[sglang, recipe] feat: add SGLang as rollout engine for one-step-off-policy (#3531)

### What does this PR do?

This PR extends the one-step-off-policy recipe by adding SGLang as an
alternative rollout engine to vLLM, allowing flexible backend selection
and improving training efficiency.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here:
https://github.com/volcengine/verl/pull/3460
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
    - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

To validate this change, we reused the existing experimental configuration of the one-step-off-policy recipe.

The evaluation shows that the SGLang rollout engine integration achieves effective acceleration in one-step-off-policy asynchronous training and gives users an additional rollout engine option for diverse deployment scenarios.

**Experimental Results**

- **Machine Configuration**: 2 nodes with 8 H20 GPUs each (16 GPUs in total)
    - Generation: 4 GPUs
    - Training: 12 GPUs
- **Model**: Qwen2.5-Math-7B
- **Max Response Length**: 8,192 tokens
- **Algorithm**: DAPO
- **Rollout Engine**: vLLM, SGLang
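
The 4/12 GPU split follows from the launcher variables in `dapo_7b_math_fsdp2_sglang_4_12.sh` (2 rollout GPUs reserved per node); the snippet below is an illustrative reproduction of that arithmetic, not part of the recipe.

```python
# Illustrative arithmetic for the 4/12 GPU split used by the sglang_4_12 script.
NNODES = 2
NGPUS_PER_NODE = 8
n_gpus_rollout_per_node = 2  # reserved per node for the SGLang rollout engine
n_gpus_training_per_node = NGPUS_PER_NODE - n_gpus_rollout_per_node

total_gpus = NNODES * NGPUS_PER_NODE                # 16
generation_gpus = NNODES * n_gpus_rollout_per_node  # 4
training_gpus = NNODES * n_gpus_training_per_node   # 12
print(total_gpus, generation_gpus, training_gpus)
```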

Timing columns (step, gen, wait_prev_gen, generate_sequences, old_log_prob, update_actor) are average seconds per training step.

| training mode | engine | step | gen | wait_prev_gen | generate_sequences | old_log_prob | update_actor | total time | acc/best@32/mean | acc/maj@32/mean |
|------------------------|--------------|------|-----|---------------|--------------------|--------------|--------------|---------------|------------------|-----------------|
| colocate sync | SGLang+FSDP2 | 452 | 131 | - | 125 | 54 | 199 | 12h25m | 0.6560 | 0.4471 |
| one-step-overlap async | SGLang+FSDP2 | 406 | - | 12 | 305 | 58 | 245 | 11h12m (+11%) | 0.6303 | 0.4443 |

* colocate sync: step ≈ gen + old_log_prob + update_actor
* one-step-overlap async: step ≈ max(wait_prev_gen + generate_sequences, old_log_prob + update_actor)
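
For intuition, the sketch below (illustrative only, not part of the recipe) plugs the SGLang+FSDP2 timings from the table into the two decompositions; the gap between these sums and the reported `step` time is per-step overhead not broken out in the table.

```python
# Illustrative only: estimate per-step time (seconds) for the two training modes
# using the decomposition above and the SGLang+FSDP2 timings from the table.

def colocate_sync(gen: float, old_log_prob: float, update_actor: float) -> float:
    # colocate sync: generation and training share the GPUs, so the stages serialize
    return gen + old_log_prob + update_actor

def one_step_overlap(wait_prev_gen: float, generate_sequences: float,
                     old_log_prob: float, update_actor: float) -> float:
    # one-step-overlap async: next-step generation runs concurrently with training,
    # so the step time is bounded by the slower of the two pipelines
    return max(wait_prev_gen + generate_sequences, old_log_prob + update_actor)

print(colocate_sync(gen=131, old_log_prob=54, update_actor=199))            # 384
print(one_step_overlap(wait_prev_gen=12, generate_sequences=305,
                       old_log_prob=58, update_actor=245))                  # 317
```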

<img width="1218" height="777" alt="image"
src="https://github.com/user-attachments/assets/58734164-2534-492f-bf00-1e80faae0fe7"
/>

### API and Usage Example

**Configuration Example**
```bash
# Use the SGLang engine (append the remaining configuration overrides after this line)
python3 -m recipe.one_step_off_policy.main_ppo \
    actor_rollout_ref.rollout.name=sglang

# Use the vLLM engine (append the remaining configuration overrides after this line)
python3 -m recipe.one_step_off_policy.main_ppo \
    actor_rollout_ref.rollout.name=vllm
```

**Script Usage**
```bash
# Using SGLang engine
bash dapo_7b_math_fsdp2_sglang_4_12.sh
bash dapo_7b_math_fsdp2_sglang_colocate.sh

# Using vLLM engine
bash dapo_7b_math_fsdp2_4_12.sh
bash dapo_7b_math_fsdp2_colocate.sh
```

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: wuxibin <wuxibin@bytedance.com>
Authored by KAMiPan on 2025-10-14 10:48:29 +08:00, committed by GitHub.
Parent 5d378b5f95, commit 3abcc09d44. 7 changed files with 479 additions and 13 deletions.


@@ -0,0 +1,23 @@
hydra:
  searchpath:
    - file://verl/trainer/config

defaults:
  - ppo_trainer
  - _self_

data:
  max_prompt_length: 1024
  max_response_length: 1024
  train_batch_size: 256
  return_raw_chat: True
  shuffle: False

actor_rollout_ref:
  hybrid_engine: True
  rollout:
    name: sglang
    multi_turn:
      enable: True
      max_assistant_turns: 2
      format: qwen


@@ -293,6 +293,6 @@ python3 -m recipe.one_step_off_policy.async_main_ppo \
 | Category | Support Situation |
 |--------------------|-----------------------------------------------------------------------------------------------------------------|
 | train engine | FSDP2 <br/> Megatron |
-| rollout engine | vLLM |
+| rollout engine | vLLM <br/> SGLang |
 | AdvantageEstimator | GRPO <br/> GRPO_PASSK <br/> REINFORCE_PLUS_PLUS <br/> RLOO <br/> OPO <br/> REINFORCE_PLUS_PLUS_BASELINE<br/>GPG |
 | Reward | all |


@@ -0,0 +1,140 @@
#!/usr/bin/env bash
set -xeuo pipefail
project_name='DAPO'
exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-sglang-one-step-off-4-12'
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0
loss_agg_mode="token-mean"
train_prompt_bsz=512
n_resp_per_prompt=12
train_prompt_mini_bsz=32
# Ray
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-2}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
n_gpus_rollout=2
n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7
# Performance Related Parameter
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
ref_offload=True
actor_offload=False
gen_tp=2
sp_size=4
fsdp_size=2
python3 -m recipe.one_step_off_policy.main_ppo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_prompt_bsz} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.hybrid_engine=False \
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.layered_summon=True \
actor_rollout_ref.rollout.load_format=safetensors \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
reward_model.reward_manager=dapo \
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.val_before_train=True \
trainer.test_freq=10 \
trainer.save_freq=-1 \
trainer.total_epochs=10 \
trainer.total_training_steps=100 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto \
trainer.log_val_generations=10 \
trainer.nnodes="${NNODES}" \
trainer.n_gpus_per_node="${n_gpus_training}" \
rollout.nnodes="${NNODES}" \
rollout.n_gpus_per_node="${n_gpus_rollout}"


@@ -0,0 +1,133 @@
#!/usr/bin/env bash
set -xeuo pipefail
project_name='DAPO'
exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-sglang-colocate'
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0
loss_agg_mode="token-mean"
train_prompt_bsz=512
n_resp_per_prompt=12
train_prompt_mini_bsz=32
# Ray
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-2}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7
# Performance Related Parameter
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
offload=True
gen_tp=2
sp_size=4
fsdp_size=2
# reference run wandb: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/runs/ow47vvon?nw=nwusertongyuxuan361
python3 -m verl.trainer.main_ppo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_prompt_bsz} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.model.use_remove_padding=True \
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.layered_summon=True \
actor_rollout_ref.rollout.load_format=safetensors \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
reward_model.reward_manager=dapo \
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
trainer.nnodes="${NNODES}" \
trainer.val_before_train=True \
trainer.test_freq=10 \
trainer.save_freq=-1 \
trainer.total_epochs=10 \
trainer.total_training_steps=100 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto \
trainer.log_val_generations=10


@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import asyncio
 import logging
 import os
@@ -83,13 +84,20 @@ class ActorRolloutRefWorker(ARRWorker):
         assert hasattr(self, "_weights_info") and self._weights_info is not None
         params = self._get_actor_params() if self._is_actor else None
+        rollout_name = self.config.rollout.name
         if self._is_rollout:
-            inference_model = (
-                self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
-            )
-            from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
-
-            patch_vllm_moe_model_weight_loader(inference_model)
+            if rollout_name == "vllm":
+                inference_model = (
+                    self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
+                )
+                from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
+
+                patch_vllm_moe_model_weight_loader(inference_model)
+            elif rollout_name == "sglang":
+                inference_model = self.rollout._engine
+            else:
+                raise NotImplementedError(f"Unknown rollout name: {rollout_name}")
+        loop = asyncio.get_event_loop()
         for key, shape, dtype in self._weights_info:
             tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())
             if self._is_actor:
@@ -102,7 +110,23 @@
             self._weight_sync_group.broadcast(tensor, src=0, stream=get_torch_device().current_stream())
             if self._is_rollout:
-                inference_model.load_weights([(key, tensor)])
+                if rollout_name == "vllm":
+                    inference_model.load_weights([(key, tensor)])
+                elif rollout_name == "sglang":
+                    loop.run_until_complete(self.update_weights(inference_model, [(key, tensor)]))
+
+    async def update_weights(self, inference_engine, params):
+        from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights
+
+        await sgl_update_weights(
+            engine=inference_engine,
+            params_batch=params,
+            device_mesh_key="infer_tp",
+            device_mesh=self.rollout_device_mesh,
+        )
+        if self.rollout_device_mesh["infer_tp"].get_local_rank() == 0:
+            await inference_engine.flush_cache()

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def get_actor_weights_info(self):
@@ -209,6 +233,7 @@ class RolloutWorker(ActorRolloutRefWorker):
         rollout_device_mesh = init_device_mesh(
             device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
         )
+        self.rollout_device_mesh = rollout_device_mesh
         is_collect = rollout_device_mesh["infer_tp"].get_local_rank() == 0
         self._register_dispatch_collect_info(
@@ -216,7 +241,8 @@
         )
         rollout_name = self.config.rollout.name
-        assert rollout_name == "vllm"
+        if rollout_name not in ("vllm", "sglang"):
+            raise NotImplementedError(f"rollout_name: {rollout_name} is not supported")
         rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)
         model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model, dataclass_type=HFModelConfig)
@@ -227,14 +253,23 @@
             config=rollout_config, model_config=model_config, device_mesh=rollout_device_mesh
         )
         log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
-        from .vllm_sharding_manager import VLLMShardingManager
-
-        rollout_sharding_manager = VLLMShardingManager(
-            inference_engine=rollout.inference_engine, device_mesh=rollout_device_mesh
-        )
-        log_gpu_memory_usage("After building sharding manager", logger=logger)
+        if rollout_name == "vllm":
+            from .vllm_sharding_manager import VLLMShardingManager
+
+            rollout_sharding_manager = VLLMShardingManager(
+                inference_engine=rollout.inference_engine, device_mesh=rollout_device_mesh
+            )
+            log_gpu_memory_usage("After building sharding manager", logger=logger)
+        elif rollout_name == "sglang":
+            from .sglang_sharding_manager import SGLangShardingManager
+
+            rollout_sharding_manager = SGLangShardingManager(device_mesh=rollout_device_mesh)
+            log_gpu_memory_usage("After building sharding manager", logger=logger)

         self.model_config = model_config
         self.rollout = rollout
         self.rollout_sharding_manager = rollout_sharding_manager
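
For readers skimming the flattened diff, the following is a condensed, hedged sketch of the weight-sync flow the hunks above implement. The helper and attribute names (`broadcast_fn`, `rollout.inference_engine_model`, `rollout.engine`) are illustrative placeholders; only the `sgl_update_weights` call mirrors the arguments used in the diff.

```python
# Illustrative sketch of the per-tensor weight sync added above (not the exact recipe code).
# The actor broadcasts each parameter over the weight-sync group; the rollout worker then
# hands it to vLLM's synchronous load_weights or to SGLang's async weight-update utility.
import asyncio

import torch


def sync_rollout_weights(weights_info, broadcast_fn, rollout, rollout_name, device_mesh):
    """weights_info: iterable of (name, shape, dtype); broadcast_fn fills a tensor in place."""
    loop = asyncio.get_event_loop()
    for name, shape, dtype in weights_info:
        tensor = torch.empty(shape, dtype=dtype)
        broadcast_fn(tensor)  # collective broadcast from the actor ranks
        if rollout_name == "vllm":
            rollout.inference_engine_model.load_weights([(name, tensor)])  # placeholder attribute
        elif rollout_name == "sglang":
            from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights

            loop.run_until_complete(
                sgl_update_weights(
                    engine=rollout.engine,  # placeholder attribute for the SGLang engine
                    params_batch=[(name, tensor)],
                    device_mesh_key="infer_tp",
                    device_mesh=device_mesh,
                )
            )
        else:
            raise NotImplementedError(f"Unknown rollout name: {rollout_name}")
```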


@@ -0,0 +1,65 @@
set -x
project_name='GRPO'
exp_name='GRPO-Qwen3-0.6b-gsm8k-fsdp2-sglang-one-step-off-2-6'
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-0.6B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/gsm8k/train.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/gsm8k/test.parquet"}
NNODES=${NNODES:-1}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
n_gpus_rollout=2
n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))
python3 -m recipe.one_step_off_policy.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.train_batch_size=1152 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.hybrid_engine=False \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=192 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.rollout.load_format=safetensors \
actor_rollout_ref.rollout.layered_summon=True \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.val_before_train=True \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=2 \
trainer.nnodes="${NNODES}" \
trainer.n_gpus_per_node="${n_gpus_training}" \
rollout.nnodes="${NNODES}" \
rollout.n_gpus_per_node="${n_gpus_rollout}" $@


@@ -0,0 +1,70 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os

from torch.distributed.device_mesh import DeviceMesh

from verl import DataProto
from verl.protocol import all_gather_data_proto
from verl.utils.debug import GPUMemoryLogger
from verl.utils.device import get_torch_device
from verl.utils.torch_functional import check_device_is_available
from verl.workers.sharding_manager.base import BaseShardingManager

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


class SGLangShardingManager(BaseShardingManager):
    @check_device_is_available()
    def __init__(self, device_mesh: DeviceMesh):
        self.device_mesh = device_mesh
        self.tp_size = self.device_mesh["infer_tp"].size()
        self.tp_rank = self.device_mesh["infer_tp"].get_local_rank()
        self.timing = {}

        gen_dp_rank = self.device_mesh["dp"].get_local_rank()
        get_torch_device().manual_seed(gen_dp_rank + 1000)
        self.gen_random_states = get_torch_device().get_rng_state()

    @GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
    def __enter__(self):
        get_torch_device().set_rng_state(self.gen_random_states)

    @GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
    def __exit__(self, exc_type, exc_value, traceback):
        self.gen_random_states = get_torch_device().get_rng_state()
        get_torch_device().empty_cache()

    @GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
    def preprocess_data(self, data: DataProto) -> DataProto:
        """All gather across tp group to make each rank has identical input."""
        if self.tp_size == 1:
            return data

        # TODO: Current impl doesn't consider FSDP with torch micro-dp
        group = self.device_mesh["infer_tp"].get_group()

        all_gather_data_proto(data=data, process_group=group)
        return data

    @GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
    def postprocess_data(self, data: DataProto) -> DataProto:
        """Get chunk data of this tp rank since we do all gather in preprocess."""
        if self.tp_size == 1:
            return data

        return data.chunk(chunks=self.tp_size)[self.tp_rank]
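
As a usage note, here is a minimal sketch of how a sharding manager like this is typically driven around generation, assuming the same calling convention as verl's other sharding managers; `rollout` and `prompts` are hypothetical placeholders.

```python
# Minimal usage sketch (assumption: same calling convention as verl's other sharding managers).
# `sharding_manager` is an SGLangShardingManager; `rollout` and `prompts` are placeholders.

def generate_with_sharding(sharding_manager, rollout, prompts):
    with sharding_manager:  # __enter__/__exit__ save and restore the generation RNG state
        # every rank in the infer_tp group receives the full batch
        prompts = sharding_manager.preprocess_data(prompts)
        output = rollout.generate_sequences(prompts)
        # keep only this tp rank's chunk, undoing the all-gather
        return sharding_manager.postprocess_data(output)
```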