Mirror of https://github.com/volcengine/verl.git (synced 2025-10-20 13:43:50 +08:00).
### What does this PR do? As title ### Checklist Before Starting - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). 
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
133 lines · 5.5 KiB · Bash
#!/usr/bin/env bash
# DAPO recipe: train Qwen2.5-Math-7B (default MODEL_PATH) on dapo-math-17k
# with the Megatron trainer config and vLLM rollout.
#
# Optional environment overrides:
#   NNODES         number of nodes (default: 4)
#   RAY_DATA_HOME  root for models/, data/ and ckpts/ (default: ${HOME}/verl)
#   MODEL_PATH, CKPTS_DIR, TRAIN_FILE, TEST_FILE
#                  override individual paths under RAY_DATA_HOME
#   RAY_ADDRESS, WORKING_DIR, RUNTIME_ENV
#                  Ray job settings; NOTE(review): not referenced by the
#                  python3 launch command in this script
set -xeuo pipefail

project_name='DAPO'
exp_name='DAPO-Qwen2.5-7b-MATH-megatron-0519a1'

adv_estimator=grpo

# KL settings for the reward (use_kl_in_reward/kl_coef) and for the actor
# loss (use_kl_loss/kl_loss_coef); all switched off here.
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

# Asymmetric clip range: the upper bound is wider than the lower bound.
clip_ratio_low=0.2
clip_ratio_high=0.28

# Length budgets; prompt + response is also used as the rollout's
# max_num_batched_tokens by the launch command.
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
# Overlong-response penalty, forwarded to the dapo reward manager's
# overlong_buffer_cfg.
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0

loss_agg_mode="token-mean"

# train_prompt_bsz -> data.train_batch_size; n_resp_per_prompt -> rollout.n;
# train_prompt_mini_bsz -> actor.ppo_mini_batch_size.
train_prompt_bsz=512
n_resp_per_prompt=16
train_prompt_mini_bsz=32

# Ray (presumably kept for a `ray job submit` launch variant -- confirm).
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-4}

# Paths (each one individually overridable from the environment).
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}

# Algorithm (sampling parameters for training and validation rollouts).
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7

# Performance Related Parameter
# use_dynamic_bsz and the two *_max_token_len values are only referenced by
# the commented-out overrides in the TODO below.
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
offload=True
# Tensor parallelism for vLLM generation; tensor/pipeline parallelism for the
# Megatron actor and reference models.
gen_tp=4
train_tp=4
train_pp=2

# TODO: support dynamic_bsz for megatron
# actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
# actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
# actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
# actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
# actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
# actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
# GPUs per node. Overridable from the environment like NNODES; the default
# keeps the previously hard-coded value of 16.
NGPUS_PER_NODE=${NGPUS_PER_NODE:-16}

