mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 13:43:50 +08:00
[sglang] refactor: Unify async rollout under SGLangRollout, and support sglang==0.4.6.post5 (#1717)
### Checklist Before Starting - [x] Search for similar PR(s). ### What does this PR do? - Unify the functionality of SGLangRollout and AsyncSGLangRollout, remove original SGLangRollout and rename AsyncSGLangRollout to SGLangRollout. - Make trivial changes due to modification in sglang==0.4.6.post5. ### High-Level Design > Demonstrate the high-level design if this PR is complex. ### Specific Changes > List the specific changes. ### API > Demonstrate how the API changes if any. ### Usage Example > Provide usage example(s) for easier usage. ```python # Add code snippet or script demonstrating how to use this ``` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluatuion results, etc. ### Additional Info. - **Issue Number**: Fixes issue # or discussion # if any. - **Training**: [Note which backend this PR will affect: FSDP, Megatron, both, or none] - **Inference**: [Note which backend this PR will affect: vLLM, SGLang, both, or none] ### Checklist Before Submitting - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting). - [ ] Add `[BREAKING]` to the PR title if it breaks any API. - [ ] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add CI test(s) if necessary. --------- Co-authored-by: zyzshishui <@qq.com> Co-authored-by: Xiang Long <mindsculptor@yeah.net> Co-authored-by: ocss884 <ocss.lin@gmail.com> Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com> Co-authored-by: H <linhaibin.eric@gmail.com>
This commit is contained in:
41
.github/workflows/e2e_ppo_trainer.yml
vendored
41
.github/workflows/e2e_ppo_trainer.yml
vendored
@ -204,7 +204,7 @@ jobs:
|
||||
HF_ENDPOINT: "https://hf-mirror.com"
|
||||
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
|
||||
container:
|
||||
image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post4
|
||||
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
|
||||
options: --gpus all --shm-size=10g
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
@ -222,7 +222,7 @@ jobs:
|
||||
ray stop --force
|
||||
ENGINE=sglang bash tests/e2e/ppo_trainer/run_function_reward.sh
|
||||
|
||||
e2e_ppo_trainer_sglang_async:
|
||||
e2e_ppo_trainer_sglang_multiturn_with_tool:
|
||||
runs-on: [L20x8]
|
||||
needs: pre_commit_for_ppo
|
||||
timeout-minutes: 40 # Increase this timeout value as needed
|
||||
@ -233,36 +233,7 @@ jobs:
|
||||
HF_ENDPOINT: "https://hf-mirror.com"
|
||||
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
|
||||
container:
|
||||
image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post4
|
||||
options: --gpus all --shm-size=10g
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Install the current repository
|
||||
run: |
|
||||
pip3 install -e .[test,gpu,sglang] --no-deps
|
||||
- name: Prepare gsm8k dataset
|
||||
run: |
|
||||
ray stop --force
|
||||
python3 examples/data_preprocess/gsm8k.py
|
||||
- name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm and save ckpt with sglang async
|
||||
run: |
|
||||
ray stop --force
|
||||
ENGINE=sglang_async bash tests/e2e/ppo_trainer/run_function_reward.sh
|
||||
|
||||
e2e_ppo_trainer_sglang_async_with_tool:
|
||||
runs-on: [L20x8]
|
||||
needs: pre_commit_for_ppo
|
||||
timeout-minutes: 40 # Increase this timeout value as needed
|
||||
env:
|
||||
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
|
||||
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
|
||||
NO_PROXY: "localhost,127.0.0.1,hf-mirror.com"
|
||||
HF_ENDPOINT: "https://hf-mirror.com"
|
||||
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
|
||||
container:
|
||||
image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post4
|
||||
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
|
||||
options: --gpus all --shm-size=10g
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
@ -275,7 +246,7 @@ jobs:
|
||||
run: |
|
||||
ray stop --force
|
||||
python3 examples/data_preprocess/gsm8k_multiturn_w_tool.py --local_dir $HOME/data/gsm8k_verl_sgl_multi_turn_preprocessed
|
||||
- name: Running GSM8K with tool E2E training tests on 8 L20 GPUs with rmpad using function rm and save ckpt with sglang async
|
||||
- name: Running GSM8K with tool E2E training tests on 8 L20 GPUs with rmpad using function rm and save ckpt with sglang
|
||||
run: |
|
||||
ray stop --force
|
||||
bash tests/e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh
|
||||
@ -295,7 +266,7 @@ jobs:
|
||||
HF_ENDPOINT: "https://hf-mirror.com"
|
||||
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
|
||||
container:
|
||||
image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post4
|
||||
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
|
||||
options: --gpus all --shm-size=50g # Visual dataloader requires large memory
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
@ -367,7 +338,7 @@ jobs:
|
||||
HF_ENDPOINT: "https://hf-mirror.com"
|
||||
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
|
||||
container:
|
||||
image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post4
|
||||
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
|
||||
options: --gpus all --shm-size=50g # Visual dataloader requires large memory
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
10
.github/workflows/e2e_ppo_trainer_megatron.yml
vendored
10
.github/workflows/e2e_ppo_trainer_megatron.yml
vendored
@ -50,7 +50,7 @@ jobs:
|
||||
HF_ENDPOINT: "https://hf-mirror.com"
|
||||
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
|
||||
container:
|
||||
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3
|
||||
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
|
||||
options: --gpus all --shm-size=10g
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
@ -92,7 +92,7 @@ jobs:
|
||||
HF_ENDPOINT: "https://hf-mirror.com"
|
||||
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
|
||||
container:
|
||||
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3
|
||||
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
|
||||
options: --gpus all --shm-size=10g
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
@ -134,7 +134,7 @@ jobs:
|
||||
HF_ENDPOINT: "https://hf-mirror.com"
|
||||
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
|
||||
container:
|
||||
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3
|
||||
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
|
||||
options: --gpus all --shm-size=10g
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
@ -167,7 +167,7 @@ jobs:
|
||||
HF_ENDPOINT: "https://hf-mirror.com"
|
||||
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
|
||||
container:
|
||||
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3
|
||||
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
|
||||
options: --gpus all --shm-size=10g
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
@ -206,7 +206,7 @@ jobs:
|
||||
HF_ENDPOINT: "https://hf-mirror.com"
|
||||
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
|
||||
container:
|
||||
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3
|
||||
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
|
||||
options: --gpus all --shm-size=10g
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
8
.github/workflows/sgl.yml
vendored
8
.github/workflows/sgl.yml
vendored
@ -56,7 +56,7 @@ jobs:
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK: "True"
|
||||
container:
|
||||
image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post4
|
||||
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
|
||||
options: --gpus all --shm-size=10g
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
@ -73,11 +73,7 @@ jobs:
|
||||
- name: Test the latest SGLang
|
||||
run: |
|
||||
cd tests/workers/rollout
|
||||
torchrun --nnodes=1 --nproc_per_node=4 $(which pytest) -s test_sglang_spmd.py
|
||||
- name: Test the latest SGLang async
|
||||
run: |
|
||||
cd tests/workers/rollout
|
||||
torchrun --nnodes=1 --nproc_per_node=2 $(which pytest) -s test_sglang_async_spmd.py
|
||||
torchrun --nnodes=1 --nproc_per_node=2 $(which pytest) -s test_sglang_spmd.py
|
||||
- name: Test the latest SGLang Rollout async with tool
|
||||
run: |
|
||||
cd tests/workers/rollout
|
||||
|
@ -6,7 +6,7 @@
|
||||
# Support - Traing: fsdp; Inference: vllm
|
||||
# FROM rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4
|
||||
# Support - Traing: fsdp; Inference: vllm, sglang
|
||||
FROM lmsysorg/sglang:v0.4.6.post4-rocm630
|
||||
FROM lmsysorg/sglang:v0.4.6.post5-rocm630
|
||||
|
||||
# Set working directory
|
||||
# WORKDIR $PWD/app
|
||||
|
@ -36,8 +36,8 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
|
||||
pip config set global.extra-index-url "${PIP_INDEX}" && \
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
# Install sglang-0.4.6.post4 and torch-memory-saver
|
||||
RUN pip install "sglang[all]==0.4.6.post4" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir
|
||||
# Install sglang-0.4.6.post5 and torch-memory-saver
|
||||
RUN pip uninstall -y cuda-python && pip install "sglang[all]==0.4.6.post5" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir
|
||||
|
||||
# Install torch-2.6.0
|
||||
RUN pip install --no-cache-dir torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 tensordict torchdata \
|
||||
|
@ -56,12 +56,12 @@ RUN aria2c --always-resume=true --max-tries=99999 https://developer.download.nvi
|
||||
update-alternatives --set cuda /usr/local/cuda-12.4 && \
|
||||
rm -rf /usr/local/cuda-12.6
|
||||
|
||||
# Install torch-2.6.0+cu124 + vllm-0.8.5.post1 + sglang-0.4.6.post4
|
||||
# Install torch-2.6.0+cu124 + vllm-0.8.5.post1 + sglang-0.4.6.post5
|
||||
# torch-2.6.0+cu124: cxx11abi=False
|
||||
# torch-2.6.0+cu126: cxx11abi=True
|
||||
# see https://github.com/flashinfer-ai/flashinfer/issues/911
|
||||
# Install sglang-0.4.6.post1 and torch-memory-saver
|
||||
RUN pip install "sglang[all]==0.4.6.post1" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir
|
||||
RUN pip install "sglang[all]==0.4.6.post5" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir
|
||||
|
||||
RUN pip install --no-cache-dir "vllm==0.8.5.post1" "torch==2.6.0" "torchvision==0.21.0" "torchaudio==2.6.0" "tensordict==0.6.2" torchdata
|
||||
|
||||
|
@ -22,7 +22,7 @@ docker/Dockerfile.rocm
|
||||
# Support - Traing: fsdp; Inference: vllm
|
||||
# FROM rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4
|
||||
# Support - Traing: fsdp; Inference: vllm, sglang
|
||||
FROM lmsysorg/sglang:v0.4.6.post4-rocm630
|
||||
FROM lmsysorg/sglang:v0.4.6.post5-rocm630
|
||||
|
||||
# Set working directory
|
||||
# WORKDIR $PWD/app
|
||||
|
@ -11,9 +11,9 @@ To enable multi-turn rollout, make sure to configure the following fields in you
|
||||
actor_rollout_ref:
|
||||
rollout:
|
||||
multi_turn: True
|
||||
name: "sglang_async"
|
||||
name: "sglang"
|
||||
|
||||
These configuration activates the sglang_async engine for multi-turn interaction during rollout.
|
||||
These configuration activates the sglang engine for multi-turn interaction during rollout.
|
||||
|
||||
Custom Tool Configuration
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
@ -42,7 +42,7 @@ For vLLM with Megatron or FSDP, please use the stable version of image ``whatcan
|
||||
|
||||
For latest vLLM with FSDP, please refer to ``hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0``.
|
||||
|
||||
For SGLang with FSDP, please use ``ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post4`` which is provided by SGLang RL Group.
|
||||
For SGLang with FSDP, please use ``ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post5`` which is provided by SGLang RL Group.
|
||||
|
||||
See files under ``docker/`` for NGC-based image or if you want to build your own.
|
||||
|
||||
@ -79,7 +79,7 @@ See files under ``docker/`` for NGC-based image or if you want to build your own
|
||||
- **Flash Attenttion**: 2.7.4.post1
|
||||
- **Flash Infer**: 0.2.2.post1
|
||||
- **vLLM**: 0.8.5
|
||||
- **SGLang**: 0.4.6.post4
|
||||
- **SGLang**: 0.4.6.post5
|
||||
- **Megatron-LM**: core_v0.12.0
|
||||
- **TransformerEngine**: 2.3
|
||||
- **Ray**: 2.44.1
|
||||
|
@ -21,7 +21,7 @@ Please always follow the following command to install SGLang with verl.
|
||||
.. code-block:: bash
|
||||
|
||||
pip install --upgrade pip
|
||||
# Currently 0.4.6.post4, subject to updates at any time, please refer to the latest version specified in `setup.py`
|
||||
# Currently 0.4.6.post5, subject to updates at any time, please refer to the latest version specified in `setup.py`
|
||||
pip install -e ".[sglang]"
|
||||
|
||||
You can check the following dependencies are in your environment:
|
||||
@ -31,8 +31,8 @@ You can check the following dependencies are in your environment:
|
||||
- **PyTorch**: 2.6.0+cu124
|
||||
- **CUDA**: 12.4
|
||||
- **flashinfer-python**: 0.2.5+cu124torch2.6
|
||||
- **sgLang**: 0.4.6.post4
|
||||
- **sgl-kernel**: 0.1.2.post1
|
||||
- **sgLang**: 0.4.6.post5
|
||||
- **sgl-kernel**: 0.1.4
|
||||
|
||||
Using SGLang as the Inference Backend for PPO Training on a Single Machine
|
||||
-------------------------------------------------------------------------
|
||||
@ -87,7 +87,7 @@ Why export SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK?
|
||||
|
||||
1. ``verl`` initializes a ``SGLangRollout`` module during rollout, which is used to evaluate/generate samples.
|
||||
|
||||
2. ``SGLangRollout`` will initialize ``VerlEngine``, and further initialize a ``torch.distributed.DeviceMesh``, used to support Tensor Parallel (TP).
|
||||
2. ``SGLangRollout`` will initialize ``Engine``, and further initialize a ``torch.distributed.DeviceMesh``, used to support Tensor Parallel (TP).
|
||||
|
||||
3. ``DeviceMesh.init()`` internally checks the free GPU memory of all participating devices. If the difference is too large (more than ~10%), it directly reports an error to avoid initialization failures or deadlocks.
|
||||
|
||||
@ -111,7 +111,7 @@ Early workers already use up GPU memory → late workers still have empty memory
|
||||
|
||||
**3. SGLang's TP init uses "all-device broadcast", but there's no uniform release timing**
|
||||
|
||||
Although ``SGLangRollout`` may only involve subset of GPUs, its ``VerlEngine`` initialization calls ``torch.distributed.init_process_group()`` and broadcasts weights, so:
|
||||
Although ``SGLangRollout`` may only involve subset of GPUs, its ``Engine`` initialization calls ``torch.distributed.init_process_group()`` and broadcasts weights, so:
|
||||
|
||||
- Non-rollout GPUs also join the communication.
|
||||
- Later on, ``DeviceMesh`` init will fail due to "inconsistent memory".
|
||||
|
@ -15,7 +15,7 @@ data:
|
||||
actor_rollout_ref:
|
||||
hybrid_engine: True
|
||||
rollout:
|
||||
name: sglang_async
|
||||
name: sglang
|
||||
multi_turn:
|
||||
enable: True
|
||||
max_turns: 5
|
||||
|
@ -15,7 +15,7 @@ data:
|
||||
actor_rollout_ref:
|
||||
hybrid_engine: True
|
||||
rollout:
|
||||
name: sglang_async
|
||||
name: sglang
|
||||
multi_turn:
|
||||
enable: True
|
||||
max_turns: 5
|
||||
|
@ -32,7 +32,7 @@ python3 -m verl.trainer.main_ppo \
|
||||
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
|
||||
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
|
||||
actor_rollout_ref.rollout.name=sglang_async \
|
||||
actor_rollout_ref.rollout.name=sglang \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
|
||||
actor_rollout_ref.rollout.n=16 \
|
||||
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
|
||||
@ -41,7 +41,7 @@ python3 -m verl.trainer.main_ppo \
|
||||
trainer.critic_warmup=0 \
|
||||
trainer.logger=['console','wandb'] \
|
||||
trainer.project_name='gsm8k_async_rl' \
|
||||
trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-w-tool-verify-n16' \
|
||||
trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-verify-n16' \
|
||||
trainer.n_gpus_per_node=8 \
|
||||
trainer.nnodes=1 \
|
||||
trainer.save_freq=-1 \
|
||||
|
@ -32,7 +32,7 @@ python3 -m verl.trainer.main_ppo \
|
||||
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
|
||||
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
|
||||
actor_rollout_ref.rollout.name=sglang_async \
|
||||
actor_rollout_ref.rollout.name=sglang \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
|
||||
actor_rollout_ref.rollout.n=16 \
|
||||
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
|
||||
|
@ -45,7 +45,7 @@ python3 -m verl.trainer.main_ppo \
|
||||
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \
|
||||
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
|
||||
actor_rollout_ref.rollout.name=sglang_async \
|
||||
actor_rollout_ref.rollout.name=sglang \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
|
||||
actor_rollout_ref.rollout.n=8 \
|
||||
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
|
||||
@ -53,7 +53,7 @@ python3 -m verl.trainer.main_ppo \
|
||||
trainer.critic_warmup=0 \
|
||||
trainer.logger=['console','wandb'] \
|
||||
trainer.project_name='gsm8k_async_rl' \
|
||||
trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-w-tool-n8-mcore-v2505201745_seed42' \
|
||||
trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-n8-mcore-v2505201745_seed42' \
|
||||
trainer.n_gpus_per_node=8 \
|
||||
trainer.nnodes=1 \
|
||||
trainer.save_freq=-1 \
|
||||
|
@ -17,6 +17,6 @@ torchdata
|
||||
torchvision
|
||||
transformers
|
||||
wandb
|
||||
sglang[all]==0.4.6.post4
|
||||
sglang[all]==0.4.6.post5
|
||||
torch-memory-saver>=0.0.5
|
||||
huggingface_hub
|
||||
huggingface_hub
|
||||
|
7
setup.py
7
setup.py
@ -49,7 +49,12 @@ GEO_REQUIRES = ["mathruler"]
|
||||
GPU_REQUIRES = ["liger-kernel", "flash-attn"]
|
||||
MATH_REQUIRES = ["math-verify"] # Add math-verify as an optional dependency
|
||||
VLLM_REQUIRES = ["tensordict<=0.6.2", "vllm<=0.8.5"]
|
||||
SGLANG_REQUIRES = ["tensordict<=0.6.2", "sglang[srt,openai]==0.4.6.post4", "torch-memory-saver>=0.0.5", "torch==2.6.0"]
|
||||
SGLANG_REQUIRES = [
|
||||
"tensordict<=0.6.2",
|
||||
"sglang[srt,openai]==0.4.6.post5",
|
||||
"torch-memory-saver>=0.0.5",
|
||||
"torch==2.6.0",
|
||||
]
|
||||
|
||||
extras_require = {
|
||||
"test": TEST_REQUIRES,
|
||||
|
@ -36,7 +36,7 @@ python3 -m verl.trainer.main_ppo \
|
||||
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
|
||||
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
|
||||
actor_rollout_ref.rollout.name=sglang_async \
|
||||
actor_rollout_ref.rollout.name=sglang \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
|
||||
actor_rollout_ref.rollout.n=8 \
|
||||
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
|
||||
@ -46,7 +46,7 @@ python3 -m verl.trainer.main_ppo \
|
||||
trainer.critic_warmup=0 \
|
||||
trainer.logger=['console'] \
|
||||
trainer.project_name='gsm8k_async_rl' \
|
||||
trainer.experiment_name=qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-w-tool-$FSDP_STRATEGY-rebased-0427-verify-n16 \
|
||||
trainer.experiment_name=qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-$FSDP_STRATEGY-rebased-0427-verify-n16 \
|
||||
trainer.n_gpus_per_node=8 \
|
||||
trainer.nnodes=1 \
|
||||
trainer.save_freq=-1 \
|
||||
|
@ -78,7 +78,7 @@ if [ $SKIP_SAVE_HF_MODEL -eq 1 ]; then
|
||||
CHECKPOINT_CONTENTS=['model','optimizer','extra']
|
||||
fi
|
||||
|
||||
ENGINES=("vllm" "sglang_async")
|
||||
ENGINES=("vllm" "sglang")
|
||||
|
||||
exp_name="$(basename "${MODEL_ID,,}")-megatron-gsm8k-minimal"
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -20,7 +21,6 @@ from omegaconf import DictConfig
|
||||
@patch.dict(
|
||||
"sys.modules",
|
||||
{
|
||||
"verl.workers.rollout.sglang_rollout.async_sglang_rollout": MagicMock(AsyncSGLangRollout=MagicMock()),
|
||||
"verl.workers.rollout.sglang_rollout.sglang_rollout": MagicMock(SGLangRollout=MagicMock()),
|
||||
},
|
||||
)
|
||||
|
@ -32,7 +32,7 @@ from verl.protocol import DataProto
|
||||
from verl.tools.schemas import OpenAIFunctionParametersSchema, OpenAIFunctionPropertySchema, OpenAIFunctionSchema, OpenAIFunctionToolSchema
|
||||
from verl.tools.search_tool import SearchTool
|
||||
from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message
|
||||
from verl.workers.rollout.sglang_rollout.async_sglang_rollout import AsyncSGLangRollout
|
||||
from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout
|
||||
|
||||
DEFAULT_USER_CONTENT_PREFIX = (
|
||||
"Answer the given question. You must conduct reasoning inside <think> and </think> "
|
||||
@ -143,11 +143,11 @@ class TestRolloutWithSearchTools:
|
||||
prompts = DataProto(batch=prompt_dict, non_tensor_batch={"raw_prompt": messages, "tools_kwargs": tools_kwargs, "index": index})
|
||||
return prompts
|
||||
|
||||
@patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None)
|
||||
@patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None)
|
||||
@patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_distributed_env", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
|
||||
def test_tools_registration(self, mock_env, mock_engine, mock_sampling, search_rollout_config, qwen_tokenizer, qwen_model_config):
|
||||
rollout = AsyncSGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
|
||||
rollout = SGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
|
||||
assert len(rollout._tool_schemas) == 1
|
||||
assert "search" in rollout._tool_map.keys()
|
||||
from verl.tools.search_tool import SearchTool
|
||||
@ -156,11 +156,11 @@ class TestRolloutWithSearchTools:
|
||||
# depend on the tokenizer
|
||||
assert rollout._tool_call_parser_type == "qwen25"
|
||||
|
||||
@patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None)
|
||||
@patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None)
|
||||
@patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_distributed_env", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
|
||||
def test_rollout_req_creation(self, mock_env, mock_engine, mock_sampling, search_rollout_config, qwen_tokenizer, qwen_model_config, search_data_proto):
|
||||
rollout = AsyncSGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
|
||||
rollout = SGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
|
||||
req_list = rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)
|
||||
assert len(req_list) == 1
|
||||
assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING
|
||||
@ -186,12 +186,12 @@ class TestRolloutWithSearchTools:
|
||||
),
|
||||
)
|
||||
|
||||
@patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None)
|
||||
@patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None)
|
||||
@patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_distributed_env", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
|
||||
def test_over_size_case(self, mock_env, mock_engine, mock_sampling, search_rollout_config, qwen_tokenizer, qwen_model_config, search_data_proto, search_data):
|
||||
search_rollout_config.multi_turn.max_turns = 1
|
||||
rollout = AsyncSGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
|
||||
rollout = SGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
|
||||
req = rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0]
|
||||
req = MagicMock(wraps=req, spec=AsyncRolloutRequest)
|
||||
req.finalize = MagicMock()
|
||||
@ -223,9 +223,9 @@ class TestRolloutWithSearchTools:
|
||||
)
|
||||
|
||||
@patch.object(SearchTool, "execute", new_callable=AsyncMock)
|
||||
@patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None)
|
||||
@patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None)
|
||||
@patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_distributed_env", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
|
||||
def test_tool_call_basic_case(self, mock_sampling, mock_engine, mock_env, mock_execute, search_rollout_config, qwen_tokenizer, qwen_model_config, search_data_proto, search_data):
|
||||
_, expect_turn_array, tool_return_array = search_data
|
||||
|
||||
@ -233,7 +233,7 @@ class TestRolloutWithSearchTools:
|
||||
mock_execute.side_effect = [(msg, 0.0, {"status": "success"}) for msg in tool_return_array]
|
||||
|
||||
search_rollout_config.multi_turn.max_turns = 10
|
||||
rollout = AsyncSGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
|
||||
rollout = SGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
|
||||
|
||||
rollout._tool_map["search"].retrieval_service_url = "mock://dummy"
|
||||
|
||||
@ -272,9 +272,9 @@ class TestRolloutWithSearchTools:
|
||||
assert search_counter == 2
|
||||
|
||||
@patch.object(SearchTool, "execute", new_callable=AsyncMock)
|
||||
@patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None)
|
||||
@patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None)
|
||||
@patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_distributed_env", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
|
||||
def test_tool_call_batch_case(self, mock_sampling, mock_engine, mock_env, mock_execute, search_rollout_config, qwen_tokenizer, qwen_model_config, search_data_proto, search_data):
|
||||
_, expect_turn_array, tool_return_array = search_data
|
||||
|
||||
@ -285,7 +285,7 @@ class TestRolloutWithSearchTools:
|
||||
] * 100
|
||||
|
||||
search_rollout_config.multi_turn.max_turns = 10
|
||||
rollout = AsyncSGLangRollout(
|
||||
rollout = SGLangRollout(
|
||||
actor_module="",
|
||||
config=search_rollout_config,
|
||||
tokenizer=qwen_tokenizer,
|
||||
@ -327,7 +327,7 @@ class TestRolloutWithSearchTools:
|
||||
req_turns_counter[_req.batch_data_id] += 1
|
||||
return await fut
|
||||
|
||||
with patch.object(AsyncSGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call):
|
||||
with patch.object(SGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call):
|
||||
rollout._tp_rank = 0
|
||||
loop = asyncio.get_event_loop()
|
||||
output_req_list = loop.run_until_complete(asyncio.gather(*[rollout._async_rollout_a_request(r, True, False) for r in req_list]))
|
||||
|
@ -32,7 +32,7 @@ from verl.protocol import DataProto
|
||||
from verl.tools.sandbox_fusion_tools import TokenBucketWorker
|
||||
from verl.tools.schemas import OpenAIFunctionParametersSchema, OpenAIFunctionPropertySchema, OpenAIFunctionSchema, OpenAIFunctionToolSchema
|
||||
from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message
|
||||
from verl.workers.rollout.sglang_rollout.async_sglang_rollout import AsyncSGLangRollout
|
||||
from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout
|
||||
|
||||
sandbox_url = ""
|
||||
|
||||
@ -200,11 +200,11 @@ class TestRolloutWithTools:
|
||||
prompts = DataProto(batch=prompt_dict, non_tensor_batch={"raw_prompt": messages, "tools_kwargs": tools_kwargs, "index": index})
|
||||
return prompts
|
||||
|
||||
@patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None)
|
||||
@patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None)
|
||||
@patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_distributed_env", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
|
||||
def test_tools_registration(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config):
|
||||
rollout = AsyncSGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
|
||||
rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
|
||||
assert len(rollout._tool_schemas) == 1
|
||||
assert "code_interpreter" in rollout._tool_map.keys()
|
||||
from verl.tools.sandbox_fusion_tools import SandboxFusionTool
|
||||
@ -212,11 +212,11 @@ class TestRolloutWithTools:
|
||||
assert isinstance(rollout._tool_map["code_interpreter"], SandboxFusionTool)
|
||||
assert rollout._tool_call_parser_type == "qwen25"
|
||||
|
||||
@patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None)
|
||||
@patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None)
|
||||
@patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_distributed_env", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
|
||||
def test_rollout_req_creation(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config, sandbox_data_proto):
|
||||
rollout = AsyncSGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
|
||||
rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
|
||||
req_list = rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)
|
||||
assert len(req_list) == 1
|
||||
assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING
|
||||
@ -242,12 +242,12 @@ class TestRolloutWithTools:
|
||||
),
|
||||
)
|
||||
|
||||
@patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None)
|
||||
@patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None)
|
||||
@patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_distributed_env", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
|
||||
def test_over_size_case(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config, sandbox_data_proto, sandbox_fusion_data):
|
||||
sandbox_fusion_rollout_config.multi_turn.max_turns = 1
|
||||
rollout = AsyncSGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
|
||||
rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
|
||||
req = rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0]
|
||||
req = MagicMock(wraps=req, spec=AsyncRolloutRequest)
|
||||
req.finalize = MagicMock()
|
||||
@ -279,12 +279,12 @@ class TestRolloutWithTools:
|
||||
)
|
||||
|
||||
@skip_if_valid_sandbox(sandbox_url)
|
||||
@patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None)
|
||||
@patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None)
|
||||
@patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_distributed_env", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
|
||||
def test_tool_call_basic_case(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config, sandbox_data_proto, sandbox_fusion_data):
|
||||
sandbox_fusion_rollout_config.multi_turn.max_turns = 10
|
||||
rollout = AsyncSGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
|
||||
rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
|
||||
self._tool_map["code_interpreter"].sandbox_fusion_url = sandbox_url
|
||||
req = rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0]
|
||||
req = MagicMock(wraps=req, spec=AsyncRolloutRequest)
|
||||
@ -323,12 +323,12 @@ class TestRolloutWithTools:
|
||||
assert code_counter == 2
|
||||
|
||||
@skip_if_valid_sandbox(sandbox_url)
|
||||
@patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None)
|
||||
@patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None)
|
||||
@patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_distributed_env", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
|
||||
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
|
||||
def test_tool_call_batch_case(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config, sandbox_data_proto, sandbox_fusion_data):
|
||||
sandbox_fusion_rollout_config.multi_turn.max_turns = 10
|
||||
rollout = AsyncSGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
|
||||
rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
|
||||
self._tool_map["code_interpreter"].sandbox_fusion_url = sandbox_url
|
||||
req = rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0]
|
||||
req_nums = 100
|
||||
@ -357,7 +357,7 @@ class TestRolloutWithTools:
|
||||
re = await result
|
||||
return re
|
||||
|
||||
with patch.object(AsyncSGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call):
|
||||
with patch.object(SGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call):
|
||||
rollout._tp_rank = 0
|
||||
loop = asyncio.get_event_loop()
|
||||
output_req_list = loop.run_until_complete(
|
||||
|
@ -35,8 +35,8 @@ from utils_sglang import (
|
||||
)
|
||||
|
||||
from verl import DataProto
|
||||
from verl.workers.rollout.sglang_rollout.async_sglang_rollout import AsyncSGLangRollout
|
||||
from verl.workers.sharding_manager.fsdp_sglang import FSDPAsyncSGLangShardingManager
|
||||
from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout
|
||||
from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager
|
||||
|
||||
|
||||
def test_async_sglang_rollout_w_tool():
|
||||
@ -78,9 +78,9 @@ def test_async_sglang_rollout_w_tool():
|
||||
)
|
||||
|
||||
rollout_config = get_rollout_config(max_response_length, max_prompt_length, dtype, tensor_parallel_size, None)
|
||||
rollout = AsyncSGLangRollout(actor_module=local_model_path, config=rollout_config, tokenizer=tokenizer, model_hf_config=actor_model.config)
|
||||
rollout = SGLangRollout(actor_module=local_model_path, config=rollout_config, tokenizer=tokenizer, model_hf_config=actor_model.config)
|
||||
|
||||
rollout_sharding_manager = FSDPAsyncSGLangShardingManager(
|
||||
rollout_sharding_manager = FSDPSGLangShardingManager(
|
||||
module=fsdp_model,
|
||||
inference_engine=rollout._engine,
|
||||
model_config=actor_model.config,
|
||||
@ -111,7 +111,7 @@ def test_async_sglang_rollout_w_tool():
|
||||
|
||||
prompts = rollout_sharding_manager.preprocess_data(prompts)
|
||||
# log_gpu_memory_usage("Before generating sequences", logger=None)
|
||||
output = rollout.generate_sequences_with_tools(prompts=prompts)
|
||||
output = rollout.generate_sequences(prompts=prompts)
|
||||
print(f"generated {output.batch['responses'].shape=}")
|
||||
# log_gpu_memory_usage("After generating sequences", logger=None)
|
||||
output = rollout_sharding_manager.postprocess_data(output)
|
||||
|
@ -1,113 +0,0 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Copyright 2025 ModelBest Inc. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
usage: torchrun --standalone --nnodes=1 \
|
||||
--nproc_per_node=2 $(which pytest) \
|
||||
-s test_sglang_async_spmd.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import torch
|
||||
from sglang.srt.entrypoints.engine import Engine
|
||||
from sglang.srt.utils import broadcast_pyobj
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
from utils_sglang import (
|
||||
are_lists_similar,
|
||||
clean_torchelastic_env,
|
||||
generate_hf_output,
|
||||
initialize_global_process_group,
|
||||
load_tokenizer_and_model,
|
||||
prepare_inputs,
|
||||
)
|
||||
|
||||
|
||||
def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor):
|
||||
non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
|
||||
token_ids = prompt_token_ids[non_pad_index:].tolist()
|
||||
return token_ids
|
||||
|
||||
|
||||
def test_sglang_spmd():
|
||||
assert torch.cuda.device_count() >= 2
|
||||
initialize_global_process_group(spmd=True)
|
||||
clean_torchelastic_env()
|
||||
|
||||
max_prompt_length = 16
|
||||
max_response_length = 16
|
||||
|
||||
local_model_path = "Qwen/Qwen2.5-0.5B"
|
||||
tokenizer, actor_model = load_tokenizer_and_model(local_model_path)
|
||||
|
||||
preencode_prompts = ["Who won the Champions League in 2019?", "The founder of Apple is", "What's your name?"]
|
||||
input_ids, attention_mask, _ = prepare_inputs(tokenizer, preencode_prompts, max_prompt_length)
|
||||
|
||||
hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length)
|
||||
|
||||
tensor_parallel_size = 2
|
||||
inference_device_mesh_cpu = init_device_mesh("cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"])
|
||||
tp_rank = inference_device_mesh_cpu["tp"].get_local_rank()
|
||||
|
||||
if tp_rank == 0:
|
||||
llm = Engine(
|
||||
model_path=local_model_path,
|
||||
dtype="bfloat16",
|
||||
mem_fraction_static=0.5,
|
||||
enable_memory_saver=True,
|
||||
tp_size=inference_device_mesh_cpu["tp"].size(),
|
||||
)
|
||||
|
||||
input_ids = input_ids.cuda()
|
||||
idx_list = []
|
||||
|
||||
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
|
||||
for i in range(input_ids.shape[0]):
|
||||
idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i]))
|
||||
|
||||
sampling_params = dict(
|
||||
n=1,
|
||||
temperature=0,
|
||||
top_p=1,
|
||||
top_k=-1,
|
||||
max_new_tokens=max_response_length,
|
||||
presence_penalty=0.0,
|
||||
frequency_penalty=0.0,
|
||||
repetition_penalty=1.0,
|
||||
skip_special_tokens=True,
|
||||
spaces_between_special_tokens=True,
|
||||
ignore_eos=False,
|
||||
)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
outputs = loop.run_until_complete(llm.async_generate(input_ids=idx_list, sampling_params=sampling_params))
|
||||
else:
|
||||
outputs = None
|
||||
|
||||
[outputs] = broadcast_pyobj(
|
||||
[outputs],
|
||||
rank=inference_device_mesh_cpu["tp"].get_local_rank(),
|
||||
src=inference_device_mesh_cpu["tp"].mesh[0].item(),
|
||||
dist_group=inference_device_mesh_cpu["tp"].get_group(),
|
||||
force_cpu_device=False,
|
||||
)
|
||||
|
||||
sglang_response_tokens = [output["text"] for output in outputs]
|
||||
|
||||
print(f"sglang response: {sglang_response_tokens}")
|
||||
assert are_lists_similar(hf_response_tokens, sglang_response_tokens)
|
||||
print("SPMD Test Passed!")
|
||||
|
||||
torch.distributed.barrier()
|
||||
torch.distributed.destroy_process_group()
|
@ -12,188 +12,102 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
"""
|
||||
usage: torchrun --standalone --nnodes=1 \
|
||||
--nproc_per_node=2 $(which pytest) \
|
||||
-s test_sglang_async_spmd.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import torch
|
||||
from sglang.srt.entrypoints.verl_engine import VerlEngine
|
||||
from sglang.srt.entrypoints.engine import Engine
|
||||
from sglang.srt.utils import broadcast_pyobj
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||
|
||||
from verl.utils.torch_functional import pad_sequence_to_length
|
||||
|
||||
|
||||
def levenshtein(s1, s2):
|
||||
m, n = len(s1), len(s2)
|
||||
# Initialize matrix of zeros
|
||||
dp = [[0] * (n + 1) for _ in range(m + 1)]
|
||||
# Initialize first column and first row of the matrix
|
||||
for i in range(m + 1):
|
||||
dp[i][0] = i # Deletion from s1 to empty string
|
||||
for j in range(n + 1):
|
||||
dp[0][j] = j # Insertion to s1 from empty string
|
||||
# Compute the Levenshtein distance matrix
|
||||
for i in range(1, m + 1):
|
||||
for j in range(1, n + 1):
|
||||
cost = 0 if s1[i - 1] == s2[j - 1] else 1 # No cost if characters match
|
||||
dp[i][j] = min(
|
||||
dp[i - 1][j] + 1, # Deletion
|
||||
dp[i][j - 1] + 1, # Insertion
|
||||
dp[i - 1][j - 1] + cost, # Substitution
|
||||
)
|
||||
return dp[m][n]
|
||||
|
||||
|
||||
def are_lists_similar(a, b):
|
||||
if len(a) != len(b):
|
||||
print("The lists are of different lengths.")
|
||||
return False
|
||||
|
||||
total_length = 0
|
||||
total_diff = 0
|
||||
|
||||
for s1, s2 in zip(a, b):
|
||||
max_len = max(len(s1), len(s2))
|
||||
total_length += max_len
|
||||
diff = levenshtein(s1, s2)
|
||||
total_diff += diff
|
||||
print(f"Comparing strings:\n{s1}\n{s2}\nDifference: {diff} characters\n")
|
||||
|
||||
percentage_difference = (total_diff / total_length) * 100
|
||||
print(f"Total difference: {percentage_difference:.2f}%")
|
||||
|
||||
return percentage_difference <= 10
|
||||
|
||||
|
||||
def initialize_global_process_group(timeout_second=36000):
|
||||
from datetime import timedelta
|
||||
|
||||
import torch.distributed
|
||||
|
||||
# NOTE MODIFIED should provide backend=None to have nccl+gloo
|
||||
# torch.distributed.init_process_group('nccl', timeout=timedelta(seconds=timeout_second))
|
||||
torch.distributed.init_process_group(timeout=timedelta(seconds=timeout_second))
|
||||
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
rank = int(os.environ["RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
|
||||
if torch.distributed.is_initialized():
|
||||
torch.cuda.set_device(local_rank)
|
||||
return local_rank, rank, world_size
|
||||
|
||||
|
||||
def test_sglang_spmd():
|
||||
assert torch.cuda.device_count() >= 2, "At least 2 GPUs is required to run tp+dp tests."
|
||||
initialize_global_process_group()
|
||||
# fill rollout config
|
||||
max_prompt_length = 16
|
||||
max_response_length = 16
|
||||
|
||||
# Initialize model and token
|
||||
local_cache_path = "~/.cache/verl/rlhf"
|
||||
local_cache_path = os.path.expanduser(local_cache_path)
|
||||
hdfs_path = "Qwen/Qwen2-7B-Instruct"
|
||||
from verl.utils.fs import copy_to_local
|
||||
|
||||
local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left")
|
||||
|
||||
preencode_prompts = [
|
||||
"Who won the Champions League in 2019?",
|
||||
"The founder of Apple is",
|
||||
"What's your name?",
|
||||
]
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
prompts = tokenizer(preencode_prompts, return_tensors="pt", padding=True)
|
||||
input_ids = prompts["input_ids"]
|
||||
attention_mask = prompts["attention_mask"]
|
||||
|
||||
input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True)
|
||||
attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True)
|
||||
|
||||
actor_model = AutoModelForCausalLM.from_pretrained(local_model_path)
|
||||
actor_model.to(torch.bfloat16)
|
||||
|
||||
sampling_params = dict(
|
||||
n=1,
|
||||
temperature=0,
|
||||
top_p=1,
|
||||
top_k=-1,
|
||||
max_new_tokens=max_response_length,
|
||||
presence_penalty=0.0,
|
||||
frequency_penalty=0.0,
|
||||
repetition_penalty=1.0,
|
||||
skip_special_tokens=True,
|
||||
spaces_between_special_tokens=True,
|
||||
ignore_eos=False,
|
||||
)
|
||||
|
||||
tensor_parallel_size = 4
|
||||
device_mesh_kwargs = dict(mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"])
|
||||
inference_device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs)
|
||||
|
||||
for k in ["TORCHELASTIC_USE_AGENT_STORE"]:
|
||||
if k in os.environ:
|
||||
del os.environ[k]
|
||||
print("building sglang rollout engine")
|
||||
llm = VerlEngine(
|
||||
model_path=local_model_path,
|
||||
dtype="bfloat16",
|
||||
mem_fraction_static=0.5,
|
||||
device_mesh_cpu=inference_device_mesh_cpu["tp"],
|
||||
base_gpu_id=0,
|
||||
gpu_id_step=1,
|
||||
)
|
||||
|
||||
llm.release_memory_occupation()
|
||||
print("start generation")
|
||||
input_ids = input_ids.cuda()
|
||||
attention_mask = attention_mask.cuda()
|
||||
batch_size = input_ids.size(0)
|
||||
|
||||
generation_config = GenerationConfig(do_sample=False)
|
||||
actor_model.cuda()
|
||||
output = actor_model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_new_tokens=max_response_length,
|
||||
# max_length=max_length,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
generation_config=generation_config,
|
||||
# renormalize_logits=True,
|
||||
output_scores=False, # this is potentially very large
|
||||
return_dict_in_generate=True,
|
||||
use_cache=False,
|
||||
) # may OOM when use_cache = True
|
||||
seq = output.sequences
|
||||
response = seq[:, max_prompt_length:]
|
||||
|
||||
hf_response_tokens = tokenizer.batch_decode(response)
|
||||
print(f"hf response: {hf_response_tokens}")
|
||||
print(f"{sampling_params=}")
|
||||
idx_list = []
|
||||
batch_size = input_ids.shape[0]
|
||||
|
||||
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
|
||||
for i in range(batch_size):
|
||||
idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i]))
|
||||
|
||||
outputs = llm.generate(input_ids=idx_list, sampling_params=sampling_params)
|
||||
sglang_response_tokens = []
|
||||
|
||||
for output in outputs:
|
||||
print(f"{output=}")
|
||||
generated_text = output["text"]
|
||||
sglang_response_tokens.append(generated_text)
|
||||
|
||||
print(f"sglang response: {sglang_response_tokens}")
|
||||
assert are_lists_similar(hf_response_tokens, sglang_response_tokens), "Strings differ more than 10%:\n"
|
||||
print("Check Pass")
|
||||
from utils_sglang import (
|
||||
are_lists_similar,
|
||||
clean_torchelastic_env,
|
||||
generate_hf_output,
|
||||
initialize_global_process_group,
|
||||
load_tokenizer_and_model,
|
||||
prepare_inputs,
|
||||
)
|
||||
|
||||
|
||||
def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor):
|
||||
# remove the left padding in the prompt token_id
|
||||
non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
|
||||
token_ids = prompt_token_ids[non_pad_index:].tolist()
|
||||
return token_ids
|
||||
|
||||
|
||||
def test_sglang_spmd():
|
||||
assert torch.cuda.device_count() >= 2
|
||||
initialize_global_process_group(spmd=True)
|
||||
clean_torchelastic_env()
|
||||
|
||||
max_prompt_length = 16
|
||||
max_response_length = 16
|
||||
|
||||
local_model_path = "Qwen/Qwen2.5-0.5B"
|
||||
tokenizer, actor_model = load_tokenizer_and_model(local_model_path)
|
||||
|
||||
preencode_prompts = ["Who won the Champions League in 2019?", "The founder of Apple is", "What's your name?"]
|
||||
input_ids, attention_mask, _ = prepare_inputs(tokenizer, preencode_prompts, max_prompt_length)
|
||||
|
||||
hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length)
|
||||
|
||||
tensor_parallel_size = 2
|
||||
inference_device_mesh_cpu = init_device_mesh("cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"])
|
||||
tp_rank = inference_device_mesh_cpu["tp"].get_local_rank()
|
||||
|
||||
if tp_rank == 0:
|
||||
llm = Engine(
|
||||
model_path=local_model_path,
|
||||
dtype="bfloat16",
|
||||
mem_fraction_static=0.5,
|
||||
enable_memory_saver=True,
|
||||
tp_size=inference_device_mesh_cpu["tp"].size(),
|
||||
)
|
||||
|
||||
input_ids = input_ids.cuda()
|
||||
idx_list = []
|
||||
|
||||
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
|
||||
for i in range(input_ids.shape[0]):
|
||||
idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i]))
|
||||
|
||||
sampling_params = dict(
|
||||
n=1,
|
||||
temperature=0,
|
||||
top_p=1,
|
||||
top_k=-1,
|
||||
max_new_tokens=max_response_length,
|
||||
presence_penalty=0.0,
|
||||
frequency_penalty=0.0,
|
||||
repetition_penalty=1.0,
|
||||
skip_special_tokens=True,
|
||||
spaces_between_special_tokens=True,
|
||||
ignore_eos=False,
|
||||
)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
outputs = loop.run_until_complete(llm.async_generate(input_ids=idx_list, sampling_params=sampling_params))
|
||||
else:
|
||||
outputs = None
|
||||
|
||||
[outputs] = broadcast_pyobj(
|
||||
[outputs],
|
||||
rank=inference_device_mesh_cpu["tp"].get_local_rank(),
|
||||
src=inference_device_mesh_cpu["tp"].mesh[0].item(),
|
||||
dist_group=inference_device_mesh_cpu["tp"].get_group(),
|
||||
force_cpu_device=False,
|
||||
)
|
||||
|
||||
sglang_response_tokens = [output["text"] for output in outputs]
|
||||
|
||||
print(f"sglang response: {sglang_response_tokens}")
|
||||
assert are_lists_similar(hf_response_tokens, sglang_response_tokens), "Strings differ more than 10%:\n"
|
||||
print("SPMD Test Passed!")
|
||||
|
||||
torch.distributed.barrier()
|
||||
torch.distributed.destroy_process_group()
|
||||
|
@ -37,7 +37,7 @@ def levenshtein(s1, s2):
|
||||
return dp[m][n]
|
||||
|
||||
|
||||
def are_lists_similar(a, b):
|
||||
def are_lists_similar(a, b, threshold=10):
|
||||
if len(a) != len(b):
|
||||
print("The lists are of different lengths.")
|
||||
return False
|
||||
@ -49,7 +49,7 @@ def are_lists_similar(a, b):
|
||||
total_diff += levenshtein(s1, s2)
|
||||
percentage_difference = (total_diff / total_length) * 100
|
||||
print(f"Total difference: {percentage_difference:.2f}%")
|
||||
return percentage_difference <= 10
|
||||
return percentage_difference <= threshold
|
||||
|
||||
|
||||
def initialize_global_process_group(timeout_second=36000, spmd=False):
|
||||
|
@ -168,7 +168,7 @@ actor_rollout_ref:
|
||||
n: 1
|
||||
do_sample: False # default eager for validation
|
||||
multi_turn:
|
||||
enable: False # should set rollout.name to sglang_async if True
|
||||
enable: False # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well
|
||||
max_turns: null # null for no limit (default max_length // 3)
|
||||
tool_config_path: null # null for no tool
|
||||
format: chatml # chatml, more formats will be supported in the future
|
||||
|
@ -140,7 +140,7 @@ actor_rollout_ref:
|
||||
n: 1
|
||||
do_sample: False # default eager for validation
|
||||
multi_turn:
|
||||
enable: False # should set rollout.name to sglang_async if True
|
||||
enable: False # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well
|
||||
max_turns: null # null for no limit (default max_length // 3)
|
||||
tool_config_path: null # null for no tool
|
||||
format: chatml # chatml, more formats will be supported in the future
|
||||
|
@@ -448,86 +448,39 @@ class ActorRolloutRefWorker(Worker):
            )
            log_gpu_memory_usage("After building sharding manager", logger=logger)

        elif rollout_name == "sglang":
            if self.config.rollout.mode == "sync":
                from verl.workers.rollout.sglang_rollout import SGLangRollout

                # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to
                # SGLang's model_runner would check CUDA device capability. However, due to verl's setting,
                # the main process of ray can not find any CUDA device, which would potentially lead to:
                # "RuntimeError: No CUDA GPUs are available".
                # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and
                # we import it here use the abs path.
                # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76
                from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager

                log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
                local_path = copy_to_local(self.config.model.path)
                rollout = SGLangRollout(
                    actor_module=local_path,
                    config=self.config.rollout,
                    tokenizer=self.tokenizer,
                    model_hf_config=self.actor_model_config,
                    trust_remote_code=trust_remote_code,
        elif rollout_name in ["sglang", "sglang_async"]:
            if rollout_name == "sglang_async":
                warnings.warn(
                    "'sglang_async' has been deprecated and merged into 'sglang'. "
                    "Please use 'sglang' going forward.",
                    DeprecationWarning,
                    stacklevel=2,
                )
                log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
            from verl.workers.rollout.sglang_rollout import SGLangRollout
            # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to
            # SGLang's model_runner would check CUDA device capability. However, due to verl's setting,
            # the main process of ray can not find any CUDA device, which would potentially lead to:
            # "RuntimeError: No CUDA GPUs are available".
            # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and
            # we import it here use the abs path.
            # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76

            if torch.distributed.get_world_size() == 1:
                self.config.rollout.load_format = "dummy_hf"
            rollout_sharding_manager = FSDPSGLangShardingManager(
                module=self.actor_module_fsdp,
                inference_engine=rollout.inference_engine,
                model_config=self.actor_model_config,
                full_params="hf" in self.config.rollout.load_format,
                device_mesh=rollout_device_mesh,
                offload_param=self._is_offload_param,
            )
            log_gpu_memory_usage("After building sharding manager", logger=logger)
            elif self.config.rollout.mode == "async":
                from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout
                from verl.workers.sharding_manager.fsdp_sglang import FSDPAsyncSGLangShardingManager

                local_path = copy_to_local(self.config.model.path)
                log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=None)
                rollout = AsyncSGLangRollout(
                    actor_module=local_path,
                    config=self.config.rollout,
                    tokenizer=self.tokenizer,
                    model_hf_config=self.actor_model_config,
                    trust_remote_code=trust_remote_code,
                )
                log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=None)

                if torch.distributed.get_world_size() == 1:
                    self.config.rollout.load_format = "dummy_hf"
                rollout_sharding_manager = FSDPAsyncSGLangShardingManager(
                    module=self.actor_module_fsdp,
                    inference_engine=rollout._engine,
                    model_config=self.actor_model_config,
                    full_params="hf" in self.config.rollout.load_format,
                    device_mesh=rollout_device_mesh,
                    offload_param=self._is_offload_param,
                )
                log_gpu_memory_usage("After building sharding manager", logger=None)
        elif rollout_name == "sglang_async":
            # TODO replace by rollout.mode == "async"
            from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout
            from verl.workers.sharding_manager.fsdp_sglang import FSDPAsyncSGLangShardingManager
            from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager

            local_path = copy_to_local(self.config.model.path)
            log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=None)
            rollout = AsyncSGLangRollout(
            log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
            rollout = SGLangRollout(
                actor_module=local_path,
                config=self.config.rollout,
                tokenizer=self.tokenizer,
                model_hf_config=self.actor_model_config,
                trust_remote_code=trust_remote_code,
            )
            log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=None)
            log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)

            if torch.distributed.get_world_size() == 1:
                self.config.rollout.load_format = "dummy_hf"
            rollout_sharding_manager = FSDPAsyncSGLangShardingManager(
            rollout_sharding_manager = FSDPSGLangShardingManager(
                module=self.actor_module_fsdp,
                inference_engine=rollout._engine,
                model_config=self.actor_model_config,
@@ -535,7 +488,7 @@ class ActorRolloutRefWorker(Worker):
                device_mesh=rollout_device_mesh,
                offload_param=self._is_offload_param,
            )
            log_gpu_memory_usage("After building sharding manager", logger=None)
            log_gpu_memory_usage("After building sharding manager", logger=logger)

        else:
            raise NotImplementedError(f"Rollout name: {self.config.rollout.name} is not supported")
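The FSDP worker above accepts both names but treats `sglang_async` purely as a deprecated alias. A condensed sketch of that dispatch (hypothetical helper, not code from the PR):

```python
import warnings


def resolve_rollout_name(rollout_name: str) -> str:
    # Map the deprecated alias onto the unified backend and warn at the call site.
    if rollout_name == "sglang_async":
        warnings.warn(
            "'sglang_async' has been deprecated and merged into 'sglang'. Please use 'sglang' going forward.",
            DeprecationWarning,
            stacklevel=2,
        )
        return "sglang"
    return rollout_name
```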
@@ -696,16 +649,8 @@ class ActorRolloutRefWorker(Worker):
            log_gpu_memory_usage("After entering rollout sharding manager", logger=logger)

            prompts = self.rollout_sharding_manager.preprocess_data(prompts)

            if self.config.rollout.name == "sglang_async":
                from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout

                if isinstance(self.rollout, AsyncSGLangRollout) and hasattr(self.rollout, "_tool_schemas") and len(self.rollout._tool_schemas) > 0:
                    output = self.rollout.generate_sequences_with_tools(prompts=prompts)
                else:
                    output = self.rollout.generate_sequences(prompts=prompts)
            else:
                output = self.rollout.generate_sequences(prompts=prompts)
            output = self.rollout.generate_sequences(prompts=prompts)

            log_gpu_memory_usage("After rollout generation", logger=logger)

            output = self.rollout_sharding_manager.postprocess_data(output)
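With the worker-side branching removed, a single `generate_sequences` call is expected to cover both plain and tool-calling rollouts. A hedged sketch of how such internal routing can look (class and method names here are illustrative assumptions, not the rollout's actual internals):

```python
class UnifiedRolloutSketch:
    def __init__(self, tool_schemas=None):
        # An empty tool list means plain single-turn generation.
        self._tool_schemas = list(tool_schemas or [])

    def generate_sequences(self, prompts):
        # One public entry point; the tool-aware path is chosen internally.
        if self._tool_schemas:
            return self._generate_multi_turn_with_tools(prompts)
        return self._generate_single_turn(prompts)

    def _generate_multi_turn_with_tools(self, prompts):
        ...  # request-level async loop with tool calls

    def _generate_single_turn(self, prompts):
        ...  # batched engine generation
```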
@@ -18,6 +18,7 @@ The main entry point to run the PPO algorithm
import logging
import os
import time
import warnings

import torch
import torch.distributed
@@ -258,7 +259,14 @@ class ActorRolloutRefWorker(MegatronWorker):
                weight_converter=weight_converter,
            )
            log_gpu_memory_usage("After building sharding manager", logger=logger)
        elif self.config.rollout.name == "sglang":

        elif self.config.rollout.name in ["sglang", "sglang_async"]:
            if self.config.rollout.name == "sglang_async":
                warnings.warn(
                    "'sglang_async' has been deprecated and merged into 'sglang'. Please use 'sglang' going forward.",
                    DeprecationWarning,
                    stacklevel=2,
                )
            from verl.workers.rollout.sglang_rollout import SGLangRollout

            # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to SGLang's model_runner would check CUDA device capability.
@@ -276,45 +284,6 @@ class ActorRolloutRefWorker(MegatronWorker):
            local_path = copy_to_local(self.config.model.path)
            log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=None)
            rollout = SGLangRollout(
                actor_module=local_path,
                config=self.config.rollout,
                tokenizer=self.tokenizer,
                model_hf_config=self.actor_model_config,
                trust_remote_code=self.config.model.get("trust_remote_code", False),
            )
            log_gpu_memory_usage(f"After building {self.config.rollout.name} rollout", logger=None)

            from verl.models.mcore import get_mcore_weight_converter

            weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)
            sharding_manager = MegatronSGLangShardingManager(
                actor_module=self.actor.actor_module,
                inference_engine=rollout.inference_engine,
                model_config=self.actor_model_config,
                transformer_config=self.tf_config,
                layer_name_mapping=layer_name_mapping,
                weight_converter=weight_converter,
                device_mesh=rollout_device_mesh,
            )
            log_gpu_memory_usage("After building sharding manager", logger=logger)
        elif self.config.rollout.name == "sglang_async":
            from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout

            # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to SGLang's model_runner would check CUDA device capability.
            # However, due to verl's setting, the main process of ray can not find any CUDA device, which would potentially lead to:
            # "RuntimeError: No CUDA GPUs are available".
            # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and we import it here use the abs path.
            # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76
            from verl.workers.sharding_manager.megatron_sglang import MegatronAsyncSGLangShardingManager

            infer_tp = self.config.rollout.tensor_model_parallel_size
            dp = self.world_size // infer_tp
            assert self.world_size % infer_tp == 0, f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"
            rollout_device_mesh = init_device_mesh("cpu", mesh_shape=(dp, infer_tp, 1), mesh_dim_names=("dp", "tp", "pp"))

            local_path = copy_to_local(self.config.model.path)
            log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=None)
            rollout = AsyncSGLangRollout(
                actor_module=local_path,
                config=self.config.rollout,
                tokenizer=self.tokenizer,
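The removed branch above builds a CPU device mesh whose shape ties the inference tensor-parallel size to the training world size. A small numeric sketch of that relationship (the 8-rank/TP-2 values are examples; `init_device_mesh` requires an already-initialized process group of matching size):

```python
from torch.distributed.device_mesh import init_device_mesh

world_size, infer_tp = 8, 2  # example values
assert world_size % infer_tp == 0, "world size must be divisible by the inference TP size"
dp = world_size // infer_tp  # 4 data-parallel replicas, each a TP=2 engine

rollout_device_mesh = init_device_mesh(
    "cpu", mesh_shape=(dp, infer_tp, 1), mesh_dim_names=("dp", "tp", "pp")
)
```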
@@ -327,7 +296,7 @@ class ActorRolloutRefWorker(MegatronWorker):
            from verl.models.mcore import get_mcore_weight_converter

            weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)
            sharding_manager = MegatronAsyncSGLangShardingManager(
            sharding_manager = MegatronSGLangShardingManager(
                actor_module=self.actor.actor_module,
                inference_engine=rollout._engine,
                model_config=self.actor_model_config,
@@ -491,16 +460,7 @@ class ActorRolloutRefWorker(MegatronWorker):
        log_gpu_memory_usage("After entering sharding manager", logger=logger)

        prompts = self.sharding_manager.preprocess_data(prompts)
        # output = self.rollout.generate_sequences(prompts=prompts)
        if self.config.rollout.name == "sglang_async":
            from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout

            if isinstance(self.rollout, AsyncSGLangRollout) and hasattr(self.rollout, "_tool_schemas") and len(self.rollout._tool_schemas) > 0:
                output = self.rollout.generate_sequences_with_tools(prompts=prompts)
            else:
                output = self.rollout.generate_sequences(prompts=prompts)
        else:
            output = self.rollout.generate_sequences(prompts=prompts)
        output = self.rollout.generate_sequences(prompts=prompts)
        output = self.sharding_manager.postprocess_data(output)

        output = output.to("cpu")
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and

from .async_sglang_rollout import AsyncSGLangRollout
from .sglang_rollout import SGLangRollout

__all__ = ["AsyncSGLangRollout", "SGLangRollout"]
__all__ = ["SGLangRollout"]
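For any downstream code that imported the removed class, the migration is a one-line change; a hypothetical before/after:

```python
# before (no longer available after this PR)
# from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout

# after: the unified rollout covers both plain and tool-calling generation
from verl.workers.rollout.sglang_rollout import SGLangRollout
```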
@ -1,876 +0,0 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Copyright 2025 ModelBest Inc. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from json import JSONDecodeError
|
||||
from typing import TYPE_CHECKING, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from omegaconf import DictConfig
|
||||
from sglang.srt.entrypoints.engine import Engine
|
||||
from sglang.srt.function_call_parser import FunctionCallParser
|
||||
from sglang.srt.openai_api.protocol import Tool
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.utils import get_ip, get_open_port
|
||||
from tensordict import TensorDict
|
||||
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from verl import DataProto
|
||||
from verl.third_party.sglang import parallel_state as sglang_ps
|
||||
from verl.tools.base_tool import BaseTool
|
||||
from verl.tools.schemas import OpenAIFunctionCallSchema, OpenAIFunctionParsedSchema, OpenAIFunctionToolCall
|
||||
from verl.utils.debug import GPUMemoryLogger
|
||||
from verl.utils.model import compute_position_id_with_mask
|
||||
from verl.utils.net_utils import is_ipv6
|
||||
from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length
|
||||
from verl.workers.rollout.base import BaseRollout
|
||||
from verl.workers.rollout.schemas import (
|
||||
AsyncRolloutRequest,
|
||||
AsyncRolloutRequestStateEnum,
|
||||
FinishReasonTypeEnum,
|
||||
Message,
|
||||
)
|
||||
from verl.workers.rollout.sglang_rollout.sglang_rollout import _post_process_outputs, _pre_process_inputs
|
||||
from verl.workers.rollout.sglang_rollout.utils import broadcast_pyobj
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch import nn
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
||||
|
||||
|
||||
def get_tool_call_parser_type(tokenizer: PreTrainedTokenizer) -> str:
|
||||
for parser_type, parser_cls in FunctionCallParser.ToolCallParserEnum.items():
|
||||
parser = parser_cls()
|
||||
if parser.bot_token in tokenizer.get_vocab() and (parser.eot_token == "" or parser.eot_token in tokenizer.get_vocab()):
|
||||
return parser_type
|
||||
else:
|
||||
raise ValueError(f"No tool call parser found for tokenizer {tokenizer}")
|
||||
|
||||
|
||||
class AsyncSGLangRollout(BaseRollout):
|
||||
def __init__(
|
||||
self,
|
||||
actor_module: nn.Module | str,
|
||||
config: DictConfig,
|
||||
tokenizer,
|
||||
model_hf_config,
|
||||
port=None,
|
||||
trust_remote_code: bool = False,
|
||||
device_mesh: DeviceMesh | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""A SGLang rollout. It requires the module is supported by the SGLang.
|
||||
|
||||
Args:
|
||||
actor_module: module here follows huggingface APIs
|
||||
config: DictConfig
|
||||
tokenizer: the task/model tokenizer
|
||||
model_hf_config: the huggingface config to initiallize the generating model in SGLang
|
||||
**kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self._tool_schemas, self._tool_map, self._tool_call_parser_type, self._sgl_tools, self._function_call_parser = self._initialize_tools(config, tokenizer)
|
||||
assert not (not config.enforce_eager and config.free_cache_engine), "disable CUDA graph (enforce_eager = False) if free cache engine"
|
||||
logger.info(f"tool_schemas: {self._tool_schemas}, tool_map: {self._tool_map}, tool_call_parser_type: {self._tool_call_parser_type}, sgl_tools: {self._sgl_tools}, function_call_parser: {self._function_call_parser}")
|
||||
|
||||
self._init_distributed_env(device_mesh_cpu=device_mesh, **kwargs)
|
||||
|
||||
self._verify_config(model_hf_config=model_hf_config)
|
||||
# initialize the inference engine
|
||||
self._init_inference_engine(trust_remote_code, actor_module, port)
|
||||
|
||||
self._init_sampling_params(**kwargs)
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.pad_token_id = tokenizer.pad_token_id
|
||||
|
||||
def _init_distributed_env(self, device_mesh_cpu, **kwargs):
|
||||
self._device_mesh_cpu = device_mesh_cpu
|
||||
os.environ.setdefault("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", "true")
|
||||
self.tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1)
|
||||
assert self.tensor_parallel_size <= dist.get_world_size(), "tensor parallel size should be less than or equal to the world size"
|
||||
self.train_tp = kwargs.get("train_tp", None)
|
||||
if self.train_tp is not None:
|
||||
# deployed with megatron
|
||||
os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0"
|
||||
os.environ["MEGATRON_IMPORT_TIMERS"] = "0"
|
||||
train_tp = kwargs.get("train_tp", None)
|
||||
num_tp_per_train_tp = train_tp // self.tensor_parallel_size
|
||||
sglang_ps.initialize_parallel_state(
|
||||
tensor_model_parallel_size=self.tensor_parallel_size,
|
||||
num_tp_per_train_tp=num_tp_per_train_tp,
|
||||
)
|
||||
|
||||
tp_size = self.tensor_parallel_size
|
||||
world_size = int(os.getenv("WORLD_SIZE", "-1"))
|
||||
|
||||
# init device mesh
|
||||
if self._device_mesh_cpu is None:
|
||||
device_mesh_kwargs = dict(
|
||||
mesh_shape=(world_size // tp_size, tp_size, 1),
|
||||
mesh_dim_names=["dp", "tp", "pp"],
|
||||
)
|
||||
|
||||
self._device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs)
|
||||
|
||||
self._rank = self._device_mesh_cpu.get_rank()
|
||||
self._tp_rank = self._device_mesh_cpu["tp"].get_local_rank()
|
||||
self._tp_size = self._device_mesh_cpu["tp"].size()
|
||||
if self._rank == 0:
|
||||
logger.info(f"_init_distributed_env: :tp_world: {self._tp_size}, global_world: {world_size}")
|
||||
# get tp_rank of this process in this tp group
|
||||
visible_devices = [None] * self._device_mesh_cpu.size(1)
|
||||
|
||||
torch.distributed.all_gather_object(visible_devices, os.environ["CUDA_VISIBLE_DEVICES"], self._device_mesh_cpu.get_group("tp"))
|
||||
self.visible_devices_set = set(",".join(visible_devices).split(","))
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(sorted(list(self.visible_devices_set)))
|
||||
|
||||
def _verify_config(self, model_hf_config):
|
||||
if not self.config.get("max_model_len", None):
|
||||
self.config.max_model_len = self.config.prompt_length + self.config.response_length
|
||||
assert self.config.max_model_len >= self.config.prompt_length + self.config.response_length, f"""max_model_len should be greater than total sequence length (prompt_length + response_length):
|
||||
{self.config.max_model_len} >= {self.config.prompt_length} + {self.config.response_length}"""
|
||||
assert model_hf_config.max_position_embeddings >= self.config.max_model_len, "model context length should be greater than total sequence length"
|
||||
# currently max_turns stand for max number of tool calls
|
||||
if self.config.multi_turn.max_turns is None:
|
||||
self.config.multi_turn.max_turns = self.config.max_model_len // 3
|
||||
|
||||
def _init_inference_engine(self, trust_remote_code, actor_module, port):
|
||||
# initialize the inference engine
|
||||
nnodes = -(-self._tp_size // len(self.visible_devices_set))
|
||||
if nnodes > 1:
|
||||
ip = get_ip()
|
||||
port = get_open_port() if port is None else port
|
||||
[ip, port] = broadcast_pyobj(
|
||||
[ip, port],
|
||||
rank=self._rank,
|
||||
dist_group=self._device_mesh_cpu.get_group("tp"),
|
||||
src=self._device_mesh_cpu["tp"].mesh[0].item(),
|
||||
force_cpu_device=False,
|
||||
)
|
||||
dist_init_addr = f"[{ip}]:{port}" if is_ipv6(ip) else f"{ip}:{port}"
|
||||
else:
|
||||
dist_init_addr = None
|
||||
|
||||
load_format = "dummy" if self.config.load_format.startswith("dummy") else self.config.load_format
|
||||
tp_size_per_node = self._tp_size // nnodes
|
||||
node_rank = self._tp_rank // tp_size_per_node
|
||||
first_rank_in_node = self._tp_rank % tp_size_per_node == 0
|
||||
|
||||
if first_rank_in_node:
|
||||
rank = dist.get_rank()
|
||||
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
|
||||
self._engine = Engine(
|
||||
model_path=actor_module,
|
||||
dtype=self.config.dtype,
|
||||
mem_fraction_static=self.config.gpu_memory_utilization,
|
||||
enable_memory_saver=True,
|
||||
base_gpu_id=0,
|
||||
gpu_id_step=1,
|
||||
tp_size=self._tp_size,
|
||||
node_rank=node_rank,
|
||||
load_format=load_format,
|
||||
dist_init_addr=dist_init_addr,
|
||||
nnodes=nnodes,
|
||||
trust_remote_code=trust_remote_code,
|
||||
# NOTE(linjunrong): add rank to prevent SGLang generate same port inside PortArgs.init_new
|
||||
# when random.seed is being set during training
|
||||
port=30000 + rank,
|
||||
# NOTE(Chenyang): if you want to debug the SGLang engine output
|
||||
# please set the following parameters
|
||||
# Otherwise, it will make the engine run too slow
|
||||
# log_level="INFO",
|
||||
# log_requests=True,
|
||||
# log_requests_level=2,
|
||||
# max_running_requests=1,
|
||||
)
|
||||
else:
|
||||
self._engine = None
|
||||
|
||||
self.sharding_manager = None
|
||||
# offload
|
||||
if self._tp_rank == 0:
|
||||
self._engine.release_memory_occupation()
|
||||
self.is_sleep = True
|
||||
|
||||
def _init_sampling_params(self, **kwargs):
|
||||
kwargs = dict(
|
||||
n=1,
|
||||
max_new_tokens=self.config.response_length,
|
||||
presence_penalty=0.0,
|
||||
frequency_penalty=0.0,
|
||||
repetition_penalty=1.0,
|
||||
)
|
||||
# supporting adding any sampling params from the config file
|
||||
for k in self.config.keys():
|
||||
if hasattr(SamplingParams(), str(k)):
|
||||
kwargs[k] = self.config.get(k)
|
||||
self.sampling_params = kwargs
|
||||
|
||||
def _initialize_tools(self, config, tokenizer):
|
||||
"""Initialize tools from configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration object containing tool settings
|
||||
tokenizer: Tokenizer instance for tool call parsing
|
||||
|
||||
Returns:
|
||||
tuple: (tool_schemas, tool_map, tool_call_parser_type, sgl_tools, function_call_parser)
|
||||
"""
|
||||
if config.multi_turn.tool_config_path is None:
|
||||
return [], {}, None, [], None
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from verl.tools.schemas import OpenAIFunctionToolSchema
|
||||
|
||||
def initialize_tools_from_config(tools_config) -> list:
|
||||
tool_list = []
|
||||
|
||||
for tool_config in tools_config.tools:
|
||||
cls_name = tool_config.class_name
|
||||
module_name, class_name = cls_name.rsplit(".", 1)
|
||||
|
||||
if module_name not in sys.modules:
|
||||
spec = importlib.util.find_spec(module_name)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
else:
|
||||
module = sys.modules[module_name]
|
||||
|
||||
tool_cls = getattr(module, class_name)
|
||||
|
||||
tool_schema_dict = OmegaConf.to_container(tool_config.tool_schema, resolve=True)
|
||||
tool_schema = OpenAIFunctionToolSchema.parse_obj(tool_schema_dict)
|
||||
|
||||
tool = tool_cls(config=OmegaConf.to_container(tool_config.config, resolve=True), tool_schema=tool_schema)
|
||||
tool_list.append(tool)
|
||||
|
||||
return tool_list
|
||||
|
||||
tools_config_file = config.multi_turn.tool_config_path
|
||||
tools_config = OmegaConf.load(tools_config_file)
|
||||
tool_list = initialize_tools_from_config(tools_config)
|
||||
logger.info(f"Initialize tools from configuration.: tool_list: {tool_list}")
|
||||
tool_schemas = [tool.get_openai_tool_schema().model_dump() for tool in tool_list]
|
||||
tool_map = {tool.name: tool for tool in tool_list}
|
||||
tool_call_parser_type = get_tool_call_parser_type(tokenizer)
|
||||
sgl_tools = [Tool.model_validate(tool_schema) for tool_schema in tool_schemas]
|
||||
function_call_parser = FunctionCallParser(
|
||||
sgl_tools,
|
||||
tool_call_parser_type,
|
||||
)
|
||||
|
||||
return tool_schemas, tool_map, tool_call_parser_type, sgl_tools, function_call_parser
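The removed `_initialize_tools` above resolves each tool's dotted class path with `importlib` and instantiates it from the OmegaConf tool config. A simplified sketch of that resolution step (using `importlib.import_module` instead of the spec-based loading above; the example path is hypothetical):

```python
import importlib


def load_tool_class(dotted_path: str):
    # "package.module.ClassName" -> the class object
    module_name, class_name = dotted_path.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# e.g. tool_cls = load_tool_class("verl.tools.gsm8k_tool.Gsm8kTool")  # hypothetical path
# tool = tool_cls(config={...}, tool_schema=parsed_schema)
```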
|
||||
|
||||
@contextmanager
|
||||
def update_sampling_params(self, **kwargs):
|
||||
# update sampling params
|
||||
old_sampling_params_args = {}
|
||||
if kwargs:
|
||||
for key, value in kwargs.items():
|
||||
if key in self.sampling_params:
|
||||
old_value = self.sampling_params[key]
|
||||
old_sampling_params_args[key] = old_value
|
||||
self.sampling_params[key] = value
|
||||
yield
|
||||
# roll back to previous sampling params
|
||||
# if len(old_sampling_params_args):
|
||||
for key, value in old_sampling_params_args.items():
|
||||
self.sampling_params[key] = value
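A hedged, standalone illustration of the `update_sampling_params` pattern above: only keys that already exist are overridden, and the old values are restored when the block exits (a `try/finally` is added here for safety; the removed method restores them after the bare `yield`):

```python
from contextlib import contextmanager


class SamplingParamsHolder:
    def __init__(self, **params):
        self.sampling_params = dict(params)

    @contextmanager
    def update_sampling_params(self, **kwargs):
        # Remember only the keys we are about to override.
        old = {k: self.sampling_params[k] for k in kwargs if k in self.sampling_params}
        for k, v in kwargs.items():
            if k in self.sampling_params:
                self.sampling_params[k] = v
        try:
            yield
        finally:
            self.sampling_params.update(old)


holder = SamplingParamsHolder(temperature=1.0, top_p=1.0, n=1)
with holder.update_sampling_params(temperature=0.0):
    assert holder.sampling_params["temperature"] == 0.0
assert holder.sampling_params["temperature"] == 1.0
```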
|
||||
|
||||
@GPUMemoryLogger(role="sglang async rollout", logger=logger)
|
||||
@torch.no_grad()
|
||||
def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
|
||||
# if self.config.free_cache_engine:
|
||||
|
||||
idx = prompts.batch["input_ids"] # (bs, prompt_length)
|
||||
# left-padded attention_mask
|
||||
attention_mask = prompts.batch["attention_mask"]
|
||||
position_ids = prompts.batch["position_ids"]
|
||||
|
||||
# used to construct attention_mask
|
||||
eos_token_id = prompts.meta_info["eos_token_id"]
|
||||
|
||||
batch_size = idx.size(0)
|
||||
|
||||
# Extract non-tensor data
|
||||
non_tensor_batch = prompts.non_tensor_batch
|
||||
if "raw_prompt_ids" not in non_tensor_batch:
|
||||
non_tensor_batch["raw_prompt_ids"] = np.array([_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object)
|
||||
|
||||
if "multi_modal_data" in non_tensor_batch:
|
||||
sglang_inputs = []
|
||||
for raw_prompt_ids, multi_modal_data in zip(non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data")):
|
||||
sglang_inputs.append(
|
||||
{
|
||||
"prompt_token_ids": raw_prompt_ids,
|
||||
"multi_modal_data": multi_modal_data,
|
||||
"image_data": multi_modal_data.get("image", None) if isinstance(multi_modal_data, dict) else None,
|
||||
}
|
||||
)
|
||||
else:
|
||||
sglang_inputs = [{"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")]
|
||||
|
||||
# Ensure token IDs are lists
|
||||
for input_data in sglang_inputs:
|
||||
if isinstance(input_data["prompt_token_ids"], np.ndarray):
|
||||
input_data["prompt_token_ids"] = input_data["prompt_token_ids"].tolist()
|
||||
elif not isinstance(input_data["prompt_token_ids"], list):
|
||||
raise TypeError(f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}")
|
||||
|
||||
# Extract token IDs and image data for SGLang Engine
|
||||
idx_list = [input_data["prompt_token_ids"] for input_data in sglang_inputs]
|
||||
image_list = [input_data.get("image_data", None) for input_data in sglang_inputs]
|
||||
|
||||
do_sample = prompts.meta_info.get("do_sample", True)
|
||||
is_validate = prompts.meta_info.get("validate", False)
|
||||
if not do_sample:
|
||||
kwargs = dict(
|
||||
n=1,
|
||||
presence_penalty=0.0,
|
||||
frequency_penalty=0.0,
|
||||
repetition_penalty=1.0,
|
||||
temperature=0,
|
||||
top_p=1,
|
||||
top_k=-1,
|
||||
ignore_eos=False,
|
||||
min_new_tokens=0,
|
||||
max_new_tokens=self.config.response_length,
|
||||
skip_special_tokens=True,
|
||||
spaces_between_special_tokens=True,
|
||||
)
|
||||
elif is_validate:
|
||||
kwargs = dict(
|
||||
top_k=self.config.val_kwargs.top_k,
|
||||
top_p=self.config.val_kwargs.top_p,
|
||||
temperature=self.config.val_kwargs.temperature,
|
||||
n=1, # if validate, already repeat in ray_trainer
|
||||
)
|
||||
|
||||
# users can customize different sampling_params at different run
|
||||
with self.update_sampling_params(**kwargs):
|
||||
# print(f"{self.sampling_params=}")
|
||||
if self._tp_rank == 0:
|
||||
loop = asyncio.get_event_loop()
|
||||
output = loop.run_until_complete(
|
||||
self._engine.async_generate(
|
||||
prompt=None, # because we have already convert it to prompt token id
|
||||
sampling_params=self.sampling_params,
|
||||
return_logprob=True,
|
||||
input_ids=idx_list,
|
||||
image_data=image_list,
|
||||
)
|
||||
)
|
||||
else:
|
||||
output = None
|
||||
# Most naive implementation, can extract tensor and send via gloo if too slow
|
||||
[output] = broadcast_pyobj(
|
||||
data=[output],
|
||||
rank=self._rank,
|
||||
dist_group=self._device_mesh_cpu["tp"].get_group(),
|
||||
src=self._device_mesh_cpu["tp"].mesh[0].item(),
|
||||
force_cpu_device=False,
|
||||
)
|
||||
out = _post_process_outputs(self.tokenizer, output)
|
||||
|
||||
response = out[0].to(idx.device)
|
||||
rollout_log_probs = out[1].to(idx.device)
|
||||
|
||||
if response.shape[1] < self.config.response_length:
|
||||
response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)
|
||||
rollout_log_probs = pad_sequence_to_length(rollout_log_probs, self.config.response_length, self.pad_token_id)
|
||||
|
||||
# utilize current sampling params
|
||||
if self.sampling_params.get("n", 1) > 1 and do_sample:
|
||||
idx = idx.repeat_interleave(self.sampling_params["n"], dim=0)
|
||||
attention_mask = attention_mask.repeat_interleave(self.sampling_params["n"], dim=0)
|
||||
position_ids = position_ids.repeat_interleave(self.sampling_params["n"], dim=0)
|
||||
batch_size = batch_size * self.sampling_params["n"]
|
||||
_non_tensor_batch = {}
|
||||
for key, val in non_tensor_batch.items():
|
||||
_non_tensor_batch[key] = np.repeat(val, self.sampling_params["n"], axis=0)
|
||||
else:
|
||||
_non_tensor_batch = non_tensor_batch
|
||||
seq = torch.cat([idx, response], dim=-1)
|
||||
|
||||
response_length = response.size(1)
|
||||
delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
|
||||
delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)
|
||||
|
||||
# TODO(sgm): fix position_ids on right_pad
|
||||
# prompt: left pad + response: right pad
|
||||
# attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
|
||||
# position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
|
||||
response_position_ids = position_ids[:, -1:] + delta_position_id
|
||||
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
|
||||
response_attention_mask = get_response_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
|
||||
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
|
||||
|
||||
# all the tp ranks should contain the same data here. data in all ranks are valid
|
||||
batch = TensorDict(
|
||||
{
|
||||
"prompts": idx,
|
||||
"responses": response,
|
||||
"input_ids": seq, # here input_ids become the whole sentences
|
||||
"rollout_log_probs": rollout_log_probs, # we will recompute old log prob with actor
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
},
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
# free cache engine
|
||||
if self.config.free_cache_engine and self._engine is not None:
|
||||
self._engine.flush_cache()
|
||||
|
||||
return DataProto(batch=batch, non_tensor_batch=_non_tensor_batch)
|
||||
|
||||
async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bool = True, is_validate: bool = False, **kwargs) -> AsyncRolloutRequest:
|
||||
assert self._tp_rank == 0, "only the master process can call this function"
|
||||
_req = deepcopy(req)
|
||||
finish_reason_type = None
|
||||
output = None
|
||||
|
||||
current_turns = 0
|
||||
while current_turns < self.config.multi_turn.max_turns:
|
||||
if _req.state == AsyncRolloutRequestStateEnum.PENDING:
|
||||
await self._handle_pending_state(_req)
|
||||
_req.state = AsyncRolloutRequestStateEnum.RUNNING
|
||||
elif _req.state == AsyncRolloutRequestStateEnum.TOOL_CALLING:
|
||||
if _req.messages[-1].tool_calls is not None:
|
||||
parsed_tool_calls = _req.messages[-1].tool_calls
|
||||
tool_call_results = await asyncio.gather(
|
||||
*[
|
||||
self._tool_map[tool_call.function.name].execute(
|
||||
_req.request_id,
|
||||
tool_call.function.arguments,
|
||||
**_req.tools_kwargs[tool_call.function.name].get("execute_kwargs", {}),
|
||||
)
|
||||
for tool_call in parsed_tool_calls
|
||||
]
|
||||
)
|
||||
for i, (tool_call, (resp, reward, metrics)) in enumerate(zip(parsed_tool_calls, tool_call_results)):
|
||||
_req.add_tool_response_message(self.tokenizer, resp, (i == len(parsed_tool_calls) - 1), format=self.config.multi_turn.format)
|
||||
_req.update_metrics(metrics, tool_call.function.name)
|
||||
if len(_req.input_ids) >= self.config.max_model_len:
|
||||
break
|
||||
if len(_req.input_ids) >= self.config.max_model_len:
|
||||
finish_reason_type = FinishReasonTypeEnum.STOP
|
||||
break
|
||||
_req.state = AsyncRolloutRequestStateEnum.RUNNING
|
||||
else:
|
||||
raise ValueError(f"Unexpected tool calling last message state: {_req.messages[-1]}")
|
||||
elif _req.state == AsyncRolloutRequestStateEnum.RUNNING:
|
||||
output = await self._handle_engine_call(_req, do_sample, is_validate, **kwargs)
|
||||
content = output["text"]
|
||||
finish_reason_type = FinishReasonTypeEnum.from_str(output["meta_info"]["finish_reason"]["type"])
|
||||
current_turns += 1
|
||||
if finish_reason_type == FinishReasonTypeEnum.LENGTH:
|
||||
_req.add_assistant_message(self.tokenizer, content, already_over_long=True, format=self.config.multi_turn.format)
|
||||
break
|
||||
else:
|
||||
if self._function_call_parser and self._function_call_parser.has_tool_call(content):
|
||||
finish_reason_type = FinishReasonTypeEnum.TOOL_CALL
|
||||
_req.state = AsyncRolloutRequestStateEnum.TOOL_CALLING
|
||||
try:
|
||||
normed_content, tool_calls = self._function_call_parser.parse_non_stream(content)
|
||||
except JSONDecodeError:
|
||||
normed_content = content
|
||||
tool_calls = []
|
||||
except AttributeError:
|
||||
normed_content = content
|
||||
tool_calls = []
|
||||
parsed_tool_calls = []
|
||||
for tool_call in tool_calls:
|
||||
function, has_decode_error = OpenAIFunctionCallSchema.from_openai_function_parsed_schema(OpenAIFunctionParsedSchema(name=tool_call.name, arguments=tool_call.parameters))
|
||||
# Drop the tool call if its arguments has decode error
|
||||
if has_decode_error:
|
||||
continue
|
||||
parsed_tool_calls.append(
|
||||
OpenAIFunctionToolCall(
|
||||
id=str(tool_call.tool_index),
|
||||
function=function,
|
||||
)
|
||||
)
|
||||
if len(parsed_tool_calls) > 0:
|
||||
_req.add_assistant_message(
|
||||
self.tokenizer,
|
||||
normed_content,
|
||||
tool_calls=parsed_tool_calls,
|
||||
format=self.config.multi_turn.format,
|
||||
)
|
||||
else:
|
||||
_req.add_assistant_message(self.tokenizer, content, format=self.config.multi_turn.format)
|
||||
finish_reason_type = FinishReasonTypeEnum.STOP
|
||||
_req.state = AsyncRolloutRequestStateEnum.COMPLETED
|
||||
break
|
||||
else:
|
||||
_req.add_assistant_message(self.tokenizer, content, format=self.config.multi_turn.format)
|
||||
break
|
||||
|
||||
if current_turns >= self.config.multi_turn.max_turns:
|
||||
finish_reason_type = FinishReasonTypeEnum.STOP
|
||||
|
||||
# Calculate the reward for each tool
|
||||
async def calc_reward_and_release_fn(name: str, tool: BaseTool):
|
||||
reward = await tool.calc_reward(_req.request_id, **_req.tools_kwargs[name].get("calc_reward_kwargs", {}))
|
||||
await tool.release(_req.request_id, **_req.tools_kwargs[name].get("release_kwargs", {}))
|
||||
return name, reward
|
||||
|
||||
tool_reward_tasks = []
|
||||
for name in _req.tools_kwargs.keys():
|
||||
tool = self._tool_map[name]
|
||||
tool_reward_tasks.append(calc_reward_and_release_fn(name, tool))
|
||||
tool_reward_scores = await asyncio.gather(*tool_reward_tasks)
|
||||
tool_reward_scores = dict(tool_reward_scores)
|
||||
_req.finalize(self.tokenizer, tool_reward_scores, finish_reason_type)
|
||||
|
||||
return _req
|
||||
|
||||
async def _handle_engine_call(self, _req: AsyncRolloutRequest, do_sample: bool, is_validate: bool, **kwargs) -> dict:
|
||||
generation_prompt_ids = _req.get_generation_prompt(self.tokenizer)
|
||||
max_new_tokens = min(self.config.response_length, self.config.max_model_len - len(generation_prompt_ids) - 1)
|
||||
if not do_sample:
|
||||
kwargs = dict(
|
||||
n=1,
|
||||
presence_penalty=0.0,
|
||||
frequency_penalty=0.0,
|
||||
repetition_penalty=1.0,
|
||||
temperature=0,
|
||||
top_p=1,
|
||||
top_k=-1,
|
||||
ignore_eos=False,
|
||||
min_new_tokens=0,
|
||||
max_new_tokens=self.config.response_length,
|
||||
skip_special_tokens=True,
|
||||
spaces_between_special_tokens=True,
|
||||
)
|
||||
elif is_validate:
|
||||
# TODO: try **
|
||||
kwargs = {
|
||||
"top_k": self.config.val_kwargs.top_k,
|
||||
"top_p": self.config.val_kwargs.top_p,
|
||||
"temperature": self.config.val_kwargs.temperature,
|
||||
"n": 1, # if validate, already repeat in ray_trainer
|
||||
}
|
||||
kwargs["max_new_tokens"] = max_new_tokens
|
||||
if "n" not in kwargs or kwargs["n"] > 1: # group size is supported in preprocess
|
||||
kwargs["n"] = 1
|
||||
# users can customize different sampling_params at different run
|
||||
with self.update_sampling_params(**kwargs):
|
||||
output = await self._engine.async_generate(
|
||||
input_ids=generation_prompt_ids,
|
||||
sampling_params=self.sampling_params,
|
||||
return_logprob=False,
|
||||
)
|
||||
return output
|
||||
|
||||
async def _handle_pending_state(self, _req: AsyncRolloutRequest) -> AsyncRolloutRequest:
|
||||
if _req.tools is not None:
|
||||
tool_creation_coroutines = []
|
||||
for tool_schema in _req.tools:
|
||||
tool = self._tool_map[tool_schema.function.name]
|
||||
create_kwargs = _req.tools_kwargs[tool.name].get("create_kwargs", {})
|
||||
tool_creation_coroutines.append(tool.create(_req.request_id, **create_kwargs))
|
||||
await asyncio.gather(*tool_creation_coroutines)
|
||||
|
||||
@GPUMemoryLogger(role="sglang async rollout", logger=logger)
|
||||
@torch.no_grad()
|
||||
def generate_sequences_with_tools(self, prompts: DataProto, **kwargs) -> DataProto:
|
||||
# Async rollout with tools support
|
||||
do_sample = prompts.meta_info.get("do_sample", True)
|
||||
is_validate = prompts.meta_info.get("validate", False)
|
||||
tgt_device = prompts.batch["input_ids"].device
|
||||
if self._tp_rank == 0:
|
||||
req_list = self._preprocess_prompt_to_async_rollout_requests(
|
||||
prompts,
|
||||
n=1 if is_validate else self.config.n,
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
output_req_list = loop.run_until_complete(
|
||||
asyncio.gather(
|
||||
*[self._async_rollout_a_request(req, do_sample, is_validate, **kwargs) for req in req_list],
|
||||
)
|
||||
)
|
||||
sorted_output_req_list = sorted(output_req_list, key=lambda x: (x.batch_data_id, x.rollout_offset))
|
||||
else:
|
||||
sorted_output_req_list = None
|
||||
|
||||
[sorted_output_req_list] = broadcast_pyobj(
|
||||
data=[sorted_output_req_list],
|
||||
rank=self._rank,
|
||||
dist_group=self._device_mesh_cpu["tp"].get_group(),
|
||||
src=self._device_mesh_cpu["tp"].mesh[0].item(),
|
||||
force_cpu_device=False,
|
||||
)
|
||||
# Construct the batch data
|
||||
prompt_ids, response_ids = [], []
|
||||
prompt_attention_mask, response_attention_mask = [], []
|
||||
prompt_position_ids, response_position_ids = [], []
|
||||
prompt_loss_mask, response_loss_mask = [], []
|
||||
messages = []
|
||||
reward_scores = []
|
||||
for req in sorted_output_req_list:
|
||||
assert req.state == AsyncRolloutRequestStateEnum.COMPLETED, f"Request {req.request_id} is not completed"
|
||||
assert len(req.input_ids) == len(req.attention_mask) == len(req.position_ids) == len(req.loss_mask), f"""Request {req.request_id} has different length of
|
||||
{len(req.input_ids)=}, {len(req.attention_mask)=}, {len(req.position_ids)=}, {len(req.loss_mask)=}"""
|
||||
error_message_lines = [
|
||||
f"""Request {req.request_id} has input_ids length {len(req.input_ids)}
|
||||
greater than max_model_len {self.config.max_model_len}""",
|
||||
f"Decoded input_ids: {self.tokenizer.decode(req.input_ids)}",
|
||||
f"Decoded prompt_ids: {self.tokenizer.decode(req.prompt_ids)}",
|
||||
f"Decoded response_ids: {self.tokenizer.decode(req.response_ids)}",
|
||||
f"Messages: {req.messages}",
|
||||
f"Max model length: {req.max_model_len}",
|
||||
]
|
||||
error_message = "\n".join(error_message_lines)
|
||||
assert len(req.input_ids) <= self.config.max_model_len, error_message
|
||||
|
||||
prompt_ids.append(torch.tensor(req.prompt_ids, dtype=torch.int, device=tgt_device))
|
||||
response_ids.append(torch.tensor(req.response_ids, dtype=torch.int, device=tgt_device))
|
||||
if len(req.response_ids) > self.config.response_length:
|
||||
logger.warning(
|
||||
f"""{req.request_id=} has response_ids length {len(req.response_ids)}
|
||||
greater than max_response_len {self.config.response_length},\n{req=}"""
|
||||
)
|
||||
prompt_attention_mask.append(torch.tensor(req.prompt_attention_mask, dtype=torch.int, device=tgt_device))
|
||||
response_attention_mask.append(torch.tensor(req.response_attention_mask, dtype=torch.int, device=tgt_device))
|
||||
prompt_position_ids.append(torch.tensor(req.prompt_position_ids, dtype=torch.int, device=tgt_device))
|
||||
response_position_ids.append(torch.tensor(req.response_position_ids, dtype=torch.int, device=tgt_device))
|
||||
prompt_loss_mask.append(torch.tensor(req.prompt_loss_mask, dtype=torch.int, device=tgt_device))
|
||||
response_loss_mask.append(torch.tensor(req.response_loss_mask, dtype=torch.int, device=tgt_device))
|
||||
messages.append({"messages": req.messages})
|
||||
reward_scores.append(req.reward_scores)
|
||||
|
||||
prompt_ids = pad_sequence(prompt_ids, batch_first=True, padding_value=self.pad_token_id, padding_side="left")
|
||||
if prompt_ids.shape[1] < self.config.prompt_length:
|
||||
prompt_ids = pad_sequence_to_length(prompt_ids, self.config.prompt_length, self.pad_token_id, left_pad=True)
|
||||
response_ids = pad_sequence(response_ids, batch_first=True, padding_value=self.pad_token_id)
|
||||
if response_ids.shape[1] < self.config.response_length:
|
||||
response_ids = pad_sequence_to_length(response_ids, self.config.response_length, self.pad_token_id)
|
||||
prompt_attention_mask = pad_sequence(prompt_attention_mask, batch_first=True, padding_value=0, padding_side="left")
|
||||
if prompt_attention_mask.shape[1] < self.config.prompt_length:
|
||||
prompt_attention_mask = pad_sequence_to_length(prompt_attention_mask, self.config.prompt_length, 0, left_pad=True)
|
||||
response_attention_mask = pad_sequence(response_attention_mask, batch_first=True, padding_value=0)
|
||||
if response_attention_mask.shape[1] < self.config.response_length:
|
||||
response_attention_mask = pad_sequence_to_length(response_attention_mask, self.config.response_length, 0)
|
||||
prompt_position_ids = pad_sequence(prompt_position_ids, batch_first=True, padding_value=0, padding_side="left")
|
||||
if prompt_position_ids.shape[1] < self.config.prompt_length:
|
||||
prompt_position_ids = pad_sequence_to_length(prompt_position_ids, self.config.prompt_length, 0, left_pad=True)
|
||||
response_length = response_ids.size(1)
|
||||
delta_position_id = torch.arange(1, response_length + 1, device=response_ids.device)
|
||||
delta_position_id = delta_position_id.unsqueeze(0).repeat(len(sorted_output_req_list), 1)
|
||||
response_position_ids = prompt_position_ids[:, -1:] + delta_position_id
|
||||
prompt_loss_mask = pad_sequence(prompt_loss_mask, batch_first=True, padding_value=0, padding_side="left")
|
||||
if prompt_loss_mask.shape[1] < self.config.prompt_length:
|
||||
prompt_loss_mask = pad_sequence_to_length(prompt_loss_mask, self.config.prompt_length, 0, left_pad=True)
|
||||
response_loss_mask = pad_sequence(response_loss_mask, batch_first=True, padding_value=0)
|
||||
if response_loss_mask.shape[1] < self.config.response_length:
|
||||
response_loss_mask = pad_sequence_to_length(response_loss_mask, self.config.response_length, 0)
|
||||
|
||||
input_ids = torch.cat((prompt_ids, response_ids), dim=-1)
|
||||
attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1)
|
||||
position_ids = torch.cat((prompt_position_ids, response_position_ids), dim=-1)
|
||||
loss_mask = torch.cat((prompt_loss_mask, response_loss_mask), dim=-1)
|
||||
|
||||
# Construct the batch data
|
||||
batch = TensorDict(
|
||||
{
|
||||
"prompts": prompt_ids,
|
||||
"responses": response_ids,
|
||||
"input_ids": input_ids, # here input_ids become the whole sentences
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
"loss_mask": loss_mask,
|
||||
},
|
||||
batch_size=len(sorted_output_req_list),
|
||||
)
|
||||
|
||||
# free cache engine
|
||||
if self.config.free_cache_engine and self._engine is not None and self._tp_rank == 0:
|
||||
self._engine.flush_cache()
|
||||
|
||||
return DataProto(batch=batch, non_tensor_batch={"messages": np.array(messages), "reward_scores": np.array(reward_scores)})
|
||||
|
||||
def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: int) -> list[AsyncRolloutRequest]:
|
||||
assert "raw_prompt" in prompts.non_tensor_batch, "need data.return_raw_chat=True, due to no official way do parse_messages"
|
||||
req_list = []
|
||||
for data_idx, raw_prompt in enumerate(prompts.non_tensor_batch["raw_prompt"]):
|
||||
for rollout_offset in range(n):
|
||||
if self._tool_schemas:
|
||||
_tools_kwargs = prompts.non_tensor_batch["tools_kwargs"][data_idx]
|
||||
_tool_schemas = []
|
||||
for k in _tools_kwargs.keys():
|
||||
_tool_schemas.append(self._tool_map[k].get_openai_tool_schema())
|
||||
prompt_with_chat_template = self.tokenizer.apply_chat_template(
|
||||
conversation=raw_prompt,
|
||||
tools=[tool.model_dump() for tool in _tool_schemas],
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_data = self.tokenizer(prompt_with_chat_template, return_tensors="pt", add_special_tokens=False)
|
||||
_input_ids = input_data["input_ids"][0].tolist()
|
||||
_attention_mask = input_data["attention_mask"][0].tolist()
|
||||
_position_ids = compute_position_id_with_mask(input_data["attention_mask"][0]).tolist()
|
||||
if len(_input_ids) > self.config.prompt_length:
|
||||
logger.warning(
|
||||
"Prompt {} has length {} greater than max_prompt_len {}",
|
||||
data_idx,
|
||||
len(_input_ids),
|
||||
self.config.prompt_length,
|
||||
)
|
||||
_input_ids = _input_ids[: self.config.prompt_length]
|
||||
_attention_mask = _attention_mask[: self.config.prompt_length]
|
||||
_position_ids = _position_ids[: self.config.prompt_length]
|
||||
else:
|
||||
_input_ids = _pre_process_inputs(self.pad_token_id, prompts.batch["input_ids"][data_idx])
|
||||
_attention_mask = _pre_process_inputs(0, prompts.batch["attention_mask"][data_idx])
|
||||
_position_ids = compute_position_id_with_mask(torch.tensor(_attention_mask)).tolist()
|
||||
_tool_schemas = []
|
||||
_tools_kwargs = {}
|
||||
|
||||
req = AsyncRolloutRequest(
|
||||
batch_data_id=data_idx,
|
||||
rollout_offset=rollout_offset,
|
||||
request_id=str(uuid4()),
|
||||
state=AsyncRolloutRequestStateEnum.PENDING,
|
||||
messages=[Message.model_validate(msg) for msg in raw_prompt],
|
||||
tools=_tool_schemas,
|
||||
tools_kwargs=_tools_kwargs,
|
||||
input_ids=_input_ids,
|
||||
prompt_ids=_input_ids,
|
||||
response_ids=[],
|
||||
attention_mask=_attention_mask,
|
||||
prompt_attention_mask=_attention_mask,
|
||||
response_attention_mask=[],
|
||||
position_ids=_position_ids,
|
||||
prompt_position_ids=_position_ids,
|
||||
response_position_ids=[],
|
||||
loss_mask=[0] * len(_input_ids),
|
||||
prompt_loss_mask=[0] * len(_input_ids),
|
||||
response_loss_mask=[],
|
||||
reward_scores={},
|
||||
max_response_len=self.config.response_length,
|
||||
max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length),
|
||||
)
|
||||
|
||||
error_message = f"Request {req.request_id} has mismatched lengths: input_ids={len(req.input_ids)}, attention_mask={len(req.attention_mask)}, position_ids={len(req.position_ids)}, loss_mask={len(req.loss_mask)}"
|
||||
assert len(req.input_ids) == len(req.attention_mask) == len(req.position_ids) == len(req.loss_mask), error_message
|
||||
|
||||
req_list.append(req)
|
||||
|
||||
return req_list
|
||||
|
||||
def execute_method(self, method: Union[str, bytes], *args, **kwargs):
|
||||
if method == "chat_completion":
|
||||
json_request = args[0]
|
||||
|
||||
formatted_messages = []
|
||||
for msg in json_request["messages"]:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content", "")
|
||||
formatted_messages.append(f"{role}: {content}")
|
||||
prompt_str = "\n".join(formatted_messages)
|
||||
|
||||
sampling_params_dict = {
|
||||
"n": json_request.get("n", 1),
|
||||
"max_new_tokens": json_request.get("max_completion_tokens", self.config.response_length),
|
||||
"temperature": json_request.get("temperature", 1.0),
|
||||
"top_p": json_request.get("top_p", 1.0),
|
||||
}
|
||||
output = None
|
||||
if self._tp_rank == 0:
|
||||
loop = asyncio.get_event_loop()
|
||||
output = loop.run_until_complete(
|
||||
self._engine.async_generate(
|
||||
prompt=prompt_str,
|
||||
sampling_params=sampling_params_dict,
|
||||
return_logprob=True,
|
||||
)
|
||||
)
|
||||
output = broadcast_pyobj(
|
||||
data=[output],
|
||||
rank=self._rank,
|
||||
dist_group=self._device_mesh_cpu["tp"].get_group(),
|
||||
src=self._device_mesh_cpu["tp"].mesh[0].item(),
|
||||
force_cpu_device=False,
|
||||
)
|
||||
|
||||
# only return value from master rank
|
||||
if self._tp_rank != 0:
|
||||
return None
|
||||
# build openai chat completion format
|
||||
choices = []
|
||||
id = None
|
||||
for i, content in enumerate(output):
|
||||
choices.append(
|
||||
{
|
||||
"index": i,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": content["text"],
|
||||
},
|
||||
"finish_reason": content["meta_info"]["finish_reason"]["type"],
|
||||
}
|
||||
)
|
||||
id = content["meta_info"]["id"]
|
||||
|
||||
return {
|
||||
"id": "chatcmpl-" + id,
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": json_request.get("model", "sglang_model"),
|
||||
"choices": choices,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"not supported method : {method}")
|
||||
|
||||
# this function is left for uniform train-inference resharding
|
||||
|
||||
def resume(self):
|
||||
if not self.is_sleep:
|
||||
return
|
||||
self.sharding_manager.__enter__() # pylint: disable=C2801
|
||||
|
||||
self.is_sleep = False
|
||||
|
||||
# this function is left for uniform train-inference resharding
|
||||
def offload(self):
|
||||
if self.is_sleep:
|
||||
return
|
||||
|
||||
self.sharding_manager.__exit__(None, None, None)
|
||||
self.is_sleep = True
|
@@ -1,3 +1,4 @@
# Copyright 2023-2024 SGLang Team
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
File diff suppressed because it is too large
@@ -34,7 +34,9 @@ def broadcast_pyobj(
    The `rank` here refer to the source rank on global process group (regardless
    of dist_group argument).
    """
    device = torch.device("cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu")
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
    )

    if rank == src:
        if len(data) == 0:
@@ -44,7 +46,9 @@ def broadcast_pyobj(
        serialized_data = pickle.dumps(data)
        size = len(serialized_data)

        tensor_data = torch.ByteTensor(np.frombuffer(serialized_data, dtype=np.uint8)).to(device)
        tensor_data = torch.ByteTensor(
            np.frombuffer(serialized_data, dtype=np.uint8)
        ).to(device)
        tensor_size = torch.tensor([size], dtype=torch.long, device=device)

        dist.broadcast(tensor_size, src=src, group=dist_group)
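For reference, the serialize-then-broadcast pattern that `broadcast_pyobj` wraps: the source rank pickles the payload, broadcasts its byte length, then the bytes; the other ranks size a buffer accordingly and unpickle. A minimal sketch under the assumption that a process group is already initialized (illustrative only, not the helper itself):

```python
import pickle

import numpy as np
import torch
import torch.distributed as dist


def broadcast_pyobj_sketch(obj, src, group=None, device="cpu"):
    rank = dist.get_rank()
    if rank == src:
        payload = pickle.dumps(obj)
        size = torch.tensor([len(payload)], dtype=torch.long, device=device)
        data = torch.from_numpy(np.frombuffer(payload, dtype=np.uint8).copy()).to(device)
        dist.broadcast(size, src=src, group=group)
        dist.broadcast(data, src=src, group=group)
        return obj
    # Non-source ranks first learn the payload size, then receive the bytes.
    size = torch.zeros(1, dtype=torch.long, device=device)
    dist.broadcast(size, src=src, group=group)
    buf = torch.zeros(int(size.item()), dtype=torch.uint8, device=device)
    dist.broadcast(buf, src=src, group=group)
    return pickle.loads(buf.cpu().numpy().tobytes())
```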
@@ -1,18 +1,5 @@
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -29,12 +16,10 @@

import logging
import os
from typing import Union

import torch
import torch.distributed as dist
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.entrypoints.verl_engine import VerlEngine
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
from sglang.srt.utils import MultiprocessingSerializer
from torch.distributed.device_mesh import DeviceMesh
@@ -66,7 +51,7 @@ class FSDPSGLangShardingManager(BaseShardingManager):
    def __init__(
        self,
        module: FSDP,
        inference_engine: Union[VerlEngine, Engine],
        inference_engine: Engine,
        model_config,
        full_params: bool = False,
        device_mesh: DeviceMesh = None,
@@ -144,44 +129,6 @@ class FSDPSGLangShardingManager(BaseShardingManager):
        self.gen_random_states = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(self.torch_random_states)

    def update_weights(self, params):
        self.inference_engine.resume_memory_occupation()
        self.inference_engine.update_weights_from_tensor([(k, v) for k, v in params.items()], load_format=None)

    def release_memory(self):
        self.inference_engine.release_memory_occupation()

    def preprocess_data(self, data: DataProto) -> DataProto:
        """All gather across tp group to make each rank has identical input."""
        if self.tp_size == 1:
            return data

        # TODO: Current impl doesn't consider FSDP with torch micro-dp
        group = self.device_mesh["infer_tp"].get_group()

        all_gather_data_proto(data=data, process_group=group)
        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        """Get chunk data of this tp rank since we do all gather in preprocess."""
        if self.tp_size == 1:
            return data

        return data.chunk(chunks=self.tp_size)[self.tp_rank]


class FSDPAsyncSGLangShardingManager(FSDPSGLangShardingManager):
    def __init__(
        self,
        module: FSDP,
        inference_engine: Engine,
        model_config,
        full_params: bool = False,
        device_mesh: DeviceMesh = None,
        offload_param: bool = False,
    ):
        super().__init__(module, inference_engine, model_config, full_params, device_mesh, offload_param)

    def update_weights(self, params):
        if self.device_mesh["infer_tp"].get_local_rank() == 0:
            self.inference_engine.resume_memory_occupation()
@@ -218,3 +165,21 @@ class FSDPAsyncSGLangShardingManager(FSDPSGLangShardingManager):
    def release_memory(self):
        if self.device_mesh["infer_tp"].get_local_rank() == 0:
            self.inference_engine.release_memory_occupation()

    def preprocess_data(self, data: DataProto) -> DataProto:
        """All gather across tp group to make each rank has identical input."""
        if self.tp_size == 1:
            return data

        # TODO: Current impl doesn't consider FSDP with torch micro-dp
        group = self.device_mesh["infer_tp"].get_group()

        all_gather_data_proto(data=data, process_group=group)
        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        """Get chunk data of this tp rank since we do all gather in preprocess."""
        if self.tp_size == 1:
            return data

        return data.chunk(chunks=self.tp_size)[self.tp_rank]
@@ -22,7 +22,6 @@ import os

import torch
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.entrypoints.verl_engine import VerlEngine
from torch import nn
from torch.distributed.device_mesh import DeviceMesh

@@ -50,7 +49,7 @@ class MegatronSGLangShardingManager(BaseShardingManager):
    def __init__(
        self,
        actor_module: nn.ModuleList,
        inference_engine: VerlEngine,
        inference_engine: Engine,
        model_config,
        transformer_config,
        layer_name_mapping,
@@ -113,50 +112,6 @@ class MegatronSGLangShardingManager(BaseShardingManager):
        self.gen_random_states = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(self.torch_random_states)

    def update_weights(self, params):
        self.inference_engine.resume_memory_occupation()
        self.inference_engine.update_weights_from_tensor(params, load_format=None)

    def release_memory(self):
        self.inference_engine.release_memory_occupation()

    @GPUMemoryLogger(role="megatron sglang sharding_manager", logger=logger)
    def preprocess_data(self, data: DataProto) -> DataProto:
        # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp
        if self.infer_tp_size == 1:
            return data
        all_gather_data_proto(data, self.device_mesh["tp"].get_group())
        return data

    @GPUMemoryLogger(role="megatron sglang sharding_manager", logger=logger)
    def postprocess_data(self, data: DataProto) -> DataProto:
        # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp
        if self.infer_tp_size == 1:
            return data
        return data.chunk(chunks=self.infer_tp_size)[self.device_mesh["tp"].get_local_rank()]


class MegatronAsyncSGLangShardingManager(MegatronSGLangShardingManager):
    def __init__(
        self,
        actor_module: nn.ModuleList,
        inference_engine: Engine,
        model_config,
        transformer_config,
        layer_name_mapping,
        weight_converter,
        device_mesh: DeviceMesh = None,
    ):
        super().__init__(
            actor_module,
            inference_engine,
            model_config,
            transformer_config,
            layer_name_mapping,
            weight_converter,
            device_mesh,
        )

    def update_weights(self, params):
        if self.device_mesh["tp"].get_local_rank() == 0:
            self.inference_engine.resume_memory_occupation()
@@ -184,3 +139,18 @@ class MegatronAsyncSGLangShardingManager(MegatronSGLangShardingManager):
    def release_memory(self):
        if self.device_mesh["tp"].get_local_rank() == 0:
            self.inference_engine.release_memory_occupation()

    @GPUMemoryLogger(role="megatron sglang sharding_manager", logger=logger)
    def preprocess_data(self, data: DataProto) -> DataProto:
        # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp
        if self.infer_tp_size == 1:
            return data
        all_gather_data_proto(data, self.device_mesh["tp"].get_group())
        return data

    @GPUMemoryLogger(role="megatron sglang sharding_manager", logger=logger)
    def postprocess_data(self, data: DataProto) -> DataProto:
        # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp
        if self.infer_tp_size == 1:
            return data
        return data.chunk(chunks=self.infer_tp_size)[self.device_mesh["tp"].get_local_rank()]