mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 13:43:50 +08:00
[sglang, recipe] feat: add SGLang as rollout engine for one-step-off-policy (#3531)
### What does this PR do? This PR extends the one-step-off-policy recipe by adding SGLang as an alternative rollout engine to vLLM, allowing flexible backend selection and improving training efficiency. ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: https://github.com/volcengine/verl/pull/3460 - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test To validate this solution, we adopted the existing experimental configuration from the recipe one-step-off-policy. The evaluation demonstrates that the proposed SGLang rollout engine integration achieves effective acceleration in one-step-off-policy asynchronous training, providing users with enhanced rollout engine options for diverse deployment scenarios. **Experimental Results** - **Machine Configuration**: 2 nodes with 16 H20 GPUs each - Generation: 4 GPUs - Training: 12 GPUs - **Model**: Qwen2.5-Math-7B - **Max Response Length**: 8,192 tokens - **Algorithm**: DAPO - **Rollout Engine**: vLLM, SGLang | training mode | engine | step | gen | wait_prev_gen | generate_sequences | old_log_prob | update_actor | total time | acc/best@32/mean | acc/maj@32/mean | |------------------------|----------------|------|-----|---------------|--------------------|--------------|--------------|---------------|------------------|-----------------| | colocate sync | SGLang+FSDP2 | 452 | 131 | - | 125 | 54 | 199 | 12h25m | 0.6560 | 0.4471 | | one-step-overlap async | SGLang+FSDP2 | 406 | - | 12 | 305 | 58 | 245 | 11h12m (+11%) | 0.6303 | 0.4443 | * colocate sync: step ≈ gen + old_log_prob + update_actor * one-step-overlap async: step ≈ max(wait_prev_gen + generate_sequences, old_log_prob + update_actor) <img width="1218" height="777" alt="image" src="https://github.com/user-attachments/assets/58734164-2534-492f-bf00-1e80faae0fe7" /> ### API and Usage Example **Configuration Example** ```bash # Using SGLang engine python3 -m recipe.one_step_off_policy.main_ppo \ actor_rollout_ref.rollout.name=sglang \ # ... other configuration parameters # Using vLLM engine python3 -m recipe.one_step_off_policy.main_ppo \ actor_rollout_ref.rollout.name=vllm \ # ... other configuration parameters ``` **Script Usage** ```bash # Using SGLang engine bash dapo_7b_math_fsdp2_sglang_4_12.sh bash dapo_7b_math_fsdp2_sglang_colocate.sh # Using vLLM engine bash dapo_7b_math_fsdp2_4_12.sh bash dapo_7b_math_fsdp2_colocate.sh ``` ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: wuxibin <wuxibin@bytedance.com>
This commit is contained in:
@ -0,0 +1,23 @@
|
||||
hydra:
|
||||
searchpath:
|
||||
- file://verl/trainer/config
|
||||
|
||||
defaults:
|
||||
- ppo_trainer
|
||||
- _self_
|
||||
|
||||
data:
|
||||
max_prompt_length: 1024
|
||||
max_response_length: 1024
|
||||
train_batch_size: 256
|
||||
return_raw_chat: True
|
||||
shuffle: False
|
||||
|
||||
actor_rollout_ref:
|
||||
hybrid_engine: True
|
||||
rollout:
|
||||
name: sglang
|
||||
multi_turn:
|
||||
enable: True
|
||||
max_assistant_turns: 2
|
||||
format: qwen
|
@ -293,6 +293,6 @@ python3 -m recipe.one_step_off_policy.async_main_ppo \
|
||||
| Category | Support Situation |
|
||||
|--------------------|-----------------------------------------------------------------------------------------------------------------|
|
||||
| train engine | FSDP2 <br/> Megatron |
|
||||
| rollout engine | vLLM |
|
||||
| rollout engine | vLLM <br/> SGLang |
|
||||
| AdvantageEstimator | GRPO <br/> GRPO_PASSK <br/> REINFORCE_PLUS_PLUS <br/> RLOO <br/> OPO <br/> REINFORCE_PLUS_PLUS_BASELINE<br/>GPG |
|
||||
| Reward | all |
|
||||
|
140
recipe/one_step_off_policy/dapo_7b_math_fsdp2_sglang_4_12.sh
Normal file
140
recipe/one_step_off_policy/dapo_7b_math_fsdp2_sglang_4_12.sh
Normal file
@ -0,0 +1,140 @@
|
||||
#!/usr/bin/env bash
|
||||
set -xeuo pipefail
|
||||
|
||||
project_name='DAPO'
|
||||
exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-sglang-one-step-off-4-12'
|
||||
|
||||
adv_estimator=grpo
|
||||
|
||||
use_kl_in_reward=False
|
||||
kl_coef=0.0
|
||||
use_kl_loss=False
|
||||
kl_loss_coef=0.0
|
||||
|
||||
clip_ratio_low=0.2
|
||||
clip_ratio_high=0.28
|
||||
|
||||
max_prompt_length=$((1024 * 2))
|
||||
max_response_length=$((1024 * 8))
|
||||
enable_overlong_buffer=True
|
||||
overlong_buffer_len=$((1024 * 4))
|
||||
overlong_penalty_factor=1.0
|
||||
|
||||
loss_agg_mode="token-mean"
|
||||
|
||||
train_prompt_bsz=512
|
||||
n_resp_per_prompt=12
|
||||
train_prompt_mini_bsz=32
|
||||
|
||||
# Ray
|
||||
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
|
||||
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
|
||||
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
|
||||
NNODES=${NNODES:-2}
|
||||
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
|
||||
|
||||
n_gpus_rollout=2
|
||||
n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))
|
||||
|
||||
# Paths
|
||||
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
|
||||
# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface
|
||||
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
|
||||
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
|
||||
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
|
||||
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
|
||||
|
||||
|
||||
# Algorithm
|
||||
temperature=1.0
|
||||
top_p=1.0
|
||||
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
|
||||
val_top_p=0.7
|
||||
|
||||
# Performance Related Parameter
|
||||
use_dynamic_bsz=True
|
||||
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
|
||||
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
|
||||
ref_offload=True
|
||||
actor_offload=False
|
||||
gen_tp=2
|
||||
sp_size=4
|
||||
fsdp_size=2
|
||||
|
||||
python3 -m recipe.one_step_off_policy.main_ppo \
|
||||
data.train_files="${TRAIN_FILE}" \
|
||||
data.val_files="${TEST_FILE}" \
|
||||
data.prompt_key=prompt \
|
||||
data.truncation='left' \
|
||||
data.max_prompt_length=${max_prompt_length} \
|
||||
data.max_response_length=${max_response_length} \
|
||||
data.train_batch_size=${train_prompt_bsz} \
|
||||
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
|
||||
algorithm.adv_estimator=${adv_estimator} \
|
||||
algorithm.use_kl_in_reward=${use_kl_in_reward} \
|
||||
algorithm.kl_ctrl.kl_coef=${kl_coef} \
|
||||
actor_rollout_ref.actor.strategy=fsdp2 \
|
||||
critic.strategy=fsdp2 \
|
||||
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
|
||||
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
|
||||
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
|
||||
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
|
||||
actor_rollout_ref.actor.clip_ratio_c=10.0 \
|
||||
actor_rollout_ref.model.use_remove_padding=True \
|
||||
actor_rollout_ref.hybrid_engine=False \
|
||||
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
|
||||
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
|
||||
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
|
||||
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
|
||||
actor_rollout_ref.model.path="${MODEL_PATH}" \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
|
||||
actor_rollout_ref.actor.optim.weight_decay=0.1 \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
|
||||
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
|
||||
actor_rollout_ref.actor.entropy_coeff=0 \
|
||||
actor_rollout_ref.actor.grad_clip=1.0 \
|
||||
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
|
||||
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
|
||||
actor_rollout_ref.rollout.layered_summon=True \
|
||||
actor_rollout_ref.rollout.load_format=safetensors \
|
||||
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
|
||||
actor_rollout_ref.rollout.temperature=${temperature} \
|
||||
actor_rollout_ref.rollout.top_p=${top_p} \
|
||||
actor_rollout_ref.rollout.top_k=${top_k} \
|
||||
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
|
||||
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
|
||||
actor_rollout_ref.rollout.val_kwargs.n=1 \
|
||||
actor_rollout_ref.rollout.name=sglang \
|
||||
actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
|
||||
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
|
||||
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
|
||||
reward_model.reward_manager=dapo \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
|
||||
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
|
||||
trainer.logger=['console','tensorboard'] \
|
||||
trainer.project_name="${project_name}" \
|
||||
trainer.experiment_name="${exp_name}" \
|
||||
trainer.val_before_train=True \
|
||||
trainer.test_freq=10 \
|
||||
trainer.save_freq=-1 \
|
||||
trainer.total_epochs=10 \
|
||||
trainer.total_training_steps=100 \
|
||||
trainer.default_local_dir="${CKPTS_DIR}" \
|
||||
trainer.resume_mode=auto \
|
||||
trainer.log_val_generations=10 \
|
||||
trainer.nnodes="${NNODES}" \
|
||||
trainer.n_gpus_per_node="${n_gpus_training}" \
|
||||
rollout.nnodes="${NNODES}" \
|
||||
rollout.n_gpus_per_node="${n_gpus_rollout}"
|
133
recipe/one_step_off_policy/dapo_7b_math_fsdp2_sglang_colocate.sh
Normal file
133
recipe/one_step_off_policy/dapo_7b_math_fsdp2_sglang_colocate.sh
Normal file
@ -0,0 +1,133 @@
|
||||
#!/usr/bin/env bash
|
||||
set -xeuo pipefail
|
||||
|
||||
project_name='DAPO'
|
||||
exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-sglang-colocate'
|
||||
|
||||
adv_estimator=grpo
|
||||
|
||||
use_kl_in_reward=False
|
||||
kl_coef=0.0
|
||||
use_kl_loss=False
|
||||
kl_loss_coef=0.0
|
||||
|
||||
clip_ratio_low=0.2
|
||||
clip_ratio_high=0.28
|
||||
|
||||
max_prompt_length=$((1024 * 2))
|
||||
max_response_length=$((1024 * 8))
|
||||
enable_overlong_buffer=True
|
||||
overlong_buffer_len=$((1024 * 4))
|
||||
overlong_penalty_factor=1.0
|
||||
|
||||
loss_agg_mode="token-mean"
|
||||
|
||||
train_prompt_bsz=512
|
||||
n_resp_per_prompt=12
|
||||
train_prompt_mini_bsz=32
|
||||
|
||||
# Ray
|
||||
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
|
||||
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
|
||||
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
|
||||
NNODES=${NNODES:-2}
|
||||
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
|
||||
# Paths
|
||||
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
|
||||
# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface
|
||||
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
|
||||
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
|
||||
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
|
||||
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
|
||||
# Algorithm
|
||||
temperature=1.0
|
||||
top_p=1.0
|
||||
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
|
||||
val_top_p=0.7
|
||||
|
||||
# Performance Related Parameter
|
||||
use_dynamic_bsz=True
|
||||
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
|
||||
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
|
||||
offload=True
|
||||
gen_tp=2
|
||||
sp_size=4
|
||||
fsdp_size=2
|
||||
|
||||
# reference run wandb: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/runs/ow47vvon?nw=nwusertongyuxuan361
|
||||
|
||||
python3 -m verl.trainer.main_ppo \
|
||||
data.train_files="${TRAIN_FILE}" \
|
||||
data.val_files="${TEST_FILE}" \
|
||||
data.prompt_key=prompt \
|
||||
data.truncation='left' \
|
||||
data.max_prompt_length=${max_prompt_length} \
|
||||
data.max_response_length=${max_response_length} \
|
||||
data.train_batch_size=${train_prompt_bsz} \
|
||||
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
|
||||
algorithm.adv_estimator=${adv_estimator} \
|
||||
algorithm.use_kl_in_reward=${use_kl_in_reward} \
|
||||
algorithm.kl_ctrl.kl_coef=${kl_coef} \
|
||||
actor_rollout_ref.actor.strategy=fsdp2 \
|
||||
critic.strategy=fsdp2 \
|
||||
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
|
||||
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
|
||||
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
|
||||
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
|
||||
actor_rollout_ref.actor.clip_ratio_c=10.0 \
|
||||
actor_rollout_ref.model.use_remove_padding=True \
|
||||
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
|
||||
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
|
||||
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
|
||||
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
|
||||
actor_rollout_ref.model.path="${MODEL_PATH}" \
|
||||
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
|
||||
actor_rollout_ref.actor.optim.weight_decay=0.1 \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
|
||||
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
|
||||
actor_rollout_ref.actor.entropy_coeff=0 \
|
||||
actor_rollout_ref.actor.grad_clip=1.0 \
|
||||
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
|
||||
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
|
||||
actor_rollout_ref.rollout.layered_summon=True \
|
||||
actor_rollout_ref.rollout.load_format=safetensors \
|
||||
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
|
||||
actor_rollout_ref.rollout.temperature=${temperature} \
|
||||
actor_rollout_ref.rollout.top_p=${top_p} \
|
||||
actor_rollout_ref.rollout.top_k=${top_k} \
|
||||
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
|
||||
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
|
||||
actor_rollout_ref.rollout.val_kwargs.n=1 \
|
||||
actor_rollout_ref.rollout.name=sglang \
|
||||
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
|
||||
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
|
||||
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
|
||||
reward_model.reward_manager=dapo \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
|
||||
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
|
||||
trainer.logger=['console','tensorboard'] \
|
||||
trainer.project_name="${project_name}" \
|
||||
trainer.experiment_name="${exp_name}" \
|
||||
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
|
||||
trainer.nnodes="${NNODES}" \
|
||||
trainer.val_before_train=True \
|
||||
trainer.test_freq=10 \
|
||||
trainer.save_freq=-1 \
|
||||
trainer.total_epochs=10 \
|
||||
trainer.total_training_steps=100 \
|
||||
trainer.default_local_dir="${CKPTS_DIR}" \
|
||||
trainer.resume_mode=auto \
|
||||
trainer.log_val_generations=10
|
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
@ -83,13 +84,20 @@ class ActorRolloutRefWorker(ARRWorker):
|
||||
assert hasattr(self, "_weights_info") and self._weights_info is not None
|
||||
|
||||
params = self._get_actor_params() if self._is_actor else None
|
||||
rollout_name = self.config.rollout.name
|
||||
if self._is_rollout:
|
||||
inference_model = (
|
||||
self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
|
||||
)
|
||||
from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
|
||||
if rollout_name == "vllm":
|
||||
inference_model = (
|
||||
self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
|
||||
)
|
||||
from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
|
||||
|
||||
patch_vllm_moe_model_weight_loader(inference_model)
|
||||
patch_vllm_moe_model_weight_loader(inference_model)
|
||||
elif rollout_name == "sglang":
|
||||
inference_model = self.rollout._engine
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown rollout name: {rollout_name}")
|
||||
loop = asyncio.get_event_loop()
|
||||
for key, shape, dtype in self._weights_info:
|
||||
tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())
|
||||
if self._is_actor:
|
||||
@ -102,7 +110,23 @@ class ActorRolloutRefWorker(ARRWorker):
|
||||
|
||||
self._weight_sync_group.broadcast(tensor, src=0, stream=get_torch_device().current_stream())
|
||||
if self._is_rollout:
|
||||
inference_model.load_weights([(key, tensor)])
|
||||
if rollout_name == "vllm":
|
||||
inference_model.load_weights([(key, tensor)])
|
||||
elif rollout_name == "sglang":
|
||||
loop.run_until_complete(self.update_weights(inference_model, [(key, tensor)]))
|
||||
|
||||
async def update_weights(self, inference_engine, params):
|
||||
from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights
|
||||
|
||||
await sgl_update_weights(
|
||||
engine=inference_engine,
|
||||
params_batch=params,
|
||||
device_mesh_key="infer_tp",
|
||||
device_mesh=self.rollout_device_mesh,
|
||||
)
|
||||
|
||||
if self.rollout_device_mesh["infer_tp"].get_local_rank() == 0:
|
||||
await inference_engine.flush_cache()
|
||||
|
||||
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
|
||||
def get_actor_weights_info(self):
|
||||
@ -209,6 +233,7 @@ class RolloutWorker(ActorRolloutRefWorker):
|
||||
rollout_device_mesh = init_device_mesh(
|
||||
device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
|
||||
)
|
||||
self.rollout_device_mesh = rollout_device_mesh
|
||||
|
||||
is_collect = rollout_device_mesh["infer_tp"].get_local_rank() == 0
|
||||
self._register_dispatch_collect_info(
|
||||
@ -216,7 +241,8 @@ class RolloutWorker(ActorRolloutRefWorker):
|
||||
)
|
||||
|
||||
rollout_name = self.config.rollout.name
|
||||
assert rollout_name == "vllm"
|
||||
if rollout_name not in ("vllm", "sglang"):
|
||||
raise NotImplementedError(f"rollout_name: {rollout_name} is not supported")
|
||||
|
||||
rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)
|
||||
model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model, dataclass_type=HFModelConfig)
|
||||
@ -227,14 +253,23 @@ class RolloutWorker(ActorRolloutRefWorker):
|
||||
config=rollout_config, model_config=model_config, device_mesh=rollout_device_mesh
|
||||
)
|
||||
log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
|
||||
from .vllm_sharding_manager import VLLMShardingManager
|
||||
|
||||
rollout_sharding_manager = VLLMShardingManager(
|
||||
inference_engine=rollout.inference_engine, device_mesh=rollout_device_mesh
|
||||
)
|
||||
if rollout_name == "vllm":
|
||||
from .vllm_sharding_manager import VLLMShardingManager
|
||||
|
||||
log_gpu_memory_usage("After building sharding manager", logger=logger)
|
||||
rollout_sharding_manager = VLLMShardingManager(
|
||||
inference_engine=rollout.inference_engine, device_mesh=rollout_device_mesh
|
||||
)
|
||||
|
||||
log_gpu_memory_usage("After building sharding manager", logger=logger)
|
||||
elif rollout_name == "sglang":
|
||||
from .sglang_sharding_manager import SGLangShardingManager
|
||||
|
||||
rollout_sharding_manager = SGLangShardingManager(device_mesh=rollout_device_mesh)
|
||||
|
||||
log_gpu_memory_usage("After building sharding manager", logger=logger)
|
||||
|
||||
self.model_config = model_config
|
||||
self.rollout = rollout
|
||||
self.rollout_sharding_manager = rollout_sharding_manager
|
||||
|
||||
|
@ -0,0 +1,65 @@
|
||||
set -x
|
||||
|
||||
project_name='GRPO'
|
||||
exp_name='GRPO-Qwen3-0.6b-gsm8k-fsdp2-sglang-one-step-off-2-6'
|
||||
|
||||
# Paths
|
||||
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
|
||||
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-0.6B"}
|
||||
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
|
||||
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/gsm8k/train.parquet"}
|
||||
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/gsm8k/test.parquet"}
|
||||
|
||||
NNODES=${NNODES:-1}
|
||||
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
|
||||
|
||||
n_gpus_rollout=2
|
||||
n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))
|
||||
|
||||
|
||||
python3 -m recipe.one_step_off_policy.main_ppo \
|
||||
algorithm.adv_estimator=grpo \
|
||||
data.train_files="${TRAIN_FILE}" \
|
||||
data.val_files="${TEST_FILE}" \
|
||||
data.train_batch_size=1152 \
|
||||
data.max_prompt_length=512 \
|
||||
data.max_response_length=1024 \
|
||||
data.filter_overlong_prompts=True \
|
||||
data.truncation='error' \
|
||||
actor_rollout_ref.actor.strategy=fsdp2 \
|
||||
critic.strategy=fsdp2 \
|
||||
actor_rollout_ref.model.path="${MODEL_PATH}" \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.hybrid_engine=False \
|
||||
actor_rollout_ref.model.use_remove_padding=True \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=192 \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
|
||||
actor_rollout_ref.actor.use_kl_loss=True \
|
||||
actor_rollout_ref.actor.kl_loss_coef=0.001 \
|
||||
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
|
||||
actor_rollout_ref.actor.entropy_coeff=0 \
|
||||
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=False \
|
||||
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
|
||||
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
|
||||
actor_rollout_ref.rollout.name=sglang \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
|
||||
actor_rollout_ref.rollout.n=5 \
|
||||
actor_rollout_ref.rollout.load_format=safetensors \
|
||||
actor_rollout_ref.rollout.layered_summon=True \
|
||||
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
|
||||
actor_rollout_ref.ref.fsdp_config.param_offload=True \
|
||||
algorithm.use_kl_in_reward=False \
|
||||
trainer.critic_warmup=0 \
|
||||
trainer.val_before_train=True \
|
||||
trainer.logger=['console','tensorboard'] \
|
||||
trainer.project_name="${project_name}" \
|
||||
trainer.experiment_name="${exp_name}" \
|
||||
trainer.save_freq=-1 \
|
||||
trainer.test_freq=5 \
|
||||
trainer.total_epochs=2 \
|
||||
trainer.nnodes="${NNODES}" \
|
||||
trainer.n_gpus_per_node="${n_gpus_training}" \
|
||||
rollout.nnodes="${NNODES}" \
|
||||
rollout.n_gpus_per_node="${n_gpus_rollout}" $@
|
70
recipe/one_step_off_policy/sglang_sharding_manager.py
Normal file
70
recipe/one_step_off_policy/sglang_sharding_manager.py
Normal file
@ -0,0 +1,70 @@
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
||||
# Copyright 2025 Meituan Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
|
||||
from verl import DataProto
|
||||
from verl.protocol import all_gather_data_proto
|
||||
from verl.utils.debug import GPUMemoryLogger
|
||||
from verl.utils.device import get_torch_device
|
||||
from verl.utils.torch_functional import check_device_is_available
|
||||
from verl.workers.sharding_manager.base import BaseShardingManager
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
||||
|
||||
|
||||
class SGLangShardingManager(BaseShardingManager):
|
||||
@check_device_is_available()
|
||||
def __init__(self, device_mesh: DeviceMesh):
|
||||
self.device_mesh = device_mesh
|
||||
self.tp_size = self.device_mesh["infer_tp"].size()
|
||||
self.tp_rank = self.device_mesh["infer_tp"].get_local_rank()
|
||||
self.timing = {}
|
||||
gen_dp_rank = self.device_mesh["dp"].get_local_rank()
|
||||
get_torch_device().manual_seed(gen_dp_rank + 1000)
|
||||
self.gen_random_states = get_torch_device().get_rng_state()
|
||||
|
||||
@GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
|
||||
def __enter__(self):
|
||||
get_torch_device().set_rng_state(self.gen_random_states)
|
||||
|
||||
@GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.gen_random_states = get_torch_device().get_rng_state()
|
||||
get_torch_device().empty_cache()
|
||||
|
||||
@GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
|
||||
def preprocess_data(self, data: DataProto) -> DataProto:
|
||||
"""All gather across tp group to make each rank has identical input."""
|
||||
if self.tp_size == 1:
|
||||
return data
|
||||
|
||||
# TODO: Current impl doesn't consider FSDP with torch micro-dp
|
||||
group = self.device_mesh["infer_tp"].get_group()
|
||||
|
||||
all_gather_data_proto(data=data, process_group=group)
|
||||
return data
|
||||
|
||||
@GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
|
||||
def postprocess_data(self, data: DataProto) -> DataProto:
|
||||
"""Get chunk data of this tp rank since we do all gather in preprocess."""
|
||||
if self.tp_size == 1:
|
||||
return data
|
||||
|
||||
return data.chunk(chunks=self.tp_size)[self.tp_rank]
|
Reference in New Issue
Block a user