# Overrides for ppo_megatron_trainer.yaml, in the original order, grouped by
# section. Kept in an array so each group can carry a comment and every
# value is passed to python3 as exactly one argument.
ppo_overrides=(
    # Data: files, prompt column, truncation side, lengths, batch size.
    data.train_files="${TRAIN_FILE}"
    data.val_files="${TEST_FILE}"
    data.prompt_key=prompt
    data.truncation='left'
    data.max_prompt_length="${max_prompt_length}"
    data.max_response_length="${max_response_length}"
    data.train_batch_size="${train_prompt_bsz}"
    actor_rollout_ref.rollout.n="${n_resp_per_prompt}"
    # Algorithm: advantage estimator and KL settings.
    algorithm.adv_estimator="${adv_estimator}"
    algorithm.use_kl_in_reward="${use_kl_in_reward}"
    algorithm.kl_ctrl.kl_coef="${kl_coef}"
    actor_rollout_ref.actor.use_kl_loss="${use_kl_loss}"
    actor_rollout_ref.actor.kl_loss_coef="${kl_loss_coef}"
    actor_rollout_ref.actor.clip_ratio_low="${clip_ratio_low}"
    actor_rollout_ref.actor.clip_ratio_high="${clip_ratio_high}"
    actor_rollout_ref.actor.clip_ratio_c=10.0
    # Fixed per-GPU micro-batch sizes (dynamic batch sizing is not enabled
    # for Megatron in this script).
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4
    # Actor model, optimizer, and Megatron parallelism/offload.
    actor_rollout_ref.model.path="${MODEL_PATH}"
    actor_rollout_ref.actor.optim.lr=1e-6
    actor_rollout_ref.actor.optim.lr_warmup_steps=10
    actor_rollout_ref.actor.optim.weight_decay=0.1
    actor_rollout_ref.actor.ppo_mini_batch_size="${train_prompt_mini_bsz}"
    actor_rollout_ref.actor.megatron.param_offload="${offload}"
    actor_rollout_ref.actor.megatron.optimizer_offload="${offload}"
    actor_rollout_ref.actor.megatron.grad_offload="${offload}"
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size="${train_pp}"
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size="${train_tp}"
    actor_rollout_ref.actor.entropy_coeff=0
    actor_rollout_ref.actor.optim.clip_grad=1.0
    actor_rollout_ref.actor.loss_agg_mode="${loss_agg_mode}"
    # vLLM rollout engine and sampling (training and validation).
    actor_rollout_ref.rollout.gpu_memory_utilization=0.80
    actor_rollout_ref.rollout.tensor_model_parallel_size="${gen_tp}"
    actor_rollout_ref.rollout.enable_chunked_prefill=True
    actor_rollout_ref.rollout.max_num_batched_tokens="$((max_prompt_length + max_response_length))"
    actor_rollout_ref.rollout.temperature="${temperature}"
    actor_rollout_ref.rollout.top_p="${top_p}"
    actor_rollout_ref.rollout.top_k="${top_k}"
    actor_rollout_ref.rollout.val_kwargs.temperature="${temperature}"
    actor_rollout_ref.rollout.val_kwargs.top_p="${val_top_p}"
    actor_rollout_ref.rollout.val_kwargs.top_k="${top_k}"
    actor_rollout_ref.rollout.val_kwargs.do_sample=True
    actor_rollout_ref.rollout.val_kwargs.n=1
    actor_rollout_ref.rollout.name=vllm
    # Reference model uses the same Megatron parallelism as the actor.
    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size="${train_pp}"
    actor_rollout_ref.ref.megatron.tensor_model_parallel_size="${train_tp}"
    actor_rollout_ref.ref.megatron.param_offload="${offload}"
    # DAPO reward manager with overlong-response penalty ('+' adds new keys).
    reward_model.reward_manager=dapo
    +reward_model.reward_kwargs.overlong_buffer_cfg.enable="${enable_overlong_buffer}"
    +reward_model.reward_kwargs.overlong_buffer_cfg.len="${overlong_buffer_len}"
    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor="${overlong_penalty_factor}"
    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False
    +reward_model.reward_kwargs.max_resp_len="${max_response_length}"
    # Trainer: logging, cluster size, eval/checkpoint cadence, resume.
    trainer.logger='["console","wandb"]'
    trainer.project_name="${project_name}"
    trainer.experiment_name="${exp_name}"
    trainer.n_gpus_per_node="${NGPUS_PER_NODE}"
    trainer.nnodes="${NNODES}"
    trainer.val_before_train=False
    trainer.test_freq=10
    trainer.save_freq=10
    trainer.total_epochs=10
    trainer.default_local_dir="${CKPTS_DIR}"
    trainer.resume_mode=auto
    trainer.log_val_generations=10
)

# Any arguments given to this script are appended as extra overrides.
# NOTE(review): confirm how Hydra treats a key that is given twice before
# using this to change a key already set above.
python3 -m verl.trainer.main_ppo \
    --config-path=config \
    --config-name='ppo_megatron_trainer.yaml' \
    "${ppo_overrides[@]}" \
    "$@"