Mirror of https://github.com/volcengine/verl.git (synced 2025-10-20 13:43:50 +08:00)

Compare commits: 7ddb9b29f0 ... v0.6.x (17 commits)

Commit SHA1s: ddd86f527a, 22d082f9a4, 8ec9bf64a1, 231d725f69, d69164e1cb, 2181d5b33a, 33eb86f54f, 67f9a21b8e, d2c51dc186, 16c2a21064, 3abcc09d44, 5d378b5f95, 21271aabb9, 7f27789961, e9ee6b39c6, 9d4554b931, 71cf69e7ad
.github/workflows/e2e_sft.yml (vendored, 2 changes)
@ -91,7 +91,7 @@ jobs:
|
||||
e2e_sft:
|
||||
needs: setup
|
||||
runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"]
|
||||
timeout-minutes: 25 # Increase this timeout value as needed
|
||||
timeout-minutes: 30 # Increase this timeout value as needed
|
||||
env:
|
||||
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
|
||||
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
|
||||
|
.github/workflows/model.yml (vendored, 1 change)
@ -208,6 +208,7 @@ jobs:
|
||||
|
||||
- name: Running mcore engine tests on 8 L20 GPUs
|
||||
run: |
|
||||
ray stop --force
|
||||
pytest -s -x tests/models/test_engine.py
|
||||
|
||||
cleanup:
|
||||
|
@ -238,6 +238,9 @@ verl is inspired by the design of Nemo-Aligner, Deepspeed-chat and OpenRLHF. The
|
||||
- [Vision-SR1](https://github.com/zli12321/Vision-SR1): Self-Rewarding Vision-Language Model via Reasoning Decomposition 
|
||||
- [SimpleVLA-RL](https://github.com/PRIME-RL/SimpleVLA-RL): SimpleVLA-RL: A Simple yet Effective Vision-Language Action Model for Reinforcement Learning 
|
||||
- [Table-R1](https://github.com/Table-R1/Table-R1): Table-R1: Inference-Time Scaling for Table Reasoning 
|
||||
- [Revisual-R1](https://github.com/CSfufu/Revisual-R1): Revisual-R1: Advancing Multimodal Reasoning From Optimized Cold Start to Staged Reinforcement Learning 
|
||||
- [ARES](https://github.com/shawn0728/ARES): ARES: Multimodal Adaptive Reasoning via Difficulty-Aware Token-Level Entropy Shaping 
|
||||
- [Meta-Bandit-LLM](https://github.com/sanxing-chen/meta-bandit-llm): Meta-Bandit-LLM: Long-horizon multiturn interactive training for meta-bandit agents 
|
||||
|
||||
and many more awesome work listed in [recipe](recipe/README.md).
|
||||
|
||||
|
@ -36,6 +36,8 @@ For vLLM with FSDP, please refer to [hiyouga/verl](https://hub.docker.com/r/hiyo
|
||||
|
||||
For SGLang with FSDP, please refer to the [ocss884/verl-sglang](https://hub.docker.com/r/ocss884/verl-sglang) repository; the latest version is ``ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post5``, provided by the SGLang RL Group.
|
||||
|
||||
For the latest vLLM with Megatron, please refer to the [iseekyan/verl](https://hub.docker.com/r/iseekyan/verl) repository; the latest version is ``iseekyan/verl:nemo.gptoss_vllm0.11.0``.
|
||||
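As a rough reference, pulling and starting this image looks like the following; the mount path and shared-memory size below are placeholders to adapt to your own setup:

```bash
# Pull the Megatron + vLLM image referenced above
docker pull iseekyan/verl:nemo.gptoss_vllm0.11.0

# Start an interactive container with GPU access (adjust mounts and shm size as needed)
docker run --gpus all --shm-size 32g -it \
  -v "$PWD":/workspace/verl \
  iseekyan/verl:nemo.gptoss_vllm0.11.0 bash
```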
|
||||
See the files under ``docker/`` for the NGC-based images, or if you want to build your own.
|
||||
|
||||
Note that for AWS instances with an EFA network interface (SageMaker AI Pod), you need to install the EFA driver as shown in ``docker/Dockerfile.extenstion.awsefa``.
|
||||
|
@ -0,0 +1,15 @@
|
||||
FROM nvcr.io/nvidia/nemo:25.07.gpt_oss
|
||||
|
||||
RUN git clone -b v0.11.0 --depth 1 https://github.com/vllm-project/vllm.git /opt/vllm
|
||||
|
||||
RUN pip install setuptools_scm
|
||||
|
||||
RUN cd /opt/vllm && pip install --no-deps --no-build-isolation --no-cache-dir -e .
|
||||
|
||||
RUN pip install cbor2 setproctitle blake3 openai_harmony pybase64 msgspec partial_json_parser py-cpuinfo diskcache gguf
|
||||
|
||||
RUN pip install --upgrade transformers tokenizers
|
||||
|
||||
RUN pip install codetiming tensordict mathruler pylatexenc
|
||||
|
||||
RUN pip3 install --no-cache-dir mbridge
|
docs/advance/rollout_is_migration.md (new file, 642 lines)
@ -0,0 +1,642 @@
|
||||
# Rollout Importance Sampling - Migration Guide
|
||||
|
||||
Last updated: 10/11/2025.
|
||||
|
||||
This document provides a comprehensive overview of the Rollout Importance Sampling (IS) implementation merged from aiic_verl into verl.
|
||||
|
||||
## References
|
||||
|
||||
- **When Speed Kills Stability**: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
|
||||
- **Off-policy RL**: https://fengyao.notion.site/off-policy-rl
|
||||
|
||||
## Overview
|
||||
|
||||
Rollout Importance Sampling corrects for distribution mismatch between:
|
||||
- **Rollout policy**: e.g., vLLM with BFloat16
|
||||
- **Training policy**: e.g., FSDP with FP32
|
||||
|
||||
This mismatch can lead to biased gradient estimates and unstable training. Rollout IS applies importance sampling weights to correct these biases.
|
||||
|
||||
## What Changed
|
||||
|
||||
### **Removed (Old Implementation)**
|
||||
|
||||
```yaml
|
||||
# Old TIS configuration (REMOVED)
|
||||
actor:
|
||||
tis_imp_ratio_cap: 2.0 # ❌ No longer supported
|
||||
```
|
||||
|
||||
The old implementation:
|
||||
- Only supported token-level truncate mode
|
||||
- Had no metrics tracking
|
||||
- Lacked numerical stability safeguards
|
||||
- No configurability for different scenarios
|
||||
|
||||
### **Added (New Implementation)**
|
||||
|
||||
```yaml
|
||||
# New Rollout IS configuration (all in algorithm config)
|
||||
algorithm:
|
||||
# Main control: set threshold to enable (null = disabled)
|
||||
rollout_is_threshold: 2.0
|
||||
# Whether to apply weights to loss (default: false = metrics only)
|
||||
rollout_is: true
|
||||
rollout_is_threshold_lower: null # Auto-reciprocal
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: truncate
|
||||
rollout_is_veto_threshold: 1e-4
|
||||
|
||||
# REQUIRED: Enable log prob calculation
|
||||
actor_rollout_ref:
|
||||
rollout:
|
||||
calculate_log_probs: true
|
||||
```
|
||||
|
||||
The new implementation:
|
||||
- ✅ Three aggregation levels: token, sequence, geometric
|
||||
- ✅ Two bounding modes: truncate, mask
|
||||
- ✅ Dual threshold support (upper/lower)
|
||||
- ✅ Veto mechanism for catastrophic outliers
|
||||
- ✅ 30+ comprehensive metrics
|
||||
- ✅ Log-space computation for numerical stability
|
||||
- ✅ Memory-efficient implementation
|
||||
|
||||
## Files Modified
|
||||
|
||||
### **Core Implementation**
|
||||
|
||||
1. **NEW**: `verl/trainer/ppo/mismatch_helper.py`
|
||||
- Contains `compute_rollout_importance_weights()` - main function
|
||||
- Contains `compute_is_metrics()` - comprehensive metrics
|
||||
|
||||
2. **MODIFIED**: `verl/trainer/ppo/core_algos.py` (lines 962-991)
|
||||
- Replaced old TIS implementation (lines 962-967)
|
||||
- Added new rollout IS with metrics support
|
||||
|
||||
3. **MODIFIED**: `verl/workers/actor/dp_actor.py`
|
||||
- Updated to use `rollout_is_threshold` instead of `tis_imp_ratio_cap`
|
||||
- Collects and logs all rollout IS metrics
|
||||
|
||||
### **Configuration Files**
|
||||
|
||||
4. **MODIFIED**: `verl/trainer/config/algorithm.py` (lines 95-100)
|
||||
- Added 6 new rollout IS parameters to `AlgoConfig`
|
||||
|
||||
5. **MODIFIED**: `verl/workers/config/actor.py` (lines 110-115)
|
||||
- Added 6 new rollout IS parameters to `ActorConfig`
|
||||
|
||||
6. **MODIFIED**: `verl/trainer/config/actor/actor.yaml` (lines 77-89)
|
||||
- Added rollout IS configuration section
|
||||
|
||||
7. **MODIFIED**: `verl/trainer/config/ppo_trainer.yaml` (lines 116-133)
|
||||
- Added rollout IS to algorithm config
|
||||
|
||||
### **Documentation**
|
||||
|
||||
8. **MODIFIED**: `docs/examples/config.rst`
|
||||
- Updated actor config with rollout IS parameters
|
||||
- Updated algorithm config with rollout IS parameters
|
||||
- Added detailed parameter descriptions
|
||||
|
||||
### **Example Scripts**
|
||||
|
||||
9. **MODIFIED**: `recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`
|
||||
- Updated from `tis_imp_ratio_cap` to rollout IS parameters
|
||||
- Added comprehensive comments
|
||||
|
||||
10. **NEW**: `examples/rollout_importance_sampling/README.md`
|
||||
- Comprehensive guide with usage patterns
|
||||
- Troubleshooting section
|
||||
- Performance considerations
|
||||
|
||||
11. **NEW**: `examples/rollout_importance_sampling/run_with_rollout_is.sh`
|
||||
- Basic example with token-level truncate
|
||||
|
||||
### **Tests**
|
||||
|
||||
12. **NEW**: `tests/trainer/ppo/test_rollout_is.py`
|
||||
- Unit tests for rollout IS functionality
|
||||
|
||||
13. **NEW**: `tests/trainer/ppo/test_rollout_is_integration.py`
|
||||
- Integration tests with PPO
|
||||
|
||||
## Configuration Parameters
|
||||
|
||||
### `algorithm.rollout_is_threshold` (float or null)
|
||||
**Main on/off switch.** Upper threshold for IS weights.
|
||||
- `null` = disabled (no computation, no metrics)
|
||||
- `float` value (e.g., 2.0) = enabled (compute weights and metrics)
|
||||
|
||||
### `algorithm.rollout_is` (bool)
|
||||
Whether to apply IS weights to policy loss. Default: `False`
|
||||
- `true` = apply weights to loss (full IS correction)
|
||||
- `false` = compute metrics only (useful for monitoring before enabling)
|
||||
|
||||
**Recommended threshold ranges:**
|
||||
- Token level: 1.5 - 5.0
|
||||
- Sequence level: 2.0 - 10.0
|
||||
- Geometric level: 1.0002 - 1.001
|
||||
|
||||
### `algorithm.rollout_is_threshold_lower` (float or null)
|
||||
Lower threshold for IS weights. If `null`, defaults to 1/upper (reciprocal).
|
||||
|
||||
### `algorithm.rollout_is_level` (str)
|
||||
Aggregation level for IS weights (illustrated in the sketch at the end of this section):
- `"token"`: per-token probability ratios
- `"sequence"`: product of the per-token ratios, applied as a single weight per sequence
- `"geometric"`: geometric mean of the per-token ratios (experimental)
|
||||
|
||||
### `algorithm.rollout_is_mode` (str)
|
||||
Bounding mode:
|
||||
- `"truncate"`: Cap weights at upper threshold only
|
||||
- `"mask"`: Zero out weights outside [lower, upper]
|
||||
|
||||
### `algorithm.rollout_is_veto_threshold` (float)
|
||||
Per-token veto threshold. If any token's ratio falls below this value, the entire sequence is rejected.
|
||||
Default: `1e-4` (ratio 10,000x off)
|
||||
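To make these parameters concrete, the sketch below shows one way the weights could be derived from per-token log-probabilities. It is illustrative only and is not the verl implementation in `mismatch_helper.py`; the tensor names, shapes, and exact masking details are assumptions.

```python
import torch

def sketch_rollout_is_weights(train_logp, rollout_logp, mask,
                              level="token", mode="truncate",
                              upper=2.0, lower=None, veto=1e-4):
    """Illustrative sketch only. train_logp, rollout_logp, mask: [batch, seq_len]."""
    lower = (1.0 / upper) if lower is None else lower
    log_ratio = (train_logp - rollout_logp) * mask      # per-token log importance ratio

    if level == "token":          # one weight per token
        weights = torch.exp(log_ratio)
    elif level == "sequence":     # product of per-token ratios, one weight per sequence
        weights = torch.exp(log_ratio.sum(-1, keepdim=True)).expand_as(log_ratio)
    else:                         # "geometric": geometric mean of per-token ratios
        n_tokens = mask.sum(-1, keepdim=True).clamp(min=1)
        weights = torch.exp(log_ratio.sum(-1, keepdim=True) / n_tokens).expand_as(log_ratio)

    if mode == "truncate":        # cap at the upper threshold only
        weights = weights.clamp(max=upper)
    else:                         # "mask": zero out weights outside [lower, upper]
        weights = weights * ((weights >= lower) & (weights <= upper)).float()

    # Veto: if any token's ratio is catastrophically small, reject the whole sequence
    catastrophic = (torch.exp(train_logp - rollout_logp) < veto) & mask.bool()
    weights = weights * (~catastrophic.any(-1, keepdim=True)).float()
    return weights * mask
```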
|
||||
## Migration Steps
|
||||
|
||||
### Step 1: Update Your Configuration
|
||||
|
||||
**Before (Old):**
|
||||
```yaml
|
||||
actor_rollout_ref:
|
||||
actor:
|
||||
tis_imp_ratio_cap: 2.0
|
||||
rollout:
|
||||
calculate_log_probs: true
|
||||
```
|
||||
|
||||
**After (New):**
|
||||
```yaml
|
||||
algorithm:
|
||||
rollout_is_threshold: 2.0 # Main control
|
||||
rollout_is: true # Apply to loss (default: false)
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: truncate
|
||||
|
||||
actor_rollout_ref:
|
||||
rollout:
|
||||
calculate_log_probs: true # Still required!
|
||||
```
|
||||
|
||||
### Step 2: Monitor New Metrics
|
||||
|
||||
All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appears as `mismatch/rollout_is_mean` in logs.
|
||||
|
||||
#### **Core IS Weight Metrics**
|
||||
|
||||
- **`rollout_is_mean`**: Mean importance sampling weight across all valid tokens
|
||||
- **Ideal value**: Close to 1.0 (indicates minimal distribution mismatch)
|
||||
- **Warning**: < 0.5 or > 2.0 suggests significant policy mismatch
|
||||
|
||||
- **`rollout_is_std`**: Standard deviation of IS weights
|
||||
- **Ideal value**: < 0.5 for stable training
|
||||
- **Warning**: > 1.0 indicates high variance, may need tighter thresholds
|
||||
|
||||
- **`rollout_is_min`**: Minimum IS weight observed
|
||||
- Shows the most underweighted token/sequence
|
||||
|
||||
- **`rollout_is_max`**: Maximum IS weight observed (before truncation/masking)
|
||||
- Shows the most overweighted token/sequence
|
||||
- Compare with `rollout_is_threshold` to see truncation impact
|
||||
|
||||
#### **Percentile Metrics**
|
||||
|
||||
- **`rollout_is_p25`**: 25th percentile of IS weights
|
||||
- **`rollout_is_p50`**: Median IS weight (50th percentile)
|
||||
- Should be close to `rollout_is_mean` if distribution is symmetric
|
||||
- **`rollout_is_p75`**: 75th percentile of IS weights
|
||||
- **`rollout_is_p95`**: 95th percentile of IS weights
|
||||
- Use to detect outliers
|
||||
- **`rollout_is_p99`**: 99th percentile of IS weights
|
||||
- Should be close to `rollout_is_threshold` if truncation is working
|
||||
|
||||
#### **Effective Sample Size**
|
||||
|
||||
- **`rollout_is_eff_sample_size`**: Effective sample size after IS weighting
|
||||
- **Formula**: `1 / mean(weights²)` where weights are normalized
|
||||
- **Range**: 0.0 to 1.0 (as fraction of original batch)
|
||||
- **Ideal value**: > 0.5 (retaining at least 50% effective samples)
|
||||
- **Warning**: < 0.3 means high variance, losing too many effective samples
|
||||
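For a quick sanity check of the formula, here is a toy example (numbers invented for illustration):

```python
import torch

w = torch.tensor([0.9, 1.1, 1.0, 3.0])       # raw IS weights for four tokens
w_norm = w / w.mean()                          # normalize weights to mean 1
ess_fraction = 1.0 / (w_norm ** 2).mean()      # 1 / mean(weights^2)
print(ess_fraction)                            # ~0.75: one outlier already costs ~25% of the effective batch
```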
|
||||
#### **Veto Mechanism Metrics**
|
||||
|
||||
- **`rollout_is_veto_fraction`**: Fraction of sequences rejected by veto mechanism
|
||||
- **Ideal value**: < 0.05 (less than 5% vetoed)
|
||||
- **Warning**: > 0.1 suggests policies are too different or numerical issues
|
||||
|
||||
- **`rollout_is_catastrophic_token_fraction`**: Fraction of tokens below veto threshold
|
||||
- Identifies problematic tokens before sequence-level veto
|
||||
- **Warning**: > 0.01 indicates widespread distribution issues
|
||||
|
||||
#### **Threshold Exceedance Metrics**
|
||||
|
||||
- **`rollout_is_ratio_fraction_high`**: Fraction of weights exceeding upper threshold
|
||||
- Shows how often truncation/masking occurs on high end
|
||||
- **Ideal value**: < 0.1 (most weights within bounds)
|
||||
|
||||
- **`rollout_is_ratio_fraction_low`**: Fraction of weights below lower threshold
|
||||
- Shows how often masking occurs on low end (mask mode only)
|
||||
- **Ideal value**: < 0.1
|
||||
|
||||
#### **Sequence-Level Metrics** (for sequence/geometric modes)
|
||||
|
||||
- **`rollout_is_seq_mean`**: Mean IS weight at sequence level
|
||||
- Should match `rollout_is_mean` for sequence-level aggregation
|
||||
|
||||
- **`rollout_is_seq_std`**: Standard deviation of sequence-level IS weights
|
||||
|
||||
- **`rollout_is_seq_min`**: Minimum sequence-level IS weight
|
||||
|
||||
- **`rollout_is_seq_max`**: Maximum sequence-level IS weight
|
||||
|
||||
- **`rollout_is_seq_max_deviation`**: Maximum absolute deviation from 1.0 at sequence level
|
||||
- **Ideal value**: < 1.0
|
||||
- Shows worst-case sequence mismatch
|
||||
|
||||
- **`rollout_is_seq_fraction_high`**: Fraction of sequences exceeding upper threshold
|
||||
|
||||
- **`rollout_is_seq_fraction_low`**: Fraction of sequences below lower threshold
|
||||
|
||||
#### **Masking Metrics** (mask mode only)
|
||||
|
||||
- **`rollout_is_masked_fraction`**: Fraction of tokens masked (set to zero)
|
||||
- **Ideal value**: < 0.1
|
||||
- **Warning**: > 0.3 means losing too much data
|
||||
|
||||
- **`rollout_is_seq_masked_fraction`**: Fraction of sequences with at least one masked token
|
||||
- Shows sequence-level impact of masking
|
||||
|
||||
#### **Distribution Mismatch Metrics** (Training vs Rollout Policy)
|
||||
|
||||
- **`mismatch_training_ppl`**: Perplexity of training policy (e.g., FSDP FP32)
|
||||
- **Formula**: `exp(-mean(log_probs))`
|
||||
- Lower is better (model is more confident)
|
||||
|
||||
- **`mismatch_rollout_ppl`**: Perplexity of rollout policy (e.g., vLLM BF16)
|
||||
- Should be close to `mismatch_training_ppl` if policies match well
|
||||
|
||||
- **`mismatch_ppl_ratio`**: Ratio of training PPL to rollout PPL
|
||||
- **Formula**: `exp(mean(log(training_ppl / rollout_ppl)))`
|
||||
- **Ideal value**: Close to 1.0
|
||||
- **Meaning**: > 1.0 means training is less confident than rollout
|
||||
|
||||
- **`mismatch_training_log_ppl`**: Log perplexity of training policy
|
||||
- Useful for identifying trends (linear scale)
|
||||
|
||||
- **`mismatch_rollout_log_ppl`**: Log perplexity of rollout policy
|
||||
|
||||
- **`mismatch_log_ppl_diff`**: Mean difference in log perplexities
|
||||
- **Formula**: `mean(log_ppl_rollout - log_ppl_training)`
|
||||
- **Ideal value**: Close to 0.0
|
||||
- Sign indicates which policy is more confident
|
||||
|
||||
- **`mismatch_log_ppl_abs_diff`**: Mean absolute log perplexity difference
|
||||
- Magnitude of mismatch regardless of direction
|
||||
|
||||
- **`mismatch_log_ppl_diff_max`**: Maximum log perplexity difference across sequences
|
||||
- Identifies worst-case sequence
|
||||
|
||||
- **`mismatch_log_ppl_diff_min`**: Minimum log perplexity difference across sequences
|
||||
|
||||
- **`mismatch_kl`**: KL divergence KL(π_rollout || π_training)
|
||||
- **Formula**: `mean(log_prob_rollout - log_prob_training)`
|
||||
- **Ideal value**: Close to 0.0 (policies match)
|
||||
- **Warning**: > 0.1 indicates significant mismatch
|
||||
- **Note**: Can be negative (rollout is less confident)
|
||||
|
||||
- **`mismatch_k3_kl`**: K3 KL estimator
|
||||
- **Formula**: `mean(exp(log_ratio) - log_ratio - 1)`
|
||||
- More stable for small KL values
|
||||
- Always non-negative
|
||||
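Under the formulas above, these diagnostics can be reproduced in a few lines of tensor code. This is an illustrative sketch with toy inputs; the per-sequence averaging in `mismatch_helper.py` may differ:

```python
import torch

# Toy per-token log-probs from the two policies, [batch, seq_len], plus a response mask
training_log_prob = torch.tensor([[-1.2, -0.7, -2.0]])
rollout_log_prob = torch.tensor([[-1.1, -0.8, -1.9]])
response_mask = torch.ones_like(training_log_prob)

valid = response_mask.bool()
log_ratio = (training_log_prob - rollout_log_prob)[valid]        # log pi_training - log pi_rollout

mismatch_kl = (-log_ratio).mean()                                 # mean(log_prob_rollout - log_prob_training)
mismatch_k3_kl = (torch.exp(log_ratio) - log_ratio - 1).mean()    # K3 estimator, always non-negative

mismatch_training_ppl = torch.exp(-training_log_prob[valid].mean())   # exp(-mean(log_probs))
mismatch_rollout_ppl = torch.exp(-rollout_log_prob[valid].mean())
```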
|
||||
#### **Example: Accessing Metrics in Code**
|
||||
|
||||
```python
|
||||
# Metrics are returned from compute_rollout_importance_weights
|
||||
from verl.trainer.ppo.mismatch_helper import compute_rollout_importance_weights
|
||||
|
||||
weights_proto, metrics = compute_rollout_importance_weights(
|
||||
old_log_prob=training_log_probs, # from training policy
|
||||
rollout_log_prob=rollout_log_probs, # from rollout policy
|
||||
response_mask=response_mask,
|
||||
rollout_is_level="token",
|
||||
rollout_is_mode="truncate",
|
||||
rollout_is_threshold=2.0,
|
||||
rollout_is_veto_threshold=1e-4,
|
||||
)
|
||||
|
||||
# All metrics have 'mismatch/' prefix
|
||||
print(f"Mean IS weight: {metrics['mismatch/rollout_is_mean']:.3f}")
|
||||
print(f"Effective sample size: {metrics['mismatch/rollout_is_eff_sample_size']:.3f}")
|
||||
print(f"Veto fraction: {metrics['mismatch/rollout_is_veto_fraction']:.3f}")
|
||||
print(f"KL divergence: {metrics['mismatch/mismatch_kl']:.3f}")
|
||||
|
||||
# Check for warning conditions
|
||||
if metrics['mismatch/rollout_is_mean'] < 0.5 or metrics['mismatch/rollout_is_mean'] > 2.0:
|
||||
print("⚠️ Warning: Mean IS weight far from 1.0, significant policy mismatch detected")
|
||||
|
||||
if metrics['mismatch/rollout_is_eff_sample_size'] < 0.3:
|
||||
print("⚠️ Warning: Low effective sample size, high variance in IS weights")
|
||||
|
||||
if metrics['mismatch/rollout_is_veto_fraction'] > 0.1:
|
||||
print("⚠️ Warning: High veto fraction, policies may be too different")
|
||||
```
|
||||
|
||||
#### **Example: Monitoring Metrics During Training**
|
||||
|
||||
```python
|
||||
# In your training loop
|
||||
for epoch in range(num_epochs):
|
||||
for batch_idx, batch in enumerate(dataloader):
|
||||
# ... rollout phase ...
|
||||
|
||||
# Compute IS weights and get metrics
|
||||
weights_proto, metrics = compute_rollout_importance_weights(
|
||||
old_log_prob=batch.old_log_prob,
|
||||
rollout_log_prob=batch.rollout_log_prob,
|
||||
response_mask=batch.response_mask,
|
||||
rollout_is_level=config.rollout_is_level,
|
||||
rollout_is_mode=config.rollout_is_mode,
|
||||
rollout_is_threshold=config.rollout_is_threshold,
|
||||
rollout_is_veto_threshold=config.rollout_is_veto_threshold,
|
||||
)
|
||||
|
||||
# Log to tensorboard/wandb
|
||||
for metric_name, metric_value in metrics.items():
|
||||
logger.log_scalar(metric_name, metric_value, step=global_step)
|
||||
|
||||
# Use IS weights in training
|
||||
is_weights = weights_proto.batch["rollout_is_weights"]
|
||||
# ... apply weights to policy gradient ...
|
||||
```
|
||||
|
||||
#### **Example: Conditional Alerting Based on Metrics**
|
||||
|
||||
```python
|
||||
def check_rollout_is_health(metrics, config):
|
||||
"""Check if rollout IS metrics indicate healthy training."""
|
||||
warnings = []
|
||||
|
||||
# Check mean IS weight
|
||||
mean_weight = metrics['mismatch/rollout_is_mean']
|
||||
if mean_weight < 0.5 or mean_weight > 2.0:
|
||||
warnings.append(f"Mean IS weight {mean_weight:.3f} is far from 1.0")
|
||||
|
||||
# Check effective sample size
|
||||
ess = metrics['mismatch/rollout_is_eff_sample_size']
|
||||
if ess < 0.3:
|
||||
warnings.append(f"Effective sample size {ess:.3f} is too low")
|
||||
|
||||
# Check veto fraction
|
||||
veto_frac = metrics['mismatch/rollout_is_veto_fraction']
|
||||
if veto_frac > 0.1:
|
||||
warnings.append(f"Veto fraction {veto_frac:.3f} is too high")
|
||||
|
||||
# Check variance
|
||||
std = metrics['mismatch/rollout_is_std']
|
||||
if std > 1.0:
|
||||
warnings.append(f"IS weight std {std:.3f} is too high")
|
||||
|
||||
# Check KL divergence
|
||||
kl = metrics['mismatch/mismatch_kl']
|
||||
if abs(kl) > 0.1:
|
||||
warnings.append(f"KL divergence {kl:.3f} indicates significant mismatch")
|
||||
|
||||
if warnings:
|
||||
print("⚠️ Rollout IS Health Warnings:")
|
||||
for warning in warnings:
|
||||
print(f" - {warning}")
|
||||
return False
|
||||
else:
|
||||
print("✅ Rollout IS metrics look healthy")
|
||||
return True
|
||||
|
||||
# Use in training
|
||||
_, metrics = compute_rollout_importance_weights(...)
|
||||
is_healthy = check_rollout_is_health(metrics, config)
|
||||
|
||||
if not is_healthy:
|
||||
# Consider adjusting config or investigating issues
|
||||
print("Consider:")
|
||||
print(" - Tightening rollout_is_threshold")
|
||||
print(" - Switching to geometric aggregation level")
|
||||
print(" - Checking if rollout and training policies are too different")
|
||||
```
|
||||
|
||||
### Step 3: Test Your Training
|
||||
|
||||
Start with the basic token-level truncate configuration:
|
||||
```bash
|
||||
bash examples/rollout_importance_sampling/run_with_rollout_is.sh
|
||||
```
|
||||
|
||||
Monitor metrics for 1-2 epochs before adjusting parameters.
|
||||
|
||||
## Configuration Examples
|
||||
|
||||
### Example 1: Full IS Correction
|
||||
```yaml
|
||||
algorithm:
|
||||
rollout_is_threshold: 2.0
|
||||
rollout_is: true # Apply weights to loss
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: truncate
|
||||
```
|
||||
|
||||
### Example 2: Metrics Only (Monitoring Mode)
|
||||
```yaml
|
||||
algorithm:
|
||||
rollout_is_threshold: 2.0
|
||||
rollout_is: false # Compute metrics, don't apply weights
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: truncate
|
||||
```
|
||||
|
||||
### Example 3: Geometric Mean with Mask
|
||||
```yaml
|
||||
algorithm:
|
||||
rollout_is_threshold: 1.0002
|
||||
rollout_is: true
|
||||
rollout_is_threshold_lower: 0.9998
|
||||
rollout_is_level: geometric
|
||||
rollout_is_mode: mask
|
||||
```
|
||||
|
||||
### Example 4: Asymmetric Thresholds
|
||||
```yaml
|
||||
algorithm:
|
||||
rollout_is_threshold: 5.0
|
||||
rollout_is: true
|
||||
rollout_is_threshold_lower: 0.8
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: mask
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: High variance in IS weights
|
||||
**Symptoms:** `rollout_is_std` > 1.0, `rollout_is_eff_sample_size` < 0.3
|
||||
|
||||
**Solutions:**
|
||||
1. Switch from `sequence` to `geometric` level
|
||||
2. Tighten thresholds
|
||||
3. Verify rollout and training aren't too different
|
||||
|
||||
### Issue: Too many sequences vetoed
|
||||
**Symptoms:** `rollout_is_veto_fraction` > 0.1
|
||||
|
||||
**Solutions:**
|
||||
1. Relax veto threshold: `rollout_is_veto_threshold: 1e-3`
|
||||
2. Check for numerical issues in log prob computation
|
||||
3. Verify policies aren't completely different
|
||||
|
||||
### Issue: Mean IS weight far from 1.0
|
||||
**Symptoms:** `rollout_is_mean` < 0.5 or > 2.0
|
||||
|
||||
**Solutions:**
|
||||
1. Verify `calculate_log_probs=True` is set
|
||||
2. Check rollout_log_probs are correctly passed
|
||||
3. Check for systematic bias
|
||||
|
||||
### Debugging: Visualizing Metrics
|
||||
|
||||
**Example: Plot IS weight distribution**
|
||||
|
||||
```python
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
def plot_is_metrics(metrics_history):
|
||||
"""Plot rollout IS metrics over training steps."""
|
||||
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
|
||||
|
||||
# Plot 1: Mean IS weight over time
|
||||
axes[0, 0].plot(metrics_history['mismatch/rollout_is_mean'])
|
||||
axes[0, 0].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
|
||||
axes[0, 0].set_title('Mean IS Weight')
|
||||
axes[0, 0].set_xlabel('Step')
|
||||
axes[0, 0].legend()
|
||||
|
||||
# Plot 2: Effective sample size
|
||||
axes[0, 1].plot(metrics_history['mismatch/rollout_is_eff_sample_size'])
|
||||
axes[0, 1].axhline(y=0.5, color='g', linestyle='--', label='Good')
|
||||
axes[0, 1].axhline(y=0.3, color='r', linestyle='--', label='Warning')
|
||||
axes[0, 1].set_title('Effective Sample Size')
|
||||
axes[0, 1].set_xlabel('Step')
|
||||
axes[0, 1].legend()
|
||||
|
||||
# Plot 3: Veto fraction
|
||||
axes[0, 2].plot(metrics_history['mismatch/rollout_is_veto_fraction'])
|
||||
axes[0, 2].axhline(y=0.1, color='r', linestyle='--', label='Warning')
|
||||
axes[0, 2].set_title('Veto Fraction')
|
||||
axes[0, 2].set_xlabel('Step')
|
||||
axes[0, 2].legend()
|
||||
|
||||
# Plot 4: IS weight distribution (latest step)
|
||||
latest_idx = -1
|
||||
percentiles = [25, 50, 75, 95, 99]
|
||||
values = [metrics_history[f'mismatch/rollout_is_p{p}'][latest_idx] for p in percentiles]
|
||||
axes[1, 0].bar([f'p{p}' for p in percentiles], values)
|
||||
axes[1, 0].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
|
||||
axes[1, 0].set_title('IS Weight Percentiles (Latest)')
|
||||
axes[1, 0].legend()
|
||||
|
||||
# Plot 5: KL divergence over time
|
||||
axes[1, 1].plot(metrics_history['mismatch/mismatch_kl'], label='KL')
|
||||
axes[1, 1].plot(metrics_history['mismatch/mismatch_k3_kl'], label='K3 KL')
|
||||
axes[1, 1].axhline(y=0, color='g', linestyle='--', alpha=0.3)
|
||||
axes[1, 1].set_title('KL Divergence')
|
||||
axes[1, 1].set_xlabel('Step')
|
||||
axes[1, 1].legend()
|
||||
|
||||
# Plot 6: PPL ratio over time
|
||||
axes[1, 2].plot(metrics_history['mismatch/mismatch_ppl_ratio'])
|
||||
axes[1, 2].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
|
||||
axes[1, 2].set_title('PPL Ratio (Training/Rollout)')
|
||||
axes[1, 2].set_xlabel('Step')
|
||||
axes[1, 2].legend()
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig('rollout_is_metrics.png', dpi=150)
|
||||
print("Saved plot to rollout_is_metrics.png")
|
||||
```
|
||||
|
||||
**Example: Metric collection during training**
|
||||
|
||||
```python
|
||||
# Collect metrics over time
|
||||
metrics_history = {
|
||||
'mismatch/rollout_is_mean': [],
|
||||
'mismatch/rollout_is_eff_sample_size': [],
|
||||
'mismatch/rollout_is_veto_fraction': [],
|
||||
'mismatch/rollout_is_p25': [],
|
||||
'mismatch/rollout_is_p50': [],
|
||||
'mismatch/rollout_is_p75': [],
|
||||
'mismatch/rollout_is_p95': [],
|
||||
'mismatch/rollout_is_p99': [],
|
||||
'mismatch/mismatch_kl': [],
|
||||
'mismatch/mismatch_k3_kl': [],
|
||||
'mismatch/mismatch_ppl_ratio': [],
|
||||
}
|
||||
|
||||
# In training loop
|
||||
for step in range(num_steps):
|
||||
# ... compute IS weights ...
|
||||
_, metrics = compute_rollout_importance_weights(...)
|
||||
|
||||
# Store metrics
|
||||
for key in metrics_history.keys():
|
||||
if key in metrics:
|
||||
metrics_history[key].append(metrics[key])
|
||||
|
||||
# Plot every 100 steps
|
||||
if step % 100 == 0:
|
||||
plot_is_metrics(metrics_history)
|
||||
```
|
||||
|
||||
## Performance Impact
|
||||
|
||||
- **Memory overhead**: ~1% of model memory
|
||||
- **Computational overhead**: 1-3% depending on level
|
||||
- **Training stability**: Significantly improved when mismatch exists
|
||||
|
||||
## Backward Compatibility
|
||||
|
||||
**The old `tis_imp_ratio_cap` parameter is completely removed.** There is no backward compatibility mode.
|
||||
|
||||
All scripts and configurations must be updated to use the new rollout IS parameters.
|
||||
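One way to locate configurations and scripts that still reference the removed parameter (a simple search; adapt the paths to your checkout):

```bash
grep -rn "tis_imp_ratio_cap" examples/ recipe/ verl/
```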
|
||||
## Testing
|
||||
|
||||
Run the test suite to verify everything works:
|
||||
|
||||
```bash
|
||||
# Basic unit tests
|
||||
python tests/trainer/ppo/test_rollout_is.py
|
||||
|
||||
# Integration tests (if pytest is available)
|
||||
pytest tests/trainer/ppo/test_rollout_is_integration.py -v
|
||||
```
|
||||
|
||||
Expected output: All tests pass ✓
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- **Implementation**: `verl/trainer/ppo/mismatch_helper.py`
|
||||
- **Examples**: `examples/rollout_importance_sampling/`
|
||||
- **DAPO Example**: `recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`
|
||||
|
||||
## Summary
|
||||
|
||||
The new Rollout Importance Sampling implementation provides:
|
||||
- ✅ More robust handling of distribution mismatch
|
||||
- ✅ Better numerical stability
|
||||
- ✅ Comprehensive metrics for monitoring
|
||||
- ✅ Flexibility for different scenarios
|
||||
- ✅ Memory-efficient computation
|
||||
|
||||
Migration is straightforward: replace `tis_imp_ratio_cap` with the new `rollout_is_*` parameters in the `algorithm` config section.
|
@ -118,7 +118,13 @@ Actor/Rollout/Reference Policy
|
||||
clip_ratio: 0.2
|
||||
entropy_coeff: 0.0
|
||||
use_kl_loss: False # True for GRPO
|
||||
tis_imp_ratio_cap: -1 # set to positive values for Truncated Importance Sampling (requires setting `rollout.calculate_log_probs` as True)
|
||||
# Rollout Importance Sampling (corrects distribution mismatch between rollout and training)
|
||||
rollout_is: False # Enable IS correction
|
||||
rollout_is_threshold: null # Upper threshold for IS weights (null to disable)
|
||||
rollout_is_threshold_lower: null # Lower threshold (null = auto 1/upper)
|
||||
rollout_is_level: token # Aggregation: token/sequence/geometric
|
||||
rollout_is_mode: truncate # Bounding: truncate/mask
|
||||
rollout_is_veto_threshold: 1e-4 # Catastrophic outlier threshold
|
||||
use_torch_compile: True # False to disable torch compile
|
||||
kl_loss_coef: 0.001 # for grpo
|
||||
kl_loss_type: low_var_kl # for grpo
|
||||
@ -132,7 +138,7 @@ Actor/Rollout/Reference Policy
|
||||
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
|
||||
min_lr_ratio: 0.0 # only used with cosine lr scheduler, default to 0.0
|
||||
num_cycles: 0.5 # only used with cosine lr scheduler, default to 0.5
|
||||
warmup_style: constant # select from constant/cosine
|
||||
lr_scheduler_type: constant # select from constant/cosine
|
||||
total_training_steps: -1 # must be override by program
|
||||
fsdp_config:
|
||||
wrap_policy:
|
||||
@ -415,7 +421,7 @@ ____________________________________________________
|
||||
|
||||
Notice that there are some differences in APIs between Megatron optimizer and FSDP optimizer.
|
||||
|
||||
- Megatron optimizer scheduler names the period after lr_warmup as lr_decay_steps, so the ``warmup_style`` actually means the style of lr decay after warmup.
|
||||
- Megatron optimizer scheduler names the period after lr_warmup as lr_decay_steps, so the ``lr_scheduler_type`` actually means the style of lr decay after warmup.
|
||||
- The Megatron optimizer also supports scheduling (decaying) the weight decay coefficient over training.
|
||||
- ``use_checkpoint_opt_param_scheduler`` determines whether to use the checkpoint optimizer parameter scheduler. If set to True, the optimizer parameter scheduler will be saved in the checkpoint and loaded from the checkpoint during resuming training.
|
||||
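For illustration, a cosine-decay configuration using the renamed option might look like the sketch below (values are placeholders; field placement follows the optimizer block shown above):

```yaml
optim:
  lr: 1e-6
  lr_warmup_steps_ratio: 0.05
  min_lr_ratio: 0.0
  num_cycles: 0.5
  lr_scheduler_type: cosine   # formerly `warmup_style`
```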
|
||||
@ -498,6 +504,13 @@ Algorithm
|
||||
kl_coef: 0.005
|
||||
horizon: 10000
|
||||
target_kl: 0.1
|
||||
# Rollout Importance Sampling
|
||||
rollout_is: False
|
||||
rollout_is_threshold: null
|
||||
rollout_is_threshold_lower: null
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: truncate
|
||||
rollout_is_veto_threshold: 1e-4
|
||||
|
||||
- ``gamma``: discount factor
|
||||
- ``lam``: Trade-off between bias and variance in the GAE estimator
|
||||
@ -510,6 +523,13 @@ Algorithm
|
||||
- ``kl_coef``: The (initial) coefficient of in-reward kl_penalty. Default is 0.001.
|
||||
- ``type``: 'fixed' for FixedKLController and 'adaptive' for AdaptiveKLController.
|
||||
- ``horizon`` and ``target_kl``: See source code of AdaptiveKLController for details.
|
||||
- ``rollout_is``: Whether to enable rollout importance sampling correction. Default is False.
|
||||
- ``rollout_is_threshold``: Upper threshold for IS weights. Set to ``null`` to disable IS completely.
|
||||
- ``rollout_is_threshold_lower``: Lower threshold for IS weights. If ``null``, defaults to reciprocal of upper (1/upper).
|
||||
- ``rollout_is_level``: Aggregation level: ``token`` (biased), ``sequence`` (unbiased), or ``geometric`` (experimental).
|
||||
- ``rollout_is_mode``: Bounding mode: ``truncate`` (cap upper only) or ``mask`` (zero outside bounds).
|
||||
- ``rollout_is_veto_threshold``: Per-token veto threshold for catastrophic outliers. Default is 1e-4.
|
||||
Note: Rollout IS requires setting ``actor_rollout_ref.rollout.calculate_log_probs=True``.
|
||||
|
||||
Trainer
|
||||
~~~~~~~
|
||||
|
@ -121,6 +121,7 @@ verl is fast with:
|
||||
examples/sandbox_fusion_example
|
||||
advance/rollout_trace.rst
|
||||
advance/rollout_skip.rst
|
||||
advance/rollout_is_migration.md
|
||||
advance/one_step_off
|
||||
advance/agent_loop
|
||||
|
||||
|
@ -79,7 +79,7 @@ For latest vLLM with FSDP, please refer to `hiyouga/verl <https://hub.docker.com
|
||||
|
||||
For latest SGLang with FSDP, please refer to `hebiaobuaa/verl <https://hub.docker.com/r/hebiaobuaa/verl>`_ repository and the latest version is ``hebiaobuaa/verl:app-verl0.5-sglang0.4.9.post6-mcore0.12.2-te2.2`` which is provided by SGLang RL Group.
|
||||
|
||||
For latest vLLM with Megatron, please refer to `iseekyan/verl:app-verl0.5-transformers4.55.4-vllm0.10.0-mcore0.15.0-te2.7`
|
||||
For latest vLLM with Megatron, please refer to `iseekyan/verl <https://hub.docker.com/r/iseekyan/verl>`_ repository and the latest version is ``iseekyan/verl:nemo.gptoss_vllm0.11.0``.
|
||||
|
||||
See files under ``docker/`` for NGC-based image or if you want to build your own.
|
||||
|
||||
|
@ -6,21 +6,20 @@
|
||||
This is the official implementation of the paper [***Geometric-Mean Policy Optimization***](https://arxiv.org/abs/2507.20673).
|
||||
|
||||
<div align=center>
|
||||
<img width="3092" height="864" alt="image" src="https://github.com/user-attachments/assets/af4c7e0f-923a-45ef-9bcf-57109b8ee61e" />
|
||||
<img width="3092" height="864" alt="image" src="https://github.com/user-attachments/assets/20b04c4e-7ee8-4775-9af8-33c0158336e2" />
|
||||
</div>
|
||||
|
||||
|
||||
## 1. Contents
|
||||
- Geometric-Mean Policy Optimization
|
||||
- [1. Contents](#1-contents)
|
||||
- [2. Introduction](#2-introduction)
|
||||
- [3. Code Usage](#4-code-usage)
|
||||
- [4. Contacts](#5-contacts)
|
||||
- [5. Citation](#7-citation)
|
||||
- [3. Code Usage](#3-code-usage)
|
||||
- [4. Contacts](#4-contacts)
|
||||
- [5. Citation](#5-citation)
|
||||
|
||||
## 2. Introduction
|
||||
|
||||
Recent advancements, such as Group Relative Policy Optimization (GRPO), have enhanced the reasoning capabilities of large language models by optimizing the arithmetic mean of token-level rewards. However, GRPO suffers from unstable policy updates when processing tokens with outlier importance-weighted rewards, which manifests as extreme importance sampling ratios during training, i.e., the ratio between the sampling probabilities assigned to a token by the current and old policies. In this work, we propose Geometric-Mean Policy Optimization (GMPO), a stabilized variant of GRPO. Instead of optimizing the arithmetic mean, GMPO maximizes the geometric mean of token-level rewards, which is inherently less sensitive to outliers and maintains a more stable range of importance sampling ratio. In addition, we provide comprehensive theoretical and experimental analysis to justify the design and stability benefits of GMPO. Beyond improved stability, GMPO-7B outperforms GRPO by an average of 4.1% on multiple mathematical benchmarks and 1.4% on multimodal reasoning benchmark, including AIME24, AMC, MATH500, OlympiadBench, Minerva, and Geometry3K.
|
||||
Group Relative Policy Optimization (GRPO) has significantly enhanced the reasoning capability of large language models by optimizing the arithmetic mean of token-level rewards. Unfortunately, GRPO is observed to suffer from unstable policy updates when facing tokens with outlier importance-weighted rewards, which manifest as extreme importance sampling ratios during training. In this study, we propose Geometric-Mean Policy Optimization (GMPO), with the aim to improve the stability of GRPO through suppressing token reward outliers. Instead of optimizing the arithmetic mean, GMPO maximizes the geometric mean of token-level rewards, which is inherently less sensitive to outliers and maintains a more stable range of importance sampling ratio. GMPO is plug-and-play—simply replacing GRPO's arithmetic mean with the geometric mean of token-level rewards, as the latter is inherently less sensitive to outliers. GMPO is theoretically plausible—analysis reveals that both GMPO and GRPO are weighted forms of the policy gradient while the former enjoys more stable weights, which consequently benefits policy optimization and performance. Experiments on multiple mathematical reasoning benchmarks show that GMPO-7B improves the average Pass@1 of GRPO by up to 4.1%, outperforming many state-of-the-art approaches.
|
||||
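As a toy illustration of why the geometric mean is less sensitive to outliers (numbers invented for illustration, not taken from the paper):

```python
import math

ratios = [0.9, 1.0, 1.1, 8.0]                             # token-level importance ratios with one outlier
arithmetic_mean = sum(ratios) / len(ratios)                # 2.75 -- dominated by the outlier
geometric_mean = math.prod(ratios) ** (1 / len(ratios))    # ~1.68 -- the outlier is dampened
print(arithmetic_mean, geometric_mean)
```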
|
||||
## 3. Code Usage
|
||||
|
||||
@ -30,7 +29,7 @@ clip_ratio_low=0.4
|
||||
clip_ratio_high=0.4
|
||||
loss_mode=geo_mean
|
||||
```
|
||||
|
||||
We observed that using a large clip ratio during Mixture-of-Experts (MoE) model training often leads to optimization instability. When training MoE models, consider lowering the clip ratio to achieve more stable convergence.
|
||||
To get started quickly, run:
|
||||
```
|
||||
bash examples/gmpo_trainer/run_qwen2_5-7b_math.sh
|
||||
@ -51,13 +50,10 @@ If you have any question about our work or this repository, please don't hesitat
|
||||
|
||||
## 5. Citation
|
||||
```
|
||||
@misc{zhao2025geometricmeanpolicyoptimization,
|
||||
title={Geometric-Mean Policy Optimization},
|
||||
author={Yuzhong Zhao and Yue Liu and Junpeng Liu and Jingye Chen and Xun Wu and Yaru Hao and Tengchao Lv and Shaohan Huang and Lei Cui and Qixiang Ye and Fang Wan and Furu Wei},
|
||||
year={2025},
|
||||
eprint={2507.20673},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL},
|
||||
url={https://arxiv.org/abs/2507.20673},
|
||||
@article{zhao2025geometric,
|
||||
title={Geometric-mean policy optimization},
|
||||
author={Zhao, Yuzhong and Liu, Yue and Liu, Junpeng and Chen, Jingye and Wu, Xun and Hao, Yaru and Lv, Tengchao and Huang, Shaohan and Cui, Lei and Ye, Qixiang and others},
|
||||
journal={arXiv preprint arXiv:2507.20673},
|
||||
year={2025}
|
||||
}
|
||||
```
|
||||
|
examples/grpo_trainer/run_qwen3_vl-30b-megatron.sh (new file, 79 lines)
@ -0,0 +1,79 @@
|
||||
set -x
|
||||
ENGINE=${1:-vllm}
|
||||
export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping
|
||||
|
||||
# vLLM >= 0.11.0 is required for Qwen3-VL support; the container docker://iseekyan/verl:nemo.gptoss_vllm0.11.0 is recommended
|
||||
# pip install -U git+https://github.com/ISEEKYAN/mbridge.git # for latest mbridge
|
||||
# pip install -U transformers # for qwen3-vl support
|
||||
# pip install --no-deps --no-cache-dir git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.13.1 # for megatron-lm0.13.1
|
||||
|
||||
|
||||
export VLLM_ALLREDUCE_USE_SYMM_MEM=0 # for vllm0.11.0 with TP
|
||||
|
||||
|
||||
HF_MODEL_PATH=${HF_MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-VL-30B-A3B-Instruct"}
|
||||
|
||||
|
||||
train_path=$HOME/data/geo3k/train.parquet
|
||||
test_path=$HOME/data/geo3k/test.parquet
|
||||
|
||||
python3 -m verl.trainer.main_ppo --config-path=config \
|
||||
--config-name='ppo_megatron_trainer.yaml'\
|
||||
algorithm.adv_estimator=grpo \
|
||||
data.train_files="$train_path" \
|
||||
data.val_files="$test_path" \
|
||||
data.train_batch_size=512 \
|
||||
data.max_prompt_length=1024 \
|
||||
data.max_response_length=2048 \
|
||||
data.filter_overlong_prompts=True \
|
||||
data.truncation='error' \
|
||||
actor_rollout_ref.model.path=$HF_MODEL_PATH \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
|
||||
actor_rollout_ref.actor.megatron.expert_model_parallel_size=8 \
|
||||
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \
|
||||
actor_rollout_ref.actor.use_kl_loss=True \
|
||||
actor_rollout_ref.actor.kl_loss_coef=0.01 \
|
||||
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
|
||||
actor_rollout_ref.actor.entropy_coeff=0 \
|
||||
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
|
||||
actor_rollout_ref.actor.use_dynamic_bsz=True \
|
||||
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=5120 \
|
||||
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \
|
||||
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=20480 \
|
||||
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \
|
||||
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=20480 \
|
||||
actor_rollout_ref.rollout.name=$ENGINE \
|
||||
+actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
|
||||
actor_rollout_ref.rollout.n=5 \
|
||||
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
|
||||
actor_rollout_ref.actor.megatron.use_mbridge=True \
|
||||
actor_rollout_ref.actor.megatron.param_offload=True \
|
||||
actor_rollout_ref.actor.megatron.optimizer_offload=True \
|
||||
actor_rollout_ref.actor.megatron.grad_offload=True \
|
||||
actor_rollout_ref.ref.megatron.param_offload=True \
|
||||
+actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=1 \
|
||||
+actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \
|
||||
+actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \
|
||||
+actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \
|
||||
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \
|
||||
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \
|
||||
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=flex \
|
||||
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \
|
||||
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \
|
||||
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \
|
||||
+actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \
|
||||
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \
|
||||
algorithm.use_kl_in_reward=False \
|
||||
trainer.critic_warmup=0 \
|
||||
trainer.logger='["console","wandb"]' \
|
||||
trainer.project_name='verl_grpo_example_geo3k' \
|
||||
trainer.experiment_name='qwen3_vl_30b_megatron' \
|
||||
trainer.n_gpus_per_node=8 \
|
||||
trainer.nnodes=1 \
|
||||
trainer.save_freq=20 \
|
||||
trainer.test_freq=5 \
|
||||
trainer.total_epochs=15 $@
|
examples/rollout_importance_sampling/README.md (new file, 242 lines)
@ -0,0 +1,242 @@
|
||||
# Rollout Importance Sampling (IS) Examples
|
||||
|
||||
This directory contains examples and documentation for using Rollout Importance Sampling to correct distribution mismatch between rollout and training policies.
|
||||
|
||||
**References:**
|
||||
- When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
|
||||
- Off-policy RL: https://fengyao.notion.site/off-policy-rl
|
||||
|
||||
## Overview
|
||||
|
||||
Rollout Importance Sampling corrects for the distribution mismatch that arises when:
1. **Rollout generation** uses one policy (e.g., vLLM with BFloat16)
2. **Training** uses another policy (e.g., FSDP with FP32)

This mismatch leads to biased gradient estimates, which the IS weights correct.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```yaml
|
||||
algorithm:
|
||||
# Main control: set threshold to enable (null = disabled)
|
||||
rollout_is_threshold: 2.0
|
||||
# Whether to apply weights to policy loss (true) or just compute metrics (false)
|
||||
rollout_is: true
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: truncate
|
||||
|
||||
# IMPORTANT: Must enable log prob calculation
|
||||
actor_rollout_ref:
|
||||
rollout:
|
||||
calculate_log_probs: true
|
||||
```
|
||||
|
||||
### Running the Example
|
||||
|
||||
```bash
|
||||
# Basic example with token-level truncate
|
||||
bash examples/rollout_importance_sampling/run_with_rollout_is.sh
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### Aggregation Levels (`rollout_is_level`)
|
||||
|
||||
| Level | Properties | Threshold Range |
|
||||
|-------|-----------|-----------------|
|
||||
| **token** | Per-token | 1.5 - 5.0 |
|
||||
| **sequence** | Per-sequence | 2.0 - 10.0 |
|
||||
| **geometric** | Geometric mean | 1.0002 - 1.001 |
|
||||
|
||||
### Bounding Modes (`rollout_is_mode`)
|
||||
|
||||
| Mode | Behavior |
|
||||
|------|----------|
|
||||
| **truncate** | Cap weights at upper threshold only |
|
||||
| **mask** | Zero out weights outside [lower, upper] |
|
||||
|
||||
### Key Parameters
|
||||
|
||||
- `rollout_is_threshold`: Upper threshold for IS weights (null = disabled, float = enabled). **Main on/off switch.**
|
||||
- `rollout_is`: Whether to apply weights to loss (true) or just compute metrics (false). Default: false.
|
||||
- `rollout_is_threshold_lower`: Lower threshold (null = auto 1/upper)
|
||||
- `rollout_is_veto_threshold`: Catastrophic outlier threshold (default: 1e-4)
|
||||
|
||||
## Configuration Examples
|
||||
|
||||
### Example 1: Full IS Correction (Apply Weights)
|
||||
|
||||
```yaml
|
||||
algorithm:
|
||||
rollout_is_threshold: 2.0
|
||||
rollout_is: true # Apply to loss
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: truncate
|
||||
rollout_is_veto_threshold: 1e-4
|
||||
```
|
||||
|
||||
### Example 2: Metrics Only (No Weight Application)
|
||||
|
||||
```yaml
|
||||
algorithm:
|
||||
rollout_is_threshold: 2.0
|
||||
rollout_is: false # Compute metrics only, don't apply to loss
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: truncate
|
||||
```
|
||||
|
||||
### Example 3: Geometric Mean with Mask
|
||||
|
||||
```yaml
|
||||
algorithm:
|
||||
rollout_is_threshold: 1.0002
|
||||
rollout_is: true
|
||||
rollout_is_threshold_lower: 0.9998
|
||||
rollout_is_level: geometric
|
||||
rollout_is_mode: mask
|
||||
rollout_is_veto_threshold: 1e-4
|
||||
```
|
||||
|
||||
### Example 4: Sequence-level with Truncate
|
||||
|
||||
```yaml
|
||||
algorithm:
|
||||
rollout_is_threshold: 5.0
|
||||
rollout_is: true
|
||||
rollout_is_threshold_lower: null # Auto-reciprocal: 0.2
|
||||
rollout_is_level: sequence
|
||||
rollout_is_mode: truncate
|
||||
rollout_is_veto_threshold: 1e-4
|
||||
```
|
||||
|
||||
### Example 5: Asymmetric Thresholds
|
||||
|
||||
```yaml
|
||||
algorithm:
|
||||
rollout_is_threshold: 5.0
|
||||
rollout_is: true
|
||||
rollout_is_threshold_lower: 0.8
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: mask
|
||||
```
|
||||
|
||||
## Monitoring Metrics
|
||||
|
||||
Key metrics to watch (all prefixed with `mismatch/` in logs):
|
||||
|
||||
### Health Indicators
|
||||
- `rollout_is_mean`: Mean IS weight across sequences
|
||||
- `rollout_is_eff_sample_size`: Effective sample size after weighting
|
||||
- `rollout_is_veto_fraction`: Fraction of sequences vetoed
|
||||
|
||||
### Distribution Metrics
|
||||
- `rollout_is_max`, `rollout_is_min`: Weight extremes
|
||||
- `rollout_is_std`: Standard deviation
|
||||
- `rollout_is_p50`, `rollout_is_p95`, `rollout_is_p99`: Percentiles
|
||||
|
||||
### Diagnostic Metrics
|
||||
- `rollout_is_ratio_fraction_high`: Fraction exceeding upper threshold
|
||||
- `rollout_is_ratio_fraction_low`: Fraction below lower threshold
|
||||
- `rollout_is_catastrophic_token_fraction`: Fraction of tokens below the veto threshold
|
||||
|
||||
### Mismatch Metrics (Training vs Rollout Policy)
|
||||
|
||||
These metrics help diagnose the distribution mismatch between rollout and training policies:
|
||||
|
||||
**Perplexity Metrics:**
|
||||
- `mismatch_training_ppl`: Perplexity of training policy
|
||||
- `mismatch_rollout_ppl`: Perplexity of rollout policy
|
||||
- `mismatch_ppl_ratio`: Ratio of training PPL to rollout PPL
|
||||
- `mismatch_log_ppl_diff`: Log perplexity difference
|
||||
|
||||
**KL Divergence Metrics:**
|
||||
- `mismatch_kl`: KL divergence KL(π_rollout || π_training)
|
||||
- `mismatch_k3_kl`: K3 KL estimator
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: High Variance in IS Weights
|
||||
|
||||
**Symptoms**: `rollout_is_std` > 1.0, `rollout_is_eff_sample_size` < 0.3
|
||||
|
||||
**Solutions**:
|
||||
1. Switch from `sequence` to `geometric` level
|
||||
2. Tighten thresholds
|
||||
3. Check if rollout and training are too different
|
||||
|
||||
### Issue: Too Many Sequences Vetoed
|
||||
|
||||
**Symptoms**: `rollout_is_veto_fraction` > 0.1
|
||||
|
||||
**Solutions**:
|
||||
1. Relax veto threshold: `rollout_is_veto_threshold: 1e-3`
|
||||
2. Check for numerical issues in log prob computation
|
||||
3. Verify rollout and training policies aren't completely different
|
||||
|
||||
### Issue: Mean IS Weight Far from 1.0
|
||||
|
||||
**Symptoms**: `rollout_is_mean` < 0.5 or > 2.0
|
||||
|
||||
**Solutions**:
|
||||
1. Check that `calculate_log_probs=True` is set
|
||||
2. Verify rollout_log_probs are correctly passed
|
||||
3. Check for systematic bias in rollout vs training
|
||||
|
||||
### Issue: Too Much Data Discarded (Mask Mode)
|
||||
|
||||
**Symptoms**: `rollout_is_masked_fraction` > 0.5
|
||||
|
||||
**Solutions**:
|
||||
1. Widen thresholds
|
||||
2. Switch to `truncate` mode
|
||||
3. Use `geometric` level for better stability
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
### Memory Usage
|
||||
- Rollout IS adds minimal memory overhead (~1% of model memory)
|
||||
- Log-space computation prevents numerical overflow
|
||||
|
||||
### Computational Cost
|
||||
- Token-level: ~1-2% overhead
|
||||
- Sequence-level: ~2-3% overhead
|
||||
- Geometric: ~2-3% overhead
|
||||
|
||||
## Advanced Topics
|
||||
|
||||
### Dual Thresholds
|
||||
|
||||
Specify both upper and lower explicitly:
|
||||
|
||||
```yaml
|
||||
rollout_is_threshold: 2.0 # Upper
|
||||
rollout_is_threshold_lower: 0.5 # Lower, set explicitly (need not equal 1/upper)
|
||||
```
|
||||
|
||||
Or use auto-reciprocal:
|
||||
|
||||
```yaml
|
||||
rollout_is_threshold: 2.0 # Upper = 2.0, Lower = 0.5 (auto)
|
||||
rollout_is_threshold_lower: null
|
||||
```
|
||||
|
||||
### Veto Mechanism
|
||||
|
||||
The veto mechanism zeros out entire sequences containing catastrophic outliers:
|
||||
|
||||
- If any token has ratio < `rollout_is_veto_threshold`, the entire sequence is rejected
|
||||
- This prevents extreme outliers from dominating training
|
||||
- Default threshold: 1e-4 (ratio 10,000x off)
|
||||
- Set to `null` to disable: `rollout_is_veto_threshold: null`
|
||||
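A minimal sketch of this check with toy tensors (illustrative only, not the verl implementation):

```python
import torch

token_ratio = torch.tensor([[0.9, 1.1, 1e-6], [1.0, 0.8, 1.2]])   # per-token ratios, [batch, seq_len]
response_mask = torch.ones_like(token_ratio)
veto_threshold = 1e-4

catastrophic = (token_ratio < veto_threshold) & response_mask.bool()
keep_sequence = ~catastrophic.any(dim=-1, keepdim=True)   # sequence 0 contains a catastrophic token
is_weights = token_ratio * keep_sequence.float()           # the entire first sequence is zeroed out
print(is_weights)
```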
|
||||
## Examples
|
||||
|
||||
See the script in this directory:
|
||||
- `run_with_rollout_is.sh`: Basic example with token-level truncate mode
|
||||
|
||||
## References
|
||||
|
||||
- Implementation: `verl/trainer/ppo/mismatch_helper.py`
|
||||
- Core algorithm: `verl/trainer/ppo/core_algos.py`
|
||||
- Paper: "Your Efficient RL Framework Secretly Brings You Off-Policy RL Training"
|
examples/rollout_importance_sampling/run_with_rollout_is.sh (new executable file, 99 lines)
@ -0,0 +1,99 @@
|
||||
#!/usr/bin/env bash
|
||||
# Example: Basic PPO training with Rollout Importance Sampling
|
||||
# This demonstrates the standard setup for correcting distribution mismatch
|
||||
|
||||
set -xeuo pipefail
|
||||
|
||||
# ==============================================================================
|
||||
# Rollout Importance Sampling Configuration
|
||||
# ==============================================================================
|
||||
|
||||
# Main control: Upper threshold for IS weights (null = disabled, float = enabled)
|
||||
rollout_is_threshold=2.0
|
||||
|
||||
# Whether to apply IS weights to policy loss
|
||||
# true = apply weights to loss, false = compute metrics only
|
||||
rollout_is=true
|
||||
|
||||
# Lower threshold (null = auto-reciprocal, i.e., 1/upper = 0.5)
|
||||
rollout_is_threshold_lower=null
|
||||
|
||||
# Aggregation level: token | sequence | geometric (experimental)
|
||||
rollout_is_level=token
|
||||
|
||||
# Bounding mode: truncate (cap upper) | mask (zero outside bounds)
|
||||
rollout_is_mode=truncate
|
||||
|
||||
# Catastrophic outlier veto threshold
|
||||
rollout_is_veto_threshold=1e-4
|
||||
|
||||
# ==============================================================================
|
||||
# Model and Data Configuration
|
||||
# ==============================================================================
|
||||
|
||||
MODEL_PATH=${MODEL_PATH:-"Qwen/Qwen2.5-7B"}
|
||||
TRAIN_FILE=${TRAIN_FILE:-"data/train.parquet"}
|
||||
TEST_FILE=${TEST_FILE:-"data/test.parquet"}
|
||||
|
||||
max_prompt_length=512
|
||||
max_response_length=1024
|
||||
|
||||
# ==============================================================================
|
||||
# Training Configuration
|
||||
# ==============================================================================
|
||||
|
||||
train_batch_size=128
|
||||
ppo_mini_batch_size=32
|
||||
ppo_epochs=1
|
||||
learning_rate=5e-7
|
||||
|
||||
# ==============================================================================
|
||||
# Algorithm Configuration
|
||||
# ==============================================================================
|
||||
|
||||
adv_estimator=gae
|
||||
gamma=1.0
|
||||
lam=0.95
|
||||
|
||||
# ==============================================================================
|
||||
# Launch Training
|
||||
# ==============================================================================
|
||||
|
||||
python3 -m verl.trainer.main_ppo \
|
||||
data.train_files="${TRAIN_FILE}" \
|
||||
data.val_files="${TEST_FILE}" \
|
||||
data.max_prompt_length=${max_prompt_length} \
|
||||
data.max_response_length=${max_response_length} \
|
||||
data.train_batch_size=${train_batch_size} \
|
||||
algorithm.adv_estimator=${adv_estimator} \
|
||||
algorithm.gamma=${gamma} \
|
||||
algorithm.lam=${lam} \
|
||||
algorithm.rollout_is=${rollout_is} \
|
||||
algorithm.rollout_is_threshold=${rollout_is_threshold} \
|
||||
algorithm.rollout_is_threshold_lower=${rollout_is_threshold_lower} \
|
||||
algorithm.rollout_is_level=${rollout_is_level} \
|
||||
algorithm.rollout_is_mode=${rollout_is_mode} \
|
||||
algorithm.rollout_is_veto_threshold=${rollout_is_veto_threshold} \
|
||||
actor_rollout_ref.model.path="${MODEL_PATH}" \
|
||||
actor_rollout_ref.actor.optim.lr=${learning_rate} \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \
|
||||
actor_rollout_ref.actor.ppo_epochs=${ppo_epochs} \
|
||||
actor_rollout_ref.rollout.calculate_log_probs=True \
|
||||
actor_rollout_ref.rollout.name=vllm \
|
||||
trainer.logger='["console","wandb"]' \
|
||||
trainer.project_name="rollout_is_example" \
|
||||
trainer.experiment_name="basic_token_truncate" \
|
||||
trainer.total_epochs=10
|
||||
|
||||
echo "Training completed!"
|
||||
echo ""
|
||||
echo "Rollout IS Configuration:"
|
||||
echo " - Threshold: ${rollout_is_threshold}"
|
||||
echo " - Apply to loss: ${rollout_is}"
|
||||
echo " - Level: ${rollout_is_level}"
|
||||
echo " - Mode: ${rollout_is_mode}"
|
||||
echo ""
|
||||
echo "Monitor these key metrics in wandb:"
|
||||
echo " - mismatch/rollout_is_mean (should be ~1.0)"
|
||||
echo " - mismatch/rollout_is_eff_sample_size (should be >0.5)"
|
||||
echo " - mismatch/rollout_is_veto_fraction (should be <0.1)"
|
@ -0,0 +1,23 @@
|
||||
hydra:
|
||||
searchpath:
|
||||
- file://verl/trainer/config
|
||||
|
||||
defaults:
|
||||
- ppo_trainer
|
||||
- _self_
|
||||
|
||||
data:
|
||||
max_prompt_length: 1024
|
||||
max_response_length: 1024
|
||||
train_batch_size: 256
|
||||
return_raw_chat: True
|
||||
shuffle: False
|
||||
|
||||
actor_rollout_ref:
|
||||
hybrid_engine: True
|
||||
rollout:
|
||||
name: sglang
|
||||
multi_turn:
|
||||
enable: True
|
||||
max_assistant_turns: 2
|
||||
format: qwen
|
@ -51,7 +51,7 @@ actor_rollout_ref:
      lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.
      lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
      min_lr_ratio: null # only useful for warmup with cosine
      warmup_style: constant # select from constant/cosine
      lr_scheduler_type: constant # select from constant/cosine
      total_training_steps: -1 # must be overridden by program
    fsdp_config:
      wrap_policy:
@ -105,7 +105,7 @@ critic:
    lr: 1e-5
    lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
    min_lr_ratio: null # only useful for warmup with cosine
    warmup_style: constant # select from constant/cosine
    lr_scheduler_type: constant # select from constant/cosine
    total_training_steps: -1 # must be overridden by program
  model:
    path: ~/models/deepseek-llm-7b-chat
|
||||
|
@ -304,6 +304,11 @@ class RayDAPOTrainer(RayPPOTrainer):
                        values = self.critic_wg.compute_values(batch)
                        batch = batch.union(values)

                    # Compute rollout IS weights and mismatch metrics (inherited from RayPPOTrainer)
                    batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)
                    # IS and mismatch metrics already have mismatch/ prefix
                    metrics.update(is_metrics)

                    with marked_timer("adv", timing_raw, "brown"):
                        # compute advantages, executed on the driver process
                        norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)
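For context on the `norm_adv_by_std_in_grpo` flag read from the config above, here is a minimal, self-contained sketch of group-relative advantage normalization; it is not the verl implementation, only the standard GRPO recipe the flag toggles.

```python
# Illustrative sketch of GRPO-style advantage normalization.
# `rewards` holds one scalar reward per sampled response, grouped by prompt.
import torch

def grpo_advantages(rewards: torch.Tensor, group_size: int, norm_by_std: bool = True):
    # rewards: (num_prompts * group_size,) -> (num_prompts, group_size)
    grouped = rewards.view(-1, group_size)
    adv = grouped - grouped.mean(dim=1, keepdim=True)
    if norm_by_std:  # what norm_adv_by_std_in_grpo=True enables
        adv = adv / (grouped.std(dim=1, keepdim=True) + 1e-6)
    return adv.view(-1)

# Example: 2 prompts, 4 responses each
print(grpo_advantages(torch.tensor([1.0, 0.0, 0.0, 1.0, 2.0, 2.0, 0.0, 0.0]), group_size=4))
```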
|
||||
|
@ -1,8 +1,13 @@
#!/usr/bin/env bash
set -xeuo pipefail

# Rollout Importance Sampling Example
# References:
# - When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
# - Off-policy RL: https://fengyao.notion.site/off-policy-rl

project_name='DAPO'
exp_name='DAPO-Qwen2.5-32B-TIS' # Truncated Importance Sampling (TIS) -> https://fengyao.notion.site/off-policy-rl
exp_name='DAPO-Qwen2.5-32B-RolloutIS' # Rollout Importance Sampling

adv_estimator=grpo

@ -10,7 +15,14 @@ use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
tis_imp_ratio_cap=2.0

# Rollout Importance Sampling parameters (matches original TIS with threshold=2)
rollout_is=True
rollout_is_threshold=2.0
rollout_is_threshold_lower=null # No lower bound (original TIS behavior)
rollout_is_level=token # token-level (original TIS behavior)
rollout_is_mode=truncate # truncate mode (original TIS behavior)
rollout_is_veto_threshold=null # No veto (original TIS behavior)

clip_ratio_low=0.2
clip_ratio_high=0.28
@ -58,14 +70,17 @@ offload=True
gen_tp=4


# Truncated Importance Sampling (TIS) -> https://fengyao.notion.site/off-policy-rl

# Please note that server mode(agent loop) hasn't return rollout_log_probs for now.
# so currently, server mode is not supported for TIS.

# To turn on TIS, you need to set the following parameters. Note 2.0 is a hyper-parameter and can be tuned.
# actor_rollout_ref.actor.tis_imp_ratio_cap=2.0
# actor_rollout_ref.rollout.calculate_log_probs=True
# Rollout Importance Sampling (corrects distribution mismatch between rollout and training)
#
# Please note that server mode (agent loop) hasn't returned rollout_log_probs for now,
# so currently server mode is not supported for Rollout IS.
#
# Rollout IS parameters (configured at top of script):
# algorithm.rollout_is=True
# algorithm.rollout_is_threshold=2.0 # Upper threshold (can be tuned)
# algorithm.rollout_is_level=token # Aggregation level
# algorithm.rollout_is_mode=truncate # Bounding mode
# actor_rollout_ref.rollout.calculate_log_probs=True # Required!

ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
    --working-dir "${WORKING_DIR}" \
@ -109,7 +124,12 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
    actor_rollout_ref.actor.grad_clip=1.0 \
    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.actor.tis_imp_ratio_cap=${tis_imp_ratio_cap} \
    algorithm.rollout_is=${rollout_is} \
    algorithm.rollout_is_threshold=${rollout_is_threshold} \
    algorithm.rollout_is_threshold_lower=${rollout_is_threshold_lower} \
    algorithm.rollout_is_level=${rollout_is_level} \
    algorithm.rollout_is_mode=${rollout_is_mode} \
    algorithm.rollout_is_veto_threshold=${rollout_is_veto_threshold} \
    actor_rollout_ref.rollout.calculate_log_probs=True \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
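The overrides above enable token-level, truncate-mode rollout importance sampling. As a rough sketch under those same threshold semantics (not the actual `mismatch_helper` implementation), the per-token weight that ends up multiplying the policy-gradient loss behaves like this:

```python
# Rough sketch of token-level "truncate" rollout IS, mirroring the override
# semantics above; not the actual verl code.
import torch

def truncated_is_weights(old_log_prob, rollout_log_prob, response_mask,
                         upper=2.0, veto=None):
    ratio = torch.exp(old_log_prob - rollout_log_prob)  # pi_train / pi_rollout per token
    weights = ratio.clamp(max=upper)                    # truncate mode: cap at the upper threshold
    if veto is not None:
        # veto: zero out whole sequences containing a catastrophically small ratio
        bad = ((ratio < veto) & response_mask.bool()).any(dim=-1, keepdim=True)
        weights = torch.where(bad, torch.zeros_like(weights), weights)
    return weights * response_mask                      # detached weights applied to the PG loss
```

Sequence-level and geometric variants aggregate the per-token log-ratios over the response (sum or mean) before bounding, which is what `rollout_is_level=sequence` or `geometric` selects.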
@ -103,7 +103,7 @@ HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.weight_decay=0 \
    actor_rollout_ref.actor.optim.warmup_style=constant \
    actor_rollout_ref.actor.optim.lr_scheduler_type=constant \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
@ -100,7 +100,7 @@ HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.weight_decay=0 \
    actor_rollout_ref.actor.optim.warmup_style=constant \
    actor_rollout_ref.actor.optim.lr_scheduler_type=constant \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
@ -99,7 +99,7 @@ HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.weight_decay=0 \
    actor_rollout_ref.actor.optim.warmup_style=constant \
    actor_rollout_ref.actor.optim.lr_scheduler_type=constant \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
@ -103,7 +103,7 @@ HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.weight_decay=0 \
    actor_rollout_ref.actor.optim.warmup_style=constant \
    actor_rollout_ref.actor.optim.lr_scheduler_type=constant \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
@ -99,7 +99,7 @@ HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.weight_decay=0 \
    actor_rollout_ref.actor.optim.warmup_style=constant \
    actor_rollout_ref.actor.optim.lr_scheduler_type=constant \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
@ -293,6 +293,6 @@ python3 -m recipe.one_step_off_policy.async_main_ppo \
| Category           | Support Situation |
|--------------------|-----------------------------------------------------------------------------------------------------------------|
| train engine       | FSDP2 <br/> Megatron |
| rollout engine     | vLLM |
| rollout engine     | vLLM <br/> SGLang |
| AdvantageEstimator | GRPO <br/> GRPO_PASSK <br/> REINFORCE_PLUS_PLUS <br/> RLOO <br/> OPO <br/> REINFORCE_PLUS_PLUS_BASELINE <br/> GPG |
| Reward             | all |
140
recipe/one_step_off_policy/dapo_7b_math_fsdp2_sglang_4_12.sh
Normal file
@ -0,0 +1,140 @@
|
||||
#!/usr/bin/env bash
|
||||
set -xeuo pipefail
|
||||
|
||||
project_name='DAPO'
|
||||
exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-sglang-one-step-off-4-12'
|
||||
|
||||
adv_estimator=grpo
|
||||
|
||||
use_kl_in_reward=False
|
||||
kl_coef=0.0
|
||||
use_kl_loss=False
|
||||
kl_loss_coef=0.0
|
||||
|
||||
clip_ratio_low=0.2
|
||||
clip_ratio_high=0.28
|
||||
|
||||
max_prompt_length=$((1024 * 2))
|
||||
max_response_length=$((1024 * 8))
|
||||
enable_overlong_buffer=True
|
||||
overlong_buffer_len=$((1024 * 4))
|
||||
overlong_penalty_factor=1.0
|
||||
|
||||
loss_agg_mode="token-mean"
|
||||
|
||||
train_prompt_bsz=512
|
||||
n_resp_per_prompt=12
|
||||
train_prompt_mini_bsz=32
|
||||
|
||||
# Ray
|
||||
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
|
||||
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
|
||||
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
|
||||
NNODES=${NNODES:-2}
|
||||
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
|
||||
|
||||
n_gpus_rollout=2
|
||||
n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))
|
||||
|
||||
# Paths
|
||||
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
|
||||
# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface
|
||||
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
|
||||
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
|
||||
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
|
||||
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
|
||||
|
||||
|
||||
# Algorithm
|
||||
temperature=1.0
|
||||
top_p=1.0
|
||||
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
|
||||
val_top_p=0.7
|
||||
|
||||
# Performance Related Parameter
|
||||
use_dynamic_bsz=True
|
||||
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
|
||||
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
|
||||
ref_offload=True
|
||||
actor_offload=False
|
||||
gen_tp=2
|
||||
sp_size=4
|
||||
fsdp_size=2
|
||||
|
||||
python3 -m recipe.one_step_off_policy.main_ppo \
|
||||
data.train_files="${TRAIN_FILE}" \
|
||||
data.val_files="${TEST_FILE}" \
|
||||
data.prompt_key=prompt \
|
||||
data.truncation='left' \
|
||||
data.max_prompt_length=${max_prompt_length} \
|
||||
data.max_response_length=${max_response_length} \
|
||||
data.train_batch_size=${train_prompt_bsz} \
|
||||
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
|
||||
algorithm.adv_estimator=${adv_estimator} \
|
||||
algorithm.use_kl_in_reward=${use_kl_in_reward} \
|
||||
algorithm.kl_ctrl.kl_coef=${kl_coef} \
|
||||
actor_rollout_ref.actor.strategy=fsdp2 \
|
||||
critic.strategy=fsdp2 \
|
||||
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
|
||||
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
|
||||
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
|
||||
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
|
||||
actor_rollout_ref.actor.clip_ratio_c=10.0 \
|
||||
actor_rollout_ref.model.use_remove_padding=True \
|
||||
actor_rollout_ref.hybrid_engine=False \
|
||||
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
|
||||
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
|
||||
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
|
||||
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
|
||||
actor_rollout_ref.model.path="${MODEL_PATH}" \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
|
||||
actor_rollout_ref.actor.optim.weight_decay=0.1 \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
|
||||
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
|
||||
actor_rollout_ref.actor.entropy_coeff=0 \
|
||||
actor_rollout_ref.actor.grad_clip=1.0 \
|
||||
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
|
||||
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
|
||||
actor_rollout_ref.rollout.layered_summon=True \
|
||||
actor_rollout_ref.rollout.load_format=safetensors \
|
||||
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
|
||||
actor_rollout_ref.rollout.temperature=${temperature} \
|
||||
actor_rollout_ref.rollout.top_p=${top_p} \
|
||||
actor_rollout_ref.rollout.top_k=${top_k} \
|
||||
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
|
||||
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
|
||||
actor_rollout_ref.rollout.val_kwargs.n=1 \
|
||||
actor_rollout_ref.rollout.name=sglang \
|
||||
actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
|
||||
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
|
||||
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
|
||||
reward_model.reward_manager=dapo \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
|
||||
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
|
||||
trainer.logger=['console','tensorboard'] \
|
||||
trainer.project_name="${project_name}" \
|
||||
trainer.experiment_name="${exp_name}" \
|
||||
trainer.val_before_train=True \
|
||||
trainer.test_freq=10 \
|
||||
trainer.save_freq=-1 \
|
||||
trainer.total_epochs=10 \
|
||||
trainer.total_training_steps=100 \
|
||||
trainer.default_local_dir="${CKPTS_DIR}" \
|
||||
trainer.resume_mode=auto \
|
||||
trainer.log_val_generations=10 \
|
||||
trainer.nnodes="${NNODES}" \
|
||||
trainer.n_gpus_per_node="${n_gpus_training}" \
|
||||
rollout.nnodes="${NNODES}" \
|
||||
rollout.n_gpus_per_node="${n_gpus_rollout}"
|
133
recipe/one_step_off_policy/dapo_7b_math_fsdp2_sglang_colocate.sh
Normal file
@ -0,0 +1,133 @@
|
||||
#!/usr/bin/env bash
|
||||
set -xeuo pipefail
|
||||
|
||||
project_name='DAPO'
|
||||
exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-sglang-colocate'
|
||||
|
||||
adv_estimator=grpo
|
||||
|
||||
use_kl_in_reward=False
|
||||
kl_coef=0.0
|
||||
use_kl_loss=False
|
||||
kl_loss_coef=0.0
|
||||
|
||||
clip_ratio_low=0.2
|
||||
clip_ratio_high=0.28
|
||||
|
||||
max_prompt_length=$((1024 * 2))
|
||||
max_response_length=$((1024 * 8))
|
||||
enable_overlong_buffer=True
|
||||
overlong_buffer_len=$((1024 * 4))
|
||||
overlong_penalty_factor=1.0
|
||||
|
||||
loss_agg_mode="token-mean"
|
||||
|
||||
train_prompt_bsz=512
|
||||
n_resp_per_prompt=12
|
||||
train_prompt_mini_bsz=32
|
||||
|
||||
# Ray
|
||||
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
|
||||
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
|
||||
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
|
||||
NNODES=${NNODES:-2}
|
||||
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
|
||||
# Paths
|
||||
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
|
||||
# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface
|
||||
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
|
||||
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
|
||||
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
|
||||
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
|
||||
# Algorithm
|
||||
temperature=1.0
|
||||
top_p=1.0
|
||||
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
|
||||
val_top_p=0.7
|
||||
|
||||
# Performance Related Parameter
|
||||
use_dynamic_bsz=True
|
||||
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
|
||||
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
|
||||
offload=True
|
||||
gen_tp=2
|
||||
sp_size=4
|
||||
fsdp_size=2
|
||||
|
||||
# reference run wandb: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/runs/ow47vvon?nw=nwusertongyuxuan361
|
||||
|
||||
python3 -m verl.trainer.main_ppo \
|
||||
data.train_files="${TRAIN_FILE}" \
|
||||
data.val_files="${TEST_FILE}" \
|
||||
data.prompt_key=prompt \
|
||||
data.truncation='left' \
|
||||
data.max_prompt_length=${max_prompt_length} \
|
||||
data.max_response_length=${max_response_length} \
|
||||
data.train_batch_size=${train_prompt_bsz} \
|
||||
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
|
||||
algorithm.adv_estimator=${adv_estimator} \
|
||||
algorithm.use_kl_in_reward=${use_kl_in_reward} \
|
||||
algorithm.kl_ctrl.kl_coef=${kl_coef} \
|
||||
actor_rollout_ref.actor.strategy=fsdp2 \
|
||||
critic.strategy=fsdp2 \
|
||||
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
|
||||
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
|
||||
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
|
||||
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
|
||||
actor_rollout_ref.actor.clip_ratio_c=10.0 \
|
||||
actor_rollout_ref.model.use_remove_padding=True \
|
||||
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
|
||||
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
|
||||
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
|
||||
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
|
||||
actor_rollout_ref.model.path="${MODEL_PATH}" \
|
||||
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
|
||||
actor_rollout_ref.actor.optim.weight_decay=0.1 \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
|
||||
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
|
||||
actor_rollout_ref.actor.entropy_coeff=0 \
|
||||
actor_rollout_ref.actor.grad_clip=1.0 \
|
||||
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
|
||||
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
|
||||
actor_rollout_ref.rollout.layered_summon=True \
|
||||
actor_rollout_ref.rollout.load_format=safetensors \
|
||||
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
|
||||
actor_rollout_ref.rollout.temperature=${temperature} \
|
||||
actor_rollout_ref.rollout.top_p=${top_p} \
|
||||
actor_rollout_ref.rollout.top_k=${top_k} \
|
||||
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
|
||||
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
|
||||
actor_rollout_ref.rollout.val_kwargs.n=1 \
|
||||
actor_rollout_ref.rollout.name=sglang \
|
||||
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
|
||||
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
|
||||
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
|
||||
reward_model.reward_manager=dapo \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
|
||||
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
|
||||
trainer.logger=['console','tensorboard'] \
|
||||
trainer.project_name="${project_name}" \
|
||||
trainer.experiment_name="${exp_name}" \
|
||||
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
|
||||
trainer.nnodes="${NNODES}" \
|
||||
trainer.val_before_train=True \
|
||||
trainer.test_freq=10 \
|
||||
trainer.save_freq=-1 \
|
||||
trainer.total_epochs=10 \
|
||||
trainer.total_training_steps=100 \
|
||||
trainer.default_local_dir="${CKPTS_DIR}" \
|
||||
trainer.resume_mode=auto \
|
||||
trainer.log_val_generations=10
|
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
@ -83,13 +84,20 @@ class ActorRolloutRefWorker(ARRWorker):
|
||||
assert hasattr(self, "_weights_info") and self._weights_info is not None
|
||||
|
||||
params = self._get_actor_params() if self._is_actor else None
|
||||
rollout_name = self.config.rollout.name
|
||||
if self._is_rollout:
|
||||
inference_model = (
|
||||
self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
|
||||
)
|
||||
from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
|
||||
if rollout_name == "vllm":
|
||||
inference_model = (
|
||||
self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
|
||||
)
|
||||
from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
|
||||
|
||||
patch_vllm_moe_model_weight_loader(inference_model)
|
||||
patch_vllm_moe_model_weight_loader(inference_model)
|
||||
elif rollout_name == "sglang":
|
||||
inference_model = self.rollout._engine
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown rollout name: {rollout_name}")
|
||||
loop = asyncio.get_event_loop()
|
||||
for key, shape, dtype in self._weights_info:
|
||||
tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())
|
||||
if self._is_actor:
|
||||
@ -102,7 +110,23 @@ class ActorRolloutRefWorker(ARRWorker):
|
||||
|
||||
self._weight_sync_group.broadcast(tensor, src=0, stream=get_torch_device().current_stream())
|
||||
if self._is_rollout:
|
||||
inference_model.load_weights([(key, tensor)])
|
||||
if rollout_name == "vllm":
|
||||
inference_model.load_weights([(key, tensor)])
|
||||
elif rollout_name == "sglang":
|
||||
loop.run_until_complete(self.update_weights(inference_model, [(key, tensor)]))
|
||||
|
||||
async def update_weights(self, inference_engine, params):
|
||||
from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights
|
||||
|
||||
await sgl_update_weights(
|
||||
engine=inference_engine,
|
||||
params_batch=params,
|
||||
device_mesh_key="infer_tp",
|
||||
device_mesh=self.rollout_device_mesh,
|
||||
)
|
||||
|
||||
if self.rollout_device_mesh["infer_tp"].get_local_rank() == 0:
|
||||
await inference_engine.flush_cache()
|
||||
|
||||
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
|
||||
def get_actor_weights_info(self):
|
||||
@ -209,6 +233,7 @@ class RolloutWorker(ActorRolloutRefWorker):
|
||||
rollout_device_mesh = init_device_mesh(
|
||||
device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
|
||||
)
|
||||
self.rollout_device_mesh = rollout_device_mesh
|
||||
|
||||
is_collect = rollout_device_mesh["infer_tp"].get_local_rank() == 0
|
||||
self._register_dispatch_collect_info(
|
||||
@ -216,7 +241,8 @@ class RolloutWorker(ActorRolloutRefWorker):
|
||||
)
|
||||
|
||||
rollout_name = self.config.rollout.name
|
||||
assert rollout_name == "vllm"
|
||||
if rollout_name not in ("vllm", "sglang"):
|
||||
raise NotImplementedError(f"rollout_name: {rollout_name} is not supported")
|
||||
|
||||
rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)
|
||||
model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model, dataclass_type=HFModelConfig)
|
||||
@ -227,14 +253,23 @@ class RolloutWorker(ActorRolloutRefWorker):
|
||||
config=rollout_config, model_config=model_config, device_mesh=rollout_device_mesh
|
||||
)
|
||||
log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
|
||||
from .vllm_sharding_manager import VLLMShardingManager
|
||||
|
||||
rollout_sharding_manager = VLLMShardingManager(
|
||||
inference_engine=rollout.inference_engine, device_mesh=rollout_device_mesh
|
||||
)
|
||||
if rollout_name == "vllm":
|
||||
from .vllm_sharding_manager import VLLMShardingManager
|
||||
|
||||
log_gpu_memory_usage("After building sharding manager", logger=logger)
|
||||
rollout_sharding_manager = VLLMShardingManager(
|
||||
inference_engine=rollout.inference_engine, device_mesh=rollout_device_mesh
|
||||
)
|
||||
|
||||
log_gpu_memory_usage("After building sharding manager", logger=logger)
|
||||
elif rollout_name == "sglang":
|
||||
from .sglang_sharding_manager import SGLangShardingManager
|
||||
|
||||
rollout_sharding_manager = SGLangShardingManager(device_mesh=rollout_device_mesh)
|
||||
|
||||
log_gpu_memory_usage("After building sharding manager", logger=logger)
|
||||
|
||||
self.model_config = model_config
|
||||
self.rollout = rollout
|
||||
self.rollout_sharding_manager = rollout_sharding_manager
|
||||
|
||||
|
@ -0,0 +1,65 @@
|
||||
set -x
|
||||
|
||||
project_name='GRPO'
|
||||
exp_name='GRPO-Qwen3-0.6b-gsm8k-fsdp2-sglang-one-step-off-2-6'
|
||||
|
||||
# Paths
|
||||
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
|
||||
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-0.6B"}
|
||||
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
|
||||
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/gsm8k/train.parquet"}
|
||||
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/gsm8k/test.parquet"}
|
||||
|
||||
NNODES=${NNODES:-1}
|
||||
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
|
||||
|
||||
n_gpus_rollout=2
|
||||
n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))
|
||||
|
||||
|
||||
python3 -m recipe.one_step_off_policy.main_ppo \
|
||||
algorithm.adv_estimator=grpo \
|
||||
data.train_files="${TRAIN_FILE}" \
|
||||
data.val_files="${TEST_FILE}" \
|
||||
data.train_batch_size=1152 \
|
||||
data.max_prompt_length=512 \
|
||||
data.max_response_length=1024 \
|
||||
data.filter_overlong_prompts=True \
|
||||
data.truncation='error' \
|
||||
actor_rollout_ref.actor.strategy=fsdp2 \
|
||||
critic.strategy=fsdp2 \
|
||||
actor_rollout_ref.model.path="${MODEL_PATH}" \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.hybrid_engine=False \
|
||||
actor_rollout_ref.model.use_remove_padding=True \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=192 \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
|
||||
actor_rollout_ref.actor.use_kl_loss=True \
|
||||
actor_rollout_ref.actor.kl_loss_coef=0.001 \
|
||||
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
|
||||
actor_rollout_ref.actor.entropy_coeff=0 \
|
||||
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=False \
|
||||
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
|
||||
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
|
||||
actor_rollout_ref.rollout.name=sglang \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
|
||||
actor_rollout_ref.rollout.n=5 \
|
||||
actor_rollout_ref.rollout.load_format=safetensors \
|
||||
actor_rollout_ref.rollout.layered_summon=True \
|
||||
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
|
||||
actor_rollout_ref.ref.fsdp_config.param_offload=True \
|
||||
algorithm.use_kl_in_reward=False \
|
||||
trainer.critic_warmup=0 \
|
||||
trainer.val_before_train=True \
|
||||
trainer.logger=['console','tensorboard'] \
|
||||
trainer.project_name="${project_name}" \
|
||||
trainer.experiment_name="${exp_name}" \
|
||||
trainer.save_freq=-1 \
|
||||
trainer.test_freq=5 \
|
||||
trainer.total_epochs=2 \
|
||||
trainer.nnodes="${NNODES}" \
|
||||
trainer.n_gpus_per_node="${n_gpus_training}" \
|
||||
rollout.nnodes="${NNODES}" \
|
||||
rollout.n_gpus_per_node="${n_gpus_rollout}" $@
|
@ -577,6 +577,11 @@ class OneStepOffRayTrainer(RayPPOTrainer):
                    else:
                        batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

                    # Compute rollout IS weights and mismatch metrics (inherited from RayPPOTrainer)
                    batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)
                    # IS and mismatch metrics already have mismatch/ prefix
                    metrics.update(is_metrics)

                    # compute advantages, executed on the driver process

                    norm_adv_by_std_in_grpo = self.config.algorithm.get(
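When `algorithm.use_kl_in_reward=False` (as in the DAPO recipes), the `else` branch above passes the raw scores through unchanged. A hedged sketch of the opposite branch, with illustrative variable names rather than verl internals, simply subtracts a KL penalty from the token-level scores:

```python
# Hedged sketch of the use_kl_in_reward=True path; names are illustrative.
import torch

def apply_kl_penalty(token_level_scores, actor_log_prob, ref_log_prob, kl_coef=0.001):
    kl = actor_log_prob - ref_log_prob          # simple per-token KL estimate (log-ratio)
    return token_level_scores - kl_coef * kl    # becomes token_level_rewards
```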
70
recipe/one_step_off_policy/sglang_sharding_manager.py
Normal file
@ -0,0 +1,70 @@
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
||||
# Copyright 2025 Meituan Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
|
||||
from verl import DataProto
|
||||
from verl.protocol import all_gather_data_proto
|
||||
from verl.utils.debug import GPUMemoryLogger
|
||||
from verl.utils.device import get_torch_device
|
||||
from verl.utils.torch_functional import check_device_is_available
|
||||
from verl.workers.sharding_manager.base import BaseShardingManager
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
||||
|
||||
|
||||
class SGLangShardingManager(BaseShardingManager):
|
||||
@check_device_is_available()
|
||||
def __init__(self, device_mesh: DeviceMesh):
|
||||
self.device_mesh = device_mesh
|
||||
self.tp_size = self.device_mesh["infer_tp"].size()
|
||||
self.tp_rank = self.device_mesh["infer_tp"].get_local_rank()
|
||||
self.timing = {}
|
||||
gen_dp_rank = self.device_mesh["dp"].get_local_rank()
|
||||
get_torch_device().manual_seed(gen_dp_rank + 1000)
|
||||
self.gen_random_states = get_torch_device().get_rng_state()
|
||||
|
||||
@GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
|
||||
def __enter__(self):
|
||||
get_torch_device().set_rng_state(self.gen_random_states)
|
||||
|
||||
@GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.gen_random_states = get_torch_device().get_rng_state()
|
||||
get_torch_device().empty_cache()
|
||||
|
||||
@GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
|
||||
def preprocess_data(self, data: DataProto) -> DataProto:
|
||||
"""All gather across tp group to make each rank has identical input."""
|
||||
if self.tp_size == 1:
|
||||
return data
|
||||
|
||||
# TODO: Current impl doesn't consider FSDP with torch micro-dp
|
||||
group = self.device_mesh["infer_tp"].get_group()
|
||||
|
||||
all_gather_data_proto(data=data, process_group=group)
|
||||
return data
|
||||
|
||||
@GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
|
||||
def postprocess_data(self, data: DataProto) -> DataProto:
|
||||
"""Get chunk data of this tp rank since we do all gather in preprocess."""
|
||||
if self.tp_size == 1:
|
||||
return data
|
||||
|
||||
return data.chunk(chunks=self.tp_size)[self.tp_rank]
|
55
recipe/open_math_reasoning/README.md
Normal file
@ -0,0 +1,55 @@
# Open math reasoning
## Introduction
In this recipe, we perform SFT on the [open math reasoning](https://huggingface.co/datasets/nvidia/OpenMathReasoning) dataset using the new SFT trainer with a backend-agnostic model engine. Note that our goal is not to replicate the [AIMO-2 Winning Solution](https://arxiv.org/abs/2504.16891) work, but to demonstrate an end-to-end SFT recipe.

Note that you may need to modify the paths as needed in the following scripts.
## Dataset Preprocessing
### Download Dataset
```bash
hf download nvidia/OpenMathReasoning --repo-type dataset --include data/cot* --local-dir /path/to/dataset/nvidia/OpenMathReasoning
hf download math-ai/aime24 --repo-type dataset --local-dir /path/to/dataset/math-ai/aime24
hf download math-ai/aime25 --repo-type dataset --local-dir /path/to/dataset/math-ai/aime25
```

### Preprocess the dataset
```bash
python3 recipe/open_math_reasoning/prepare_nvidia-OpenMathReasoning_sft.py --local_dataset_path /path/to/nvidia/OpenMathReasoning --local_save_dir /path/to/open_math_reasoning
```

### Prepare the eval dataset
```bash
python3 recipe/open_math_reasoning/prepare_eval_dataset.py --local_dataset_path /path/to/dataset --local_save_dir /path/to/eval_dataset
```

## Train the model using SFT
### FSDP backend
```bash
export CKPT_HOME=/path/to/ckpt
export BACKEND=fsdp2
export MODEL_ID=Qwen/Qwen3-8B-Base
export TRAIN_FILES=/path/to/open_math_reasoning/cot_dataset.parquet
bash recipe/open_math_reasoning/run_sft_qwen3_8b.sh
```

### Megatron backend
TODO

## Eval the model
### Merge checkpoint into huggingface format
```bash
python -m verl.model_merger merge --backend fsdp --local_dir /path/to/ckpt/global_step_19751 --target_dir /path/to/ckpt/global_step_19751/huggingface
```

### Generate the responses
```bash
export MODEL_PATH=/path/to/ckpt/global_step_19751/huggingface
bash recipe/open_math_reasoning/run_generation.sh
```

### Evaluate the responses
```bash
bash recipe/open_math_reasoning/run_eval.sh
```

You should see results like:
```python
{'test_score/aime24': 0.584375, 'test_score/aime25': 0.43333333333333335}
```
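If you want to sanity-check the score function outside of `main_eval`, the snippet below is a rough, hedged sketch: it assumes the generation parquet carries `data_source`, `responses`, and `reward_model.ground_truth` fields (an assumption about the output schema, not a documented interface) and averages `compute_score_data_source` per data source.

```python
# Hedged sketch: aggregate per-source test scores from the generation output
# parquet. Column names are assumptions, not a documented schema.
import os
from collections import defaultdict

import pandas as pd

from recipe.open_math_reasoning.compute_score import compute_score_data_source

df = pd.read_parquet(os.path.expanduser("~/data/gen/qwen_8b_gen_test.parquet"))
totals, counts = defaultdict(float), defaultdict(int)
for _, row in df.iterrows():
    source = row["data_source"]
    ground_truth = row["reward_model"]["ground_truth"]
    for response in row["responses"]:
        totals[source] += compute_score_data_source(source, response, ground_truth)
        counts[source] += 1

print({f"test_score/{s}": totals[s] / counts[s] for s in totals})
```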
|
22
recipe/open_math_reasoning/compute_score.py
Normal file
@ -0,0 +1,22 @@
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
def compute_score_data_source(data_source, response, ground_truth):
|
||||
from verl.utils.reward_score.math_reward import compute_score
|
||||
|
||||
if data_source in ["aime24", "aime25"]:
|
||||
return compute_score(response, ground_truth)
|
||||
else:
|
||||
raise ValueError(f"Unknown data source: {data_source}")
|
96
recipe/open_math_reasoning/prepare_eval_dataset.py
Normal file
96
recipe/open_math_reasoning/prepare_eval_dataset.py
Normal file
@ -0,0 +1,96 @@
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# prepare eval dataset including AIME'24, AIME'25
|
||||
|
||||
# hf download math-ai/aime24 --repo-type dataset --local-dir /opt/tiger/datasets/math-ai/aime24
|
||||
# hf download math-ai/aime25 --repo-type dataset --local-dir /opt/tiger/datasets/math-ai/aime25
|
||||
|
||||
import os
|
||||
|
||||
import datasets
|
||||
|
||||
from verl.utils.reward_score.math_reward import remove_boxed
|
||||
|
||||
instruction_following = "Please reason step by step, and put your final answer within \\boxed{}."
|
||||
|
||||
|
||||
def make_map_fn(data_source):
|
||||
def process_fn(example, idx):
|
||||
question_raw = example.pop("problem")
|
||||
|
||||
question = question_raw + " " + instruction_following
|
||||
|
||||
if "solution" not in example:
|
||||
example["solution"] = example["answer"]
|
||||
|
||||
answer_raw = example.pop("solution")
|
||||
|
||||
example.clear()
|
||||
|
||||
try:
|
||||
solution = remove_boxed(answer_raw)
|
||||
except Exception:
|
||||
solution = answer_raw
|
||||
|
||||
data = {
|
||||
"data_source": data_source,
|
||||
"prompt": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": question,
|
||||
}
|
||||
],
|
||||
"ability": "math",
|
||||
"reward_model": {"style": "rule", "ground_truth": solution},
|
||||
"extra_info": {
|
||||
"index": idx,
|
||||
"answer": answer_raw,
|
||||
"question": question_raw,
|
||||
},
|
||||
}
|
||||
return data
|
||||
|
||||
return process_fn
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--local_dataset_path", default=None, help="The local path to the raw dataset, if it exists.")
|
||||
parser.add_argument(
|
||||
"--local_save_dir", default="~/data/math-ai", help="The save directory for the preprocessed dataset."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.local_dataset_path is not None:
|
||||
aime24_dataset_path = os.path.join(args.local_dataset_path, "math-ai/aime24")
|
||||
aime25_dataset_path = os.path.join(args.local_dataset_path, "math-ai/aime25")
|
||||
else:
|
||||
aime24_dataset_path = "math-ai/aime24"
|
||||
aime25_dataset_path = "math-ai/aime25"
|
||||
|
||||
aime24_dataset = datasets.load_dataset(aime24_dataset_path, split="test")
|
||||
aime25_dataset = datasets.load_dataset(aime25_dataset_path, split="test")
|
||||
|
||||
aime24_dataset = aime24_dataset.map(function=make_map_fn("aime24"), with_indices=True)
|
||||
aime25_dataset = aime25_dataset.map(function=make_map_fn("aime25"), with_indices=True)
|
||||
|
||||
local_save_dir = os.path.expanduser(args.local_save_dir)
|
||||
os.makedirs(local_save_dir, exist_ok=True)
|
||||
|
||||
aime24_dataset.to_parquet(os.path.join(local_save_dir, "aime24_test.parquet"))
|
||||
aime25_dataset.to_parquet(os.path.join(local_save_dir, "aime25_test.parquet"))
|
@ -0,0 +1,72 @@
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
huggingface-cli download nvidia/OpenMathReasoning --repo-type dataset --include data/cot* \
|
||||
--local-dir /path/to/nvidia/OpenMathReasoning
|
||||
huggingface-cli download nvidia/OpenMathReasoning --repo-type dataset --include data/cot* \
|
||||
--local-dir /opt/tiger/nvidia/OpenMathReasoning
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import datasets
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--local_dataset_path", default=None, help="The local path to the raw dataset, if it exists.")
|
||||
parser.add_argument(
|
||||
"--local_save_dir",
|
||||
default="~/data/open_math_reasoning",
|
||||
help="The save directory for the preprocessed dataset.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
local_dataset_path = args.local_dataset_path
|
||||
|
||||
data_source = "nvidia/OpenMathReasoning"
|
||||
|
||||
if local_dataset_path is not None:
|
||||
dataset = datasets.load_dataset(local_dataset_path, split="cot")
|
||||
else:
|
||||
dataset = datasets.load_dataset(data_source, split="cot")
|
||||
|
||||
def make_map_fn(split):
|
||||
def process_fn(example, idx):
|
||||
question = example.pop("problem")
|
||||
solution = example.pop("generated_solution")
|
||||
|
||||
extra_info = {}
|
||||
for key, value in example.items():
|
||||
extra_info[key] = value
|
||||
example.clear()
|
||||
|
||||
data = {
|
||||
"messages": [
|
||||
{"role": "user", "content": question, "loss_mask": 0},
|
||||
{"role": "assistant", "content": solution, "loss_mask": 1},
|
||||
],
|
||||
"extra_info": extra_info,
|
||||
}
|
||||
return data
|
||||
|
||||
return process_fn
|
||||
|
||||
# filter out data where the problem_type is not has_answer_extracted
|
||||
dataset = dataset.filter(lambda example: example["problem_type"] == "has_answer_extracted")
|
||||
dataset = dataset.map(function=make_map_fn("cot"), with_indices=True)
|
||||
local_save_dir = os.path.expanduser(args.local_save_dir)
|
||||
os.makedirs(local_save_dir, exist_ok=True)
|
||||
dataset.to_parquet(os.path.join(local_save_dir, "cot_dataset.parquet"))
|
7
recipe/open_math_reasoning/run_eval.sh
Normal file
@ -0,0 +1,7 @@
#!/usr/bin/env bash

# Evaluation
python3 -m verl.trainer.main_eval \
    data.path=$HOME/data/gen/qwen_8b_gen_test.parquet \
    custom_reward_function.path=recipe/open_math_reasoning/compute_score.py \
    custom_reward_function.name=compute_score_data_source
|
32
recipe/open_math_reasoning/run_generation.sh
Normal file
@ -0,0 +1,32 @@
#!/usr/bin/env bash

MODEL_PATH=${MODEL_PATH:-/path/to/ckpt/global_step_19751/huggingface}

NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
NNODES=${NNODES:-1}
OUTPUT_PATH=${OUTPUT_PATH:-$HOME/data/gen/qwen_8b_gen_test.parquet}
GEN_TP=${GEN_TP:-1} # Tensor parallel size for generation

aime24_test_path=${HOME}/data/math-ai/aime24_test.parquet
aime25_test_path=${HOME}/data/math-ai/aime25_test.parquet
train_files="['$aime24_test_path', '$aime25_test_path']"

python3 -m verl.trainer.main_generation_server \
    trainer.nnodes="${NNODES}" \
    trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.model.trust_remote_code=True \
    actor_rollout_ref.rollout.temperature=1.0 \
    actor_rollout_ref.rollout.top_p=0.7 \
    actor_rollout_ref.rollout.prompt_length=2048 \
    actor_rollout_ref.rollout.response_length=20480 \
    actor_rollout_ref.rollout.tensor_model_parallel_size="${GEN_TP}" \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.n=32 \
    data.train_files="$train_files" \
    data.prompt_key=prompt \
    +data.output_path="${OUTPUT_PATH}"
|
||||
|
||||
|
||||
|
94
recipe/open_math_reasoning/run_sft_qwen3_8b.sh
Normal file
@ -0,0 +1,94 @@
|
||||
#!/usr/bin/env bash
|
||||
set -xeuo pipefail
|
||||
|
||||
ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.sft_trainer"}
|
||||
|
||||
TRAIN_FILES=${TRAIN_FILES:-/path/to/cot_dataset.parquet}
|
||||
|
||||
backend=${BACKEND:-fsdp}
|
||||
|
||||
project_name=verl_sft_test
|
||||
|
||||
RESUME_MODE=auto
|
||||
MODEL_ID=${MODEL_ID:-Qwen/Qwen3-8B-Base}
|
||||
|
||||
SP_SIZE=${SP_SIZE:-8}
|
||||
FSDP_SIZE=${FSDP_SIZE:-16}
|
||||
FSDP_STRATEGY=${FSDP_STRATEGY:-"fsdp2"}
|
||||
|
||||
TP_SIZE=${TP_SIZE:-1}
|
||||
PP_SIZE=${PP_SIZE:-1}
|
||||
VPP_SIZE=${VPP_SIZE:-null}
|
||||
CP_SIZE=${CP_SIZE:-1}
|
||||
|
||||
PAD_MODE=${PAD_MODE:-no_padding}
|
||||
|
||||
USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True}
|
||||
|
||||
FSDP_ENGINE_CONFIG="\
|
||||
engine=${backend} \
|
||||
optim=${backend} \
|
||||
optim.lr=2e-5 \
|
||||
optim.lr_warmup_steps_ratio=0.01 \
|
||||
optim.weight_decay=0.1 \
|
||||
optim.betas="[0.9,0.95]" \
|
||||
optim.clip_grad=1.0 \
|
||||
optim.min_lr_ratio=0.1 \
|
||||
optim.warmup_style=cosine \
|
||||
engine.ulysses_sequence_parallel_size=${SP_SIZE} \
|
||||
engine.strategy=${FSDP_STRATEGY} \
|
||||
engine.fsdp_size=${FSDP_SIZE}"
|
||||
|
||||
|
||||
MEGATRON_ENGINE_CONFIG="\
|
||||
engine=${backend} \
|
||||
optim=${backend} \
|
||||
optim.lr=1e-5 \
|
||||
optim.lr_warmup_steps_ratio=0.2 \
|
||||
optim.weight_decay=0.1 \
|
||||
optim.betas="[0.9,0.95]" \
|
||||
optim.clip_grad=1.0 \
|
||||
optim.lr_warmup_init=0 \
|
||||
optim.lr_decay_style=cosine \
|
||||
optim.min_lr=1e-6 \
|
||||
engine.tensor_model_parallel_size=${TP_SIZE} \
|
||||
engine.pipeline_model_parallel_size=${PP_SIZE} \
|
||||
engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \
|
||||
engine.context_parallel_size=${CP_SIZE}"
|
||||
|
||||
if [ "$backend" = "fsdp" ]; then
|
||||
ENGINE_CONFIG="$FSDP_ENGINE_CONFIG"
|
||||
echo "Using fsdp engine"
|
||||
exp_name=nvidia-openmathreasoning-qwen3-8b-${backend}-${FSDP_STRATEGY}-sp${SP_SIZE}-fsdp-1008a1
|
||||
else
|
||||
ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG"
|
||||
echo "Using megatron engine"
|
||||
exp_name=nvidia-openmathreasoning-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-vpp${VPP_SIZE}-cp${CP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}
|
||||
fi
|
||||
|
||||
CKPT_HOME=${CKPT_HOME:-$HOME/open_verl/sft/${project_name}/${exp_name}}
|
||||
mkdir -p "${CKPT_HOME}"
|
||||
|
||||
torchrun --standalone --nnodes=1 --nproc-per-node=${NUM_TRAINERS:-8} \
|
||||
${ENTRYPOINT} \
|
||||
data.train_files="${TRAIN_FILES}" \
|
||||
data.train_batch_size=96 \
|
||||
data.max_length=32768 \
|
||||
data.pad_mode=${PAD_MODE} \
|
||||
data.truncation=error \
|
||||
data.use_dynamic_bsz=True \
|
||||
data.max_token_len_per_gpu=65536 \
|
||||
data.messages_key=messages \
|
||||
model.path=$MODEL_ID \
|
||||
model.use_remove_padding=${USE_REMOVE_PADDING} \
|
||||
${ENGINE_CONFIG} \
|
||||
trainer.test_freq=-1 \
|
||||
trainer.save_freq=4000 \
|
||||
trainer.logger=['console','wandb'] \
|
||||
trainer.project_name="${project_name}" \
|
||||
trainer.experiment_name="${exp_name}" \
|
||||
trainer.total_epochs=1 \
|
||||
trainer.default_local_dir="${CKPT_HOME}" \
|
||||
trainer.resume_mode=${RESUME_MODE} \
|
||||
trainer.max_ckpt_to_keep=5 \
|
||||
checkpoint.save_contents=[model,optimizer,extra]
|
@ -48,7 +48,8 @@ reward_model:
    lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.
    lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
    min_lr_ratio: null
    warmup_style: constant
    warmup_style: null # deprecated
    lr_scheduler_type: constant
    total_training_steps: -1 # must be overridden by program
    weight_decay: 0.
    grad_clip: 10.0
@ -24,7 +24,7 @@ import ray
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from transformers import AutoModelForCausalLM, AutoModelForTokenClassification, Qwen3Config, Qwen3MoeConfig
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, Qwen3Config, Qwen3MoeConfig

from verl import DataProto
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
@ -289,8 +289,9 @@ def _worker(rank: int, world_size: int, rendezvous_file: str, strategy: str, mod
        world_size=world_size,
    )

    ref_model_config = AutoConfig.from_pretrained(model_path)
    with torch.device("meta"):
        ref_model = AutoModelForCausalLM.from_pretrained(model_path)
        ref_model = AutoModelForCausalLM.from_config(ref_model_config)

    from verl.workers.engine import BaseEngine, EngineRegistry
|
||||
|
||||
|
@ -42,7 +42,7 @@ FSDP_ENGINE_CONFIG="\
    optim.betas="[0.9,0.95]" \
    optim.clip_grad=1.0 \
    optim.min_lr_ratio=0.1 \
    optim.warmup_style=cosine \
    optim.lr_scheduler_type=cosine \
    engine.ulysses_sequence_parallel_size=${SP_SIZE} \
    engine.strategy=${FSDP_STRATEGY} \
    engine.fsdp_size=${FSDP_SIZE}"
|
||||
|
@ -301,8 +301,8 @@ actor_rollout_ref:
      # Number of cosine cycles in LR schedule
      num_cycles: 0.5

      # LR warmup style: "constant" or "cosine"
      warmup_style: constant
      # LR scheduler type: "constant" or "cosine"
      lr_scheduler_type: constant

      # Total training steps (must be overridden at runtime)
      total_training_steps: -1
@ -605,8 +605,8 @@ critic:
    # Minimum LR ratio for cosine schedule
    min_lr_ratio: 0.0

    # LR warmup style: "constant" or "cosine"
    warmup_style: constant
    # LR scheduler type: "constant" or "cosine"
    lr_scheduler_type: constant

    # Total training steps (must be overridden at runtime)
    total_training_steps: -1
|
||||
|
289
tests/trainer/ppo/test_rollout_is.py
Normal file
@ -0,0 +1,289 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Quick Sanity Test for Rollout Importance Sampling
|
||||
|
||||
This is a standalone test script that can be run without pytest to quickly verify
|
||||
the rollout IS implementation is working correctly. For comprehensive integration
|
||||
tests, see: tests/trainer/ppo/test_rollout_is_integration.py
|
||||
|
||||
Usage:
|
||||
python test_rollout_is.py
|
||||
|
||||
This tests:
|
||||
- Basic rollout IS functionality (3 levels, 2 modes)
|
||||
- Metrics completeness (32 total: 21 IS + 11 mismatch metrics)
|
||||
- Veto mechanism
|
||||
- Edge cases
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from verl.trainer.ppo.mismatch_helper import compute_mismatch_metrics, compute_rollout_importance_weights
|
||||
|
||||
|
||||
def test_basic_rollout_is():
|
||||
"""Test basic rollout IS functionality."""
|
||||
print("Testing basic rollout IS functionality...")
|
||||
|
# Create test data
batch_size, seq_length = 4, 10
device = "cuda" if torch.cuda.is_available() else "cpu"

# Create slightly different log probs (simulating BF16 vs FP32 mismatch)
old_log_prob = torch.randn(batch_size, seq_length, device=device)
rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.1
eos_mask = torch.ones(batch_size, seq_length, device=device)

# Test token-level truncate mode (equivalent to old TIS)
print("\n1. Testing token-level truncate mode...")
weights_proto, metrics = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=eos_mask,
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
rollout_is_veto_threshold=1e-4,
)

weights = weights_proto.batch["rollout_is_weights"]
print(f" Weights shape: {weights.shape}")
print(f" Mean weight: {metrics['mismatch/rollout_is_mean']:.4f}")
print(f" Max weight: {metrics['mismatch/rollout_is_max']:.4f}")
print(f" Min weight: {metrics['mismatch/rollout_is_min']:.4f}")
print(f" Veto fraction: {metrics['mismatch/rollout_is_veto_fraction']:.4f}")
assert weights.shape == old_log_prob.shape
assert weights.max() <= 2.0, "Weights should be capped at threshold"
print(" ✓ Token-level truncate mode passed")

# Test sequence-level mode
print("\n2. Testing sequence-level mode...")
weights_seq_proto, metrics_seq = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=eos_mask,
rollout_is_level="sequence",
rollout_is_mode="truncate",
rollout_is_threshold=5.0,
rollout_is_veto_threshold=1e-4,
)

weights_seq = weights_seq_proto.batch["rollout_is_weights"]
print(f" Mean weight: {metrics_seq['mismatch/rollout_is_mean']:.4f}")
print(f" Effective sample size: {metrics_seq['mismatch/rollout_is_eff_sample_size']:.4f}")
# Check that all tokens in a sequence have the same weight
for i in range(batch_size):
seq_weights = weights_seq[i, eos_mask[i].bool()]
assert torch.allclose(seq_weights, seq_weights[0]), "All tokens in sequence should have same weight"
print(" ✓ Sequence-level mode passed")

# Test geometric mean mode
print("\n3. Testing geometric mean mode...")
weights_geo_proto, metrics_geo = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=eos_mask,
rollout_is_level="geometric",
rollout_is_mode="mask",
rollout_is_threshold=1.5,
rollout_is_threshold_lower=0.5,
rollout_is_veto_threshold=1e-4,
)

print(f" Mean weight: {metrics_geo['mismatch/rollout_is_mean']:.4f}")
print(f" Masked fraction: {metrics_geo['mismatch/rollout_is_masked_fraction']:.4f}")
print(" ✓ Geometric mean mode passed")

# Test veto mechanism
print("\n4. Testing veto mechanism...")
# Create data with catastrophic outliers
old_log_prob_veto = torch.randn(2, 5, device=device)
rollout_log_prob_veto = old_log_prob_veto.clone()
# Make one token have catastrophically low ratio
rollout_log_prob_veto[0, 2] = old_log_prob_veto[0, 2] + 15.0 # ratio ~= 3e-7
eos_mask_veto = torch.ones(2, 5, device=device)

weights_veto_proto, metrics_veto = compute_rollout_importance_weights(
old_log_prob=old_log_prob_veto,
rollout_log_prob=rollout_log_prob_veto,
response_mask=eos_mask_veto,
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
rollout_is_veto_threshold=1e-4,
)

weights_veto = weights_veto_proto.batch["rollout_is_weights"]
print(f" Veto fraction: {metrics_veto['mismatch/rollout_is_veto_fraction']:.4f}")
# Check that the sequence with catastrophic token has all weights zeroed
assert weights_veto[0].sum() == 0, "Sequence with catastrophic token should be vetoed"
assert weights_veto[1].sum() > 0, "Normal sequence should not be vetoed"
print(" ✓ Veto mechanism passed")

# Test disabled IS (threshold=None)
print("\n5. Testing disabled IS...")
weights_disabled, metrics_disabled = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=eos_mask,
rollout_is_threshold=None,
)

assert weights_disabled is None, "Should return None when threshold is None"
assert len(metrics_disabled) == 0, "Should return empty metrics when disabled"
print(" ✓ Disabled IS passed")

print("\n✓ All tests passed!")


def test_metrics_completeness():
"""Test that all expected metrics are returned."""
print("\nTesting metrics completeness...")

batch_size, seq_length = 3, 8
device = "cuda" if torch.cuda.is_available() else "cpu"

old_log_prob = torch.randn(batch_size, seq_length, device=device)
rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.2
eos_mask = torch.ones(batch_size, seq_length, device=device)

_, metrics = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=eos_mask,
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.5,
)

# Expected IS metrics
expected_is_metrics = [
"mismatch/rollout_is_mean",
"mismatch/rollout_is_max",
"mismatch/rollout_is_min",
"mismatch/rollout_is_std",
"mismatch/rollout_is_eff_sample_size",
"mismatch/rollout_is_veto_fraction",
"mismatch/rollout_is_catastrophic_token_fraction",
"mismatch/rollout_is_ratio_fraction_high",
"mismatch/rollout_is_ratio_fraction_low",
"mismatch/rollout_is_p25",
"mismatch/rollout_is_p50",
"mismatch/rollout_is_p75",
"mismatch/rollout_is_p95",
"mismatch/rollout_is_p99",
]

# Expected mismatch/diagnostic metrics (also included now)
expected_mismatch_metrics = [
"mismatch/mismatch_training_ppl",
"mismatch/mismatch_training_log_ppl",
"mismatch/mismatch_kl",
"mismatch/mismatch_k3_kl",
"mismatch/mismatch_rollout_ppl",
"mismatch/mismatch_rollout_log_ppl",
"mismatch/mismatch_log_ppl_diff",
"mismatch/mismatch_log_ppl_abs_diff",
"mismatch/mismatch_log_ppl_diff_max",
"mismatch/mismatch_log_ppl_diff_min",
"mismatch/mismatch_ppl_ratio",
]

expected_metrics = expected_is_metrics + expected_mismatch_metrics

missing_metrics = [m for m in expected_metrics if m not in metrics]
if missing_metrics:
print(f" ✗ Missing metrics: {missing_metrics}")
return False

print(f" ✓ All {len(expected_metrics)} expected metrics present")
print(f" Total metrics returned: {len(metrics)}")
return True


def test_mismatch_metrics():
"""Test mismatch metrics computation."""
print("\nTesting mismatch metrics computation...")

batch_size, seq_length = 4, 12
device = "cuda" if torch.cuda.is_available() else "cpu"

# Create test data with some mismatch
old_log_prob = torch.randn(batch_size, seq_length, device=device) - 2.0 # training policy
rollout_log_prob = torch.randn(batch_size, seq_length, device=device) - 1.5 # rollout policy (more confident)
response_mask = torch.ones(batch_size, seq_length, device=device)

# Test with rollout log probs
metrics = compute_mismatch_metrics(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=response_mask,
)

expected_metrics = [
"mismatch_training_ppl",
"mismatch_training_log_ppl",
"mismatch_kl",
"mismatch_k3_kl",
"mismatch_rollout_ppl",
"mismatch_rollout_log_ppl",
"mismatch_log_ppl_diff",
"mismatch_log_ppl_abs_diff",
"mismatch_log_ppl_diff_max",
"mismatch_log_ppl_diff_min",
"mismatch_ppl_ratio",
]

for metric in expected_metrics:
assert metric in metrics, f"Missing metric: {metric}"

print(f" Training PPL: {metrics['mismatch_training_ppl']:.4f}")
print(f" Rollout PPL: {metrics['mismatch_rollout_ppl']:.4f}")
print(f" KL divergence: {metrics['mismatch_kl']:.6f}")
print(f" K3 KL: {metrics['mismatch_k3_kl']:.6f}")
print(f" PPL ratio: {metrics['mismatch_ppl_ratio']:.4f}")
print(f" ✓ All {len(expected_metrics)} mismatch metrics present")

# Test without rollout log probs
metrics_no_rollout = compute_mismatch_metrics(
old_log_prob=old_log_prob,
rollout_log_prob=None,
response_mask=response_mask,
)

assert "mismatch_training_ppl" in metrics_no_rollout
assert "mismatch_rollout_ppl" not in metrics_no_rollout
print(" ✓ Mismatch metrics work without rollout log probs")


if __name__ == "__main__":
print("=" * 60)
print("Rollout Importance Sampling Test Suite")
print("=" * 60)

try:
test_basic_rollout_is()
test_metrics_completeness()
test_mismatch_metrics()
print("\n" + "=" * 60)
print("ALL TESTS PASSED ✓")
print("=" * 60)
except Exception as e:
print(f"\n✗ Test failed with error: {e}")
import traceback

traceback.print_exc()
exit(1)
241
tests/trainer/ppo/test_rollout_is_integration.py
Normal file
@ -0,0 +1,241 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests for Rollout Importance Sampling."""

import pytest
import torch

from verl.trainer.ppo.core_algos import compute_policy_loss_vanilla
from verl.trainer.ppo.mismatch_helper import compute_mismatch_metrics, compute_rollout_importance_weights
from verl.workers.config.actor import ActorConfig


class TestRolloutISIntegration:
"""Integration tests for Rollout IS with PPO."""

@pytest.fixture
def sample_data(self):
"""Create sample training data."""
batch_size, seq_length = 4, 16
device = "cuda" if torch.cuda.is_available() else "cpu"

return {
"old_log_prob": torch.randn(batch_size, seq_length, device=device),
"log_prob": torch.randn(batch_size, seq_length, device=device),
"rollout_log_prob": torch.randn(batch_size, seq_length, device=device),
"advantages": torch.randn(batch_size, seq_length, device=device),
"response_mask": torch.ones(batch_size, seq_length, device=device),
}

@pytest.fixture
def config_with_rollout_is(self):
"""Create config for policy loss computation.

Note: rollout_is config has been moved to algorithm config.
This config only needs fields used by policy loss (clip_ratio, etc).
"""
config = ActorConfig(
strategy="fsdp",
rollout_n=1,
ppo_micro_batch_size=2,
clip_ratio=0.2,
)
return config

def test_policy_loss_with_rollout_is(self, sample_data, config_with_rollout_is):
"""Test that policy loss computation works with rollout IS weights.

Note: In production, IS weights are computed centrally in the trainer
(before advantage computation) and passed to policy loss.
This test simulates that workflow.
"""
# First compute IS weights (as trainer would do centrally)
rollout_is_weights_proto, _ = compute_rollout_importance_weights(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
rollout_is_veto_threshold=1e-4,
)

rollout_is_weights = rollout_is_weights_proto.batch["rollout_is_weights"]

# Policy loss function receives pre-computed IS weights
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss_vanilla(
old_log_prob=sample_data["old_log_prob"],
log_prob=sample_data["log_prob"],
advantages=sample_data["advantages"],
response_mask=sample_data["response_mask"],
loss_agg_mode="token-mean",
config=config_with_rollout_is,
rollout_is_weights=rollout_is_weights,
)

# Check loss is valid
assert isinstance(pg_loss, torch.Tensor)
assert pg_loss.ndim == 0 # Scalar
assert not torch.isnan(pg_loss)
assert not torch.isinf(pg_loss)

def test_rollout_is_weights_computation(self, sample_data):
"""Test rollout IS weights and metrics computation."""
weights_proto, metrics = compute_rollout_importance_weights(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
rollout_is_veto_threshold=1e-4,
)

# Check weights
from verl.protocol import DataProto

assert isinstance(weights_proto, DataProto)
weights = weights_proto.batch["rollout_is_weights"]
assert isinstance(weights, torch.Tensor)
assert weights.shape == sample_data["old_log_prob"].shape

# Check metrics are returned
assert isinstance(metrics, dict)
assert len(metrics) > 0
assert "mismatch/rollout_is_mean" in metrics

def test_all_aggregation_levels(self, sample_data):
"""Test all three aggregation levels."""
levels = ["token", "sequence", "geometric"]

for level in levels:
_, metrics = compute_rollout_importance_weights(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
rollout_is_level=level,
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
)

assert "mismatch/rollout_is_mean" in metrics

def test_both_bounding_modes(self, sample_data):
"""Test both truncate and mask modes."""
modes = ["truncate", "mask"]

for mode in modes:
_, metrics = compute_rollout_importance_weights(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
rollout_is_level="token",
rollout_is_mode=mode,
rollout_is_threshold=2.0,
rollout_is_threshold_lower=0.5,
)

assert "mismatch/rollout_is_mean" in metrics

def test_mismatch_metrics(self, sample_data):
"""Test mismatch diagnostic metrics computation."""
metrics = compute_mismatch_metrics(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
)

# Check key metrics are present
assert "mismatch_training_ppl" in metrics
assert "mismatch_rollout_ppl" in metrics
assert "mismatch_kl" in metrics
assert isinstance(metrics["mismatch_kl"], float)

def test_veto_mechanism(self):
"""Test veto mechanism with catastrophic outliers."""
batch_size, seq_length = 2, 5
device = "cuda" if torch.cuda.is_available() else "cpu"

old_log_prob = torch.randn(batch_size, seq_length, device=device)
rollout_log_prob = old_log_prob.clone()

# Create catastrophic outlier in first sequence
rollout_log_prob[0, 2] += 15.0 # Makes ratio ~3e-7

response_mask = torch.ones(batch_size, seq_length, device=device)

_, metrics = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=response_mask,
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
rollout_is_veto_threshold=1e-4,
)

# Should have vetoed one sequence
assert metrics["mismatch/rollout_is_veto_fraction"] > 0
assert metrics["mismatch/rollout_is_veto_fraction"] <= 1.0

def test_metrics_only_mode(self, sample_data, config_with_rollout_is):
"""Test metrics-only mode: compute IS weights/metrics but don't apply to loss.

This tests the use case where rollout_is_threshold is set (enables computation)
but rollout_is=False (disables weight application to policy loss).
"""
# Compute IS weights (as trainer would do)
rollout_is_weights_proto, is_metrics = compute_rollout_importance_weights(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
)

# Metrics should be computed
assert len(is_metrics) > 0
assert "mismatch/rollout_is_mean" in is_metrics

# In metrics-only mode, we compute loss WITHOUT applying weights
# (simulating rollout_is=False)
pg_loss_no_weights, _, _, _ = compute_policy_loss_vanilla(
old_log_prob=sample_data["old_log_prob"],
log_prob=sample_data["log_prob"],
advantages=sample_data["advantages"],
response_mask=sample_data["response_mask"],
loss_agg_mode="token-mean",
config=config_with_rollout_is,
rollout_is_weights=None, # Don't apply weights
)

# Compare to loss WITH weights (rollout_is=True)
rollout_is_weights = rollout_is_weights_proto.batch["rollout_is_weights"]
pg_loss_with_weights, _, _, _ = compute_policy_loss_vanilla(
old_log_prob=sample_data["old_log_prob"],
log_prob=sample_data["log_prob"],
advantages=sample_data["advantages"],
response_mask=sample_data["response_mask"],
loss_agg_mode="token-mean",
config=config_with_rollout_is,
rollout_is_weights=rollout_is_weights,
)

# Losses should be different (weights have an effect)
assert not torch.allclose(pg_loss_no_weights, pg_loss_with_weights)


if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])
@ -21,15 +21,24 @@ class TestFSDPOptimizerConfigCPU:
def test_default_configuration(self):
config = FSDPOptimizerConfig(lr=0.1)
assert config.min_lr_ratio is None
assert config.warmup_style == "constant"
assert config.lr_scheduler_type == "constant"
assert config.num_cycles == 0.5

@pytest.mark.parametrize("warmup_style", ["constant", "cosine"])
def test_valid_warmup_styles(self, warmup_style):
config = FSDPOptimizerConfig(warmup_style=warmup_style, lr=0.1)
assert config.warmup_style == warmup_style
@pytest.mark.parametrize("lr_scheduler_type", ["constant", "cosine"])
def test_valid_lr_scheduler_types(self, lr_scheduler_type):
config = FSDPOptimizerConfig(lr_scheduler_type=lr_scheduler_type, lr=0.1)
assert config.lr_scheduler_type == lr_scheduler_type

def test_invalid_warmup_style(self):
@pytest.mark.parametrize("warmup_style", ["constant", "cosine"])
def test_valid_warmup_style_types(self, warmup_style):
config = FSDPOptimizerConfig(warmup_style=warmup_style, lr=0.1)
assert config.lr_scheduler_type == warmup_style

def test_invalid_lr_scheduler_type(self):
with pytest.raises((ValueError, AssertionError)):
FSDPOptimizerConfig(lr_scheduler_type="invalid_style", lr=0.1)

def test_invalid_warmup_style_type(self):
with pytest.raises((ValueError, AssertionError)):
FSDPOptimizerConfig(warmup_style="invalid_style", lr=0.1)

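The tests above reflect the rename of the scheduler option: lr_scheduler_type is now the canonical field and warmup_style is kept only as a deprecated alias that maps onto it. A minimal sketch of the intended usage, assuming FSDPOptimizerConfig is importable from verl.workers.config (the import path here is an assumption):

# Sketch only: illustrates the renamed field; import path is an assumption.
from verl.workers.config import FSDPOptimizerConfig

cfg = FSDPOptimizerConfig(lr=1e-5, lr_scheduler_type="cosine")   # preferred spelling
legacy = FSDPOptimizerConfig(lr=1e-5, warmup_style="cosine")     # deprecated alias, still accepted
assert legacy.lr_scheduler_type == "cosine"
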
@ -113,7 +113,7 @@ def gptmodel_forward_qwen2_5_vl(
output_orig = model(
input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids,
position_ids=None, # model will calculate position_ids
packed_seq_params=packed_seq_params,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
@ -74,6 +74,7 @@ class SupportedModel(Enum):
GLM4_MOE = "Glm4MoeForCausalLM"

QWEN3_TOKEN_CLASSIFICATION = "Qwen3ForTokenClassification"
QWEN3_MOE_VL = "Qwen3VLMoeForConditionalGeneration"


# Registry for model configuration converters
@ -118,6 +119,7 @@ MODEL_FORWARD_REGISTRY: dict[SupportedModel, Callable] = {
SupportedModel.QWEN3: gptmodel_forward,
SupportedModel.QWEN3_MOE: gptmodel_forward,
SupportedModel.QWEN2_5_VL: gptmodel_forward_qwen2_5_vl,
SupportedModel.QWEN3_MOE_VL: gptmodel_forward_qwen2_5_vl,
SupportedModel.DEEPSEEK_V3: gptmodel_forward,
SupportedModel.GLM4_MOE: gptmodel_forward,
SupportedModel.QWEN3_TOKEN_CLASSIFICATION: gptmodel_forward,
@ -131,6 +133,7 @@ MODEL_FORWARD_NOPAD_REGISTRY: dict[SupportedModel, Callable] = {
SupportedModel.MIXTRAL: gptmodel_forward_no_padding,
SupportedModel.DEEPSEEK_V3: gptmodel_forward_no_padding,
SupportedModel.QWEN2_5_VL: gptmodel_forward_no_padding,
SupportedModel.QWEN3_MOE_VL: gptmodel_forward_no_padding,
SupportedModel.LLAMA4: gptmodel_forward_no_padding,
SupportedModel.QWEN3: gptmodel_forward_no_padding,
SupportedModel.QWEN3_MOE: gptmodel_forward_no_padding,
@ -148,6 +151,7 @@ MODEL_FORWARD_FUSED_REGISTRY: dict[SupportedModel, Callable] = {
SupportedModel.MIXTRAL: fused_forward_gptmodel,
SupportedModel.DEEPSEEK_V3: fused_forward_gptmodel,
SupportedModel.QWEN2_5_VL: fused_forward_qwen2_5_vl,
SupportedModel.QWEN3_MOE_VL: fused_forward_qwen2_5_vl,
SupportedModel.LLAMA4: fused_forward_gptmodel,
SupportedModel.QWEN3: fused_forward_gptmodel,
SupportedModel.QWEN3_MOE: fused_forward_gptmodel,
@ -127,6 +127,8 @@ def patch_vlm_for_ulysses_input_slicing(model_class: type):
def ulysses_wrapped_decoder_forward(self, *args, **kwargs):
inputs_embeds = kwargs.get("inputs_embeds")
position_ids = kwargs.get("position_ids")
visual_pos_masks = kwargs.get("visual_pos_masks")
deepstack_visual_embeds = kwargs.get("deepstack_visual_embeds")
call_kwargs = kwargs.copy()

current_ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
@ -139,6 +141,43 @@ def patch_vlm_for_ulysses_input_slicing(model_class: type):
if slice_now:
call_kwargs["inputs_embeds"] = slice_input_tensor(inputs_embeds, dim=1, padding=False)
call_kwargs["position_ids"] = slice_input_tensor(position_ids, dim=-1, padding=False)
# Also slice visual_pos_masks and deepstack_visual_embeds for Qwen3 VL models
if visual_pos_masks is not None:
original_visual_mask = visual_pos_masks
sliced_visual_mask = slice_input_tensor(visual_pos_masks, dim=1, padding=False)
call_kwargs["visual_pos_masks"] = sliced_visual_mask

if deepstack_visual_embeds is not None:
sliced_embeds = []

num_visual_before = original_visual_mask.sum().item()
num_visual_in_shard = sliced_visual_mask.sum().item()

if num_visual_in_shard > 0 and num_visual_before > 0:
# Calculate which visual embeddings belong to this shard
# We need to find the offset of visual tokens in this shard
from verl.utils.ulysses import get_ulysses_sequence_parallel_rank

rank = get_ulysses_sequence_parallel_rank()
seq_len = original_visual_mask.shape[1]
local_seq_len = seq_len // current_ulysses_sp_size
start_idx = rank * local_seq_len
end_idx = start_idx + local_seq_len

# Get total visual tokens before and up to the end of the shard's sequence slice
# This correctly handles batches by summing across all samples
visual_start = original_visual_mask[:, :start_idx].sum().item() if start_idx > 0 else 0
visual_end = original_visual_mask[:, :end_idx].sum().item()

# Slice each tensor in deepstack_visual_embeds
for embed in deepstack_visual_embeds:
sliced_embeds.append(embed[visual_start:visual_end])
else:
# No visual tokens in this shard, create empty tensors to maintain gradient flow
for embed in deepstack_visual_embeds:
sliced_embeds.append(embed[:0])
call_kwargs["deepstack_visual_embeds"] = sliced_embeds

self._needs_initial_slice = False
try:
return original_forward(self, *args, **call_kwargs)
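The slicing above selects the rows of deepstack_visual_embeds that belong to the current sequence-parallel shard by counting visual tokens before and up to the end of the shard's sequence slice. A standalone sketch of that offset arithmetic with toy shapes (the mask and embedding tensors below are made up for illustration and do not come from verl):

import torch

sp_size, rank = 2, 1  # 2-way sequence parallel, this is the second shard
visual_pos_masks = torch.tensor([[0, 1, 1, 0, 1, 1, 0, 0]], dtype=torch.bool)  # (batch, seq)
deepstack_visual_embeds = [torch.randn(4, 8)]  # one tensor per layer, rows = visual tokens

seq_len = visual_pos_masks.shape[1]
local_seq_len = seq_len // sp_size
start_idx, end_idx = rank * local_seq_len, (rank + 1) * local_seq_len

# Visual tokens before the shard and up to its end give the row range of
# each deepstack embedding tensor that this shard owns.
visual_start = int(visual_pos_masks[:, :start_idx].sum())
visual_end = int(visual_pos_masks[:, :end_idx].sum())
sliced = [emb[visual_start:visual_end] for emb in deepstack_visual_embeds]
assert sliced[0].shape[0] == int(visual_pos_masks[:, start_idx:end_idx].sum())
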
@ -290,9 +329,7 @@ def apply_monkey_patch(
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention,
)
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2VLFlashAttention2 as Qwen2VLAttention,
)
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 as Qwen2VLAttention

if use_remove_padding or ulysses_sp_size > 1:
from verl.models.transformers.qwen2_vl import qwen2_vl_attn_forward
@ -209,8 +209,10 @@ def _get_input_embeds(
patch_dim = config.in_channels * config.temporal_patch_size * config.patch_size**2
pixel_values = torch.zeros((16, patch_dim), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
image_grid_thw = torch.tensor([[1, 4, 4]], dtype=torch.long, device=inputs_embeds.device)
image_embeds, _ = model.visual(pixel_values, grid_thw=image_grid_thw)
image_embeds, dummy_deepstack_image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
inputs_embeds += 0.0 * image_embeds.mean()
for emb in dummy_deepstack_image_embeds or []:
inputs_embeds += 0.0 * emb.mean()

if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
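The `0.0 * emb.mean()` additions above are a zero-weighted dummy-forward trick: the visual tower's outputs are touched with zero weight so its parameters stay in the autograd graph (and participate in gradient synchronization) even when a batch carries no real images. A standalone sketch of the idea, independent of verl:

import torch
from torch import nn

vision = nn.Linear(4, 4)  # stands in for the visual tower
text_loss = torch.randn(3, 4).sum()

# Zero-weighted contribution: the loss value is unchanged, but vision's
# parameters now appear in the graph and receive (zero) gradients.
dummy = vision(torch.zeros(1, 4)).mean()
loss = text_loss + 0.0 * dummy
loss.backward()
assert vision.weight.grad is not None and torch.all(vision.weight.grad == 0)
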
@ -75,7 +75,6 @@ actor_rollout_ref:
clip_ratio_c: 3.0
loss_agg_mode: token-mean
entropy_coeff: 0
tis_imp_ratio_cap: -1
use_kl_loss: false
use_torch_compile: true
kl_loss_coef: 0.001
@ -484,6 +483,12 @@ algorithm:
pf_ppo:
reweight_method: pow
weight_pow: 2.0
rollout_is_threshold: null
rollout_is_threshold_lower: null
rollout_is_level: token
rollout_is_mode: truncate
rollout_is_veto_threshold: 0.0001
rollout_is: false
trainer:
balance_batch: true
total_epochs: 30

@ -18,7 +18,8 @@ actor_rollout_ref:
clip_grad: 1.0
min_lr_ratio: 0.0
num_cycles: 0.5
warmup_style: constant
lr_scheduler_type: constant
warmup_style: null
fsdp_config:
_target_: verl.workers.config.FSDPEngineConfig
wrap_policy:
@ -59,7 +60,6 @@ actor_rollout_ref:
clip_ratio_c: 3.0
loss_agg_mode: token-mean
entropy_coeff: 0
tis_imp_ratio_cap: -1
use_kl_loss: false
use_torch_compile: true
kl_loss_coef: 0.001
@ -315,7 +315,8 @@ critic:
clip_grad: 1.0
min_lr_ratio: 0.0
num_cycles: 0.5
warmup_style: constant
lr_scheduler_type: constant
warmup_style: null
model:
fsdp_config:
_target_: verl.workers.config.FSDPEngineConfig
@ -462,6 +463,12 @@ algorithm:
pf_ppo:
reweight_method: pow
weight_pow: 2.0
rollout_is_threshold: null
rollout_is_threshold_lower: null
rollout_is_level: token
rollout_is_mode: truncate
rollout_is_veto_threshold: 0.0001
rollout_is: false
trainer:
balance_batch: true
total_epochs: 30

@ -74,10 +74,6 @@ loss_agg_mode: token-mean
# Entropy regularization coefficient in PPO loss
entropy_coeff: 0

# Truncated Importance Sampling (TIS): https://fengyao.notion.site/off-policy-rl
# the truncation value C of truncated Importance Sampling (-1 for disable TIS)
tis_imp_ratio_cap: -1

# Whether to use KL loss instead of KL reward penalty. True for GRPO
use_kl_loss: false

@ -73,6 +73,14 @@ class AlgoConfig(BaseConfig):
use_pf_ppo (bool): Whether to enable preference feedback PPO.
pf_ppo (dict[str, Any]): Preference feedback PPO settings.
filter_groups (Optional[FilterGroupsConfig]): Filter groups configuration, used in DAPO and Entropy
rollout_is_threshold (Optional[float]): Upper threshold for IS weights. null = disabled,
float value = enabled (compute weights and metrics). This is the main on/off switch.
rollout_is_threshold_lower (Optional[float]): Lower threshold for IS weights. If None, defaults to 1/upper.
rollout_is_level (str): Aggregation level: "token", "sequence", or "geometric".
rollout_is_mode (str): Bounding mode: "truncate" (cap upper only) or "mask" (zero outside bounds).
rollout_is_veto_threshold (float): Per-token veto threshold for catastrophic outliers.
rollout_is (bool): Whether to apply IS weights to policy loss. True = apply weights,
False = compute metrics only (useful for monitoring before enabling correction). Default: False.
"""

gamma: float = 1.0
@ -85,3 +93,13 @@ class AlgoConfig(BaseConfig):
use_pf_ppo: bool = False
pf_ppo: dict[str, Any] = field(default_factory=dict)
filter_groups: Optional[FilterGroupsConfig] = None
# Rollout Importance Sampling (replaces legacy tis_imp_ratio_cap)
# Controls computation of IS weights and mismatch metrics
rollout_is_threshold: Optional[float] = None  # null = disabled, float = enabled
rollout_is_threshold_lower: Optional[float] = None
rollout_is_level: str = "token"
rollout_is_mode: str = "truncate"
rollout_is_veto_threshold: Optional[float] = 1e-4
# Controls whether to apply IS weights to policy loss (only if rollout_is_threshold is set)
# True = apply weights to loss, False = compute metrics only (no weight application)
rollout_is: bool = False
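Taken together, these fields make the algorithm config the single switch for rollout IS. A minimal sketch of setting them in Python; the import path of AlgoConfig is an assumption here and may differ in the actual package layout:

# Sketch under assumptions: AlgoConfig import path; field names as in the diff above.
from verl.trainer.config import AlgoConfig

algo = AlgoConfig(
    rollout_is_threshold=2.0,   # enables IS weight + mismatch metric computation
    rollout_is_level="token",
    rollout_is_mode="truncate",
    rollout_is=False,           # metrics-only: monitor mismatch before applying weights
)
# Flip rollout_is=True once the logged mismatch/* metrics look stable.
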
@ -28,6 +28,8 @@ min_lr_ratio: 0.0
# Number of cosine cycles in LR schedule
num_cycles: 0.5

# LR warmup style: "constant" or "cosine"
warmup_style: constant
# LR scheduler type: "constant" or "cosine"
lr_scheduler_type: constant

# deprecated
warmup_style: null
@ -73,6 +73,28 @@ algorithm:
reweight_method: pow  # ["pow", "max_min", "max_random"]
weight_pow: 2.0

# Rollout Importance Sampling: corrects distribution mismatch between rollout and training policies
# Main control: Upper threshold for IS weights (null = disabled, float = enabled)
# When enabled, computes IS weights and mismatch metrics (KL, PPL, etc.)
rollout_is_threshold: null

# Lower threshold for IS weights (null = auto-reciprocal of upper)
rollout_is_threshold_lower: null

# Aggregation level: "token" (biased), "sequence" (unbiased), "geometric" (experimental)
rollout_is_level: token

# Bounding mode: "truncate" (cap upper only), "mask" (zero outside bounds)
rollout_is_mode: truncate

# Per-token veto threshold for catastrophic outliers
rollout_is_veto_threshold: 1e-4

# Whether to apply IS weights to policy loss
# true = apply weights to loss, false = compute metrics only (no weight application)
# Useful for monitoring mismatch before enabling correction
rollout_is: false

trainer:
balance_batch: True
total_epochs: 30
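Because the switch lives under the algorithm section, an existing run can turn on monitoring without touching trainer code. A minimal sketch, assuming the YAML above is loaded with OmegaConf (the file name and override values are illustrative):

# Sketch only: file name and override values are illustrative.
from omegaconf import OmegaConf

cfg = OmegaConf.load("ppo_trainer.yaml")
cfg.algorithm.rollout_is_threshold = 2.0   # start computing IS weights and mismatch metrics
cfg.algorithm.rollout_is = False           # but do not apply them to the loss yet
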
@ -113,6 +113,28 @@ algorithm:
# Power used for weight scaling in "pow" method
weight_pow: 2.0

# Rollout Importance Sampling: corrects distribution mismatch between rollout and training policies
# Main control: Upper threshold for IS weights (null = disabled, float = enabled)
# When enabled, computes IS weights and mismatch metrics (KL, PPL, etc.)
rollout_is_threshold: null

# Lower threshold for IS weights (null = auto-reciprocal of upper)
rollout_is_threshold_lower: null

# Aggregation level: "token" (biased), "sequence" (unbiased), "geometric" (experimental)
rollout_is_level: token

# Bounding mode: "truncate" (cap upper only), "mask" (zero outside bounds)
rollout_is_mode: truncate

# Per-token veto threshold for catastrophic outliers
rollout_is_veto_threshold: 1e-4

# Whether to apply IS weights to policy loss
# true = apply weights to loss, false = compute metrics only (no weight application)
# Useful for monitoring mismatch before enabling correction
rollout_is: false

# config for the trainer
trainer:

@ -18,7 +18,7 @@ data:
max_token_len_per_gpu: 8192
use_dynamic_bsz: True
train_files: ~/data/gsm8k/train.parquet
val_files: ~/data/gsm8k/test.parquet
val_files: null
# Multi-turn settings
messages_key: messages # Key for messages list in multi-turn mode
tools_key: tools # Key for tools list in multi-turn mode
@ -31,7 +31,8 @@ from verl.utils.fs import copy_to_local


@ray.remote
def process_item(reward_fn, data_source, response_lst, reward_data):
def process_item(config, data_source, response_lst, reward_data):
reward_fn = get_custom_reward_fn(config)
ground_truth = reward_data["ground_truth"]
score_lst = [reward_fn(data_source, r, ground_truth) for r in response_lst]
return data_source, np.mean(score_lst)
@ -53,11 +54,9 @@ def main(config):

# evaluate test_score based on data source
data_source_reward = defaultdict(list)
compute_score = get_custom_reward_fn(config)

# Create remote tasks
remote_tasks = [
process_item.remote(compute_score, data_sources[i], responses[i], reward_model_data[i]) for i in range(total)
process_item.remote(config, data_sources[i], responses[i], reward_model_data[i]) for i in range(total)
]

# Process results as they come in
@ -17,6 +17,7 @@ Generate responses given a dataset of prompts

import os

import aiohttp
import hydra
import numpy as np
import ray
@ -30,31 +31,12 @@ from pprint import pprint

import pandas as pd
from omegaconf import OmegaConf
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion

from verl.utils.hdfs_io import makedirs
from verl.workers.rollout.replica import get_rollout_replica_class


@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None)
def main(config):
run_generation(config)


def run_generation(config) -> None:
if not ray.is_initialized():
# this is for local ray cluster
default_runtime_env = {"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_USE_V1": "1"}}
ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
print(f"ray init kwargs: {ray_init_kwargs}")
ray.init(**OmegaConf.to_container(ray_init_kwargs))

ray.get(main_task.remote(config))


async def start_server(config):
tp_size = config.actor_rollout_ref.rollout.tensor_model_parallel_size
num_replicas = (config.trainer.n_gpus_per_node * config.trainer.nnodes) // tp_size
@ -81,23 +63,42 @@ async def start_server(config):
return server_handles, server_addresses


async def generate_per_replica(server_address, model_path: str, n_samples: int, sampling_params: dict, chat_lst: list):
# here we should sample n_samples for each chat_lst
client = AsyncOpenAI(
api_key="123-abc",
base_url=f"http://{server_address}/v1",
)
async def submit_request(server_address, **chat_complete_request):
try:
extra_headers = chat_complete_request.pop("extra_headers", {})
timeout = aiohttp.ClientTimeout(total=None)
session = aiohttp.ClientSession(timeout=timeout)
async with session.post(
url=f"http://{server_address}/v1/chat/completions",
headers={"Authorization": "Bearer token-abc123", **extra_headers},
json=chat_complete_request,
) as resp:
data = await resp.json()
return ChatCompletion(**data)
finally:
await session.close()

tasks = [
client.chat.completions.create(
model=model_path,
messages=messages,

async def generate_per_replica(server_address, model_path: str, n_samples: int, sampling_params: dict, chat_lst: list):
# here we should sample n_samples for each chat_lst.
# we use aiohttp to avoid hang in AsyncOpenAI when the number of requests is large.

# client = AsyncOpenAI(
#     api_key="123-abc",
#     base_url=f"http://{server_address}/v1",
# )

chat_complete_request = [
{
"model": model_path,
"messages": messages,
**sampling_params,
)
}
for messages in chat_lst
for _ in range(n_samples)
]

tasks = [submit_request(server_address, **req) for req in chat_complete_request]
results = await asyncio.gather(*tasks)
return results

@ -118,8 +119,10 @@ async def generate(
return results


@ray.remote(num_cpus=1)
def main_task(config):
@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None)
def main(config):
ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_USE_V1": "1"}})

pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
OmegaConf.resolve(config)

@ -136,8 +139,21 @@ def main_task(config):
"max_tokens": config.actor_rollout_ref.rollout.response_length,
}

from omegaconf import ListConfig

train_files = config.data.train_files
if not isinstance(train_files, list | ListConfig):
train_files = [train_files]

# read dataset. Note that the dataset should directly contain chat template format (e.g., a list of dictionary)
dataset = pd.read_parquet(config.data.train_files)

datasets = []
for train_file in train_files:
dataset = pd.read_parquet(train_file)
datasets.append(dataset)

# concat dataset
dataset = pd.concat(datasets, axis=0, ignore_index=True)
chat_lst = dataset[config.data.prompt_key].tolist()
chat_lst = [chat.tolist() for chat in chat_lst]
chat_numpy = np.array(chat_lst)
@ -151,7 +167,6 @@ def main_task(config):
)

# reshape results into a numpy array

import itertools

results = list(itertools.chain.from_iterable(gen_results))
@ -170,6 +185,7 @@ def main_task(config):
# write to a new parquet
output_dir = os.path.dirname(config.data.output_path)
makedirs(output_dir, exist_ok=True)
print(f"Saving results to {config.data.output_path}")
dataset.to_parquet(config.data.output_path)

@ -881,7 +881,7 @@ def compute_policy_loss(
return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower


@register_policy_loss("vanilla")
@register_policy_loss("vanilla")  # type: ignore[arg-type]
def compute_policy_loss_vanilla(
old_log_prob: torch.Tensor,
log_prob: torch.Tensor,
@ -889,7 +889,7 @@ def compute_policy_loss_vanilla(
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for PPO.
@ -959,11 +959,9 @@ def compute_policy_loss_vanilla(

pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)

if config.tis_imp_ratio_cap > 0 and rollout_log_probs is not None:
# Apply truncated importance sampling -> https://fengyao.notion.site/off-policy-rl
tis_imp_ratio = torch.exp(old_log_prob - rollout_log_probs)
tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap)
pg_losses = pg_losses * tis_imp_ratio
# Apply rollout importance sampling weights if provided
if rollout_is_weights is not None:
pg_losses = pg_losses * rollout_is_weights

pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

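The removed TIS branch is subsumed by the centrally computed weights: token-level truncated IS with the old cap C corresponds to rollout_is_level="token", rollout_is_mode="truncate", and rollout_is_threshold=C. A minimal sketch of that equivalence on raw tensors (illustrative only, and ignoring the extra veto/masking machinery of the new helper):

import torch

old_log_prob = torch.randn(2, 6)
rollout_log_probs = old_log_prob + 0.1 * torch.randn(2, 6)
C = 2.0

# Legacy TIS: per-token ratio clamped from above by the cap.
tis = torch.exp(old_log_prob - rollout_log_probs).clamp(max=C)

# Rollout IS with token level + truncate mode yields the same multiplier.
rollout_is_weights = torch.exp(old_log_prob - rollout_log_probs).clamp(max=C)
assert torch.allclose(tis, rollout_is_weights)
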
@ -978,7 +976,7 @@ def compute_policy_loss_gspo(
response_mask: torch.Tensor,
loss_agg_mode: str = "seq-mean-token-mean",
config: Optional[DictConfig | ActorConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for GSPO.
@ -1024,6 +1022,10 @@ def compute_policy_loss_gspo(
pg_losses2 = -advantages * torch.clamp(seq_importance_ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
pg_losses = torch.maximum(pg_losses1, pg_losses2)

# Apply rollout importance sampling weights if provided
if rollout_is_weights is not None:
pg_losses = pg_losses * rollout_is_weights

# for GSPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean)
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode="seq-mean-token-mean")

@ -1044,7 +1046,7 @@ def compute_policy_loss_gpg(
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Adapted from
https://github.com/AMAP-ML/GPG/blob/main/VisualThinker-R1-Zero/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py#L495
@ -1061,6 +1063,10 @@ def compute_policy_loss_gpg(
"""
pg_losses = -log_prob * advantages

# Apply rollout importance sampling weights if provided
if rollout_is_weights is not None:
pg_losses = pg_losses * rollout_is_weights

pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
return pg_loss, torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)

@ -1073,7 +1079,7 @@ def compute_policy_loss_clip_cov(
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for Clip-Cov.
@ -1155,6 +1161,11 @@ def compute_policy_loss_clip_cov(
pg_clipfrac = verl_F.masked_mean((corr == 0).float(), response_mask)

pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr

# Apply rollout importance sampling weights if provided
if rollout_is_weights is not None:
pg_losses = pg_losses * rollout_is_weights

pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

return pg_loss, pg_clipfrac, ppo_kl, torch.tensor(0.0)
@ -1168,7 +1179,7 @@ def compute_policy_loss_kl_cov(
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for Clip-Cov.
@ -1227,6 +1238,10 @@ def compute_policy_loss_kl_cov(
large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]
]

# Apply rollout importance sampling weights if provided
if rollout_is_weights is not None:
pg_losses = pg_losses * rollout_is_weights

pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0)
@ -1240,7 +1255,7 @@ def compute_policy_loss_geo_mean(
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for GMPO.
@ -1293,6 +1308,17 @@ def compute_policy_loss_geo_mean(
# otherwise, below would be not consistent with the paper
advantage = (advantages * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8)
pg_losses = -advantage * ratio

# Apply rollout importance sampling weights if provided
# For geo_mean, IS weights are 2D (batch_size, seq_length) and need to be aggregated to sequence level
if rollout_is_weights is not None:
# Aggregate token-level weights to sequence level using geometric mean for consistency
# Note: rollout_is_weights is always 2D regardless of rollout_is_level
seq_is_weights = torch.exp(
(torch.log(rollout_is_weights + 1e-10) * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8)
)
pg_losses = pg_losses * seq_is_weights

pg_loss = torch.mean(pg_losses)

# higher: ratio is too large that need clamp to clip_high (when adv > 0)
459
verl/trainer/ppo/mismatch_helper.py
Normal file
@ -0,0 +1,459 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Rollout Importance Sampling (IS) Helper Module

This module handles importance sampling weight computation for correcting
distribution mismatch between rollout policy (e.g., vLLM BFloat16) and
training policy (e.g., FSDP FP32).

Key Features:
1. Three aggregation levels: token, sequence, geometric
2. Two handling modes: truncate (TIS), mask (MIS)
3. Per-token veto mechanism for catastrophic outliers
4. Memory-efficient computation to prevent CUDA OOM
5. Comprehensive metrics tracking

Usage Notes:
- compute_rollout_importance_weights() computes both IS weights and mismatch metrics
- Used in ray_trainer.py via compute_rollout_importance_weights_and_add_to_batch()
- Also used in dp_actor.py for distributed worker computations
- compute_mismatch_metrics() is called internally by compute_rollout_importance_weights()

References:
- When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
- Off-policy RL: https://fengyao.notion.site/off-policy-rl
"""

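For reference, the three aggregation levels listed in the docstring correspond to the following weights (notation introduced here for exposition only), with per-token ratio r_t = \pi_{train}(a_t \mid s_t) / \pi_{rollout}(a_t \mid s_t) and T response tokens:

w^{token}_t = r_t, \qquad w^{seq} = \prod_{t=1}^{T} r_t, \qquad w^{geo} = \Big(\prod_{t=1}^{T} r_t\Big)^{1/T}

In the code below these are computed in log space as exp(log_ratio), exp(masked_sum(log_ratio)), and exp(masked_mean(log_ratio)) respectively.
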
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import verl.utils.torch_functional as verl_F
|
||||
from verl.protocol import DataProto
|
||||
|
||||
|
||||
def compute_rollout_importance_weights(
|
||||
old_log_prob: torch.Tensor,
|
||||
rollout_log_prob: torch.Tensor,
|
||||
response_mask: torch.Tensor,
|
||||
rollout_is_level: str = "token",
|
||||
rollout_is_mode: str = "truncate",
|
||||
rollout_is_threshold: Optional[float] = None,
|
||||
rollout_is_threshold_lower: Optional[float] = None,
|
||||
rollout_is_veto_threshold: Optional[float] = 1e-4,
|
||||
) -> tuple[Optional[DataProto], dict[str, Any]]:
|
||||
"""Compute importance sampling weights and metrics for rollout-training mismatch correction.
|
||||
|
||||
This function handles the computation of importance sampling (IS) weights to correct
|
||||
for the distribution mismatch between rollout policy and training policy.
|
||||
|
||||
Reference:
|
||||
When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
|
||||
|
||||
Memory-efficient implementation that prevents CUDA OOM by:
|
||||
- Using log-space computation where possible
|
||||
- Applying safety bounds to prevent numerical overflow
|
||||
- Computing metrics without creating huge intermediate tensors
|
||||
|
||||
Args:
|
||||
old_log_prob: Log probabilities from training policy (e.g., FSDP), shape (batch_size, seq_length)
|
||||
rollout_log_prob: Log probabilities from rollout policy (e.g., vLLM), shape (batch_size, seq_length)
|
||||
response_mask: Mask for valid tokens, shape (batch_size, seq_length)
|
||||
rollout_is_level: Level of IS aggregation:
|
||||
- "token": Per-token ratios (biased)
|
||||
- "sequence": Product of ratios (unbiased)
|
||||
- "geometric": Geometric mean of ratios (experimental)
|
||||
rollout_is_mode: How to handle weights exceeding threshold:
|
||||
- "truncate": Cap weights at upper_threshold only (TIS)
|
||||
- "mask": Zero out weights outside [lower_threshold, upper_threshold] (MIS)
|
||||
rollout_is_threshold: Upper threshold for IS weights
|
||||
rollout_is_threshold_lower: Lower threshold for IS weights (mask mode only; if None, defaults to 1/upper)
|
||||
rollout_is_veto_threshold: Per-token veto threshold. If any token ratio < this, zero entire sequence.
|
||||
If None, veto mechanism is disabled.
|
||||
|
||||
Returns:
|
||||
Tuple of (weights_proto, metrics) where:
|
||||
weights_proto: DataProto containing IS weights with key "rollout_is_weights",
|
||||
shape (batch_size, seq_length). Returns None if rollout_is_threshold is None.
|
||||
metrics: Dictionary of IS statistics and mismatch metrics (KL, PPL, etc.),
|
||||
all converted to scalars and prefixed with "mismatch/"
|
||||
"""
|
||||
if rollout_is_threshold is None:
|
||||
return None, {}
|
||||
|
||||
# Parse thresholds: if lower not specified, use 1/upper (reciprocal)
|
||||
upper_threshold = rollout_is_threshold
|
||||
if rollout_is_threshold_lower is not None:
|
||||
lower_threshold = rollout_is_threshold_lower
|
||||
else:
|
||||
# Default: lower = 1/upper (reciprocal)
|
||||
lower_threshold = 1.0 / upper_threshold
|
||||
|
||||
# Step 1: Compute raw importance weights based on the specified level
|
||||
log_ratio = old_log_prob - rollout_log_prob
|
||||
|
||||
# Pre-compute log thresholds
|
||||
device = old_log_prob.device
|
||||
log_threshold_upper = torch.log(torch.tensor(upper_threshold, device=device))
|
||||
log_threshold_lower = torch.log(torch.tensor(lower_threshold, device=device))
|
||||
|
||||
# Safety bound to prevent numerical overflow (exp(20) ≈ 485M)
|
||||
SAFETY_BOUND = 20.0
|
||||
|
||||
# Store unclamped values in log-space for accurate metrics
|
||||
if rollout_is_level == "token":
|
||||
# Token-level IS: π_train(a|s) / π_rollout(a|s) per token
|
||||
log_ratio_for_metrics = log_ratio
|
||||
|
||||
# Apply safety bound to prevent overflow
|
||||
log_ratio_safe = torch.clamp(log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND)
|
||||
rollout_is_weights = torch.exp(log_ratio_safe)
|
||||
|
||||
elif rollout_is_level == "sequence":
|
||||
# Sequence-level IS: π_train(y|x) / π_rollout(y|x) for entire sequence
|
||||
# Product of token ratios: exp(Σ log(π_train/π_rollout))
|
||||
log_ratio_sum = verl_F.masked_sum(log_ratio, response_mask, axis=-1).unsqueeze(-1)
|
||||
log_ratio_for_metrics = log_ratio_sum # Store for metrics
|
||||
|
||||
# Apply safety bound to prevent overflow
|
||||
log_ratio_sum_safe = torch.clamp(log_ratio_sum, min=-SAFETY_BOUND, max=SAFETY_BOUND)
|
||||
rollout_is_weights = torch.exp(log_ratio_sum_safe).expand_as(old_log_prob)
|
||||
|
||||
elif rollout_is_level == "geometric":
|
||||
# Geometric mean IS: (∏ π_train/π_rollout)^(1/T)
|
||||
# Equivalent to exp(mean(log(π_train/π_rollout)))
|
||||
log_ratio_mean = verl_F.masked_mean(log_ratio, response_mask, axis=-1).unsqueeze(-1)
|
||||
log_ratio_for_metrics = log_ratio_mean # Store for metrics
|
||||
|
||||
# Geometric mean rarely explodes due to averaging, but apply safety bound anyway
|
||||
log_ratio_mean_safe = torch.clamp(log_ratio_mean, min=-SAFETY_BOUND, max=SAFETY_BOUND)
|
||||
rollout_is_weights = torch.exp(log_ratio_mean_safe).expand_as(old_log_prob)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid rollout_is_level: {rollout_is_level}. Must be 'token', 'sequence', or 'geometric'.")
|
||||
|
||||
# Step 1.5: Apply per-token veto check in log space (memory efficient)
|
||||
if rollout_is_veto_threshold is not None:
|
||||
log_veto_threshold = torch.log(torch.tensor(rollout_is_veto_threshold, device=device))
|
||||
|
||||
# Check if any token ratio is below veto threshold (in log space)
|
||||
# log(π_train/π_rollout) < log(veto_threshold) ⟺ π_train/π_rollout < veto_threshold
|
||||
catastrophic_tokens = (log_ratio < log_veto_threshold) & response_mask.bool()
|
||||
|
||||
# For each sequence, check if it has any catastrophic token
|
||||
# Use broadcasting instead of expand_as to save memory
|
||||
has_catastrophic = catastrophic_tokens.any(dim=-1, keepdim=True)
|
||||
|
||||
# Create veto mask: 0 if sequence has catastrophic token, 1 otherwise
|
||||
veto_mask = (~has_catastrophic).float()
|
||||
else:
|
||||
# No veto mechanism
|
||||
catastrophic_tokens = torch.zeros_like(response_mask, dtype=torch.bool)
|
||||
has_catastrophic = torch.zeros((old_log_prob.size(0), 1), dtype=torch.bool, device=device)
|
||||
veto_mask = torch.ones((old_log_prob.size(0), 1), dtype=torch.float32, device=device)
|
||||
|
||||
# Step 2: Compute comprehensive metrics
|
||||
metrics = compute_is_metrics(
|
||||
rollout_is_weights=rollout_is_weights,
|
||||
log_ratio_for_metrics=log_ratio_for_metrics,
|
||||
response_mask=response_mask,
|
||||
rollout_is_level=rollout_is_level,
|
||||
rollout_is_threshold=upper_threshold,
|
||||
rollout_is_threshold_lower=lower_threshold,
|
||||
log_threshold_upper=log_threshold_upper,
|
||||
log_threshold_lower=log_threshold_lower,
|
||||
has_catastrophic=has_catastrophic,
|
||||
catastrophic_tokens=catastrophic_tokens,
|
||||
SAFETY_BOUND=SAFETY_BOUND,
|
||||
)
|
||||
|
||||
# Step 3: Apply truncation or masking based on mode
|
||||
if rollout_is_mode == "truncate":
|
||||
# Truncated IS (TIS): only cap upper bound to prevent overweighting
|
||||
rollout_is_weights = rollout_is_weights.clamp(max=upper_threshold)
|
||||
|
||||
elif rollout_is_mode == "mask":
|
||||
# Masked IS (MIS): zero out weights outside [lower_threshold, upper_threshold]
|
||||
mask = (rollout_is_weights >= lower_threshold) & (rollout_is_weights <= upper_threshold)
|
||||
mask = mask.float()
|
||||
|
||||
# Track MIS-specific metrics
|
||||
metrics["rollout_is_masked_fraction"] = verl_F.masked_mean(1 - mask, response_mask)
|
||||
|
||||
# Sequence-level masking fraction
|
||||
if rollout_is_level in ["sequence", "geometric"]:
|
||||
# All tokens in a sequence have the same weight, so reuse mask
|
||||
metrics["rollout_is_seq_masked_fraction"] = (1 - mask[:, 0]).mean()
|
||||
else:
|
||||
# Check if any token in each sequence is masked
|
||||
seq_has_masked = verl_F.masked_sum(1 - mask, response_mask, axis=-1) > 0
|
||||
metrics["rollout_is_seq_masked_fraction"] = seq_has_masked.float().mean()
|
||||
|
||||
rollout_is_weights = rollout_is_weights * mask
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid rollout_is_mode: {rollout_is_mode}. Must be 'truncate' or 'mask'.")
|
||||
|
||||
# Apply veto mask AFTER all thresholding
|
||||
# This zeros out entire sequences that have any catastrophic token
|
||||
rollout_is_weights = rollout_is_weights * veto_mask
|
||||
|
||||
# Apply response_mask to ensure weights are 0 where mask is 0
|
||||
rollout_is_weights = rollout_is_weights * response_mask
|
||||
|
||||
# Wrap in DataProto for consistency with worker methods
|
||||
rollout_is_weights_proto = DataProto.from_dict(tensors={"rollout_is_weights": rollout_is_weights})
|
||||
|
||||
# Compute mismatch metrics (KL, PPL, etc.) and merge with IS metrics
|
||||
mismatch_metrics = compute_mismatch_metrics(
|
||||
old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=response_mask
|
||||
)
|
||||
metrics.update(mismatch_metrics)
|
||||
|
||||
# Convert all tensor metrics to scalars for logging
|
||||
# Note: No need to detach since old_log_prob and rollout_log_prob are computed with torch.no_grad()
|
||||
metrics_scalar = {}
|
||||
for key, value in metrics.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
metrics_scalar[f"mismatch/{key}"] = value.item()
|
||||
else:
|
||||
metrics_scalar[f"mismatch/{key}"] = value
|
||||
|
||||
return rollout_is_weights_proto, metrics_scalar
|
||||
|
||||
|
||||
def compute_is_metrics(
    rollout_is_weights: torch.Tensor,
    log_ratio_for_metrics: torch.Tensor,
    response_mask: torch.Tensor,
    rollout_is_level: str,
    rollout_is_threshold: float,
    rollout_is_threshold_lower: float,
    log_threshold_upper: torch.Tensor,
    log_threshold_lower: torch.Tensor,
    has_catastrophic: torch.Tensor,
    catastrophic_tokens: torch.Tensor,
    SAFETY_BOUND: float,
) -> dict[str, Any]:
    """Compute comprehensive metrics for importance sampling weights.

    Reference:
        When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda

    This function computes metrics using a mix of true unclamped values (for max/min/fractions
    in sequence/geometric mode via log-space) and safety-clamped values (for mean/std/ESS)
    to balance accuracy with numerical stability and avoid overflow.
    """
    # Validate that we have at least one valid sample
    assert response_mask.any(), "Expected at least one valid sample in response_mask"

    metrics = {}
    device = rollout_is_weights.device

    # Track veto statistics
    metrics["rollout_is_veto_fraction"] = has_catastrophic.float().mean()
    metrics["rollout_is_catastrophic_token_fraction"] = verl_F.masked_mean(catastrophic_tokens.float(), response_mask)

    # Compute metrics based on IS level
    if rollout_is_level in ["sequence", "geometric"]:
        # For sequence/geometric, compute true statistics from log-space
        # This reflects the actual distribution before clamping

        # True max/min in log space
        log_max = log_ratio_for_metrics.max()
        log_min = log_ratio_for_metrics.min()

        # Convert to regular space with safety bound
        metrics["rollout_is_max"] = torch.exp(torch.clamp(log_max, max=SAFETY_BOUND))
        metrics["rollout_is_min"] = torch.exp(log_min)

        # Mean uses clamped weights to avoid overflow
        metrics["rollout_is_mean"] = verl_F.masked_mean(rollout_is_weights, response_mask)

        # Compute fraction exceeding threshold in log space (accurate)
        exceeds_upper = log_ratio_for_metrics > log_threshold_upper
        below_lower = log_ratio_for_metrics < log_threshold_lower

        if rollout_is_level == "sequence":
            # For sequence level, all tokens in a sequence have the same weight
            metrics["rollout_is_ratio_fraction_high"] = exceeds_upper.float().mean()
            metrics["rollout_is_ratio_fraction_low"] = below_lower.float().mean()
        else:  # geometric
            # Need to expand to match token dimensions
            exceeds_upper_expanded = exceeds_upper.expand_as(response_mask)
            below_lower_expanded = below_lower.expand_as(response_mask)
            metrics["rollout_is_ratio_fraction_high"] = verl_F.masked_mean(
                exceeds_upper_expanded.float(), response_mask
            )
            metrics["rollout_is_ratio_fraction_low"] = verl_F.masked_mean(below_lower_expanded.float(), response_mask)

    else:
        # Token-level: compute directly from weights
        metrics["rollout_is_mean"] = verl_F.masked_mean(rollout_is_weights, response_mask)

        # Fraction exceeding thresholds
        rollout_is_above_threshold = rollout_is_weights > rollout_is_threshold
        rollout_is_below_threshold = rollout_is_weights < rollout_is_threshold_lower
        metrics["rollout_is_ratio_fraction_high"] = verl_F.masked_mean(
            rollout_is_above_threshold.float(), response_mask
        )
        metrics["rollout_is_ratio_fraction_low"] = verl_F.masked_mean(rollout_is_below_threshold.float(), response_mask)

        # Max/min for token level
        mask_bool = response_mask.bool()
        metrics["rollout_is_max"] = rollout_is_weights.masked_fill(~mask_bool, float("-inf")).max()
        metrics["rollout_is_min"] = rollout_is_weights.masked_fill(~mask_bool, float("inf")).min()

    # Compute standard deviation using clamped weights to avoid overflow
    mask_count = response_mask.sum()
    if mask_count > 1:
        # Use clamped weights for variance to avoid squaring huge values
        weights_for_std = rollout_is_weights.clamp(min=rollout_is_threshold_lower, max=rollout_is_threshold)
        # Use mean from clamped weights for consistency
        mean_clamped = verl_F.masked_mean(weights_for_std, response_mask)
        rollout_is_var = verl_F.masked_mean(weights_for_std.square(), response_mask) - mean_clamped.square()
        metrics["rollout_is_std"] = torch.sqrt(torch.clamp(rollout_is_var, min=0.0))
    else:
        metrics["rollout_is_std"] = torch.tensor(0.0, device=device)

    # Effective sample size (use clamped weights to avoid overflow)
    weights_for_ess = rollout_is_weights.clamp(min=rollout_is_threshold_lower, max=rollout_is_threshold)
    mean_for_ess = verl_F.masked_mean(weights_for_ess, response_mask)
    is_weights_normalized = weights_for_ess / (mean_for_ess + 1e-8)
    metrics["rollout_is_eff_sample_size"] = 1.0 / verl_F.masked_mean(is_weights_normalized.square(), response_mask)

    # Per-sequence breakdown metrics
    if rollout_is_weights.dim() > 1:
        # Compute mean IS weight per sequence
        seq_mean_weights = verl_F.masked_mean(rollout_is_weights, response_mask, axis=-1)

        # Per-sequence statistics
        metrics["rollout_is_seq_mean"] = seq_mean_weights.mean()
        metrics["rollout_is_seq_std"] = (
            seq_mean_weights.std() if seq_mean_weights.numel() > 1 else torch.tensor(0.0, device=device)
        )
        metrics["rollout_is_seq_max"] = seq_mean_weights.max()
        metrics["rollout_is_seq_min"] = seq_mean_weights.min()

        # Identify most problematic sequences
        seq_deviation = (seq_mean_weights - 1.0).abs()
        metrics["rollout_is_seq_max_deviation"] = seq_deviation.max()

        # Fraction of sequences with high IS weights
        metrics["rollout_is_seq_fraction_high"] = (seq_mean_weights > rollout_is_threshold).float().mean()
        metrics["rollout_is_seq_fraction_low"] = (seq_mean_weights < rollout_is_threshold_lower).float().mean()

    # Percentile metrics for better distribution understanding
    # Get all valid IS weights
    flat_weights = rollout_is_weights[response_mask.bool()]
    # Compute key percentiles (guaranteed to have elements due to assertion at function start)
    assert flat_weights.numel() > 0, "flat_weights should not be empty"
    metrics["rollout_is_p25"] = torch.quantile(flat_weights, 0.25)
    metrics["rollout_is_p50"] = torch.quantile(flat_weights, 0.50)  # median
    metrics["rollout_is_p75"] = torch.quantile(flat_weights, 0.75)
    metrics["rollout_is_p95"] = torch.quantile(flat_weights, 0.95)
    metrics["rollout_is_p99"] = torch.quantile(flat_weights, 0.99)

    return metrics

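# --- Illustrative sketch (not part of the verl source above): the effective-sample-size
# --- statistic computed at the end of compute_is_metrics, on a toy 1-D weight vector
# --- with every token assumed valid (so masked_mean reduces to a plain mean).
import torch

w = torch.tensor([1.0, 1.0, 1.0, 5.0])
w_normalized = w / w.mean()               # mean(w) = 2.0
ess = 1.0 / w_normalized.square().mean()  # 1 / 1.75 ≈ 0.571; equal weights would give 1.0
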
def compute_mismatch_metrics(
    old_log_prob: torch.Tensor,
    rollout_log_prob: Optional[torch.Tensor],
    response_mask: torch.Tensor,
) -> dict[str, Any]:
    """Compute training-inference mismatch metrics (helper function).

    This helper function operates on raw tensors and is used internally by:
    - compute_rollout_importance_weights() in this module (automatically included)
    - Tests (test_rollout_is.py, test_rollout_is_integration.py)

    These metrics help diagnose the mismatch between the rollout policy (e.g., vLLM)
    and the training policy (e.g., FSDP), which can cause training instability.

    Key metrics:
    - mismatch_kl: Direct KL divergence estimator KL(π_rollout || π_training)
    - mismatch_k3_kl: K3 KL estimator for stability (more stable for small KL)
    - training_ppl: Perplexity of training policy
    - rollout_ppl: Perplexity of rollout policy
    - log_ppl_diff: Difference in log perplexities
    - ppl_ratio: Ratio of training PPL to rollout PPL

    Args:
        old_log_prob: Log probabilities from training policy, shape (batch_size, seq_length)
        rollout_log_prob: Log probabilities from rollout policy, shape (batch_size, seq_length)
        response_mask: Mask for valid tokens, shape (batch_size, seq_length)

    Returns:
        Dictionary of mismatch metrics (without prefix)

    Reference:
        - When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
    """
    # Validate that we have at least one valid token
    assert response_mask.any(), "Expected at least one valid token in response_mask"

    metrics = {}

    # 1. Training policy perplexity (always available)
    # Formula: exp(-1/|T| * Σ log π_training(y_t|y_<t))
    # where |T| is the number of tokens generated by the model
    mean_log_prob_training = verl_F.masked_mean(old_log_prob, response_mask, axis=-1)  # (batch_size,)
    training_ppl = torch.exp(-mean_log_prob_training).mean()  # Batch mean of per-sequence PPL
    metrics["mismatch_training_ppl"] = training_ppl.detach().item()

    # Also log log-ppl for easier analysis (avoids exponential scale)
    metrics["mismatch_training_log_ppl"] = (-mean_log_prob_training).mean().detach().item()

    # 2. Compute rollout mismatch metrics (only if rollout_log_probs available)
    if rollout_log_prob is not None:
        # 2a. mismatch_kl: Direct estimator for KL(π_rollout || π_training)
        # This is the standard KL divergence: E[log(π_rollout) - log(π_training)]
        # Positive value means rollout policy is more confident than training policy
        metrics["mismatch_kl"] = verl_F.masked_mean(rollout_log_prob - old_log_prob, response_mask).detach().item()

        # 2b. mismatch_k3_kl: K3 estimator for KL(π_rollout || π_training)
        # More stable for small KL values using: E[exp(log_ratio) - log_ratio - 1]
        # Formula: KL ≈ E[r - log(r) - 1] where r = π_training/π_rollout
        log_ratio = old_log_prob - rollout_log_prob
        mismatch_k3_kl_matrix = torch.exp(log_ratio) - log_ratio - 1
        metrics["mismatch_k3_kl"] = verl_F.masked_mean(mismatch_k3_kl_matrix, response_mask).detach().item()

        # 2c. Rollout policy perplexity
        mean_log_prob_rollout = verl_F.masked_mean(rollout_log_prob, response_mask, axis=-1)  # (batch_size,)
        rollout_ppl = torch.exp(-mean_log_prob_rollout).mean()  # Batch mean of per-sequence PPL
        metrics["mismatch_rollout_ppl"] = rollout_ppl.detach().item()
        metrics["mismatch_rollout_log_ppl"] = (-mean_log_prob_rollout).mean().detach().item()

        # 2d. Log PPL difference (sequence-level perplexity difference)
        # log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training
        # Since ppl = exp(-log_prob), we have:
        # log(ppl_ratio) = log(training_ppl/rollout_ppl) = log_ppl_diff
        # Positive value means training assigns lower probability (higher PPL) than rollout
        log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training
        metrics["mismatch_log_ppl_diff"] = log_ppl_diff.mean().detach().item()
        metrics["mismatch_log_ppl_abs_diff"] = log_ppl_diff.abs().mean().detach().item()
        metrics["mismatch_log_ppl_diff_max"] = log_ppl_diff.max().detach().item()
        metrics["mismatch_log_ppl_diff_min"] = log_ppl_diff.min().detach().item()

        # 2e. PPL ratio (how much higher is training PPL vs rollout PPL)
        # IMPORTANT: Compute per-sequence ratio first, then average
        # For numerical stability, compute in log space using log_ppl_diff
        # Note: log_ppl_diff = log(ppl_ratio), so ppl_ratio = exp(log_ppl_diff)
        # This is the inverse of geometric IS: ppl_ratio_i = 1 / geometric_is_i for each sequence
        ppl_ratio = torch.exp(log_ppl_diff).mean()  # mean(exp(log_ppl_diff)) = mean(ppl_ratio_i)
        metrics["mismatch_ppl_ratio"] = ppl_ratio.detach().item()

    return metrics

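# --- Illustrative sketch (not part of the verl source above): the two KL estimators used
# --- in compute_mismatch_metrics on toy per-token log-probs, with every token assumed
# --- valid so masked_mean is replaced by a plain mean. Values are arbitrary.
import torch

toy_old_log_prob = torch.tensor([[-1.0, -2.0, -0.5]])      # training policy
toy_rollout_log_prob = torch.tensor([[-1.1, -1.8, -0.6]])  # rollout policy

mismatch_kl = (toy_rollout_log_prob - toy_old_log_prob).mean()  # direct estimator of KL(rollout || training)
log_ratio = toy_old_log_prob - toy_rollout_log_prob             # r = π_training / π_rollout in log space
mismatch_k3_kl = (torch.exp(log_ratio) - log_ratio - 1).mean()  # K3 estimator, always >= 0
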
@@ -49,6 +49,7 @@ from verl.trainer.ppo.metric_utils import (
    compute_timing_metrics,
    process_validation_metrics,
)
from verl.trainer.ppo.mismatch_helper import compute_rollout_importance_weights
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi
@@ -918,6 +919,49 @@ class RayPPOTrainer:
        )
        metrics.update(global_balance_stats)

    def compute_rollout_importance_weights_and_add_to_batch(self, batch: DataProto) -> tuple[DataProto, dict]:
        """Compute rollout importance sampling weights and mismatch metrics, conditionally add weights to batch.

        This method computes IS weights to correct for distribution mismatch between
        rollout policy and training policy. It always computes metrics when enabled, but
        only adds weights to batch if algorithm.rollout_is is True.

        Args:
            batch: DataProto containing old_log_probs, rollout_log_probs, response_mask

        Returns:
            Tuple of (updated_batch, metrics) where:
            - updated_batch: Batch with rollout_is_weights added (if rollout_is=True)
            - metrics: Dictionary of IS and mismatch metrics (all with mismatch/ prefix)
        """
        # Compute rollout IS weights if enabled and data is available
        # rollout_is_threshold is the main on/off switch
        if self.config.algorithm.rollout_is_threshold is not None and "rollout_log_probs" in batch.batch:
            rollout_is_weights, rollout_is_metrics = compute_rollout_importance_weights(
                old_log_prob=batch.batch["old_log_probs"],
                rollout_log_prob=batch.batch["rollout_log_probs"],
                response_mask=batch.batch["response_mask"],
                rollout_is_level=self.config.algorithm.rollout_is_level,
                rollout_is_mode=self.config.algorithm.rollout_is_mode,
                rollout_is_threshold=self.config.algorithm.rollout_is_threshold,
                rollout_is_threshold_lower=self.config.algorithm.rollout_is_threshold_lower,
                rollout_is_veto_threshold=self.config.algorithm.rollout_is_veto_threshold,
            )

            # Control: Should we apply weights to policy loss?
            # True = add weights to batch (actor will apply them)
            # False = don't add weights (metrics only, no loss modification)
            apply_weights = self.config.algorithm.get("rollout_is", False)

            if apply_weights:
                # Add IS weights to batch for distribution to workers
                batch = batch.union(rollout_is_weights)

            return batch, rollout_is_metrics

        # Return unchanged batch and empty metrics if IS is disabled
        return batch, {}

    def fit(self):
        """
        The training loop of PPO.
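# --- Illustrative sketch (not part of the verl diff above): the algorithm-level knobs
# --- this method reads, written as a plain dict. Only the key names appear in the diff;
# --- the sample values, defaults, and the exact set of valid options are assumptions.
rollout_is_example_config = {
    "rollout_is": True,                 # apply the weights in the policy loss (False = metrics only)
    "rollout_is_threshold": 2.0,        # upper threshold; None switches the whole feature off
    "rollout_is_threshold_lower": 0.5,  # lower threshold (used by the "mask" mode)
    "rollout_is_level": "sequence",     # "token", "sequence", or "geometric"
    "rollout_is_mode": "truncate",      # "truncate" (TIS) or "mask" (MIS)
    "rollout_is_veto_threshold": 1e-4,  # per-token catastrophic-ratio veto (assumed semantics)
}
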
@@ -1107,6 +1151,13 @@ class RayPPOTrainer:
                    else:
                        batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

                    # Compute rollout importance sampling weights centrally (once per batch)
                    # This corrects for mismatch between rollout policy and training policy
                    # Also computes mismatch metrics (KL, PPL, etc.)
                    batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)
                    # IS and mismatch metrics already have mismatch/ prefix
                    metrics.update(is_metrics)

                    # compute advantages, executed on the driver process
                    norm_adv_by_std_in_grpo = self.config.algorithm.get(
                        "norm_adv_by_std_in_grpo", True
@@ -1205,6 +1256,7 @@ class RayPPOTrainer:
                # TODO: implement actual tflpo and theoretical tflpo
                n_gpus = self.resource_pool_manager.get_n_gpus()
                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
                # Note: mismatch metrics (KL, PPL, etc.) are collected at line 1179 after advantage computation

                # this is experimental and may be changed/removed in the future in favor of a general-purpose one
                if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):
@@ -146,7 +146,10 @@ class SFTTrainer:
        config = self.config
        tokenizer = self.model_config.tokenizer
        train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer)
        val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)
        if config.data.val_files:
            val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)
        else:
            val_dataset = None

        self.train_dataset, self.val_dataset = train_dataset, val_dataset

@@ -181,19 +184,22 @@ class SFTTrainer:
            pin_memory_device=device_name,
        )

        self.val_sampler = DistributedSampler(
            self.val_dataset, shuffle=False, num_replicas=dp_size, rank=dp_rank, drop_last=True
        )
        self.val_dataloader = StatefulDataLoader(
            dataset=self.val_dataset,
            batch_size=self.train_batch_size_per_dp,
            sampler=self.val_sampler,
            collate_fn=self.collate_fn,
            num_workers=8,
            pin_memory=True,
            drop_last=True,
            pin_memory_device=device_name,
        )
        if self.val_dataset:
            self.val_sampler = DistributedSampler(
                self.val_dataset, shuffle=False, num_replicas=dp_size, rank=dp_rank, drop_last=True
            )
            self.val_dataloader = StatefulDataLoader(
                dataset=self.val_dataset,
                batch_size=self.train_batch_size_per_dp,
                sampler=self.val_sampler,
                collate_fn=self.collate_fn,
                num_workers=8,
                pin_memory=True,
                drop_last=True,
                pin_memory_device=device_name,
            )
        else:
            self.val_dataloader = None

    def fit(self):
        is_logging = self.engine.is_mp_src_rank_with_outputs() and self.engine.get_data_parallel_rank() == 0
@@ -242,6 +248,7 @@ class SFTTrainer:
        }

        train_time = 0
        total_tokens = 0
        for epoch in range(start_epoch, self.config.trainer.total_epochs):
            self.train_sampler.set_epoch(epoch=epoch)

@@ -302,6 +309,8 @@ class SFTTrainer:
                metrics["train/grad_norm"] = metrics.pop("grad_norm")
                metrics["train/lr"] = lr
                metrics["train/global_tokens"] = output_tensor.sum().item()
                total_tokens += metrics["train/global_tokens"]
                metrics["train/total_tokens(B)"] = total_tokens / 1e9
                # mfu
                delta_time = timer.last
                estimated_flops, promised_flops = self.flops_counter.estimate_flops(batch_seqlens, delta_time)
@@ -315,7 +324,7 @@ class SFTTrainer:
                is_save_step = global_step % self.save_freq == 0

                # early exit or validation step
                if is_last_step or (self.test_freq > 0 and is_valid_step):
                if is_last_step and self.val_dataloader is not None or (self.test_freq > 0 and is_valid_step):
                    # Perform validation
                    val_losses = []
                    for val_data in self.val_dataloader:
@@ -182,7 +182,8 @@ def find_latest_ckpt_path(path, directory_format="global_step_{}"):

    tracker_file = get_checkpoint_tracker_filename(path)
    if not os.path.exists(tracker_file):
        print(f"Checkpoint tracker file does not exist: {tracker_file}")
        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            print(f"Checkpoint tracker file does not exist: {tracker_file}")
        return None

    with open(tracker_file, "rb") as f:
@@ -1 +1 @@
0.5.0.dev
0.6.0
@@ -373,13 +373,10 @@ class DataParallelPPOActor(BasePPOActor):
        ]
        if self.config.use_kl_loss:
            select_keys.append("ref_log_prob")
        if self.config.tis_imp_ratio_cap > 0:
            assert "rollout_log_probs" in data.batch.keys(), (
                "Truncated Importance Sampling (TIS) requires to configure "
                "`actor_rollout_ref.rollout.calculate_log_probs=True` "
                "and is not currently supported in Server mode (agent loop)."
            )
            select_keys.append("rollout_log_probs")
        # Include pre-computed IS weights if present in batch
        # Weights are computed centrally in trainer and added to batch when algorithm.rollout_is=True
        if "rollout_is_weights" in data.batch.keys():
            select_keys.append("rollout_is_weights")

        has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
        non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []
@@ -412,7 +409,6 @@ class DataParallelPPOActor(BasePPOActor):
                model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
                response_mask = model_inputs["response_mask"]
                old_log_prob = model_inputs["old_log_probs"]
                rollout_log_probs = model_inputs["rollout_log_probs"] if self.config.tis_imp_ratio_cap > 0 else None
                advantages = model_inputs["advantages"]

                entropy_coeff = self.config.entropy_coeff
@@ -438,9 +434,21 @@ class DataParallelPPOActor(BasePPOActor):

                loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
                # vanilla -> verl.trainer.ppo.core_algos.compute_policy_loss_vanilla

                # Extract pre-computed rollout importance sampling weights if present
                # Weights are computed centrally in trainer and added when algorithm.rollout_is=True
                rollout_is_weights = model_inputs.get("rollout_is_weights", None)

                # NOTE: Both mismatch diagnostic metrics (PPL, KL, etc.) and IS weight metrics
                # are computed centrally in ray_trainer.py for consistency and efficiency.
                # This ensures metrics are computed uniformly across all batches at the trainer level
                # and avoids redundant computation across workers and micro-batches.

                # gpg -> verl.trainer.ppo.core_algos.compute_policy_loss_gpg
                # clip_cov -> verl.trainer.ppo.core_algos.compute_policy_loss_clip_cov
                policy_loss_fn = get_policy_loss_fn(loss_mode)

                # Compute policy loss (all functions return 4 values)
                pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
                    old_log_prob=old_log_prob,
                    log_prob=log_prob,
@@ -448,7 +456,7 @@ class DataParallelPPOActor(BasePPOActor):
                    response_mask=response_mask,
                    loss_agg_mode=loss_agg_mode,
                    config=self.config,
                    rollout_log_probs=rollout_log_probs,
                    rollout_is_weights=rollout_is_weights,
                )

                if entropy_coeff != 0:
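# --- Illustrative sketch (an assumption, not the actual verl loss functions): one way a
# --- policy-loss implementation can consume the pre-computed `rollout_is_weights`,
# --- multiplying the per-token objective by the weight before masked aggregation.
from typing import Optional

import torch

def apply_rollout_is_weights(
    per_token_loss: torch.Tensor,
    rollout_is_weights: Optional[torch.Tensor],
    response_mask: torch.Tensor,
) -> torch.Tensor:
    """Weight the per-token loss, then reduce with a masked token-mean."""
    if rollout_is_weights is not None:
        per_token_loss = per_token_loss * rollout_is_weights
    return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1)
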
@@ -316,6 +316,10 @@ class MegatronPPOActor(BasePPOActor):
        ]
        if self.config.use_kl_loss:
            select_keys.append("ref_log_prob")
        # Include pre-computed IS weights if present in batch
        # Weights are computed centrally in trainer and added to batch when algorithm.rollout_is=True
        if "rollout_is_weights" in data.batch.keys():
            select_keys.append("rollout_is_weights")
        self.has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
        if self.has_multi_modal_inputs:
            data = data.select(select_keys, ["multi_modal_inputs"])
@@ -419,7 +423,6 @@ class MegatronPPOActor(BasePPOActor):
            response_length = responses.size(1)
            response_mask = data["response_mask"].to(bool)
            loss_agg_mode = self.config.loss_agg_mode

            # compute policy loss
            log_prob = output["log_probs"][:, -response_length - 1 : -1].contiguous()
            ret_entropy = None
@@ -434,6 +437,15 @@ class MegatronPPOActor(BasePPOActor):
            loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")

            policy_loss_fn = get_policy_loss_fn(loss_mode)

            # Extract pre-computed rollout importance sampling weights if present
            # Weights are computed centrally in trainer and added when algorithm.rollout_is=True
            rollout_is_weights = data.get("rollout_is_weights", None)

            # NOTE: Both mismatch diagnostic metrics (PPL, KL, etc.) and IS weight metrics
            # are computed centrally in ray_trainer.py for consistency and efficiency.
            # This ensures metrics are computed uniformly across all batches at the trainer level
            # and avoids redundant computation across workers and micro-batches.
            pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
                old_log_prob=old_log_prob,
                log_prob=log_prob,
@@ -441,6 +453,7 @@ class MegatronPPOActor(BasePPOActor):
                response_mask=response_mask,
                loss_agg_mode=loss_agg_mode,
                config=self.config,
                rollout_is_weights=rollout_is_weights,
            )

            stats.update(
@@ -106,7 +106,6 @@ class ActorConfig(BaseConfig):
    clip_ratio_c: float = 3.0
    loss_agg_mode: str = "token-mean"
    entropy_coeff: float = 0
    tis_imp_ratio_cap: float = -1
    use_kl_loss: bool = False
    use_torch_compile: bool = True
    kl_loss_coef: float = 0.001
@@ -60,16 +60,27 @@ class FSDPOptimizerConfig(OptimizerConfig):
    Args:
        lr (float): Learning rate.
        min_lr_ratio (Optional[float]): Minimum LR ratio for cosine schedule.
        warmup_style (str): LR warmup style: "constant" or "cosine".
        lr_scheduler_type (str): LR scheduler type: "constant" or "cosine".
        num_cycles (float): Number of cosine cycles in LR schedule.
    """

    _mutable_fields = OptimizerConfig._mutable_fields.copy()
    _mutable_fields.add("lr_scheduler_type")

    min_lr_ratio: Optional[float] = None
    warmup_style: str = "constant"
    # deprecate warmup_style
    warmup_style: Optional[str] = None
    lr_scheduler_type: str = "constant"
    num_cycles: float = 0.5

    def __post_init__(self):
        assert self.warmup_style in ["constant", "cosine"]
        if self.warmup_style is not None:
            assert self.warmup_style in ["constant", "cosine"]
            warnings.warn(
                "`warmup_style` is deprecated, use `lr_scheduler_type` instead.", DeprecationWarning, stacklevel=2
            )
            self.lr_scheduler_type = self.warmup_style
        assert self.lr_scheduler_type in ["constant", "cosine"]
        return super().__post_init__()

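# --- Illustrative sketch (not part of the verl diff above): what the deprecation shim in
# --- __post_init__ does for a caller that still passes `warmup_style`. The import path and
# --- the assumption that `lr` is the only required field are guesses, not confirmed by the diff.
import warnings

from verl.workers.config import FSDPOptimizerConfig  # import path is an assumption

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cfg = FSDPOptimizerConfig(lr=1e-6, warmup_style="cosine")  # deprecated spelling
assert cfg.lr_scheduler_type == "cosine"  # mapped from warmup_style by __post_init__
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
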
@@ -370,7 +370,7 @@ class FSDPEngine(BaseEngine):

        total_steps = optim_config.total_training_steps
        num_warmup_steps = optim_config.lr_warmup_steps
        warmup_style = optim_config.warmup_style
        lr_scheduler_type = optim_config.lr_scheduler_type
        min_lr_ratio = optim_config.min_lr_ratio
        num_cycles = optim_config.num_cycles
        if num_warmup_steps <= 0:
@@ -380,9 +380,9 @@ class FSDPEngine(BaseEngine):
        if self.rank == 0:
            print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")

        if warmup_style == "constant":
        if lr_scheduler_type == "constant":
            lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=num_warmup_steps)
        elif warmup_style == "cosine":
        elif lr_scheduler_type == "cosine":
            lr_scheduler = get_cosine_schedule_with_warmup(
                optimizer=optimizer,
                num_warmup_steps=num_warmup_steps,
@@ -391,7 +391,7 @@ class FSDPEngine(BaseEngine):
                num_cycles=num_cycles,
            )
        else:
            raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
            raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported")
        return lr_scheduler

    def _build_model_optimizer(self):
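# --- Illustrative sketch (an assumption, not verl's own helpers): the two schedules the
# --- branch above selects, written with torch.optim.lr_scheduler.LambdaLR. verl's
# --- get_constant/get_cosine_schedule_with_warmup live in verl.utils.torch_functional and
# --- may differ in signature and in how min_lr_ratio is applied.
import math

import torch

def make_lr_scheduler(optimizer, lr_scheduler_type, num_warmup_steps, total_steps, num_cycles=0.5):
    if lr_scheduler_type == "constant":
        def lr_lambda(step):
            # Linear warmup, then hold the base learning rate
            return min(1.0, (step + 1) / max(1, num_warmup_steps))
    elif lr_scheduler_type == "cosine":
        def lr_lambda(step):
            # Linear warmup followed by a cosine decay over the remaining steps
            if step < num_warmup_steps:
                return (step + 1) / max(1, num_warmup_steps)
            progress = (step - num_warmup_steps) / max(1, total_steps - num_warmup_steps)
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))
    else:
        raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported")
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
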
@@ -529,7 +529,7 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):

            total_steps = optim_config.get("total_training_steps", 0)
            num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1))
            warmup_style = optim_config.get("warmup_style", "constant")
            lr_scheduler_type = optim_config.get("lr_scheduler_type", "constant")
            min_lr_ratio = optim_config.get("min_lr_ratio", 0.0)
            num_cycles = optim_config.get("num_cycles", 0.5)
            if num_warmup_steps < 0:
@@ -539,11 +539,11 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
            if self.rank == 0:
                print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")

            if warmup_style == "constant":
            if lr_scheduler_type == "constant":
                actor_lr_scheduler = get_constant_schedule_with_warmup(
                    optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps
                )
            elif warmup_style == "cosine":
            elif lr_scheduler_type == "cosine":
                actor_lr_scheduler = get_cosine_schedule_with_warmup(
                    optimizer=actor_optimizer,
                    num_warmup_steps=num_warmup_steps,
@@ -552,7 +552,7 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
                    num_cycles=num_cycles,
                )
            else:
                raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
                raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported")

            log_gpu_memory_usage(f"After {role} optimizer init", logger=logger)
        else:
@@ -1386,7 +1386,8 @@ class CriticWorker(Worker, DistProfilerExtension):

        total_steps = config.optim.get("total_training_steps", 0)
        num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1))
        warmup_style = config.optim.get("warmup_style", "constant")

        lr_scheduler_type = config.optim.get("lr_scheduler_type", "constant")
        if num_warmup_steps < 0:
            num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0)
            num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
@@ -1396,11 +1397,11 @@ class CriticWorker(Worker, DistProfilerExtension):

        from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup

        if warmup_style == "constant":
        if lr_scheduler_type == "constant":
            critic_lr_scheduler = get_constant_schedule_with_warmup(
                optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps
            )
        elif warmup_style == "cosine":
        elif lr_scheduler_type == "cosine":
            min_lr_ratio = config.optim.get("min_lr_ratio", 0.0)
            num_cycles = config.optim.get("num_cycles", 0.5)
            critic_lr_scheduler = get_cosine_schedule_with_warmup(
@@ -1411,7 +1412,7 @@ class CriticWorker(Worker, DistProfilerExtension):
                num_cycles=num_cycles,
            )
        else:
            raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
            raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported")

        return critic_module, critic_optimizer, critic_lr_scheduler