mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 13:43:50 +08:00
# Rollout Importance Sampling Framework

## Summary

This PR introduces a comprehensive **Rollout Importance Sampling (IS)** framework that corrects the distribution mismatch between the data-collecting (rollout) policy and the training policy. Correcting this mismatch is critical for stable and efficient RL fine-tuning.

This work is motivated by the analysis in our blog post, [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda).

If you find this implementation useful in your research, please consider citing:

```bibtex
@misc{liu-li-2025,
  title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch},
  url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda},
  author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen},
  year = {2025},
  month = {September},
}
```

---

## Problem Statement

Rollout generation and training often run the same model weights on different engines and precisions. For example, rollout may use vLLM with BFloat16 while training uses FSDP with FP32. In that case the two assign different probabilities to the same tokens. The resulting distribution mismatch leads to:

- Biased gradient estimates
- Training instability and collapse
- Reduced sample efficiency
- Poor convergence properties

This framework addresses these issues through principled importance sampling correction.

---

## Key Features & Improvements

### 1. **Flexible Aggregation Levels**

Three levels at which IS weights can be computed:

- **`token`**: Per-token importance ratios
- **`sequence`**: Product of the per-token ratios over the sequence
- **`geometric`**: Geometric mean of the per-token ratios over the sequence

### 2. **Advanced Bounding Modes**

Two strategies to control weight variance:

- **`truncate`** (TIS): Caps weights at the upper threshold only. No token is dropped, so every token keeps a (scaled) gradient.
- **`clip`** (CIS): Zeros out weights outside the [lower, upper] bounds. This is more aggressive filtering.

### 3. **Comprehensive Diagnostics**

Detailed metrics to monitor distribution mismatch and training health:

**Rollout IS Metrics** (automatically prefixed with `mismatch/`):

- Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean`
- Distribution statistics: `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`, `rollout_is_min`, `rollout_is_std`
- Diagnostics: `rollout_is_veto_fraction`, `rollout_is_catastrophic_token_fraction`, `rollout_is_clipped_fraction` (clip mode)
- Sequence-level statistics (for sequence/geometric modes): `rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`, `rollout_is_seq_min`, etc.

**Mismatch Metrics** (computed efficiently within the IS weight computation):

- KL divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3 estimator for stability)
- Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`, `mismatch_ppl_ratio`
- Log perplexity statistics: `mismatch_log_ppl_diff`, `mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`, `mismatch_log_ppl_diff_min`

### 4. **Outlier Mitigation**

- **Veto mechanism**: Automatically discards sequences that contain catastrophic importance weights (any per-token ratio below the veto threshold)
- Prevents gradient corruption from extreme outliers
- Configurable threshold (default: 1e-4)

### 5. **Numerical Stability**

- All core computations in **log-space** to prevent underflow/overflow
- Carefully designed clipping and bounding to maintain numerical precision
- Safe handling of edge cases (zero probabilities, extreme ratios)

### 6. **Memory Efficiency**

- Optimized computation to minimize CUDA memory usage
- Efficient metric aggregation without large intermediate tensors
- Suitable for large-scale distributed training

### 7. **Metrics-Only Mode**

- Compute and monitor mismatch metrics **without** applying IS weights
- Useful for:
  - Understanding distribution mismatch before intervention
  - Deciding whether IS correction is needed
  - A/B testing the impact of IS
- Controlled by the `algorithm.rollout_is` flag (independent of weight computation)

### 8. **Universal PPO Support**

- Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov, KL-Cov, geo_mean
- Consistent interface across different policy loss functions
- Automatic weight application when enabled

---

## API and Configuration Changes

### Migration from Legacy TIS

#### ❌ **Before (REMOVED)**

```yaml
# Old TIS configuration - NO LONGER SUPPORTED
actor_rollout_ref:
  actor:
    tis_imp_ratio_cap: 2.0  # Removed from actor config
```

The legacy implementation:

- Only supported token-level truncation
- Had no metrics tracking
- Lacked numerical stability
- Offered limited configurability

#### ✅ **After (New Framework)**

Configuration has moved to the `algorithm` section for better organization:

```yaml
algorithm:
  # Main on/off switch: null = disabled, float = enabled
  rollout_is_threshold: 2.0
  # Control weight application (independent of metrics computation)
  rollout_is: true  # true = apply weights, false = metrics only
  # Optional: lower threshold (defaults to 1/upper if null)
  rollout_is_threshold_lower: null
  # Aggregation level: "token", "sequence", or "geometric"
  rollout_is_level: token
  # Bounding mode: "truncate" or "clip"
  rollout_is_mode: truncate
  # Veto threshold for catastrophic outliers (null = disabled)
  rollout_is_veto_threshold: 1e-4

# REQUIRED: Enable log probability calculation
actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

### Configuration Examples

**1. Token-level truncation (recommended starting point)**

```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate
```

**2. Sequence-level clipping (more aggressive)**

```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: sequence
  rollout_is_mode: clip
```

**3. Metrics-only mode (monitoring without correction)**

```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: false  # Compute metrics but don't apply weights
  rollout_is_level: token
  rollout_is_mode: truncate
```

**Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh`

---

## Code Changes Overview

### New Files (4 files, 1,442 lines)

1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines)
   - Core implementation of IS weight computation
   - Three aggregation levels: token, sequence, geometric
   - Two bounding modes: truncate, clip
   - Veto mechanism for outlier detection
   - Comprehensive metrics computation (IS + mismatch)
   - All computations in log-space for numerical stability
   - Memory-efficient design
2. **`docs/advance/rollout_is_migration.md`** (642 lines)
   - Comprehensive migration guide from legacy TIS
   - Detailed explanation of all configuration options
   - Recommended threshold ranges for each aggregation level
   - Troubleshooting guide and best practices
   - Metrics interpretation guide
3. **`examples/rollout_importance_sampling/README.md`** (242 lines)
   - Quick start guide with working examples
   - Configuration templates for common scenarios
   - Threshold tuning guidelines
   - Metrics monitoring instructions
4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99 lines)
   - Complete working example script
   - Demonstrates token-level and sequence-level configurations
   - Ready to run with minimal modifications

### Modified Core Files (9 files)

1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed)
   - Removed legacy TIS logic (`tis_imp_ratio_cap`)
   - Added `rollout_is_weights` parameter to all policy loss functions
   - Unified IS weight application interface across all PPO variants:
     - `compute_policy_loss_vanilla`
     - `compute_policy_loss_gspo`
     - `compute_policy_loss_gpg`
     - `compute_policy_loss_clip_cov`
     - `compute_policy_loss_kl_cov`
     - `compute_policy_loss_geo_mean`
   - Special handling for `geo_mean` (sequence-level aggregation)
2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added)
   - New method: `compute_rollout_importance_weights_and_add_to_batch()`
   - Centralized IS computation (once per batch, on driver)
   - Conditional weight distribution to workers based on `algorithm.rollout_is`
   - Metrics collection and aggregation
   - Integration with existing training loop
3. **`verl/trainer/config/algorithm.py`** (+18 lines)
   - Added 6 new Rollout IS parameters:
     - `rollout_is_threshold` (main on/off switch)
     - `rollout_is` (weight application control)
     - `rollout_is_threshold_lower`
     - `rollout_is_level`
     - `rollout_is_mode`
     - `rollout_is_veto_threshold`
   - Comprehensive docstrings explaining each parameter
4. **`verl/workers/config/actor.py`** (-1 line)
   - Removed deprecated `tis_imp_ratio_cap` parameter
5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed)
   - Updated to use new `rollout_is_weights` parameter
   - Removed legacy TIS logic
6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed)
   - Updated to use new `rollout_is_weights` parameter
   - Removed legacy TIS logic
7. **Configuration Files** (4 files updated)
   - `verl/trainer/config/ppo_trainer.yaml`
   - `verl/trainer/config/ppo_megatron_trainer.yaml`
   - `verl/trainer/config/_generated_ppo_trainer.yaml`
   - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml`
   - Added default Rollout IS configuration section with explanatory comments

### Testing (2 files, 530 lines)

1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines)
   - Unit tests for `mismatch_helper.py`
   - Coverage for all aggregation levels (token, sequence, geometric)
   - Coverage for all bounding modes (truncate, clip)
   - Veto mechanism tests
   - Edge case handling (zeros, extremes, empty sequences)
   - Numerical stability verification
   - Metrics correctness validation
2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines)
   - Integration tests with PPO training loop
   - End-to-end workflow validation
   - Batch processing tests
   - Configuration validation
   - Metrics collection verification
   - Compatibility with distributed training

### Updated Recipes (2 files)

1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines)
   - Updated imports to use new framework
2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed)
   - Migrated from legacy TIS to new Rollout IS configuration
   - Updated documentation and comments

### Documentation Updates (2 files)

1. **`docs/examples/config.rst`** (~22 lines changed)
   - Updated configuration examples
   - Added Rollout IS section
2. **`docs/index.rst`** (+1 line)
   - Added link to Rollout IS migration guide

---

## Implementation Highlights

### Centralized Architecture

The new design follows a clean separation of concerns:

```
ray_trainer.py (driver)
  └─> compute_rollout_importance_weights_and_add_to_batch()
        └─> mismatch_helper.compute_rollout_importance_weights()
              ├─> Computes IS weights (token/sequence/geometric)
              ├─> Applies bounding (truncate/clip)
              ├─> Veto mechanism for outliers
              ├─> Computes IS metrics
              └─> Computes mismatch metrics (KL, PPL)
        └─> Conditionally adds weights to batch (if rollout_is=True)
  └─> Distributes batch to workers

actor workers (dp_actor, megatron_actor)
  └─> Receive batch with rollout_is_weights (if enabled)
  └─> Pass weights to policy loss function

core_algos.py
  └─> All policy loss functions accept rollout_is_weights
  └─> Apply weights if provided: pg_losses *= rollout_is_weights
```

### Key Design Decisions

1. **Centralized Computation**: IS weights are computed once on the driver, not per worker
   - Reduces redundant computation
   - Ensures consistency across workers
   - Simplifies debugging and metrics collection
2. **Configuration in Algorithm**: Moved from the actor config to the algorithm config
   - Better conceptual organization (algorithm-level concern, not worker-level)
   - Easier to manage and validate
   - Consistent with other algorithm parameters
3. **Two-Level Control**:
   - `rollout_is_threshold`: Enables/disables the entire system (null = off)
   - `rollout_is`: Controls weight application (true = apply, false = metrics only)
   - Allows flexible monitoring and gradual rollout
4. **Metrics Consolidation**: Mismatch metrics are computed within the IS weight computation
   - Eliminates duplicate computation
   - Reduces memory overhead
   - Maintains metric accuracy
5. **Universal PPO Support**: Single interface for all PPO variants
   - Minimal code changes required
   - Consistent behavior across algorithms
   - Easy to add new variants

---

## Migration Guide

### For Users of Legacy TIS

**Step 1: Update your configuration file**

```yaml
# OLD (remove this)
actor_rollout_ref:
  actor:
    tis_imp_ratio_cap: 2.0

# NEW (add this)
algorithm:
  rollout_is_threshold: 2.0  # Use same value as old tis_imp_ratio_cap
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate

# REQUIRED (add if not present)
actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

**Step 2: Monitor metrics**

The first time you run with the new configuration, check these metrics:

- `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size
- `mismatch/rollout_is_veto_fraction`: Should be < 5%
- `mismatch/rollout_is_mean`: Should be close to 1.0

**Step 3: Tune if needed**

If the effective sample size is too low:

- Increase `rollout_is_threshold`
- Try `rollout_is_mode: clip` with an appropriate lower bound
- Consider `rollout_is_level: sequence` for more aggressive correction

For detailed guidance, see `docs/advance/rollout_is_migration.md`.
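Step 2 above relies on a few summary statistics of the importance weights. To show roughly what such statistics measure, here is a toy shell/awk sketch that computes them from paired per-token log-probabilities. It is not verl's `mismatch_helper.py`. Several pieces are assumptions made for illustration:

- the helper name `toy_mismatch_metrics` and its output keys;
- the input format (one `train_logp rollout_logp` pair per line, tokens sampled from the rollout policy);
- the ESS normalization, Kish's (Σw)²/(N·Σw²) reported as a fraction of N;
- the use of the textbook k1 and k3 KL estimators.

verl's definitions of `mismatch/rollout_is_eff_sample_size`, `mismatch_kl`, and `mismatch_k3_kl` may differ.

```bash
#!/usr/bin/env bash
# Toy sketch (hypothetical helper, not verl's implementation) of mismatch diagnostics.
# Input on stdin: one "train_logp rollout_logp" pair per token, tokens sampled from the rollout policy.
# Usage: toy_mismatch_metrics [VETO]   (VETO defaults to 1e-4)
toy_mismatch_metrics() {
  local veto=${1:-1e-4}
  awk -v veto="$veto" '
    {
      lr = $1 - $2              # log r, with r = pi_train / pi_rollout
      w = exp(lr)
      n++; sw += w; sw2 += w * w
      k1 += -lr                 # k1 estimator of KL(rollout || train); can be negative on small samples
      k3 += w - 1 - lr          # k3 estimator: r - 1 - log r, never negative per token
      if (w < veto) bad++       # catastrophic token: ratio below the veto threshold
    }
    END {
      if (n == 0) { print "no tokens" > "/dev/stderr"; exit 1 }
      printf "is_mean=%.4f\n", sw / n
      printf "ess_fraction=%.4f\n", (sw * sw) / (n * sw2)   # Kish ESS divided by N
      printf "kl_k1=%.4f\n", k1 / n
      printf "kl_k3=%.4f\n", k3 / n
      printf "catastrophic_token_fraction=%.4f\n", bad / n
    }'
}
```

Consider the four token pairs `-1.0 -1.5`, `-2.0 -1.0`, `-0.5 -0.5`, `-0.1 -1.1` with a veto of 0.4. They give `ess_fraction=0.7313`, `kl_k1=-0.1250`, `kl_k3=0.3087`, and `catastrophic_token_fraction=0.2500`. The negative k1 value shows why a non-negative estimator such as k3 is useful to report alongside it.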
### For New Users

Start with recommended defaults:

```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate

actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

Run the example script to see it in action:

```bash
bash examples/rollout_importance_sampling/run_with_rollout_is.sh
```

---

## Testing

### Unit Tests

- **289 lines** of comprehensive unit tests in `test_rollout_is.py`
- Covers all aggregation levels, bounding modes, and edge cases
- Validates numerical stability and correctness
- Fast execution (~1-2 seconds)

### Integration Tests

- **241 lines** of integration tests in `test_rollout_is_integration.py`
- End-to-end workflow with PPO training loop
- Distributed training compatibility
- Metrics collection validation
- Moderate execution time (~10-20 seconds)

### Running Tests

```bash
# Run all Rollout IS tests
pytest tests/trainer/ppo/test_rollout_is.py -v
pytest tests/trainer/ppo/test_rollout_is_integration.py -v

# Run specific test
pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v
```

---

## Metrics Reference

### Rollout IS Metrics (all prefixed with `mismatch/`)

| Metric | Description | Ideal Range |
|--------|-------------|-------------|
| `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch |
| `rollout_is_mean` | Mean IS weight | ~1.0 |
| `rollout_is_std` | Standard deviation of IS weights | Low variance |
| `rollout_is_p25` | 25th percentile | ~0.8-1.0 |
| `rollout_is_p50` | Median IS weight | ~1.0 |
| `rollout_is_p75` | 75th percentile | ~1.0-1.2 |
| `rollout_is_p95` | 95th percentile | < threshold |
| `rollout_is_p99` | 99th percentile | < threshold |
| `rollout_is_max` | Maximum weight | ≤ threshold |
| `rollout_is_min` | Minimum weight | ≥ lower threshold (clip mode) |
| `rollout_is_veto_fraction` | % sequences vetoed | < 5% |
| `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% |
| `rollout_is_clipped_fraction` | % tokens clipped (clip mode) | Variable |

### Mismatch Metrics (all prefixed with `mismatch/`)

| Metric | Description | What It Means |
|--------|-------------|---------------|
| `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) |
| `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences |
| `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy |
| `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy |
| `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty |
| `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch |
| `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch |
| `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch |
| `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch |
| `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity |
| `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity |

---

## Performance Impact

### Memory

- Minimal overhead: ~1-2% increase in peak memory usage
- Efficient log-space computation
- No large intermediate tensors

### Computation

- Negligible impact on training speed: < 1% overhead
- Centralized computation on driver (no per-worker redundancy)
- Optimized tensor operations

### Training Stability

- Significant improvement in stability when distribution mismatch exists
- Faster convergence in many scenarios
- Reduced risk of training collapse

---

## Breaking Changes

> [!IMPORTANT]
> This PR contains **BREAKING CHANGES** to the configuration API.

### Removed

- `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported

### Migration Required

All users of the legacy TIS implementation must update their configuration files.
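Before upgrading, it helps to find every config or launch script that still sets the removed key. The helper below is a hypothetical convenience: the name `find_legacy_tis_configs` and the file patterns it searches are assumptions, not part of verl. It uses only standard `grep` options. It prints `file:line:` matches for both YAML keys and Hydra-style command-line overrides, and returns non-zero when any are found, so it can gate a CI step.

```bash
#!/usr/bin/env bash
# Hypothetical pre-upgrade check: list files that still set the removed
# actor_rollout_ref.actor.tis_imp_ratio_cap key. Returns 1 if any are found.
# Usage: find_legacy_tis_configs [ROOT_DIR]
find_legacy_tis_configs() {
  local root=${1:-.}
  # -r recurse, -n print line numbers, -I skip binary files, -E extended regex.
  # Matches YAML keys ("tis_imp_ratio_cap: 2.0") and CLI overrides ("...tis_imp_ratio_cap=2.0").
  if grep -rnIE 'tis_imp_ratio_cap[[:space:]]*[:=]' \
      --include='*.yaml' --include='*.yml' --include='*.sh' "$root"; then
    echo "Legacy TIS settings found; move them to the algorithm.rollout_is_* keys." >&2
    return 1
  fi
  return 0
}
```

Run it from the repository root, for example `find_legacy_tis_configs recipe/`, and migrate each reported line as described in the migration guide.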
See the migration guide above or `docs/advance/rollout_is_migration.md` for detailed instructions.

### Backward Compatibility

- No backward compatibility with legacy TIS
- Configuration files with `tis_imp_ratio_cap` will raise validation errors
- Affected recipes have been updated in this PR

---

## Pre-Submission Checklist

- [x] Search for similar PRs: [https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling)
- [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI)
  - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework`
- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md)
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting)
- [x] Add/update [documentation](https://github.com/volcengine/verl/tree/main/docs) (3 new docs, 2 updated)
- [x] Add unit and integration tests (530 lines of tests)
- [x] Once PR is ready for CI, send message in `ci-request` channel

---

## References

- **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda)
- **Migration guide:** `docs/advance/rollout_is_migration.md`
- **Examples:** `examples/rollout_importance_sampling/`
- **Tests:** `tests/trainer/ppo/test_rollout_is*.py`

---------

Co-authored-by: Yan Bai <bayan@nvidia.com>
165 lines
6.9 KiB
Bash
#!/usr/bin/env bash
set -xeuo pipefail

# Rollout Importance Sampling Example
# References:
# - When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
# - Off-policy RL: https://fengyao.notion.site/off-policy-rl

project_name='DAPO'
exp_name='DAPO-Qwen2.5-32B-RolloutIS' # Rollout Importance Sampling

adv_estimator=grpo

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

# Rollout Importance Sampling parameters (matches original TIS with threshold=2)
rollout_is=True
rollout_is_threshold=2.0
rollout_is_threshold_lower=null # null defaults to 1/upper; truncate mode applies no lower bound (original TIS behavior)
rollout_is_level=token # token-level (original TIS behavior)
rollout_is_mode=truncate # truncate mode (original TIS behavior)
rollout_is_veto_threshold=null # No veto (original TIS behavior)

clip_ratio_low=0.2
clip_ratio_high=0.28

max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 20))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0

loss_agg_mode="token-mean"

enable_filter_groups=True
filter_groups_metric=acc
max_num_gen_batches=10
train_prompt_bsz=512
gen_prompt_bsz=$((train_prompt_bsz * 3))
n_resp_per_prompt=16
train_prompt_mini_bsz=32

# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-16}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}

# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7

# Performance-related parameters
sp_size=8
use_dynamic_bsz=True
actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
offload=True
gen_tp=4

# Rollout Importance Sampling (corrects distribution mismatch between rollout and training)
#
# Note: server mode (agent loop) does not return rollout_log_probs yet,
# so Rollout IS is currently not supported in server mode.
#
# Rollout IS parameters (configured at the top of this script):
# algorithm.rollout_is=True
# algorithm.rollout_is_threshold=2.0 # Upper threshold (can be tuned)
# algorithm.rollout_is_level=token # Aggregation level
# algorithm.rollout_is_mode=truncate # Bounding mode
# actor_rollout_ref.rollout.calculate_log_probs=True # Required!

ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
    --working-dir "${WORKING_DIR}" \
    -- python3 -m recipe.dapo.main_dapo \
    data.train_files="${TRAIN_FILE}" \
    data.val_files="${TEST_FILE}" \
    data.prompt_key=prompt \
    data.truncation='left' \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    data.gen_batch_size=${gen_prompt_bsz} \
    data.train_batch_size=${train_prompt_bsz} \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
    algorithm.adv_estimator=${adv_estimator} \
    algorithm.use_kl_in_reward=${use_kl_in_reward} \
    algorithm.kl_ctrl.kl_coef=${kl_coef} \
    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
    actor_rollout_ref.actor.clip_ratio_c=10.0 \
    algorithm.filter_groups.enable=${enable_filter_groups} \
    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
    algorithm.filter_groups.metric=${filter_groups_metric} \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
    actor_rollout_ref.actor.optim.weight_decay=0.1 \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.grad_clip=1.0 \
    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
    algorithm.rollout_is=${rollout_is} \
    algorithm.rollout_is_threshold=${rollout_is_threshold} \
    algorithm.rollout_is_threshold_lower=${rollout_is_threshold_lower} \
    algorithm.rollout_is_level=${rollout_is_level} \
    algorithm.rollout_is_mode=${rollout_is_mode} \
    algorithm.rollout_is_veto_threshold=${rollout_is_veto_threshold} \
    actor_rollout_ref.rollout.calculate_log_probs=True \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
    actor_rollout_ref.rollout.enable_chunked_prefill=True \
    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.rollout.temperature=${temperature} \
    actor_rollout_ref.rollout.top_p=${top_p} \
    actor_rollout_ref.rollout.top_k="${top_k}" \
    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
    actor_rollout_ref.rollout.val_kwargs.n=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
    reward_model.reward_manager=dapo \
    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
    reward_model.overlong_buffer.len=${overlong_buffer_len} \
    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
    trainer.logger='["console","wandb"]' \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes="${NNODES}" \
    trainer.val_before_train=True \
    trainer.test_freq=5 \
    trainer.save_freq=5 \
    trainer.total_epochs=1 \
    trainer.default_local_dir="${CKPTS_DIR}" \
    trainer.resume_mode=auto
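The script above selects token-level truncation with an upper threshold of 2.0 and no veto. To illustrate what the `rollout_is_level`, `rollout_is_mode`, threshold, and veto settings compute, here is a toy shell/awk sketch for a single sequence. It is not verl's `mismatch_helper.py`. These details are assumptions made for illustration:

- the function name `toy_rollout_is_weights` and its argument order;
- the input format (one `train_logp rollout_logp` pair per token per line);
- the exact veto rule (any token ratio below the veto threshold zeroes the whole sequence).

Padding, batching, and the order in which verl applies the veto and the bounds are not modeled. Following the framework's stated log-space design, the bounds are compared against log-ratios before anything is exponentiated.

```bash
#!/usr/bin/env bash
# Toy sketch (hypothetical helper, not verl's implementation) of rollout IS weights
# for ONE sequence. Input on stdin: one "train_logp rollout_logp" pair per token.
# Usage: toy_rollout_is_weights LEVEL MODE UPPER [LOWER] [VETO]
#   LEVEL: token | sequence | geometric    MODE: truncate | clip
#   LOWER defaults to 1/UPPER (only clip uses it); an empty VETO disables the veto.
toy_rollout_is_weights() {
  local level=$1 mode=$2 upper=$3 lower=${4:-} veto=${5:-}
  case "$level" in token|sequence|geometric) ;; *) echo "bad level: $level" >&2; return 2 ;; esac
  case "$mode" in truncate|clip) ;; *) echo "bad mode: $mode" >&2; return 2 ;; esac
  if [ -z "$lower" ]; then
    lower=$(awk -v u="$upper" 'BEGIN { printf "%.17g", 1 / u }')
  fi
  awk -v level="$level" -v mode="$mode" -v upper="$upper" -v lower="$lower" -v veto="$veto" '
    { n++; lr[n] = $1 - $2; total += lr[n] }      # per-token log ratio log(pi_train / pi_rollout)
    END {
      lu = log(upper); ll = log(lower); vetoed = 0
      for (i = 1; i <= n; i++)                     # veto: one ratio below VETO zeroes the sequence
        if (veto != "" && lr[i] < log(veto)) vetoed = 1
      for (i = 1; i <= n; i++) {
        if (level == "token") lw = lr[i]            # per-token ratio
        else if (level == "sequence") lw = total    # log of the product of ratios
        else lw = total / n                         # log of the geometric mean
        # Bound in log space; exponentiate only in-range values.
        if (mode == "truncate") w = (lw > lu) ? upper : exp(lw)
        else w = (lw > lu || lw < ll) ? 0 : exp(lw)
        if (vetoed) w = 0
        printf "%.4f\n", w
      }
    }'
}
```

For example, the four pairs `-1.0 -1.5`, `-2.0 -1.0`, `-0.5 -0.5`, `-0.1 -1.1` behave as follows:

- `token truncate 2.0` yields 1.6487, 0.3679, 1.0000, 2.0000.
- `token clip 2.0` zeroes the two out-of-range tokens instead, giving 1.6487, 0.0000, 1.0000, 0.0000.
- The `sequence` and `geometric` levels assign one shared weight to every token of the sequence.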