[sglang] refactor: Unify async rollout under SGLangRollout, and support sglang==0.4.6.post5 (#1717)

### Checklist Before Starting

- [x] Search for similar PR(s).

### What does this PR do?

- Unify the functionality of SGLangRollout and AsyncSGLangRollout: remove the
original SGLangRollout and rename AsyncSGLangRollout to SGLangRollout.
- Make minor adjustments required by changes in sglang==0.4.6.post5.
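
With the unification, `SGLangRollout` drives `sglang.srt.entrypoints.engine.Engine` directly instead of `VerlEngine`. A minimal sketch of the engine construction, with arguments copied from the updated SPMD test in this PR (the model path and TP size are just the test's values, and a GPU environment with sglang==0.4.6.post5 is assumed):

```python
# Sketch mirroring tests/workers/rollout/test_sglang_spmd.py in this PR.
# The engine is created only on the TP-rank-0 process of each DP group.
from sglang.srt.entrypoints.engine import Engine

llm = Engine(
    model_path="Qwen/Qwen2.5-0.5B",  # model used by the updated test
    dtype="bfloat16",
    mem_fraction_static=0.5,
    enable_memory_saver=True,
    tp_size=2,                       # tensor-parallel size from the test
)
```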

### High-Level Design

> Demonstrate the high-level design if this PR is complex.

### Specific Changes

> List the specific changes.

- Remove `AsyncSGLangRollout` and the separate async rollout module; the unified `SGLangRollout` now covers both plain and multi-turn/tool rollout.
- Rename `FSDPAsyncSGLangShardingManager` to `FSDPSGLangShardingManager` and `MegatronAsyncSGLangShardingManager` to `MegatronSGLangShardingManager`.
- Workers now always call `rollout.generate_sequences()`; the special-cased `generate_sequences_with_tools()` path is removed.
- `rollout.name=sglang_async` is deprecated and mapped to `sglang` with a `DeprecationWarning`.
- Bump SGLang to 0.4.6.post5 across `setup.py`, requirements, Dockerfiles, docs, and CI images, and fold the async-specific CI jobs and `test_sglang_async_spmd.py` into the regular sglang ones.

### API

> Demonstrate how the API changes if any.
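
User-facing, the rollout name `sglang_async` is folded into `sglang`: passing `sglang_async` still works for now but emits a `DeprecationWarning` and is treated as `sglang`. A simplified sketch of that mapping (the real check is inlined in the FSDP and Megatron workers; the helper function below is illustrative):

```python
import warnings


def resolve_rollout_name(name: str) -> str:
    """Illustrative helper; the actual check lives inline in the workers."""
    if name == "sglang_async":
        warnings.warn(
            "'sglang_async' has been deprecated and merged into 'sglang'. "
            "Please use 'sglang' going forward.",
            DeprecationWarning,
            stacklevel=2,
        )
        return "sglang"
    return name
```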

### Usage Example

> Provide usage example(s) for easier usage.

```python
# Add code snippet or script demonstrating how to use this 
```
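
As a minimal, illustrative example, a multi-turn tool-calling run now selects the unified backend with `actor_rollout_ref.rollout.name=sglang` (previously `sglang_async`). The sketch below shows only the fields touched by the rename and assumes the rest of the trainer config is supplied elsewhere:

```python
# Illustrative overrides only; a real run needs the full trainer config
# (see the updated example scripts and YAML files in this diff).
from omegaconf import OmegaConf

rollout_overrides = OmegaConf.create(
    {
        "actor_rollout_ref": {
            "rollout": {
                "name": "sglang",  # was: sglang_async
                "multi_turn": {"enable": True, "max_turns": 5},
            }
        }
    }
)
print(OmegaConf.to_yaml(rollout_overrides))
```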

### Test

> For changes that cannot be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### Additional Info.

- **Issue Number**: Fixes issue # or discussion # if any.
- **Training**: both FSDP and Megatron (the SGLang rollout setup in both
workers is updated).
- **Inference**: SGLang.

### Checklist Before Submitting

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [ ] Add `[BREAKING]` to the PR title if it breaks any API.
- [ ] Update the documentation about your changes in the
[docs](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add CI test(s) if necessary.

---------

Co-authored-by: zyzshishui <@qq.com>
Co-authored-by: Xiang Long <mindsculptor@yeah.net>
Co-authored-by: ocss884 <ocss.lin@gmail.com>
Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
Co-authored-by: H <linhaibin.eric@gmail.com>
Authored by Yuzhen Zhou on 2025-05-31 19:47:25 -07:00, committed by GitHub
parent cef6361def
commit 4de247fe4d
37 changed files with 1124 additions and 1665 deletions

View File

@ -204,7 +204,7 @@ jobs:
HF_ENDPOINT: "https://hf-mirror.com"
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
container:
image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post4
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
options: --gpus all --shm-size=10g
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@ -222,7 +222,7 @@ jobs:
ray stop --force
ENGINE=sglang bash tests/e2e/ppo_trainer/run_function_reward.sh
e2e_ppo_trainer_sglang_async:
e2e_ppo_trainer_sglang_multiturn_with_tool:
runs-on: [L20x8]
needs: pre_commit_for_ppo
timeout-minutes: 40 # Increase this timeout value as needed
@ -233,36 +233,7 @@ jobs:
HF_ENDPOINT: "https://hf-mirror.com"
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
container:
image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post4
options: --gpus all --shm-size=10g
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
- name: Install the current repository
run: |
pip3 install -e .[test,gpu,sglang] --no-deps
- name: Prepare gsm8k dataset
run: |
ray stop --force
python3 examples/data_preprocess/gsm8k.py
- name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm and save ckpt with sglang async
run: |
ray stop --force
ENGINE=sglang_async bash tests/e2e/ppo_trainer/run_function_reward.sh
e2e_ppo_trainer_sglang_async_with_tool:
runs-on: [L20x8]
needs: pre_commit_for_ppo
timeout-minutes: 40 # Increase this timeout value as needed
env:
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
NO_PROXY: "localhost,127.0.0.1,hf-mirror.com"
HF_ENDPOINT: "https://hf-mirror.com"
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
container:
image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post4
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
options: --gpus all --shm-size=10g
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@ -275,7 +246,7 @@ jobs:
run: |
ray stop --force
python3 examples/data_preprocess/gsm8k_multiturn_w_tool.py --local_dir $HOME/data/gsm8k_verl_sgl_multi_turn_preprocessed
- name: Running GSM8K with tool E2E training tests on 8 L20 GPUs with rmpad using function rm and save ckpt with sglang async
- name: Running GSM8K with tool E2E training tests on 8 L20 GPUs with rmpad using function rm and save ckpt with sglang
run: |
ray stop --force
bash tests/e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh
@ -295,7 +266,7 @@ jobs:
HF_ENDPOINT: "https://hf-mirror.com"
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
container:
image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post4
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
options: --gpus all --shm-size=50g # Visual dataloader requires large memory
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@ -367,7 +338,7 @@ jobs:
HF_ENDPOINT: "https://hf-mirror.com"
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
container:
image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post4
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
options: --gpus all --shm-size=50g # Visual dataloader requires large memory
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

View File

@ -50,7 +50,7 @@ jobs:
HF_ENDPOINT: "https://hf-mirror.com"
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
container:
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
options: --gpus all --shm-size=10g
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@ -92,7 +92,7 @@ jobs:
HF_ENDPOINT: "https://hf-mirror.com"
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
container:
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
options: --gpus all --shm-size=10g
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@ -134,7 +134,7 @@ jobs:
HF_ENDPOINT: "https://hf-mirror.com"
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
container:
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
options: --gpus all --shm-size=10g
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@ -167,7 +167,7 @@ jobs:
HF_ENDPOINT: "https://hf-mirror.com"
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
container:
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
options: --gpus all --shm-size=10g
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@ -206,7 +206,7 @@ jobs:
HF_ENDPOINT: "https://hf-mirror.com"
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
container:
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
options: --gpus all --shm-size=10g
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

View File

@ -56,7 +56,7 @@ jobs:
HF_HUB_ENABLE_HF_TRANSFER: 1
SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK: "True"
container:
image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post4
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
options: --gpus all --shm-size=10g
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@ -73,11 +73,7 @@ jobs:
- name: Test the latest SGLang
run: |
cd tests/workers/rollout
torchrun --nnodes=1 --nproc_per_node=4 $(which pytest) -s test_sglang_spmd.py
- name: Test the latest SGLang async
run: |
cd tests/workers/rollout
torchrun --nnodes=1 --nproc_per_node=2 $(which pytest) -s test_sglang_async_spmd.py
torchrun --nnodes=1 --nproc_per_node=2 $(which pytest) -s test_sglang_spmd.py
- name: Test the latest SGLang Rollout async with tool
run: |
cd tests/workers/rollout

View File

@ -6,7 +6,7 @@
# Support - Traing: fsdp; Inference: vllm
# FROM rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4
# Support - Traing: fsdp; Inference: vllm, sglang
FROM lmsysorg/sglang:v0.4.6.post4-rocm630
FROM lmsysorg/sglang:v0.4.6.post5-rocm630
# Set working directory
# WORKDIR $PWD/app

View File

@ -36,8 +36,8 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
pip config set global.extra-index-url "${PIP_INDEX}" && \
python -m pip install --upgrade pip
# Install sglang-0.4.6.post4 and torch-memory-saver
RUN pip install "sglang[all]==0.4.6.post4" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir
# Install sglang-0.4.6.post5 and torch-memory-saver
RUN pip uninstall -y cuda-python && pip install "sglang[all]==0.4.6.post5" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir
# Install torch-2.6.0
RUN pip install --no-cache-dir torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 tensordict torchdata \

View File

@ -56,12 +56,12 @@ RUN aria2c --always-resume=true --max-tries=99999 https://developer.download.nvi
update-alternatives --set cuda /usr/local/cuda-12.4 && \
rm -rf /usr/local/cuda-12.6
# Install torch-2.6.0+cu124 + vllm-0.8.5.post1 + sglang-0.4.6.post4
# Install torch-2.6.0+cu124 + vllm-0.8.5.post1 + sglang-0.4.6.post5
# torch-2.6.0+cu124: cxx11abi=False
# torch-2.6.0+cu126: cxx11abi=True
# see https://github.com/flashinfer-ai/flashinfer/issues/911
# Install sglang-0.4.6.post1 and torch-memory-saver
RUN pip install "sglang[all]==0.4.6.post1" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir
RUN pip install "sglang[all]==0.4.6.post5" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir
RUN pip install --no-cache-dir "vllm==0.8.5.post1" "torch==2.6.0" "torchvision==0.21.0" "torchaudio==2.6.0" "tensordict==0.6.2" torchdata

View File

@ -22,7 +22,7 @@ docker/Dockerfile.rocm
# Support - Traing: fsdp; Inference: vllm
# FROM rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4
# Support - Traing: fsdp; Inference: vllm, sglang
FROM lmsysorg/sglang:v0.4.6.post4-rocm630
FROM lmsysorg/sglang:v0.4.6.post5-rocm630
# Set working directory
# WORKDIR $PWD/app

View File

@ -11,9 +11,9 @@ To enable multi-turn rollout, make sure to configure the following fields in you
actor_rollout_ref:
rollout:
multi_turn: True
name: "sglang_async"
name: "sglang"
These configuration activates the sglang_async engine for multi-turn interaction during rollout.
These configuration activates the sglang engine for multi-turn interaction during rollout.
Custom Tool Configuration
~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -42,7 +42,7 @@ For vLLM with Megatron or FSDP, please use the stable version of image ``whatcan
For latest vLLM with FSDP, please refer to ``hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0``.
For SGLang with FSDP, please use ``ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post4`` which is provided by SGLang RL Group.
For SGLang with FSDP, please use ``ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post5`` which is provided by SGLang RL Group.
See files under ``docker/`` for NGC-based image or if you want to build your own.
@ -79,7 +79,7 @@ See files under ``docker/`` for NGC-based image or if you want to build your own
- **Flash Attenttion**: 2.7.4.post1
- **Flash Infer**: 0.2.2.post1
- **vLLM**: 0.8.5
- **SGLang**: 0.4.6.post4
- **SGLang**: 0.4.6.post5
- **Megatron-LM**: core_v0.12.0
- **TransformerEngine**: 2.3
- **Ray**: 2.44.1

View File

@ -21,7 +21,7 @@ Please always follow the following command to install SGLang with verl.
.. code-block:: bash
pip install --upgrade pip
# Currently 0.4.6.post4, subject to updates at any time, please refer to the latest version specified in `setup.py`
# Currently 0.4.6.post5, subject to updates at any time, please refer to the latest version specified in `setup.py`
pip install -e ".[sglang]"
You can check the following dependencies are in your environment:
@ -31,8 +31,8 @@ You can check the following dependencies are in your environment:
- **PyTorch**: 2.6.0+cu124
- **CUDA**: 12.4
- **flashinfer-python**: 0.2.5+cu124torch2.6
- **sgLang**: 0.4.6.post4
- **sgl-kernel**: 0.1.2.post1
- **sgLang**: 0.4.6.post5
- **sgl-kernel**: 0.1.4
Using SGLang as the Inference Backend for PPO Training on a Single Machine
-------------------------------------------------------------------------
@ -87,7 +87,7 @@ Why export SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK?
1. ``verl`` initializes a ``SGLangRollout`` module during rollout, which is used to evaluate/generate samples.
2. ``SGLangRollout`` will initialize ``VerlEngine``, and further initialize a ``torch.distributed.DeviceMesh``, used to support Tensor Parallel (TP).
2. ``SGLangRollout`` will initialize ``Engine``, and further initialize a ``torch.distributed.DeviceMesh``, used to support Tensor Parallel (TP).
3. ``DeviceMesh.init()`` internally checks the free GPU memory of all participating devices. If the difference is too large (more than ~10%), it directly reports an error to avoid initialization failures or deadlocks.
@ -111,7 +111,7 @@ Early workers already use up GPU memory → late workers still have empty memory
**3. SGLang's TP init uses "all-device broadcast", but there's no uniform release timing**
Although ``SGLangRollout`` may only involve subset of GPUs, its ``VerlEngine`` initialization calls ``torch.distributed.init_process_group()`` and broadcasts weights, so:
Although ``SGLangRollout`` may only involve subset of GPUs, its ``Engine`` initialization calls ``torch.distributed.init_process_group()`` and broadcasts weights, so:
- Non-rollout GPUs also join the communication.
- Later on, ``DeviceMesh`` init will fail due to "inconsistent memory".

View File

@ -15,7 +15,7 @@ data:
actor_rollout_ref:
hybrid_engine: True
rollout:
name: sglang_async
name: sglang
multi_turn:
enable: True
max_turns: 5

View File

@ -15,7 +15,7 @@ data:
actor_rollout_ref:
hybrid_engine: True
rollout:
name: sglang_async
name: sglang
multi_turn:
enable: True
max_turns: 5

View File

@ -32,7 +32,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=sglang_async \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
actor_rollout_ref.rollout.n=16 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
@ -41,7 +41,7 @@ python3 -m verl.trainer.main_ppo \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='gsm8k_async_rl' \
trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-w-tool-verify-n16' \
trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-verify-n16' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \

View File

@ -32,7 +32,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=sglang_async \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
actor_rollout_ref.rollout.n=16 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \

View File

@ -45,7 +45,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=sglang_async \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
actor_rollout_ref.rollout.n=8 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
@ -53,7 +53,7 @@ python3 -m verl.trainer.main_ppo \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='gsm8k_async_rl' \
trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-w-tool-n8-mcore-v2505201745_seed42' \
trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-n8-mcore-v2505201745_seed42' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \

View File

@ -17,6 +17,6 @@ torchdata
torchvision
transformers
wandb
sglang[all]==0.4.6.post4
sglang[all]==0.4.6.post5
torch-memory-saver>=0.0.5
huggingface_hub
huggingface_hub

View File

@ -49,7 +49,12 @@ GEO_REQUIRES = ["mathruler"]
GPU_REQUIRES = ["liger-kernel", "flash-attn"]
MATH_REQUIRES = ["math-verify"] # Add math-verify as an optional dependency
VLLM_REQUIRES = ["tensordict<=0.6.2", "vllm<=0.8.5"]
SGLANG_REQUIRES = ["tensordict<=0.6.2", "sglang[srt,openai]==0.4.6.post4", "torch-memory-saver>=0.0.5", "torch==2.6.0"]
SGLANG_REQUIRES = [
"tensordict<=0.6.2",
"sglang[srt,openai]==0.4.6.post5",
"torch-memory-saver>=0.0.5",
"torch==2.6.0",
]
extras_require = {
"test": TEST_REQUIRES,

View File

@ -36,7 +36,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=sglang_async \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
actor_rollout_ref.rollout.n=8 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
@ -46,7 +46,7 @@ python3 -m verl.trainer.main_ppo \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
trainer.project_name='gsm8k_async_rl' \
trainer.experiment_name=qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-w-tool-$FSDP_STRATEGY-rebased-0427-verify-n16 \
trainer.experiment_name=qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-$FSDP_STRATEGY-rebased-0427-verify-n16 \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \

View File

@ -78,7 +78,7 @@ if [ $SKIP_SAVE_HF_MODEL -eq 1 ]; then
CHECKPOINT_CONTENTS=['model','optimizer','extra']
fi
ENGINES=("vllm" "sglang_async")
ENGINES=("vllm" "sglang")
exp_name="$(basename "${MODEL_ID,,}")-megatron-gsm8k-minimal"

View File

@ -1,3 +1,4 @@
# Copyright 2023-2024 SGLang Team
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -20,7 +21,6 @@ from omegaconf import DictConfig
@patch.dict(
"sys.modules",
{
"verl.workers.rollout.sglang_rollout.async_sglang_rollout": MagicMock(AsyncSGLangRollout=MagicMock()),
"verl.workers.rollout.sglang_rollout.sglang_rollout": MagicMock(SGLangRollout=MagicMock()),
},
)

View File

@ -32,7 +32,7 @@ from verl.protocol import DataProto
from verl.tools.schemas import OpenAIFunctionParametersSchema, OpenAIFunctionPropertySchema, OpenAIFunctionSchema, OpenAIFunctionToolSchema
from verl.tools.search_tool import SearchTool
from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message
from verl.workers.rollout.sglang_rollout.async_sglang_rollout import AsyncSGLangRollout
from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout
DEFAULT_USER_CONTENT_PREFIX = (
"Answer the given question. You must conduct reasoning inside <think> and </think> "
@ -143,11 +143,11 @@ class TestRolloutWithSearchTools:
prompts = DataProto(batch=prompt_dict, non_tensor_batch={"raw_prompt": messages, "tools_kwargs": tools_kwargs, "index": index})
return prompts
@patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None)
@patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None)
@patch.object(SGLangRollout, "_init_distributed_env", return_value=None)
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
def test_tools_registration(self, mock_env, mock_engine, mock_sampling, search_rollout_config, qwen_tokenizer, qwen_model_config):
rollout = AsyncSGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
rollout = SGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
assert len(rollout._tool_schemas) == 1
assert "search" in rollout._tool_map.keys()
from verl.tools.search_tool import SearchTool
@ -156,11 +156,11 @@ class TestRolloutWithSearchTools:
# depend on the tokenizer
assert rollout._tool_call_parser_type == "qwen25"
@patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None)
@patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None)
@patch.object(SGLangRollout, "_init_distributed_env", return_value=None)
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
def test_rollout_req_creation(self, mock_env, mock_engine, mock_sampling, search_rollout_config, qwen_tokenizer, qwen_model_config, search_data_proto):
rollout = AsyncSGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
rollout = SGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
req_list = rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)
assert len(req_list) == 1
assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING
@ -186,12 +186,12 @@ class TestRolloutWithSearchTools:
),
)
@patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None)
@patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None)
@patch.object(SGLangRollout, "_init_distributed_env", return_value=None)
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
def test_over_size_case(self, mock_env, mock_engine, mock_sampling, search_rollout_config, qwen_tokenizer, qwen_model_config, search_data_proto, search_data):
search_rollout_config.multi_turn.max_turns = 1
rollout = AsyncSGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
rollout = SGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
req = rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0]
req = MagicMock(wraps=req, spec=AsyncRolloutRequest)
req.finalize = MagicMock()
@ -223,9 +223,9 @@ class TestRolloutWithSearchTools:
)
@patch.object(SearchTool, "execute", new_callable=AsyncMock)
@patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None)
@patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None)
@patch.object(SGLangRollout, "_init_distributed_env", return_value=None)
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
def test_tool_call_basic_case(self, mock_sampling, mock_engine, mock_env, mock_execute, search_rollout_config, qwen_tokenizer, qwen_model_config, search_data_proto, search_data):
_, expect_turn_array, tool_return_array = search_data
@ -233,7 +233,7 @@ class TestRolloutWithSearchTools:
mock_execute.side_effect = [(msg, 0.0, {"status": "success"}) for msg in tool_return_array]
search_rollout_config.multi_turn.max_turns = 10
rollout = AsyncSGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
rollout = SGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
rollout._tool_map["search"].retrieval_service_url = "mock://dummy"
@ -272,9 +272,9 @@ class TestRolloutWithSearchTools:
assert search_counter == 2
@patch.object(SearchTool, "execute", new_callable=AsyncMock)
@patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None)
@patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None)
@patch.object(SGLangRollout, "_init_distributed_env", return_value=None)
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
def test_tool_call_batch_case(self, mock_sampling, mock_engine, mock_env, mock_execute, search_rollout_config, qwen_tokenizer, qwen_model_config, search_data_proto, search_data):
_, expect_turn_array, tool_return_array = search_data
@ -285,7 +285,7 @@ class TestRolloutWithSearchTools:
] * 100
search_rollout_config.multi_turn.max_turns = 10
rollout = AsyncSGLangRollout(
rollout = SGLangRollout(
actor_module="",
config=search_rollout_config,
tokenizer=qwen_tokenizer,
@ -327,7 +327,7 @@ class TestRolloutWithSearchTools:
req_turns_counter[_req.batch_data_id] += 1
return await fut
with patch.object(AsyncSGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call):
with patch.object(SGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call):
rollout._tp_rank = 0
loop = asyncio.get_event_loop()
output_req_list = loop.run_until_complete(asyncio.gather(*[rollout._async_rollout_a_request(r, True, False) for r in req_list]))

View File

@ -32,7 +32,7 @@ from verl.protocol import DataProto
from verl.tools.sandbox_fusion_tools import TokenBucketWorker
from verl.tools.schemas import OpenAIFunctionParametersSchema, OpenAIFunctionPropertySchema, OpenAIFunctionSchema, OpenAIFunctionToolSchema
from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message
from verl.workers.rollout.sglang_rollout.async_sglang_rollout import AsyncSGLangRollout
from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout
sandbox_url = ""
@ -200,11 +200,11 @@ class TestRolloutWithTools:
prompts = DataProto(batch=prompt_dict, non_tensor_batch={"raw_prompt": messages, "tools_kwargs": tools_kwargs, "index": index})
return prompts
@patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None)
@patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None)
@patch.object(SGLangRollout, "_init_distributed_env", return_value=None)
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
def test_tools_registration(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config):
rollout = AsyncSGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
assert len(rollout._tool_schemas) == 1
assert "code_interpreter" in rollout._tool_map.keys()
from verl.tools.sandbox_fusion_tools import SandboxFusionTool
@ -212,11 +212,11 @@ class TestRolloutWithTools:
assert isinstance(rollout._tool_map["code_interpreter"], SandboxFusionTool)
assert rollout._tool_call_parser_type == "qwen25"
@patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None)
@patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None)
@patch.object(SGLangRollout, "_init_distributed_env", return_value=None)
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
def test_rollout_req_creation(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config, sandbox_data_proto):
rollout = AsyncSGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
req_list = rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)
assert len(req_list) == 1
assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING
@ -242,12 +242,12 @@ class TestRolloutWithTools:
),
)
@patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None)
@patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None)
@patch.object(SGLangRollout, "_init_distributed_env", return_value=None)
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
def test_over_size_case(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config, sandbox_data_proto, sandbox_fusion_data):
sandbox_fusion_rollout_config.multi_turn.max_turns = 1
rollout = AsyncSGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
req = rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0]
req = MagicMock(wraps=req, spec=AsyncRolloutRequest)
req.finalize = MagicMock()
@ -279,12 +279,12 @@ class TestRolloutWithTools:
)
@skip_if_valid_sandbox(sandbox_url)
@patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None)
@patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None)
@patch.object(SGLangRollout, "_init_distributed_env", return_value=None)
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
def test_tool_call_basic_case(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config, sandbox_data_proto, sandbox_fusion_data):
sandbox_fusion_rollout_config.multi_turn.max_turns = 10
rollout = AsyncSGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
self._tool_map["code_interpreter"].sandbox_fusion_url = sandbox_url
req = rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0]
req = MagicMock(wraps=req, spec=AsyncRolloutRequest)
@ -323,12 +323,12 @@ class TestRolloutWithTools:
assert code_counter == 2
@skip_if_valid_sandbox(sandbox_url)
@patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None)
@patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None)
@patch.object(SGLangRollout, "_init_distributed_env", return_value=None)
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
def test_tool_call_batch_case(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config, sandbox_data_proto, sandbox_fusion_data):
sandbox_fusion_rollout_config.multi_turn.max_turns = 10
rollout = AsyncSGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
self._tool_map["code_interpreter"].sandbox_fusion_url = sandbox_url
req = rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0]
req_nums = 100
@ -357,7 +357,7 @@ class TestRolloutWithTools:
re = await result
return re
with patch.object(AsyncSGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call):
with patch.object(SGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call):
rollout._tp_rank = 0
loop = asyncio.get_event_loop()
output_req_list = loop.run_until_complete(

View File

@ -35,8 +35,8 @@ from utils_sglang import (
)
from verl import DataProto
from verl.workers.rollout.sglang_rollout.async_sglang_rollout import AsyncSGLangRollout
from verl.workers.sharding_manager.fsdp_sglang import FSDPAsyncSGLangShardingManager
from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout
from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager
def test_async_sglang_rollout_w_tool():
@ -78,9 +78,9 @@ def test_async_sglang_rollout_w_tool():
)
rollout_config = get_rollout_config(max_response_length, max_prompt_length, dtype, tensor_parallel_size, None)
rollout = AsyncSGLangRollout(actor_module=local_model_path, config=rollout_config, tokenizer=tokenizer, model_hf_config=actor_model.config)
rollout = SGLangRollout(actor_module=local_model_path, config=rollout_config, tokenizer=tokenizer, model_hf_config=actor_model.config)
rollout_sharding_manager = FSDPAsyncSGLangShardingManager(
rollout_sharding_manager = FSDPSGLangShardingManager(
module=fsdp_model,
inference_engine=rollout._engine,
model_config=actor_model.config,
@ -111,7 +111,7 @@ def test_async_sglang_rollout_w_tool():
prompts = rollout_sharding_manager.preprocess_data(prompts)
# log_gpu_memory_usage("Before generating sequences", logger=None)
output = rollout.generate_sequences_with_tools(prompts=prompts)
output = rollout.generate_sequences(prompts=prompts)
print(f"generated {output.batch['responses'].shape=}")
# log_gpu_memory_usage("After generating sequences", logger=None)
output = rollout_sharding_manager.postprocess_data(output)

View File

@ -1,113 +0,0 @@
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
usage: torchrun --standalone --nnodes=1 \
--nproc_per_node=2 $(which pytest) \
-s test_sglang_async_spmd.py
"""
import asyncio
import torch
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.utils import broadcast_pyobj
from torch.distributed.device_mesh import init_device_mesh
from utils_sglang import (
are_lists_similar,
clean_torchelastic_env,
generate_hf_output,
initialize_global_process_group,
load_tokenizer_and_model,
prepare_inputs,
)
def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor):
non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
token_ids = prompt_token_ids[non_pad_index:].tolist()
return token_ids
def test_sglang_spmd():
assert torch.cuda.device_count() >= 2
initialize_global_process_group(spmd=True)
clean_torchelastic_env()
max_prompt_length = 16
max_response_length = 16
local_model_path = "Qwen/Qwen2.5-0.5B"
tokenizer, actor_model = load_tokenizer_and_model(local_model_path)
preencode_prompts = ["Who won the Champions League in 2019?", "The founder of Apple is", "What's your name?"]
input_ids, attention_mask, _ = prepare_inputs(tokenizer, preencode_prompts, max_prompt_length)
hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length)
tensor_parallel_size = 2
inference_device_mesh_cpu = init_device_mesh("cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"])
tp_rank = inference_device_mesh_cpu["tp"].get_local_rank()
if tp_rank == 0:
llm = Engine(
model_path=local_model_path,
dtype="bfloat16",
mem_fraction_static=0.5,
enable_memory_saver=True,
tp_size=inference_device_mesh_cpu["tp"].size(),
)
input_ids = input_ids.cuda()
idx_list = []
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
for i in range(input_ids.shape[0]):
idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i]))
sampling_params = dict(
n=1,
temperature=0,
top_p=1,
top_k=-1,
max_new_tokens=max_response_length,
presence_penalty=0.0,
frequency_penalty=0.0,
repetition_penalty=1.0,
skip_special_tokens=True,
spaces_between_special_tokens=True,
ignore_eos=False,
)
loop = asyncio.get_event_loop()
outputs = loop.run_until_complete(llm.async_generate(input_ids=idx_list, sampling_params=sampling_params))
else:
outputs = None
[outputs] = broadcast_pyobj(
[outputs],
rank=inference_device_mesh_cpu["tp"].get_local_rank(),
src=inference_device_mesh_cpu["tp"].mesh[0].item(),
dist_group=inference_device_mesh_cpu["tp"].get_group(),
force_cpu_device=False,
)
sglang_response_tokens = [output["text"] for output in outputs]
print(f"sglang response: {sglang_response_tokens}")
assert are_lists_similar(hf_response_tokens, sglang_response_tokens)
print("SPMD Test Passed!")
torch.distributed.barrier()
torch.distributed.destroy_process_group()

View File

@ -12,188 +12,102 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
"""
usage: torchrun --standalone --nnodes=1 \
--nproc_per_node=2 $(which pytest) \
-s test_sglang_async_spmd.py
"""
import asyncio
import torch
from sglang.srt.entrypoints.verl_engine import VerlEngine
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.utils import broadcast_pyobj
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from verl.utils.torch_functional import pad_sequence_to_length
def levenshtein(s1, s2):
m, n = len(s1), len(s2)
# Initialize matrix of zeros
dp = [[0] * (n + 1) for _ in range(m + 1)]
# Initialize first column and first row of the matrix
for i in range(m + 1):
dp[i][0] = i # Deletion from s1 to empty string
for j in range(n + 1):
dp[0][j] = j # Insertion to s1 from empty string
# Compute the Levenshtein distance matrix
for i in range(1, m + 1):
for j in range(1, n + 1):
cost = 0 if s1[i - 1] == s2[j - 1] else 1 # No cost if characters match
dp[i][j] = min(
dp[i - 1][j] + 1, # Deletion
dp[i][j - 1] + 1, # Insertion
dp[i - 1][j - 1] + cost, # Substitution
)
return dp[m][n]
def are_lists_similar(a, b):
if len(a) != len(b):
print("The lists are of different lengths.")
return False
total_length = 0
total_diff = 0
for s1, s2 in zip(a, b):
max_len = max(len(s1), len(s2))
total_length += max_len
diff = levenshtein(s1, s2)
total_diff += diff
print(f"Comparing strings:\n{s1}\n{s2}\nDifference: {diff} characters\n")
percentage_difference = (total_diff / total_length) * 100
print(f"Total difference: {percentage_difference:.2f}%")
return percentage_difference <= 10
def initialize_global_process_group(timeout_second=36000):
from datetime import timedelta
import torch.distributed
# NOTE MODIFIED should provide backend=None to have nccl+gloo
# torch.distributed.init_process_group('nccl', timeout=timedelta(seconds=timeout_second))
torch.distributed.init_process_group(timeout=timedelta(seconds=timeout_second))
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
if torch.distributed.is_initialized():
torch.cuda.set_device(local_rank)
return local_rank, rank, world_size
def test_sglang_spmd():
assert torch.cuda.device_count() >= 2, "At least 2 GPUs is required to run tp+dp tests."
initialize_global_process_group()
# fill rollout config
max_prompt_length = 16
max_response_length = 16
# Initialize model and token
local_cache_path = "~/.cache/verl/rlhf"
local_cache_path = os.path.expanduser(local_cache_path)
hdfs_path = "Qwen/Qwen2-7B-Instruct"
from verl.utils.fs import copy_to_local
local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path)
tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left")
preencode_prompts = [
"Who won the Champions League in 2019?",
"The founder of Apple is",
"What's your name?",
]
tokenizer.pad_token = tokenizer.eos_token
prompts = tokenizer(preencode_prompts, return_tensors="pt", padding=True)
input_ids = prompts["input_ids"]
attention_mask = prompts["attention_mask"]
input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True)
attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True)
actor_model = AutoModelForCausalLM.from_pretrained(local_model_path)
actor_model.to(torch.bfloat16)
sampling_params = dict(
n=1,
temperature=0,
top_p=1,
top_k=-1,
max_new_tokens=max_response_length,
presence_penalty=0.0,
frequency_penalty=0.0,
repetition_penalty=1.0,
skip_special_tokens=True,
spaces_between_special_tokens=True,
ignore_eos=False,
)
tensor_parallel_size = 4
device_mesh_kwargs = dict(mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"])
inference_device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs)
for k in ["TORCHELASTIC_USE_AGENT_STORE"]:
if k in os.environ:
del os.environ[k]
print("building sglang rollout engine")
llm = VerlEngine(
model_path=local_model_path,
dtype="bfloat16",
mem_fraction_static=0.5,
device_mesh_cpu=inference_device_mesh_cpu["tp"],
base_gpu_id=0,
gpu_id_step=1,
)
llm.release_memory_occupation()
print("start generation")
input_ids = input_ids.cuda()
attention_mask = attention_mask.cuda()
batch_size = input_ids.size(0)
generation_config = GenerationConfig(do_sample=False)
actor_model.cuda()
output = actor_model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=max_response_length,
# max_length=max_length,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
generation_config=generation_config,
# renormalize_logits=True,
output_scores=False, # this is potentially very large
return_dict_in_generate=True,
use_cache=False,
) # may OOM when use_cache = True
seq = output.sequences
response = seq[:, max_prompt_length:]
hf_response_tokens = tokenizer.batch_decode(response)
print(f"hf response: {hf_response_tokens}")
print(f"{sampling_params=}")
idx_list = []
batch_size = input_ids.shape[0]
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
for i in range(batch_size):
idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i]))
outputs = llm.generate(input_ids=idx_list, sampling_params=sampling_params)
sglang_response_tokens = []
for output in outputs:
print(f"{output=}")
generated_text = output["text"]
sglang_response_tokens.append(generated_text)
print(f"sglang response: {sglang_response_tokens}")
assert are_lists_similar(hf_response_tokens, sglang_response_tokens), "Strings differ more than 10%:\n"
print("Check Pass")
from utils_sglang import (
are_lists_similar,
clean_torchelastic_env,
generate_hf_output,
initialize_global_process_group,
load_tokenizer_and_model,
prepare_inputs,
)
def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor):
# remove the left padding in the prompt token_id
non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
token_ids = prompt_token_ids[non_pad_index:].tolist()
return token_ids
def test_sglang_spmd():
assert torch.cuda.device_count() >= 2
initialize_global_process_group(spmd=True)
clean_torchelastic_env()
max_prompt_length = 16
max_response_length = 16
local_model_path = "Qwen/Qwen2.5-0.5B"
tokenizer, actor_model = load_tokenizer_and_model(local_model_path)
preencode_prompts = ["Who won the Champions League in 2019?", "The founder of Apple is", "What's your name?"]
input_ids, attention_mask, _ = prepare_inputs(tokenizer, preencode_prompts, max_prompt_length)
hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length)
tensor_parallel_size = 2
inference_device_mesh_cpu = init_device_mesh("cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"])
tp_rank = inference_device_mesh_cpu["tp"].get_local_rank()
if tp_rank == 0:
llm = Engine(
model_path=local_model_path,
dtype="bfloat16",
mem_fraction_static=0.5,
enable_memory_saver=True,
tp_size=inference_device_mesh_cpu["tp"].size(),
)
input_ids = input_ids.cuda()
idx_list = []
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
for i in range(input_ids.shape[0]):
idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i]))
sampling_params = dict(
n=1,
temperature=0,
top_p=1,
top_k=-1,
max_new_tokens=max_response_length,
presence_penalty=0.0,
frequency_penalty=0.0,
repetition_penalty=1.0,
skip_special_tokens=True,
spaces_between_special_tokens=True,
ignore_eos=False,
)
loop = asyncio.get_event_loop()
outputs = loop.run_until_complete(llm.async_generate(input_ids=idx_list, sampling_params=sampling_params))
else:
outputs = None
[outputs] = broadcast_pyobj(
[outputs],
rank=inference_device_mesh_cpu["tp"].get_local_rank(),
src=inference_device_mesh_cpu["tp"].mesh[0].item(),
dist_group=inference_device_mesh_cpu["tp"].get_group(),
force_cpu_device=False,
)
sglang_response_tokens = [output["text"] for output in outputs]
print(f"sglang response: {sglang_response_tokens}")
assert are_lists_similar(hf_response_tokens, sglang_response_tokens), "Strings differ more than 10%:\n"
print("SPMD Test Passed!")
torch.distributed.barrier()
torch.distributed.destroy_process_group()

View File

@ -37,7 +37,7 @@ def levenshtein(s1, s2):
return dp[m][n]
def are_lists_similar(a, b):
def are_lists_similar(a, b, threshold=10):
if len(a) != len(b):
print("The lists are of different lengths.")
return False
@ -49,7 +49,7 @@ def are_lists_similar(a, b):
total_diff += levenshtein(s1, s2)
percentage_difference = (total_diff / total_length) * 100
print(f"Total difference: {percentage_difference:.2f}%")
return percentage_difference <= 10
return percentage_difference <= threshold
def initialize_global_process_group(timeout_second=36000, spmd=False):

View File

@ -168,7 +168,7 @@ actor_rollout_ref:
n: 1
do_sample: False # default eager for validation
multi_turn:
enable: False # should set rollout.name to sglang_async if True
enable: False # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well
max_turns: null # null for no limit (default max_length // 3)
tool_config_path: null # null for no tool
format: chatml # chatml, more formats will be supported in the future

View File

@ -140,7 +140,7 @@ actor_rollout_ref:
n: 1
do_sample: False # default eager for validation
multi_turn:
enable: False # should set rollout.name to sglang_async if True
enable: False # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well
max_turns: null # null for no limit (default max_length // 3)
tool_config_path: null # null for no tool
format: chatml # chatml, more formats will be supported in the future

View File

@ -448,86 +448,39 @@ class ActorRolloutRefWorker(Worker):
)
log_gpu_memory_usage("After building sharding manager", logger=logger)
elif rollout_name == "sglang":
if self.config.rollout.mode == "sync":
from verl.workers.rollout.sglang_rollout import SGLangRollout
# NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to
# SGLang's model_runner would check CUDA device capability. However, due to verl's setting,
# the main process of ray can not find any CUDA device, which would potentially lead to:
# "RuntimeError: No CUDA GPUs are available".
# For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and
# we import it here use the abs path.
# check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76
from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager
log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
local_path = copy_to_local(self.config.model.path)
rollout = SGLangRollout(
actor_module=local_path,
config=self.config.rollout,
tokenizer=self.tokenizer,
model_hf_config=self.actor_model_config,
trust_remote_code=trust_remote_code,
elif rollout_name in ["sglang", "sglang_async"]:
if rollout_name == "sglang_async":
warnings.warn(
"'sglang_async' has been deprecated and merged into 'sglang'. "
"Please use 'sglang' going forward.",
DeprecationWarning,
stacklevel=2,
)
log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
from verl.workers.rollout.sglang_rollout import SGLangRollout
# NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to
# SGLang's model_runner would check CUDA device capability. However, due to verl's setting,
# the main process of ray can not find any CUDA device, which would potentially lead to:
# "RuntimeError: No CUDA GPUs are available".
# For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and
# we import it here use the abs path.
# check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76
if torch.distributed.get_world_size() == 1:
self.config.rollout.load_format = "dummy_hf"
rollout_sharding_manager = FSDPSGLangShardingManager(
module=self.actor_module_fsdp,
inference_engine=rollout.inference_engine,
model_config=self.actor_model_config,
full_params="hf" in self.config.rollout.load_format,
device_mesh=rollout_device_mesh,
offload_param=self._is_offload_param,
)
log_gpu_memory_usage("After building sharding manager", logger=logger)
elif self.config.rollout.mode == "async":
from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout
from verl.workers.sharding_manager.fsdp_sglang import FSDPAsyncSGLangShardingManager
local_path = copy_to_local(self.config.model.path)
log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=None)
rollout = AsyncSGLangRollout(
actor_module=local_path,
config=self.config.rollout,
tokenizer=self.tokenizer,
model_hf_config=self.actor_model_config,
trust_remote_code=trust_remote_code,
)
log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=None)
if torch.distributed.get_world_size() == 1:
self.config.rollout.load_format = "dummy_hf"
rollout_sharding_manager = FSDPAsyncSGLangShardingManager(
module=self.actor_module_fsdp,
inference_engine=rollout._engine,
model_config=self.actor_model_config,
full_params="hf" in self.config.rollout.load_format,
device_mesh=rollout_device_mesh,
offload_param=self._is_offload_param,
)
log_gpu_memory_usage("After building sharding manager", logger=None)
elif rollout_name == "sglang_async":
# TODO replace by rollout.mode == "async"
from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout
from verl.workers.sharding_manager.fsdp_sglang import FSDPAsyncSGLangShardingManager
from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager
local_path = copy_to_local(self.config.model.path)
log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=None)
rollout = AsyncSGLangRollout(
log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
rollout = SGLangRollout(
actor_module=local_path,
config=self.config.rollout,
tokenizer=self.tokenizer,
model_hf_config=self.actor_model_config,
trust_remote_code=trust_remote_code,
)
log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=None)
log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
if torch.distributed.get_world_size() == 1:
self.config.rollout.load_format = "dummy_hf"
rollout_sharding_manager = FSDPAsyncSGLangShardingManager(
rollout_sharding_manager = FSDPSGLangShardingManager(
module=self.actor_module_fsdp,
inference_engine=rollout._engine,
model_config=self.actor_model_config,
@ -535,7 +488,7 @@ class ActorRolloutRefWorker(Worker):
device_mesh=rollout_device_mesh,
offload_param=self._is_offload_param,
)
log_gpu_memory_usage("After building sharding manager", logger=None)
log_gpu_memory_usage("After building sharding manager", logger=logger)
else:
raise NotImplementedError(f"Rollout name: {self.config.rollout.name} is not supported")
@ -696,16 +649,8 @@ class ActorRolloutRefWorker(Worker):
log_gpu_memory_usage("After entering rollout sharding manager", logger=logger)
prompts = self.rollout_sharding_manager.preprocess_data(prompts)
if self.config.rollout.name == "sglang_async":
from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout
if isinstance(self.rollout, AsyncSGLangRollout) and hasattr(self.rollout, "_tool_schemas") and len(self.rollout._tool_schemas) > 0:
output = self.rollout.generate_sequences_with_tools(prompts=prompts)
else:
output = self.rollout.generate_sequences(prompts=prompts)
else:
output = self.rollout.generate_sequences(prompts=prompts)
output = self.rollout.generate_sequences(prompts=prompts)
log_gpu_memory_usage("After rollout generation", logger=logger)
output = self.rollout_sharding_manager.postprocess_data(output)

View File

@ -18,6 +18,7 @@ The main entry point to run the PPO algorithm
import logging
import os
import time
import warnings
import torch
import torch.distributed
@ -258,7 +259,14 @@ class ActorRolloutRefWorker(MegatronWorker):
weight_converter=weight_converter,
)
log_gpu_memory_usage("After building sharding manager", logger=logger)
elif self.config.rollout.name == "sglang":
elif self.config.rollout.name in ["sglang", "sglang_async"]:
if self.config.rollout.name == "sglang_async":
warnings.warn(
"'sglang_async' has been deprecated and merged into 'sglang'. Please use 'sglang' going forward.",
DeprecationWarning,
stacklevel=2,
)
from verl.workers.rollout.sglang_rollout import SGLangRollout
# NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to SGLang's model_runner would check CUDA device capability.
@ -276,45 +284,6 @@ class ActorRolloutRefWorker(MegatronWorker):
local_path = copy_to_local(self.config.model.path)
log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=None)
rollout = SGLangRollout(
actor_module=local_path,
config=self.config.rollout,
tokenizer=self.tokenizer,
model_hf_config=self.actor_model_config,
trust_remote_code=self.config.model.get("trust_remote_code", False),
)
log_gpu_memory_usage(f"After building {self.config.rollout.name} rollout", logger=None)
from verl.models.mcore import get_mcore_weight_converter
weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)
sharding_manager = MegatronSGLangShardingManager(
actor_module=self.actor.actor_module,
inference_engine=rollout.inference_engine,
model_config=self.actor_model_config,
transformer_config=self.tf_config,
layer_name_mapping=layer_name_mapping,
weight_converter=weight_converter,
device_mesh=rollout_device_mesh,
)
log_gpu_memory_usage("After building sharding manager", logger=logger)
elif self.config.rollout.name == "sglang_async":
from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout
# NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to SGLang's model_runner would check CUDA device capability.
# However, due to verl's setting, the main process of ray can not find any CUDA device, which would potentially lead to:
# "RuntimeError: No CUDA GPUs are available".
# For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and we import it here use the abs path.
# check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76
from verl.workers.sharding_manager.megatron_sglang import MegatronAsyncSGLangShardingManager
infer_tp = self.config.rollout.tensor_model_parallel_size
dp = self.world_size // infer_tp
assert self.world_size % infer_tp == 0, f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"
rollout_device_mesh = init_device_mesh("cpu", mesh_shape=(dp, infer_tp, 1), mesh_dim_names=("dp", "tp", "pp"))
local_path = copy_to_local(self.config.model.path)
log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=None)
rollout = AsyncSGLangRollout(
actor_module=local_path,
config=self.config.rollout,
tokenizer=self.tokenizer,
@@ -327,7 +296,7 @@ class ActorRolloutRefWorker(MegatronWorker):
from verl.models.mcore import get_mcore_weight_converter
weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)
sharding_manager = MegatronAsyncSGLangShardingManager(
sharding_manager = MegatronSGLangShardingManager(
actor_module=self.actor.actor_module,
inference_engine=rollout._engine,
model_config=self.actor_model_config,
@@ -491,16 +460,7 @@ class ActorRolloutRefWorker(MegatronWorker):
log_gpu_memory_usage("After entering sharding manager", logger=logger)
prompts = self.sharding_manager.preprocess_data(prompts)
# output = self.rollout.generate_sequences(prompts=prompts)
if self.config.rollout.name == "sglang_async":
from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout
if isinstance(self.rollout, AsyncSGLangRollout) and hasattr(self.rollout, "_tool_schemas") and len(self.rollout._tool_schemas) > 0:
output = self.rollout.generate_sequences_with_tools(prompts=prompts)
else:
output = self.rollout.generate_sequences(prompts=prompts)
else:
output = self.rollout.generate_sequences(prompts=prompts)
output = self.rollout.generate_sequences(prompts=prompts)
output = self.sharding_manager.postprocess_data(output)
output = output.to("cpu")
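For configuration, `sglang_async` is now only a deprecated alias: both names reach the same SGLangRollout path, and the alias emits a `DeprecationWarning` as shown above. A small self-contained sketch of that aliasing behavior (the helper function is illustrative; the worker itself simply accepts both names):

```python
import warnings


def resolve_rollout_name(name: str) -> str:
    # Mirrors the branch above: "sglang_async" still works but warns,
    # and both names end up on the SGLangRollout code path.
    if name == "sglang_async":
        warnings.warn(
            "'sglang_async' has been deprecated and merged into 'sglang'. Please use 'sglang' going forward.",
            DeprecationWarning,
            stacklevel=2,
        )
        return "sglang"
    return name


assert resolve_rollout_name("sglang") == "sglang"
assert resolve_rollout_name("sglang_async") == "sglang"  # emits a DeprecationWarning
```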

View File

@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
from .async_sglang_rollout import AsyncSGLangRollout
from .sglang_rollout import SGLangRollout
__all__ = ["AsyncSGLangRollout", "SGLangRollout"]
__all__ = ["SGLangRollout"]

View File

@@ -1,876 +0,0 @@
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import logging
import os
import time
from contextlib import contextmanager
from copy import deepcopy
from json import JSONDecodeError
from typing import TYPE_CHECKING, Union
from uuid import uuid4
import numpy as np
import torch
import torch.distributed as dist
from omegaconf import DictConfig
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.function_call_parser import FunctionCallParser
from sglang.srt.openai_api.protocol import Tool
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.utils import get_ip, get_open_port
from tensordict import TensorDict
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedTokenizer
from verl import DataProto
from verl.third_party.sglang import parallel_state as sglang_ps
from verl.tools.base_tool import BaseTool
from verl.tools.schemas import OpenAIFunctionCallSchema, OpenAIFunctionParsedSchema, OpenAIFunctionToolCall
from verl.utils.debug import GPUMemoryLogger
from verl.utils.model import compute_position_id_with_mask
from verl.utils.net_utils import is_ipv6
from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length
from verl.workers.rollout.base import BaseRollout
from verl.workers.rollout.schemas import (
AsyncRolloutRequest,
AsyncRolloutRequestStateEnum,
FinishReasonTypeEnum,
Message,
)
from verl.workers.rollout.sglang_rollout.sglang_rollout import _post_process_outputs, _pre_process_inputs
from verl.workers.rollout.sglang_rollout.utils import broadcast_pyobj
if TYPE_CHECKING:
from torch import nn
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
def get_tool_call_parser_type(tokenizer: PreTrainedTokenizer) -> str:
for parser_type, parser_cls in FunctionCallParser.ToolCallParserEnum.items():
parser = parser_cls()
if parser.bot_token in tokenizer.get_vocab() and (parser.eot_token == "" or parser.eot_token in tokenizer.get_vocab()):
return parser_type
else:
raise ValueError(f"No tool call parser found for tokenizer {tokenizer}")
class AsyncSGLangRollout(BaseRollout):
def __init__(
self,
actor_module: nn.Module | str,
config: DictConfig,
tokenizer,
model_hf_config,
port=None,
trust_remote_code: bool = False,
device_mesh: DeviceMesh | None = None,
**kwargs,
):
"""A SGLang rollout. It requires the module is supported by the SGLang.
Args:
actor_module: module here follows huggingface APIs
config: DictConfig
tokenizer: the task/model tokenizer
model_hf_config: the huggingface config used to initialize the generating model in SGLang
**kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group
"""
super().__init__()
self.config = config
self._tool_schemas, self._tool_map, self._tool_call_parser_type, self._sgl_tools, self._function_call_parser = self._initialize_tools(config, tokenizer)
assert not (not config.enforce_eager and config.free_cache_engine), "disable CUDA graph (enforce_eager = False) if free cache engine"
logger.info(f"tool_schemas: {self._tool_schemas}, tool_map: {self._tool_map}, tool_call_parser_type: {self._tool_call_parser_type}, sgl_tools: {self._sgl_tools}, function_call_parser: {self._function_call_parser}")
self._init_distributed_env(device_mesh_cpu=device_mesh, **kwargs)
self._verify_config(model_hf_config=model_hf_config)
# initialize the inference engine
self._init_inference_engine(trust_remote_code, actor_module, port)
self._init_sampling_params(**kwargs)
self.tokenizer = tokenizer
self.pad_token_id = tokenizer.pad_token_id
def _init_distributed_env(self, device_mesh_cpu, **kwargs):
self._device_mesh_cpu = device_mesh_cpu
os.environ.setdefault("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", "true")
self.tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1)
assert self.tensor_parallel_size <= dist.get_world_size(), "tensor parallel size should be less than or equal to the world size"
self.train_tp = kwargs.get("train_tp", None)
if self.train_tp is not None:
# deployed with megatron
os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0"
os.environ["MEGATRON_IMPORT_TIMERS"] = "0"
train_tp = kwargs.get("train_tp", None)
num_tp_per_train_tp = train_tp // self.tensor_parallel_size
sglang_ps.initialize_parallel_state(
tensor_model_parallel_size=self.tensor_parallel_size,
num_tp_per_train_tp=num_tp_per_train_tp,
)
tp_size = self.tensor_parallel_size
world_size = int(os.getenv("WORLD_SIZE", "-1"))
# init device mesh
if self._device_mesh_cpu is None:
device_mesh_kwargs = dict(
mesh_shape=(world_size // tp_size, tp_size, 1),
mesh_dim_names=["dp", "tp", "pp"],
)
self._device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs)
self._rank = self._device_mesh_cpu.get_rank()
self._tp_rank = self._device_mesh_cpu["tp"].get_local_rank()
self._tp_size = self._device_mesh_cpu["tp"].size()
if self._rank == 0:
logger.info(f"_init_distributed_env: :tp_world: {self._tp_size}, global_world: {world_size}")
# get tp_rank of this process in this tp group
visible_devices = [None] * self._device_mesh_cpu.size(1)
torch.distributed.all_gather_object(visible_devices, os.environ["CUDA_VISIBLE_DEVICES"], self._device_mesh_cpu.get_group("tp"))
self.visible_devices_set = set(",".join(visible_devices).split(","))
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(sorted(list(self.visible_devices_set)))
def _verify_config(self, model_hf_config):
if not self.config.get("max_model_len", None):
self.config.max_model_len = self.config.prompt_length + self.config.response_length
assert self.config.max_model_len >= self.config.prompt_length + self.config.response_length, f"""max_model_len should be greater than total sequence length (prompt_length + response_length):
{self.config.max_model_len} >= {self.config.prompt_length} + {self.config.response_length}"""
assert model_hf_config.max_position_embeddings >= self.config.max_model_len, "model context length should be greater than total sequence length"
# currently, max_turns stands for the max number of tool calls
if self.config.multi_turn.max_turns is None:
self.config.multi_turn.max_turns = self.config.max_model_len // 3
def _init_inference_engine(self, trust_remote_code, actor_module, port):
# initialize the inference engine
nnodes = -(-self._tp_size // len(self.visible_devices_set))
if nnodes > 1:
ip = get_ip()
port = get_open_port() if port is None else port
[ip, port] = broadcast_pyobj(
[ip, port],
rank=self._rank,
dist_group=self._device_mesh_cpu.get_group("tp"),
src=self._device_mesh_cpu["tp"].mesh[0].item(),
force_cpu_device=False,
)
dist_init_addr = f"[{ip}]:{port}" if is_ipv6(ip) else f"{ip}:{port}"
else:
dist_init_addr = None
load_format = "dummy" if self.config.load_format.startswith("dummy") else self.config.load_format
tp_size_per_node = self._tp_size // nnodes
node_rank = self._tp_rank // tp_size_per_node
first_rank_in_node = self._tp_rank % tp_size_per_node == 0
if first_rank_in_node:
rank = dist.get_rank()
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
self._engine = Engine(
model_path=actor_module,
dtype=self.config.dtype,
mem_fraction_static=self.config.gpu_memory_utilization,
enable_memory_saver=True,
base_gpu_id=0,
gpu_id_step=1,
tp_size=self._tp_size,
node_rank=node_rank,
load_format=load_format,
dist_init_addr=dist_init_addr,
nnodes=nnodes,
trust_remote_code=trust_remote_code,
# NOTE(linjunrong): add rank to prevent SGLang generate same port inside PortArgs.init_new
# when random.seed is being set during training
port=30000 + rank,
# NOTE(Chenyang): if you want to debug the SGLang engine output
# please set the following parameters
# Otherwise, it will make the engine run too slow
# log_level="INFO",
# log_requests=True,
# log_requests_level=2,
# max_running_requests=1,
)
else:
self._engine = None
self.sharding_manager = None
# offload
if self._tp_rank == 0:
self._engine.release_memory_occupation()
self.is_sleep = True
def _init_sampling_params(self, **kwargs):
kwargs = dict(
n=1,
max_new_tokens=self.config.response_length,
presence_penalty=0.0,
frequency_penalty=0.0,
repetition_penalty=1.0,
)
# supporting adding any sampling params from the config file
for k in self.config.keys():
if hasattr(SamplingParams(), str(k)):
kwargs[k] = self.config.get(k)
self.sampling_params = kwargs
def _initialize_tools(self, config, tokenizer):
"""Initialize tools from configuration.
Args:
config: Configuration object containing tool settings
tokenizer: Tokenizer instance for tool call parsing
Returns:
tuple: (tool_schemas, tool_map, tool_call_parser_type, sgl_tools, function_call_parser)
"""
if config.multi_turn.tool_config_path is None:
return [], {}, None, [], None
import importlib.util
import sys
from omegaconf import OmegaConf
from verl.tools.schemas import OpenAIFunctionToolSchema
def initialize_tools_from_config(tools_config) -> list:
tool_list = []
for tool_config in tools_config.tools:
cls_name = tool_config.class_name
module_name, class_name = cls_name.rsplit(".", 1)
if module_name not in sys.modules:
spec = importlib.util.find_spec(module_name)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
else:
module = sys.modules[module_name]
tool_cls = getattr(module, class_name)
tool_schema_dict = OmegaConf.to_container(tool_config.tool_schema, resolve=True)
tool_schema = OpenAIFunctionToolSchema.parse_obj(tool_schema_dict)
tool = tool_cls(config=OmegaConf.to_container(tool_config.config, resolve=True), tool_schema=tool_schema)
tool_list.append(tool)
return tool_list
tools_config_file = config.multi_turn.tool_config_path
tools_config = OmegaConf.load(tools_config_file)
tool_list = initialize_tools_from_config(tools_config)
logger.info(f"Initialize tools from configuration.: tool_list: {tool_list}")
tool_schemas = [tool.get_openai_tool_schema().model_dump() for tool in tool_list]
tool_map = {tool.name: tool for tool in tool_list}
tool_call_parser_type = get_tool_call_parser_type(tokenizer)
sgl_tools = [Tool.model_validate(tool_schema) for tool_schema in tool_schemas]
function_call_parser = FunctionCallParser(
sgl_tools,
tool_call_parser_type,
)
return tool_schemas, tool_map, tool_call_parser_type, sgl_tools, function_call_parser
@contextmanager
def update_sampling_params(self, **kwargs):
# update sampling params
old_sampling_params_args = {}
if kwargs:
for key, value in kwargs.items():
if key in self.sampling_params:
old_value = self.sampling_params[key]
old_sampling_params_args[key] = old_value
self.sampling_params[key] = value
yield
# roll back to previous sampling params
# if len(old_sampling_params_args):
for key, value in old_sampling_params_args.items():
self.sampling_params[key] = value
@GPUMemoryLogger(role="sglang async rollout", logger=logger)
@torch.no_grad()
def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
# if self.config.free_cache_engine:
idx = prompts.batch["input_ids"] # (bs, prompt_length)
# left-padded attention_mask
attention_mask = prompts.batch["attention_mask"]
position_ids = prompts.batch["position_ids"]
# used to construct attention_mask
eos_token_id = prompts.meta_info["eos_token_id"]
batch_size = idx.size(0)
# Extract non-tensor data
non_tensor_batch = prompts.non_tensor_batch
if "raw_prompt_ids" not in non_tensor_batch:
non_tensor_batch["raw_prompt_ids"] = np.array([_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object)
if "multi_modal_data" in non_tensor_batch:
sglang_inputs = []
for raw_prompt_ids, multi_modal_data in zip(non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data")):
sglang_inputs.append(
{
"prompt_token_ids": raw_prompt_ids,
"multi_modal_data": multi_modal_data,
"image_data": multi_modal_data.get("image", None) if isinstance(multi_modal_data, dict) else None,
}
)
else:
sglang_inputs = [{"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")]
# Ensure token IDs are lists
for input_data in sglang_inputs:
if isinstance(input_data["prompt_token_ids"], np.ndarray):
input_data["prompt_token_ids"] = input_data["prompt_token_ids"].tolist()
elif not isinstance(input_data["prompt_token_ids"], list):
raise TypeError(f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}")
# Extract token IDs and image data for SGLang Engine
idx_list = [input_data["prompt_token_ids"] for input_data in sglang_inputs]
image_list = [input_data.get("image_data", None) for input_data in sglang_inputs]
do_sample = prompts.meta_info.get("do_sample", True)
is_validate = prompts.meta_info.get("validate", False)
if not do_sample:
kwargs = dict(
n=1,
presence_penalty=0.0,
frequency_penalty=0.0,
repetition_penalty=1.0,
temperature=0,
top_p=1,
top_k=-1,
ignore_eos=False,
min_new_tokens=0,
max_new_tokens=self.config.response_length,
skip_special_tokens=True,
spaces_between_special_tokens=True,
)
elif is_validate:
kwargs = dict(
top_k=self.config.val_kwargs.top_k,
top_p=self.config.val_kwargs.top_p,
temperature=self.config.val_kwargs.temperature,
n=1, # if validate, already repeat in ray_trainer
)
# users can customize different sampling_params at different run
with self.update_sampling_params(**kwargs):
# print(f"{self.sampling_params=}")
if self._tp_rank == 0:
loop = asyncio.get_event_loop()
output = loop.run_until_complete(
self._engine.async_generate(
prompt=None, # because we have already convert it to prompt token id
sampling_params=self.sampling_params,
return_logprob=True,
input_ids=idx_list,
image_data=image_list,
)
)
else:
output = None
# Most naive implementation, can extract tensor and send via gloo if too slow
[output] = broadcast_pyobj(
data=[output],
rank=self._rank,
dist_group=self._device_mesh_cpu["tp"].get_group(),
src=self._device_mesh_cpu["tp"].mesh[0].item(),
force_cpu_device=False,
)
out = _post_process_outputs(self.tokenizer, output)
response = out[0].to(idx.device)
rollout_log_probs = out[1].to(idx.device)
if response.shape[1] < self.config.response_length:
response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)
rollout_log_probs = pad_sequence_to_length(rollout_log_probs, self.config.response_length, self.pad_token_id)
# utilize current sampling params
if self.sampling_params.get("n", 1) > 1 and do_sample:
idx = idx.repeat_interleave(self.sampling_params["n"], dim=0)
attention_mask = attention_mask.repeat_interleave(self.sampling_params["n"], dim=0)
position_ids = position_ids.repeat_interleave(self.sampling_params["n"], dim=0)
batch_size = batch_size * self.sampling_params["n"]
_non_tensor_batch = {}
for key, val in non_tensor_batch.items():
_non_tensor_batch[key] = np.repeat(val, self.sampling_params["n"], axis=0)
else:
_non_tensor_batch = non_tensor_batch
seq = torch.cat([idx, response], dim=-1)
response_length = response.size(1)
delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)
# TODO(sgm): fix position_ids on right_pad
# prompt: left pad + response: right pad
# attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
# position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
response_position_ids = position_ids[:, -1:] + delta_position_id
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
response_attention_mask = get_response_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
# all the tp ranks should contain the same data here. data in all ranks are valid
batch = TensorDict(
{
"prompts": idx,
"responses": response,
"input_ids": seq, # here input_ids become the whole sentences
"rollout_log_probs": rollout_log_probs, # we will recompute old log prob with actor
"attention_mask": attention_mask,
"position_ids": position_ids,
},
batch_size=batch_size,
)
# free cache engine
if self.config.free_cache_engine and self._engine is not None:
self._engine.flush_cache()
return DataProto(batch=batch, non_tensor_batch=_non_tensor_batch)
async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bool = True, is_validate: bool = False, **kwargs) -> AsyncRolloutRequest:
assert self._tp_rank == 0, "only the master process can call this function"
_req = deepcopy(req)
finish_reason_type = None
output = None
current_turns = 0
while current_turns < self.config.multi_turn.max_turns:
if _req.state == AsyncRolloutRequestStateEnum.PENDING:
await self._handle_pending_state(_req)
_req.state = AsyncRolloutRequestStateEnum.RUNNING
elif _req.state == AsyncRolloutRequestStateEnum.TOOL_CALLING:
if _req.messages[-1].tool_calls is not None:
parsed_tool_calls = _req.messages[-1].tool_calls
tool_call_results = await asyncio.gather(
*[
self._tool_map[tool_call.function.name].execute(
_req.request_id,
tool_call.function.arguments,
**_req.tools_kwargs[tool_call.function.name].get("execute_kwargs", {}),
)
for tool_call in parsed_tool_calls
]
)
for i, (tool_call, (resp, reward, metrics)) in enumerate(zip(parsed_tool_calls, tool_call_results)):
_req.add_tool_response_message(self.tokenizer, resp, (i == len(parsed_tool_calls) - 1), format=self.config.multi_turn.format)
_req.update_metrics(metrics, tool_call.function.name)
if len(_req.input_ids) >= self.config.max_model_len:
break
if len(_req.input_ids) >= self.config.max_model_len:
finish_reason_type = FinishReasonTypeEnum.STOP
break
_req.state = AsyncRolloutRequestStateEnum.RUNNING
else:
raise ValueError(f"Unexpected tool calling last message state: {_req.messages[-1]}")
elif _req.state == AsyncRolloutRequestStateEnum.RUNNING:
output = await self._handle_engine_call(_req, do_sample, is_validate, **kwargs)
content = output["text"]
finish_reason_type = FinishReasonTypeEnum.from_str(output["meta_info"]["finish_reason"]["type"])
current_turns += 1
if finish_reason_type == FinishReasonTypeEnum.LENGTH:
_req.add_assistant_message(self.tokenizer, content, already_over_long=True, format=self.config.multi_turn.format)
break
else:
if self._function_call_parser and self._function_call_parser.has_tool_call(content):
finish_reason_type = FinishReasonTypeEnum.TOOL_CALL
_req.state = AsyncRolloutRequestStateEnum.TOOL_CALLING
try:
normed_content, tool_calls = self._function_call_parser.parse_non_stream(content)
except JSONDecodeError:
normed_content = content
tool_calls = []
except AttributeError:
normed_content = content
tool_calls = []
parsed_tool_calls = []
for tool_call in tool_calls:
function, has_decode_error = OpenAIFunctionCallSchema.from_openai_function_parsed_schema(OpenAIFunctionParsedSchema(name=tool_call.name, arguments=tool_call.parameters))
# Drop the tool call if its arguments has decode error
if has_decode_error:
continue
parsed_tool_calls.append(
OpenAIFunctionToolCall(
id=str(tool_call.tool_index),
function=function,
)
)
if len(parsed_tool_calls) > 0:
_req.add_assistant_message(
self.tokenizer,
normed_content,
tool_calls=parsed_tool_calls,
format=self.config.multi_turn.format,
)
else:
_req.add_assistant_message(self.tokenizer, content, format=self.config.multi_turn.format)
finish_reason_type = FinishReasonTypeEnum.STOP
_req.state = AsyncRolloutRequestStateEnum.COMPLETED
break
else:
_req.add_assistant_message(self.tokenizer, content, format=self.config.multi_turn.format)
break
if current_turns >= self.config.multi_turn.max_turns:
finish_reason_type = FinishReasonTypeEnum.STOP
# Calculate the reward for each tool
async def calc_reward_and_release_fn(name: str, tool: BaseTool):
reward = await tool.calc_reward(_req.request_id, **_req.tools_kwargs[name].get("calc_reward_kwargs", {}))
await tool.release(_req.request_id, **_req.tools_kwargs[name].get("release_kwargs", {}))
return name, reward
tool_reward_tasks = []
for name in _req.tools_kwargs.keys():
tool = self._tool_map[name]
tool_reward_tasks.append(calc_reward_and_release_fn(name, tool))
tool_reward_scores = await asyncio.gather(*tool_reward_tasks)
tool_reward_scores = dict(tool_reward_scores)
_req.finalize(self.tokenizer, tool_reward_scores, finish_reason_type)
return _req
async def _handle_engine_call(self, _req: AsyncRolloutRequest, do_sample: bool, is_validate: bool, **kwargs) -> dict:
generation_prompt_ids = _req.get_generation_prompt(self.tokenizer)
max_new_tokens = min(self.config.response_length, self.config.max_model_len - len(generation_prompt_ids) - 1)
if not do_sample:
kwargs = dict(
n=1,
presence_penalty=0.0,
frequency_penalty=0.0,
repetition_penalty=1.0,
temperature=0,
top_p=1,
top_k=-1,
ignore_eos=False,
min_new_tokens=0,
max_new_tokens=self.config.response_length,
skip_special_tokens=True,
spaces_between_special_tokens=True,
)
elif is_validate:
# TODO: try **
kwargs = {
"top_k": self.config.val_kwargs.top_k,
"top_p": self.config.val_kwargs.top_p,
"temperature": self.config.val_kwargs.temperature,
"n": 1, # if validate, already repeat in ray_trainer
}
kwargs["max_new_tokens"] = max_new_tokens
if "n" not in kwargs or kwargs["n"] > 1: # group size is supported in preprocess
kwargs["n"] = 1
# users can customize different sampling_params at different run
with self.update_sampling_params(**kwargs):
output = await self._engine.async_generate(
input_ids=generation_prompt_ids,
sampling_params=self.sampling_params,
return_logprob=False,
)
return output
async def _handle_pending_state(self, _req: AsyncRolloutRequest) -> AsyncRolloutRequest:
if _req.tools is not None:
tool_creation_coroutines = []
for tool_schema in _req.tools:
tool = self._tool_map[tool_schema.function.name]
create_kwargs = _req.tools_kwargs[tool.name].get("create_kwargs", {})
tool_creation_coroutines.append(tool.create(_req.request_id, **create_kwargs))
await asyncio.gather(*tool_creation_coroutines)
@GPUMemoryLogger(role="sglang async rollout", logger=logger)
@torch.no_grad()
def generate_sequences_with_tools(self, prompts: DataProto, **kwargs) -> DataProto:
# Async rollout with tools support
do_sample = prompts.meta_info.get("do_sample", True)
is_validate = prompts.meta_info.get("validate", False)
tgt_device = prompts.batch["input_ids"].device
if self._tp_rank == 0:
req_list = self._preprocess_prompt_to_async_rollout_requests(
prompts,
n=1 if is_validate else self.config.n,
)
loop = asyncio.get_event_loop()
output_req_list = loop.run_until_complete(
asyncio.gather(
*[self._async_rollout_a_request(req, do_sample, is_validate, **kwargs) for req in req_list],
)
)
sorted_output_req_list = sorted(output_req_list, key=lambda x: (x.batch_data_id, x.rollout_offset))
else:
sorted_output_req_list = None
[sorted_output_req_list] = broadcast_pyobj(
data=[sorted_output_req_list],
rank=self._rank,
dist_group=self._device_mesh_cpu["tp"].get_group(),
src=self._device_mesh_cpu["tp"].mesh[0].item(),
force_cpu_device=False,
)
# Construct the batch data
prompt_ids, response_ids = [], []
prompt_attention_mask, response_attention_mask = [], []
prompt_position_ids, response_position_ids = [], []
prompt_loss_mask, response_loss_mask = [], []
messages = []
reward_scores = []
for req in sorted_output_req_list:
assert req.state == AsyncRolloutRequestStateEnum.COMPLETED, f"Request {req.request_id} is not completed"
assert len(req.input_ids) == len(req.attention_mask) == len(req.position_ids) == len(req.loss_mask), f"""Request {req.request_id} has different length of
{len(req.input_ids)=}, {len(req.attention_mask)=}, {len(req.position_ids)=}, {len(req.loss_mask)=}"""
error_message_lines = [
f"""Request {req.request_id} has input_ids length {len(req.input_ids)}
greater than max_model_len {self.config.max_model_len}""",
f"Decoded input_ids: {self.tokenizer.decode(req.input_ids)}",
f"Decoded prompt_ids: {self.tokenizer.decode(req.prompt_ids)}",
f"Decoded response_ids: {self.tokenizer.decode(req.response_ids)}",
f"Messages: {req.messages}",
f"Max model length: {req.max_model_len}",
]
error_message = "\n".join(error_message_lines)
assert len(req.input_ids) <= self.config.max_model_len, error_message
prompt_ids.append(torch.tensor(req.prompt_ids, dtype=torch.int, device=tgt_device))
response_ids.append(torch.tensor(req.response_ids, dtype=torch.int, device=tgt_device))
if len(req.response_ids) > self.config.response_length:
logger.warning(
f"""{req.request_id=} has response_ids length {len(req.response_ids)}
greater than max_response_len {self.config.response_length},\n{req=}"""
)
prompt_attention_mask.append(torch.tensor(req.prompt_attention_mask, dtype=torch.int, device=tgt_device))
response_attention_mask.append(torch.tensor(req.response_attention_mask, dtype=torch.int, device=tgt_device))
prompt_position_ids.append(torch.tensor(req.prompt_position_ids, dtype=torch.int, device=tgt_device))
response_position_ids.append(torch.tensor(req.response_position_ids, dtype=torch.int, device=tgt_device))
prompt_loss_mask.append(torch.tensor(req.prompt_loss_mask, dtype=torch.int, device=tgt_device))
response_loss_mask.append(torch.tensor(req.response_loss_mask, dtype=torch.int, device=tgt_device))
messages.append({"messages": req.messages})
reward_scores.append(req.reward_scores)
prompt_ids = pad_sequence(prompt_ids, batch_first=True, padding_value=self.pad_token_id, padding_side="left")
if prompt_ids.shape[1] < self.config.prompt_length:
prompt_ids = pad_sequence_to_length(prompt_ids, self.config.prompt_length, self.pad_token_id, left_pad=True)
response_ids = pad_sequence(response_ids, batch_first=True, padding_value=self.pad_token_id)
if response_ids.shape[1] < self.config.response_length:
response_ids = pad_sequence_to_length(response_ids, self.config.response_length, self.pad_token_id)
prompt_attention_mask = pad_sequence(prompt_attention_mask, batch_first=True, padding_value=0, padding_side="left")
if prompt_attention_mask.shape[1] < self.config.prompt_length:
prompt_attention_mask = pad_sequence_to_length(prompt_attention_mask, self.config.prompt_length, 0, left_pad=True)
response_attention_mask = pad_sequence(response_attention_mask, batch_first=True, padding_value=0)
if response_attention_mask.shape[1] < self.config.response_length:
response_attention_mask = pad_sequence_to_length(response_attention_mask, self.config.response_length, 0)
prompt_position_ids = pad_sequence(prompt_position_ids, batch_first=True, padding_value=0, padding_side="left")
if prompt_position_ids.shape[1] < self.config.prompt_length:
prompt_position_ids = pad_sequence_to_length(prompt_position_ids, self.config.prompt_length, 0, left_pad=True)
response_length = response_ids.size(1)
delta_position_id = torch.arange(1, response_length + 1, device=response_ids.device)
delta_position_id = delta_position_id.unsqueeze(0).repeat(len(sorted_output_req_list), 1)
response_position_ids = prompt_position_ids[:, -1:] + delta_position_id
prompt_loss_mask = pad_sequence(prompt_loss_mask, batch_first=True, padding_value=0, padding_side="left")
if prompt_loss_mask.shape[1] < self.config.prompt_length:
prompt_loss_mask = pad_sequence_to_length(prompt_loss_mask, self.config.prompt_length, 0, left_pad=True)
response_loss_mask = pad_sequence(response_loss_mask, batch_first=True, padding_value=0)
if response_loss_mask.shape[1] < self.config.response_length:
response_loss_mask = pad_sequence_to_length(response_loss_mask, self.config.response_length, 0)
input_ids = torch.cat((prompt_ids, response_ids), dim=-1)
attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1)
position_ids = torch.cat((prompt_position_ids, response_position_ids), dim=-1)
loss_mask = torch.cat((prompt_loss_mask, response_loss_mask), dim=-1)
# Construct the batch data
batch = TensorDict(
{
"prompts": prompt_ids,
"responses": response_ids,
"input_ids": input_ids, # here input_ids become the whole sentences
"attention_mask": attention_mask,
"position_ids": position_ids,
"loss_mask": loss_mask,
},
batch_size=len(sorted_output_req_list),
)
# free cache engine
if self.config.free_cache_engine and self._engine is not None and self._tp_rank == 0:
self._engine.flush_cache()
return DataProto(batch=batch, non_tensor_batch={"messages": np.array(messages), "reward_scores": np.array(reward_scores)})
def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: int) -> list[AsyncRolloutRequest]:
assert "raw_prompt" in prompts.non_tensor_batch, "need data.return_raw_chat=True, due to no official way do parse_messages"
req_list = []
for data_idx, raw_prompt in enumerate(prompts.non_tensor_batch["raw_prompt"]):
for rollout_offset in range(n):
if self._tool_schemas:
_tools_kwargs = prompts.non_tensor_batch["tools_kwargs"][data_idx]
_tool_schemas = []
for k in _tools_kwargs.keys():
_tool_schemas.append(self._tool_map[k].get_openai_tool_schema())
prompt_with_chat_template = self.tokenizer.apply_chat_template(
conversation=raw_prompt,
tools=[tool.model_dump() for tool in _tool_schemas],
add_generation_prompt=True,
tokenize=False,
return_tensors="pt",
)
input_data = self.tokenizer(prompt_with_chat_template, return_tensors="pt", add_special_tokens=False)
_input_ids = input_data["input_ids"][0].tolist()
_attention_mask = input_data["attention_mask"][0].tolist()
_position_ids = compute_position_id_with_mask(input_data["attention_mask"][0]).tolist()
if len(_input_ids) > self.config.prompt_length:
logger.warning(
"Prompt {} has length {} greater than max_prompt_len {}",
data_idx,
len(_input_ids),
self.config.prompt_length,
)
_input_ids = _input_ids[: self.config.prompt_length]
_attention_mask = _attention_mask[: self.config.prompt_length]
_position_ids = _position_ids[: self.config.prompt_length]
else:
_input_ids = _pre_process_inputs(self.pad_token_id, prompts.batch["input_ids"][data_idx])
_attention_mask = _pre_process_inputs(0, prompts.batch["attention_mask"][data_idx])
_position_ids = compute_position_id_with_mask(torch.tensor(_attention_mask)).tolist()
_tool_schemas = []
_tools_kwargs = {}
req = AsyncRolloutRequest(
batch_data_id=data_idx,
rollout_offset=rollout_offset,
request_id=str(uuid4()),
state=AsyncRolloutRequestStateEnum.PENDING,
messages=[Message.model_validate(msg) for msg in raw_prompt],
tools=_tool_schemas,
tools_kwargs=_tools_kwargs,
input_ids=_input_ids,
prompt_ids=_input_ids,
response_ids=[],
attention_mask=_attention_mask,
prompt_attention_mask=_attention_mask,
response_attention_mask=[],
position_ids=_position_ids,
prompt_position_ids=_position_ids,
response_position_ids=[],
loss_mask=[0] * len(_input_ids),
prompt_loss_mask=[0] * len(_input_ids),
response_loss_mask=[],
reward_scores={},
max_response_len=self.config.response_length,
max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length),
)
error_message = f"Request {req.request_id} has mismatched lengths: input_ids={len(req.input_ids)}, attention_mask={len(req.attention_mask)}, position_ids={len(req.position_ids)}, loss_mask={len(req.loss_mask)}"
assert len(req.input_ids) == len(req.attention_mask) == len(req.position_ids) == len(req.loss_mask), error_message
req_list.append(req)
return req_list
def execute_method(self, method: Union[str, bytes], *args, **kwargs):
if method == "chat_completion":
json_request = args[0]
formatted_messages = []
for msg in json_request["messages"]:
role = msg.get("role", "user")
content = msg.get("content", "")
formatted_messages.append(f"{role}: {content}")
prompt_str = "\n".join(formatted_messages)
sampling_params_dict = {
"n": json_request.get("n", 1),
"max_new_tokens": json_request.get("max_completion_tokens", self.config.response_length),
"temperature": json_request.get("temperature", 1.0),
"top_p": json_request.get("top_p", 1.0),
}
output = None
if self._tp_rank == 0:
loop = asyncio.get_event_loop()
output = loop.run_until_complete(
self._engine.async_generate(
prompt=prompt_str,
sampling_params=sampling_params_dict,
return_logprob=True,
)
)
output = broadcast_pyobj(
data=[output],
rank=self._rank,
dist_group=self._device_mesh_cpu["tp"].get_group(),
src=self._device_mesh_cpu["tp"].mesh[0].item(),
force_cpu_device=False,
)
# only return value from master rank
if self._tp_rank != 0:
return None
# build openai chat completion format
choices = []
id = None
for i, content in enumerate(output):
choices.append(
{
"index": i,
"message": {
"role": "assistant",
"content": content["text"],
},
"finish_reason": content["meta_info"]["finish_reason"]["type"],
}
)
id = content["meta_info"]["id"]
return {
"id": "chatcmpl-" + id,
"object": "chat.completion",
"created": int(time.time()),
"model": json_request.get("model", "sglang_model"),
"choices": choices,
}
else:
raise ValueError(f"not supported method : {method}")
# this function is left for uniform train-inference resharding
def resume(self):
if not self.is_sleep:
return
self.sharding_manager.__enter__() # pylint: disable=C2801
self.is_sleep = False
# this function is left for uniform train-inference resharding
def offload(self):
if self.is_sleep:
return
self.sharding_manager.__exit__(None, None, None)
self.is_sleep = True
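Two pieces of the class above are worth illustrating. First, `_initialize_tools` expects `multi_turn.tool_config_path` to point at a config whose `tools` list entries provide `class_name`, `config`, and `tool_schema`. A hedged sketch of such a config built with OmegaConf; the class path and schema contents are hypothetical placeholders:

```python
from omegaconf import OmegaConf

# Only the field names (tools, class_name, config, tool_schema) are taken from
# _initialize_tools above; everything else here is a hypothetical placeholder.
tools_config = OmegaConf.create(
    {
        "tools": [
            {
                "class_name": "my_package.my_tools.CalculatorTool",  # hypothetical
                "config": {"timeout_s": 5},  # hypothetical tool-specific settings
                "tool_schema": {
                    "type": "function",
                    "function": {
                        "name": "calculator",
                        "description": "Evaluate a math expression.",
                        "parameters": {
                            "type": "object",
                            "properties": {"expression": {"type": "string"}},
                            "required": ["expression"],
                        },
                    },
                },
            }
        ]
    }
)

print(OmegaConf.to_yaml(tools_config))
```

Second, the `update_sampling_params` context manager overrides only the keys it is given and puts the previous values back on exit. A standalone sketch of the same pattern on a plain dict (with a `try/finally` added so the rollback also happens on exceptions):

```python
from contextlib import contextmanager

sampling_params = {"n": 1, "temperature": 1.0, "max_new_tokens": 128}


@contextmanager
def update_sampling_params(**kwargs):
    # Save and override only keys that already exist, as in the method above.
    old = {k: sampling_params[k] for k in kwargs if k in sampling_params}
    sampling_params.update({k: v for k, v in kwargs.items() if k in sampling_params})
    try:
        yield
    finally:
        # Roll back to the previous values.
        sampling_params.update(old)


with update_sampling_params(temperature=0.0, n=1):
    assert sampling_params["temperature"] == 0.0  # overridden inside the block
assert sampling_params["temperature"] == 1.0  # restored afterwards
```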

View File

@@ -1,3 +1,4 @@
# Copyright 2023-2024 SGLang Team
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");

File diff suppressed because it is too large

View File

@@ -34,7 +34,9 @@ def broadcast_pyobj(
The `rank` here refer to the source rank on global process group (regardless
of dist_group argument).
"""
device = torch.device("cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu")
device = torch.device(
"cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
)
if rank == src:
if len(data) == 0:
@@ -44,7 +46,9 @@ def broadcast_pyobj(
serialized_data = pickle.dumps(data)
size = len(serialized_data)
tensor_data = torch.ByteTensor(np.frombuffer(serialized_data, dtype=np.uint8)).to(device)
tensor_data = torch.ByteTensor(
np.frombuffer(serialized_data, dtype=np.uint8)
).to(device)
tensor_size = torch.tensor([size], dtype=torch.long, device=device)
dist.broadcast(tensor_size, src=src, group=dist_group)
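The point the docstring above makes is that `rank` is always expressed in global process-group numbering, even when `dist_group` is a subgroup such as the rollout's TP group, which is how the calls in the rollout code pass it. A single-process sketch of the calling pattern, assuming verl is installed and that `broadcast_pyobj` keeps the keyword signature used above:

```python
import torch.distributed as dist

from verl.workers.rollout.sglang_rollout.utils import broadcast_pyobj

# world_size=1 demonstration: rank is this process's *global* rank and src is the
# global rank of the source inside dist_group (here both are 0).
if not dist.is_initialized():
    dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29511", world_size=1, rank=0)

payload = {"text": "hello", "meta_info": {"finish_reason": {"type": "stop"}}}
[received] = broadcast_pyobj(
    data=[payload],
    rank=dist.get_rank(),  # caller's global rank
    dist_group=dist.group.WORLD,  # a TP subgroup in real use
    src=0,  # global rank of the source within dist_group
    force_cpu_device=True,
)
assert received == payload
```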

View File

@@ -1,18 +1,5 @@
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -29,12 +16,10 @@
import logging
import os
from typing import Union
import torch
import torch.distributed as dist
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.entrypoints.verl_engine import VerlEngine
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
from sglang.srt.utils import MultiprocessingSerializer
from torch.distributed.device_mesh import DeviceMesh
@@ -66,7 +51,7 @@ class FSDPSGLangShardingManager(BaseShardingManager):
def __init__(
self,
module: FSDP,
inference_engine: Union[VerlEngine, Engine],
inference_engine: Engine,
model_config,
full_params: bool = False,
device_mesh: DeviceMesh = None,
@@ -144,44 +129,6 @@ class FSDPSGLangShardingManager(BaseShardingManager):
self.gen_random_states = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(self.torch_random_states)
def update_weights(self, params):
self.inference_engine.resume_memory_occupation()
self.inference_engine.update_weights_from_tensor([(k, v) for k, v in params.items()], load_format=None)
def release_memory(self):
self.inference_engine.release_memory_occupation()
def preprocess_data(self, data: DataProto) -> DataProto:
"""All gather across tp group to make each rank has identical input."""
if self.tp_size == 1:
return data
# TODO: Current impl doesn't consider FSDP with torch micro-dp
group = self.device_mesh["infer_tp"].get_group()
all_gather_data_proto(data=data, process_group=group)
return data
def postprocess_data(self, data: DataProto) -> DataProto:
"""Get chunk data of this tp rank since we do all gather in preprocess."""
if self.tp_size == 1:
return data
return data.chunk(chunks=self.tp_size)[self.tp_rank]
class FSDPAsyncSGLangShardingManager(FSDPSGLangShardingManager):
def __init__(
self,
module: FSDP,
inference_engine: Engine,
model_config,
full_params: bool = False,
device_mesh: DeviceMesh = None,
offload_param: bool = False,
):
super().__init__(module, inference_engine, model_config, full_params, device_mesh, offload_param)
def update_weights(self, params):
if self.device_mesh["infer_tp"].get_local_rank() == 0:
self.inference_engine.resume_memory_occupation()
@@ -218,3 +165,21 @@ class FSDPAsyncSGLangShardingManager(FSDPSGLangShardingManager):
def release_memory(self):
if self.device_mesh["infer_tp"].get_local_rank() == 0:
self.inference_engine.release_memory_occupation()
def preprocess_data(self, data: DataProto) -> DataProto:
"""All gather across tp group to make each rank has identical input."""
if self.tp_size == 1:
return data
# TODO: Current impl doesn't consider FSDP with torch micro-dp
group = self.device_mesh["infer_tp"].get_group()
all_gather_data_proto(data=data, process_group=group)
return data
def postprocess_data(self, data: DataProto) -> DataProto:
"""Get chunk data of this tp rank since we do all gather in preprocess."""
if self.tp_size == 1:
return data
return data.chunk(chunks=self.tp_size)[self.tp_rank]
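The added `preprocess_data`/`postprocess_data` pair is a simple round trip: every rank in the inference TP group contributes its micro-batch, all ranks generate on the identical gathered batch, and each rank then keeps only its own chunk of the result. A self-contained sketch of that invariant with plain lists standing in for `all_gather_data_proto` and `DataProto.chunk`:

```python
# Conceptual stand-in for preprocess_data/postprocess_data on a TP group of size 3.
tp_size = 3
per_rank_batches = [["a0", "a1"], ["b0", "b1"], ["c0", "c1"]]  # one micro-batch per rank

# preprocess_data: all-gather so every rank holds the identical concatenated batch.
gathered = [item for batch in per_rank_batches for item in batch]


def chunk(seq, chunks):
    # Mimics DataProto.chunk: split into `chunks` contiguous, equally sized pieces.
    size = len(seq) // chunks
    return [seq[i * size:(i + 1) * size] for i in range(chunks)]


# postprocess_data: each rank takes back the chunk that corresponds to its inputs.
for tp_rank in range(tp_size):
    assert chunk(gathered, tp_size)[tp_rank] == per_rank_batches[tp_rank]
```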

View File

@@ -22,7 +22,6 @@ import os
import torch
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.entrypoints.verl_engine import VerlEngine
from torch import nn
from torch.distributed.device_mesh import DeviceMesh
@@ -50,7 +49,7 @@ class MegatronSGLangShardingManager(BaseShardingManager):
def __init__(
self,
actor_module: nn.ModuleList,
inference_engine: VerlEngine,
inference_engine: Engine,
model_config,
transformer_config,
layer_name_mapping,
@@ -113,50 +112,6 @@ class MegatronSGLangShardingManager(BaseShardingManager):
self.gen_random_states = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(self.torch_random_states)
def update_weights(self, params):
self.inference_engine.resume_memory_occupation()
self.inference_engine.update_weights_from_tensor(params, load_format=None)
def release_memory(self):
self.inference_engine.release_memory_occupation()
@GPUMemoryLogger(role="megatron sglang sharding_manager", logger=logger)
def preprocess_data(self, data: DataProto) -> DataProto:
# DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp
if self.infer_tp_size == 1:
return data
all_gather_data_proto(data, self.device_mesh["tp"].get_group())
return data
@GPUMemoryLogger(role="megatron sglang sharding_manager", logger=logger)
def postprocess_data(self, data: DataProto) -> DataProto:
# DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp
if self.infer_tp_size == 1:
return data
return data.chunk(chunks=self.infer_tp_size)[self.device_mesh["tp"].get_local_rank()]
class MegatronAsyncSGLangShardingManager(MegatronSGLangShardingManager):
def __init__(
self,
actor_module: nn.ModuleList,
inference_engine: Engine,
model_config,
transformer_config,
layer_name_mapping,
weight_converter,
device_mesh: DeviceMesh = None,
):
super().__init__(
actor_module,
inference_engine,
model_config,
transformer_config,
layer_name_mapping,
weight_converter,
device_mesh,
)
def update_weights(self, params):
if self.device_mesh["tp"].get_local_rank() == 0:
self.inference_engine.resume_memory_occupation()
@@ -184,3 +139,18 @@ class MegatronAsyncSGLangShardingManager(MegatronSGLangShardingManager):
def release_memory(self):
if self.device_mesh["tp"].get_local_rank() == 0:
self.inference_engine.release_memory_occupation()
@GPUMemoryLogger(role="megatron sglang sharding_manager", logger=logger)
def preprocess_data(self, data: DataProto) -> DataProto:
# DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp
if self.infer_tp_size == 1:
return data
all_gather_data_proto(data, self.device_mesh["tp"].get_group())
return data
@GPUMemoryLogger(role="megatron sglang sharding_manager", logger=logger)
def postprocess_data(self, data: DataProto) -> DataProto:
# DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp
if self.infer_tp_size == 1:
return data
return data.chunk(chunks=self.infer_tp_size)[self.device_mesh["tp"].get_local_rank()]