Merge branch 'volcengine:main' into recipe/async_training

This commit is contained in:
arron
2025-07-16 19:24:55 +08:00
committed by GitHub
66 changed files with 2560 additions and 311 deletions

.github/CODEOWNERS vendored
View File

@ -8,7 +8,7 @@
/third_party/sglang @zhaochenyang20 @SwordFaith
/third_party/vllm @PeterSH6 @wuxibin89
/verl/single_controller @zw0610 @wuxibin89
/verl/single_controller @zw0610 @wuxibin89 @hongpeng-guo
/verl/trainer @eric-haibin-lin @vermouth1992 @tongyx361 @PeterSH6
/verl/workers/rollout/vllm_rollout @wuxibin89 @PeterSH6 @chenhaiq
/verl/workers/rollout/sglang_rollout @zhaochenyang20 @SwordFaith @chenhaiq

View File

@ -131,6 +131,10 @@ jobs:
run: |
ray stop --force
python scripts/converter_hf_to_mcore.py --hf_model_path=${HOME}/models/Qwen/Qwen1.5-MoE-A2.7B-Chat --output_path checkpoints/Qwen/Qwen1.5-MoE-A2.7B-Chat --use_cpu_initialization
- name: Running distributed Huggingface to Megatron dist_ckpt CPU converter (Qwen/Qwen1.5-MoE-A2.7B-Chat)
run: |
ray stop --force
torchrun --nproc_per_node 8 --nnodes 1 scripts/converter_hf_to_mcore.py --hf_model_path=${HOME}/models/Qwen/Qwen1.5-MoE-A2.7B-Chat --output_path checkpoints/Qwen/Qwen1.5-MoE-A2.7B-Chat_dist --use_cpu_initialization
- name: clean up
run: |
rm -rf checkpoints

View File

@ -139,6 +139,10 @@ jobs:
exp_name="deepseek-coder-1.3b-instruct-megatron-gsm8k-minimal"
python -m verl.model_merger test --backend megatron --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface
python -m verl.model_merger test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface
- name: Test Megatron distributed checkpoints merging function (DeepSeek)
run: |
exp_name="deepseek-coder-1.3b-instruct-megatron-gsm8k-minimal"
torchrun --nproc_per_node 4 --nnodes 1 -m verl.model_merger merge --backend megatron --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --target_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/hf_model
- name: Running GRPO GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Deepseek)
run: |
ray stop --force

View File

@ -14,7 +14,7 @@ The first two types of images are hosted on dockerhub [verlai/verl](https://hub.
## Base Image
The stable base image is ``verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4-te2.3``. The installed package versions can be found from tags, and the Dockerfile can be found in ``verl[version]-[packages]/Dockerfile.base``.
The stable base image is ``verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4``. The installed package versions can be found from tags, and the Dockerfile can be found in ``verl[version]-[packages]/Dockerfile.base``.
The base images for preview are ``verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0`` and ``verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8.0`` with different CUDA versions.
@ -76,4 +76,4 @@ pip3 install --no-deps -e .
git clone https://github.com/volcengine/verl && cd verl
pip3 install -e .[vllm]
pip3 install -e .[sglang]
```
```

View File

@ -99,6 +99,16 @@ Example usage for merging Megatron checkpoints:
--local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \
--target_dir /path/to/merged_hf_model
Example usage for merging Megatron checkpoints in a distributed manner:
.. code:: bash
torchrun --nproc_per_node 1 --nnodes 8 --node_rank ${RANK} -m verl.model_merger merge \
--backend megatron \
--tie-word-embedding \
--local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \
--target_dir /path/to/merged_hf_model
Example usage for merging FSDP checkpoints:
.. code:: bash
@ -145,6 +155,15 @@ Example command to convert the model is as follows:
--use_cpu_initialization # Only work for MoE models
Example command for distributed conversion of a huge model such as DeepSeek-V3 671B is as follows:
.. code:: bash
torchrun --nproc_per_node 1 --nnodes 8 --node_rank ${RANK} scripts/converter_hf_to_mcore.py \
--hf_model_path deepseek-ai/DeepSeek-V3 \
--output_path /mnt/disk/deepseek-ai/DeepSeek-V3 \
--use_cpu_initialization # Only work for MoE models
Original Checkpoint Utils
-------------------------

View File

@ -60,7 +60,7 @@ Ulysses Utilities
--------------------
.. automodule:: verl.utils.ulysses
:members: gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
:members: gather_outputs_and_unpad, ulysses_pad_and_slice_inputs
FSDP Utilities
------------------

View File

@ -33,6 +33,7 @@ verl is fast with:
start/multinode
start/ray_debug_tutorial
start/more_resources
start/agentic_rl
.. toctree::
:maxdepth: 2

docs/start/agentic_rl.rst Normal file
View File

@ -0,0 +1,125 @@
Agentic RL Training
===================
Last updated: 07/15/2025.
Overview
----------
The goal of Agentic RL is to improve the performance of the backend model by applying reinforcement learning to the Agent. To support this training process, a series of features have been developed:
1. Server-based asynchronous rollout
2. Multi-turn conversations and tool calls
3. LangGraph-based Agent
This document explains the underlying system design and its usage to help users implement Agentic RL.
Server-based Asynchronous Rollout
---------------------------------
Since Agents need to interact with the environment through various tool calls, an asyncio-based coroutine mechanism is used to execute each rollout request asynchronously, which avoids GPU idling while waiting for tool-call results and thereby improves training performance. To support asynchronous rollout, the inference engine (server) and the agent (client) are architecturally separated, implementing a server-based system with the following objectives:
1. Enabling load-balancing mechanisms to spread load across multiple GPUs and reduce the impact of long-tail requests on performance. For this purpose, stream-mode scheduling capabilities are implemented as a recipe (recipe/stream_mode).
2. Preventing agent-specific features such as tracing from affecting the inference engine.
System Architecture
~~~~~~~~~~~~~~~~~~~
.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/agent_loop.png?raw=true
System Components
~~~~~~~~~~~~~~~~~
+--------------------------+----------------------------------------------------------------------------+
| Component | Role |
+==========================+============================================================================+
| AgentLoop | Client, implements Agent functions |
+--------------------------+----------------------------------------------------------------------------+
| AsyncLLMServerManager | Inference gateway, provides generate interface for AgentLoop |
+--------------------------+----------------------------------------------------------------------------+
| AsyncServer | Server, each instance is connected to one DP group of the inference engine |
+--------------------------+----------------------------------------------------------------------------+
**"generate" Interface**
The "generate" function based on ray actor is used between the Client and Server instead of the standard chat completion API. This is because the conversion between tokens and text can be irreversible. For example, the token converted from "<think>" will be different from that generated by the LLM. During the training phase, it is necessary to strictly use the tokens generated by LLM inference to avoid inaccurate in computing advantage, which may affect model performance. Having the Server provide a token-based API helps the Client maintain the relationship between the text generated by tool calls and the tokens returned by the LLM, so as to output correct tokens for training.
**Inference Engine Adaptation**
AsyncServer uniformly provides a generate function to the upper layer, with separate implementations for SGLang and vLLM to hide underlying differences:
1. The SGLang AsyncServer uses the async_generate interface of the SGLang engine, which is located on the first GPU of each TP group. Therefore, AsyncServer needs to remotely call async_generate through ray actor.
2. The vLLM AsyncServer uses the generate interface of the vLLM engine, which can communicate with the GPUs in the TP group through ZMQ and can be directly called in AsyncServer.
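The uniform interface can be pictured with the hypothetical sketch below; apart from ``generate``, the class names and exact signature are assumptions for illustration rather than the verl API:

.. code-block:: python

    from abc import ABC, abstractmethod

    class AsyncServerBase(ABC):
        """One instance per DP group of the inference engine."""

        @abstractmethod
        async def generate(self, request_id: str, prompt_ids: list[int], sampling_params: dict) -> list[int]:
            """Token-in/token-out generation; backend differences are hidden behind this call."""

    class SGLangAsyncServer(AsyncServerBase):
        async def generate(self, request_id, prompt_ids, sampling_params):
            # SGLang's async_generate lives on the first GPU of each TP group,
            # so it is reached through a remote ray actor call (omitted here).
            raise NotImplementedError

    class VLLMAsyncServer(AsyncServerBase):
        async def generate(self, request_id, prompt_ids, sampling_params):
            # vLLM's generate talks to the TP group over ZMQ and can be called
            # directly in-process (omitted here).
            raise NotImplementedError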
Usage Example
~~~~~~~~~~~~~
Follow :doc:`GSM8K example<../examples/gsm8k_example>` to prepare the dataset and model checkpoints.
This example uses the sglang inference engine by default, and you can also modify rollout_name to use vllm.
.. code-block:: bash
bash examples/grpo_trainer/run_qwen2-7b_seq_balance.sh
Multi-turn Conversations and Tool Calls
---------------------------------------
Follow :doc:`Multi-turn Rollout Support<../sglang_multiturn/multiturn>` to prepare tool and configuration files.
The Tool Agent Loop has an additional requirement: adding an "agent_name" field to the dataset. During rollout, it will choose to use tool_agent_loop or single_turn_agent (default) based on this field.
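For reference, each row produced by the GSM8K preprocessing script added in this commit carries the field alongside the prompt; an abridged row looks like the following (values other than ``agent_name`` are placeholders):

.. code-block:: python

    row = {
        "data_source": "openai/gsm8k",
        "agent_name": "tool_agent",  # rollout selects tool_agent_loop based on this field
        "prompt": [
            {"role": "system", "content": "You are a math expert. ..."},
            {"role": "user", "content": "<question> Let's think step by step ..."},
        ],
        "ability": "math",
        "reward_model": {"style": "rule", "ground_truth": "<answer>"},
    }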
Usage Example
~~~~~~~~~~~~~
.. code-block:: bash
# install mlflow to view toolcall and llm trace
pip install mlflow
# This will download and preprocess the GSM8K dataset into ~/data/gsm8k/ and add the "agent_name" field.
python3 examples/data_preprocess/gsm8k_tool_agent_loop.py
# Start training with tool calls; mlflow-based tracing is enabled to help debug the rollout details
bash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh
# When training is done, start an mlflow server to view the traces
mlflow ui -h 0.0.0.0 -p 5000 --backend-store-uri sqlite:////tmp/mlruns.db
# then open http://<your ip address>:5000 in a browser to view the traces
Note: During training, the model may sometimes fail to generate correct tool-call tags; the resulting "Failed to decode tool call" messages printed to the console do not indicate a problem with training.
Follow :doc:`Rollout trace<../advance/rollout_trace>` to learn more about the trace feature.
Agent Framework
---------------
System Architecture
~~~~~~~~~~~~~~~~~~~
.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/langgraph_agent.png?raw=true
System Components
~~~~~~~~~~~~~~~~~
+--------------------------+-----------------------------------------------------------------------------------------------+
| Component | Role |
+==========================+===============================================================================================+
| ChatModel | LLM object of LangChain, used to adapt to the “generate” api provided by AsyncLLMServerManager|
+--------------------------+-----------------------------------------------------------------------------------------------+
| ReactAgentLoop           | Agent adaptation layer, which by default supports a naive LangGraph ReAct agent.             |
| | New classes can be derived to support user-defined Agents, and the run function needs to be |
| | implemented to complete Agent calls. |
+--------------------------+-----------------------------------------------------------------------------------------------+
| AsyncServer | Server, each instance is connected to one DP group of the inference engine. |
+--------------------------+-----------------------------------------------------------------------------------------------+
Follow doc "recipe/langgraph_agent/example/README.md" for more details.

View File

@ -52,7 +52,7 @@ The first two types of images are hosted on dockerhub `verlai/verl <https://hub.
Base Image
::::::::::
The stable base image is ``verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4-te2.3``. The installed package versions can be found from tags, and the Dockerfile can be found in ``docker/verl[version]-[packages]/Dockerfile.base``.
The stable base image is ``verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4``. The installed package versions can be found from tags, and the Dockerfile can be found in ``docker/verl[version]-[packages]/Dockerfile.base``.
The base images for preview are ``verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0`` and ``verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8.0`` with different CUDA versions. From verl0.5, images are built with `Deep-EP <https://github.com/deepseek-ai/DeepEP>`_ for efficient EP communication.
@ -255,7 +255,7 @@ If you encounter issues about package versions during running verl, please updat
Install with AMD GPUs - ROCM kernel support
------------------------------------------------------------------
When you run on AMD GPUs (MI300) with ROCM platform, you cannot use the previous quickstart to run verl. You should follow the following steps to build a docker and run it.
When you run on AMD GPUs (MI300) with ROCM platform, you cannot use the previous quickstart to run verl. You should follow the following steps to build a docker and run it.
If you encounter any issues in using AMD GPUs running verl, feel free to contact me - `Yusheng Su <https://yushengsu-thu.github.io/>`_.
Find the docker for AMD ROCm: `docker/Dockerfile.rocm <https://github.com/volcengine/verl/blob/main/docker/Dockerfile.rocm>`_
@ -336,6 +336,6 @@ Launch the container
/bin/bash
If you do not want to run in root mode and want to assign yourself as the user,
Please add ``-e HOST_UID=$(id -u)`` and ``-e HOST_GID=$(id -g)`` into the above docker launch script.
Please add ``-e HOST_UID=$(id -u)`` and ``-e HOST_GID=$(id -g)`` into the above docker launch script.
verl with AMD GPUs currently supports FSDP as the training engine, vLLM and SGLang as the inference engine. We will support Megatron in the future.

View File

@ -0,0 +1,117 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the GSM8k dataset to parquet format
"""
import argparse
import os
import re
import datasets
from verl.utils.hdfs_io import copy, makedirs
def extract_solution(solution_str):
solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
assert solution is not None
final_solution = solution.group(0)
final_solution = final_solution.split("#### ")[1].replace(",", "")
return final_solution
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local_dir", default="~/data/gsm8k")
parser.add_argument("--hdfs_dir", default=None)
args = parser.parse_args()
data_source = "openai/gsm8k"
dataset = datasets.load_dataset(data_source, "main")
train_dataset = dataset["train"]
test_dataset = dataset["test"]
instruction_following = "Let's think step by step and output the final answer after `####`."
# add a row to each data item that represents a unique id
def make_map_fn(split):
def process_fn(example, idx):
question_raw = example.pop("question")
question = question_raw + " " + instruction_following
answer_raw = example.pop("answer")
solution = extract_solution(answer_raw)
data = {
"data_source": data_source,
"agent_name": "tool_agent",
"prompt": [
{
"role": "system",
"content": (
"You are a math expert. You are given a question and you need to solve it step by step. "
"Reasoning step by step before any tool call. "
"You should use the `calc_gsm8k_reward` tool after step by step solving the question, "
"before generate final answer at least once and refine your answer if necessary. "
"Put your final answer in the format of `#### <answer>`."
),
},
{
"role": "user",
"content": question,
},
],
"ability": "math",
"reward_model": {"style": "rule", "ground_truth": solution},
"extra_info": {
"split": split,
"index": idx,
"answer": answer_raw,
"question": question_raw,
"need_tools_kwargs": True,
"tools_kwargs": {
"calc_gsm8k_reward": {
"create_kwargs": {"ground_truth": solution},
# "execute_kwargs": {},
# "calc_reward_kwargs": {},
# "release_kwargs": {},
},
},
"interaction_kwargs": {
"query": question,
"ground_truth": solution,
},
},
}
return data
return process_fn
train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
local_dir = args.local_dir
hdfs_dir = args.hdfs_dir
train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
if hdfs_dir is not None:
makedirs(hdfs_dir)
copy(src=local_dir, dst=hdfs_dir)

View File

@ -19,7 +19,7 @@ python3 -m verl.trainer.main_ppo \
data.filter_overlong_prompts=True \
data.truncation='error' \
data.return_raw_chat=True \
actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \

View File

@ -18,7 +18,7 @@ python3 -m verl.trainer.main_ppo \
data.filter_overlong_prompts=True \
data.truncation='error' \
data.return_raw_chat=True \
actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \

View File

@ -25,7 +25,7 @@ python3 -m verl.trainer.main_ppo \
data.filter_overlong_prompts=True \
data.truncation='error' \
data.return_raw_chat=True \
actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \

View File

@ -49,5 +49,6 @@ python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \
trainer.total_epochs=15 $@
trainer.total_epochs=15 \
actor_rollout_ref.rollout.update_weights_bucket_megabytes=512 $@

View File

@ -0,0 +1,57 @@
# run on 8xH100
# make sure your current working directory is the root of the project
set -x
ulimit -n 65535
PROJECT_DIR="$(pwd)"
CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config"
python3 -m verl.trainer.main_ppo \
--config-path="$CONFIG_PATH" \
--config-name='gsm8k_multiturn_grpo' \
algorithm.adv_estimator=grpo \
data.train_batch_size=256 \
data.max_prompt_length=1024 \
data.max_response_length=1024 \
data.filter_overlong_prompts=True \
data.truncation='error' \
data.return_raw_chat=True \
actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.rollout.mode=async \
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
actor_rollout_ref.rollout.n=16 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
actor_rollout_ref.rollout.trace.backend=mlflow \
actor_rollout_ref.rollout.trace.token2text=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger='["console","mlflow"]' \
trainer.project_name='gsm8k_tool-agent' \
trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-tool-agent-verify-n16' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=20 \
trainer.total_training_steps=2 \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \
trainer.total_epochs=15 $@

View File

@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,357 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Ref: https://python.langchain.com/docs/how_to/custom_chat_model/
"""
import asyncio
import json
import logging
import os
import uuid
from typing import Any, Optional
from langchain_core.language_models import BaseChatModel
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.messages import (
AIMessage,
BaseMessage,
convert_to_openai_messages,
)
from langchain_core.messages.tool import InvalidToolCall, ToolCall
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tools import StructuredTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import Field
from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, AsyncLLMServerManager
from verl.experimental.agent_loop.tool_parser import ToolParser
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
class MaxTokenExceededError(Exception):
"""Indicate that history chat messages + tool message exceeds LLM max_tokens."""
pass
class ChatModel(BaseChatModel):
model_name: str = Field(alias="model")
"""The name of the model"""
client: AsyncLLMServerManager
"""AsyncLLM server manager"""
tokenizer: Any
"""Tokenizer for the model"""
max_tokens: int
"""Max tokens to generate"""
tool_parser: str = "hermes"
"""Tool parser for the model"""
max_parallel_calls: int = 1
"""Max parallel tool calls"""
temperature: float = 1.0
"""Temperature for sampling"""
top_p: float = 1.0
"""Top p for sampling"""
repetition_penalty: float = 1.0
"""Repetition penalty for sampling"""
def bind_tools(self, tools, **kwargs) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tools to the model.
Args:
tools: Sequence of tools to bind to the model.
Returns:
A Runnable that returns a message.
"""
formatted_tools: list = [convert_to_openai_tool(tool) for tool in tools]
# used to remove system prompt prefix when encoding tool response
system_prompt = self.tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True)
kwargs["system_prompt"] = system_prompt
return self.bind(tools=formatted_tools, **kwargs)
def with_structured_output(
self,
schema: dict | type,
*,
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, dict | BaseChatModel]:
"""Ref: https://langchain-ai.github.io/langgraph/how-tos/react-agent-structured-output/"""
raise NotImplementedError
def _generate(
self,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> ChatResult:
raise NotImplementedError
async def _agenerate(
self,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> ChatResult:
"""Asynchronously generate chat completion message.
Args:
messages (list[BaseMessage]): List of list of messages.
stop (Optional[list[str]], optional): Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings. Defaults to None.
Returns:
ChatResult: Chat result.
"""
request_id, prompt_ids, response_mask = await self._preprocess(messages, **kwargs)
sampling_params = {
"temperature": self.temperature,
"top_p": self.top_p,
"repetition_penalty": self.repetition_penalty,
}
if "sampling_params" in kwargs:
sampling_params.update(kwargs["sampling_params"])
response_ids = await self.client.generate(
request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params
)
message = await self._postprocess(request_id, prompt_ids, response_mask, response_ids, **kwargs)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
@property
def _llm_type(self) -> str:
"""Get the type of language model used by this chat model."""
return self.model_name
async def _preprocess(self, messages: list[BaseMessage], **kwargs: Any) -> tuple[str, list[int], list[int]]:
"""Preprocess messages for chat completion.
To ensure strong consistency with the policy model, the AsyncLLM server generates responses token-in/token-out
instead of via a messages list.
But all agent frameworks use a messages list to represent chat history. To mitigate this gap, we store the trajectory
(prompt_ids, response_mask) in the latest AIMessage.response_metadata.
1. Encode ToolMessage to token ids.
2. Retrieve trajectory (prompt_ids, response_mask) from the latest AIMessage.response_metadata.
3. Append ToolMessage token ids to prompt_ids, and append 0 to response_mask.
Ref: https://python.langchain.com/docs/concepts/chat_history/
Args:
messages (list[BaseMessage]): List of messages.
Returns:
tuple[str, list[int], list[int]]: Request id, prompt ids, response mask.
"""
# messages: [system], human, ai, human|tool, ai, human|tool, ...
assert messages[-1].type in ["human", "tool"], (
f"Last message must be human or tool, but got {messages[-1].type}"
)
loop = asyncio.get_running_loop()
# Case 1: initial chat completion: [system], human
if messages[-1].type == "human" and (len(messages) == 1 or messages[-2].type != "ai"):
prompt_ids = await loop.run_in_executor(
None,
lambda: self.tokenizer.apply_chat_template(
convert_to_openai_messages(messages),
tools=kwargs.get("tools"),
add_generation_prompt=True,
tokenize=True,
),
)
return str(uuid.uuid4()), prompt_ids, []
# Case 2: follow up chat completion with tool/human response: [system], human, ai, human|tool, ...
for i in range(len(messages) - 1, -1, -1):
if messages[i].type == "ai":
break
assert "prompt_ids" in messages[i].response_metadata, "Last message must have prompt_ids in response_metadata"
assert "response_mask" in messages[i].response_metadata, (
"Last message must have response_mask in response_metadata"
)
# encode tool response
tool_responses = convert_to_openai_messages(messages[i + 1 :])
tool_response_ids = await loop.run_in_executor(
None,
lambda messages=tool_responses: self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True
),
)
tool_response_ids = tool_response_ids[len(kwargs["system_prompt"]) :]
# stop generation if response length exceeds max response length
if len(messages[i].response_metadata["response_mask"]) + len(tool_response_ids) >= self.max_tokens:
raise MaxTokenExceededError(f"Max response length {self.max_tokens} exceeded")
# append tool response to prompt
request_id = messages[i].response_metadata.pop("request_id")
prompt_ids = messages[i].response_metadata.pop("prompt_ids")
response_mask = messages[i].response_metadata.pop("response_mask")
prompt_ids += tool_response_ids
response_mask += [0] * len(tool_response_ids)
return request_id, prompt_ids, response_mask
async def _postprocess(
self, request_id: str, prompt_ids: list[int], response_mask: list[int], response_ids: list[int], **kwargs: Any
) -> AIMessage:
"""Postprocess response_ids when chat completion is done.
1. Decode response_ids, parse tool calls to AIMessage.
2. Append response_ids to prompt_ids, and append 1 to response_mask.
3. Store trajectory (prompt_ids, response_mask) in AIMessage.response_metadata.
Args:
request_id (str): Unique request id.
prompt_ids (list[int]): Input prompt token ids in this chat completion.
response_mask (list[int]): Response mask before this chat completion.
response_ids (list[int]): LLM generated token ids in this chat completion.
Returns:
AIMessage: Postprocessed message.
"""
prompt_ids += response_ids
response_mask += [1] * len(response_ids)
tool_parser = ToolParser.get_tool_parser(self.tool_parser, self.tokenizer)
content, function_calls = await tool_parser.extract_tool_calls(response_ids)
tool_calls, invalid_tool_calls = [], []
for function_call in function_calls:
try:
args = json.loads(function_call.arguments)
if not isinstance(args, dict):
raise json.JSONDecodeError(f"Invalid json tool arguments: {args}")
tool_call = ToolCall(
args=args,
name=function_call.name,
id=str(uuid.uuid4()),
)
tool_calls.append(tool_call)
except json.JSONDecodeError as e:
logger.warning(f"Invalid json tool arguments: {e}")
tool_call = InvalidToolCall(
args=function_call.arguments,
name=function_call.name,
error=f"Invalid json tool arguments: {e}",
)
invalid_tool_calls.append(tool_call)
message = AIMessage(
content=content,
tool_calls=tool_calls[: self.max_parallel_calls],
invalid_tool_calls=invalid_tool_calls[: self.max_parallel_calls],
response_metadata={
"request_id": request_id,
"prompt_ids": prompt_ids,
"response_mask": response_mask,
},
)
return message
class TruncateStructuredTool(StructuredTool):
"""Structured tool with response truncation."""
tool_response_truncate_side: str
"""truncate side of tool response: left, middle, right"""
max_tool_response_length: int
"""max length of tool response"""
async def _arun(
self,
*args: Any,
config: RunnableConfig,
**kwargs: Any,
) -> Any:
tool_response = await super()._arun(*args, config=config, **kwargs)
tool_response = str(tool_response)
if len(tool_response) > self.max_tool_response_length:
if self.tool_response_truncate_side == "left":
tool_response = tool_response[: self.max_tool_response_length] + "...(truncated)"
elif self.tool_response_truncate_side == "right":
tool_response = "(truncated)..." + tool_response[-self.max_tool_response_length :]
else:
length = self.max_tool_response_length // 2
tool_response = tool_response[:length] + "...(truncated)..." + tool_response[-length:]
return tool_response
def convert_to_agent_output(messages: list[BaseMessage], response_length: int) -> AgentLoopOutput:
"""Convert messages to AgentLoopOutput.
Args:
messages (List[BaseMessage]): List of messages, last message must be assistant
with response_metadata containing `prompt_ids` and `response_mask`.
response_length (int): Max length of response.
Returns:
AgentLoopOutput: agent loop output trajectory used for training.
"""
# skip last tool calls
for i in range(len(messages) - 1, -1, -1):
if messages[i].type != "tool":
break
last_message = messages[i]
assert last_message.type == "ai", f"Last message must be assistant, but got {last_message.type}"
assert "prompt_ids" in last_message.response_metadata, "Last message must have prompt_ids in response_metadata"
assert "response_mask" in last_message.response_metadata, (
"Last message must have response_mask in response_metadata"
)
num_turns = 0
for i in range(len(messages)):
if messages[i].type == "system":
continue
# parallel tool calls are in single turn
if i == 0 or messages[i].type != messages[i - 1].type:
num_turns += 1
prompt_ids = last_message.response_metadata["prompt_ids"]
response_mask = last_message.response_metadata["response_mask"]
response_ids = prompt_ids[-len(response_mask) :]
prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)]
output = AgentLoopOutput(
prompt_ids=prompt_ids,
response_ids=response_ids[:response_length],
response_mask=response_mask[:response_length],
num_turns=num_turns,
metrics={},
)
return output

View File

@ -0,0 +1,111 @@
# MathExpression: LangGraph Agent Example
MathExpression is a tiny example to demonstrate multi-turn rollout with [LangGraph ReactAgent](https://langchain-ai.github.io/langgraph/agents/overview/).
### Define react agent with tool
Firstly, to force ReactAgent to evaluate math expression by tool, we define a special operand `@`:
```python
@tool(parse_docstring=True)
def calculate(a: int, b: int, operand: str) -> int:
"""
Compute the results using operand with two integers
Args:
a: the first operand
b: the second operand
operand: '+' or '-' or '*' or '@'
"""
assert operand in ["+", "-", "*", "@"], f"unknown operand {operand}"
if operand == "@":
return 3 * a - 2 * b
return eval(f"{a} {operand} {b}")
```
Without calling `calculate`, the ReactAgent cannot evaluate the math expression correctly.
Then, we can equip ReactAgent with `calculate` tool:
```python
class MathExpressionReactAgentLoop(ReactAgentLoop):
@classmethod
def init_class(cls, config, tokenizer):
cls.tools = [calculate]
super().init_class(config, tokenizer)
```
We can define the agent loop config in a yaml file, which will be used by AgentLoopWorker to dynamically load the custom AgentLoop class.
```yaml
- name: math_expression
_target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop
```
### Prepare dataset
Now, let's prepare two small datasets for training and evaluation:
```bash
python recipe/langgraph_agent/example/create_dataset.py
```
Note that the dataset should contain a column `agent_name` with the value `math_expression`, which is used by `AgentLoopWorker` to select the
agent loop class.
| prompt | reward_model | agent_name |
|--------------------------------------|------------------------------|-----------------|
| [{'role': 'user', 'content': '...'}] | {'ground_truth': '-10', ...} | math_expression |
| [{'role': 'user', 'content': '...'}] | {'ground_truth': '-10', ...} | math_expression |
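A quick sanity check of the generated files (a minimal sketch, assuming `pandas` is available; `create_dataset.py` writes the parquet files to the current directory):

```python
import pandas as pd

df = pd.read_parquet("train.parquet")
print(df[["prompt", "reward_model", "agent_name"]].head())
assert (df["agent_name"] == "math_expression").all()
```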
Generated math expressions look like the ones below, requiring the model to call `calculate` multiple times to solve sub-expressions.
```
(2 @ (8 @ 8 @ 5 @ 5 @ 3) @ 6 @ (1 @ 4 @ 4 @ 4) @ 2) @ 6
(4.6 @ (9.05 @ 4.0) @ 8.3 @ 1.21) @ 8.6
9 @ 4
((2 @ 2) @ (3 @ 3)) @ 4
```
### Training
Hook all these up and start training:
```bash
bash recipe/langgraph_agent/example/run_qwen2.5_3b.sh 2>&1 | tee train.log
```
After 39 steps in total, the model should achieve 100% accuracy on the test dataset:
- val-aux/lighteval/MATH/reward: 1.0
- val-aux/num_turns/mean: 9.0, the average number of messages, including assistant and tool turns.
We can see that the model solves the math expression by calling the `calculate` tool:
```
system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.
user
We define a new math operator @, where you can only call an external tool to compute. Please put your final answer inside \boxed{} only in the last turn. Now answer the following questions:
Compute 5 @ (3 @ 6 @ 8)
assistant
Sure, let's break this down step by step. We need to compute the expression \(5 @ (3 @ 6 @ 8)\). First, we need to compute the inner expression \(3 @ 6 @ 8\).
Let's compute \(3 @ 6\) first using the external tool.
<tool_call>
{"name": "calculate", "arguments": {"a": 3, "b": 6, "operand": "@"}}
</tool_call>user
<tool_response>
-3
</tool_response>
assistant
We have computed that \(3 @ 6 = -3\). Now we need to compute \(-3 @ 8\).
Let's use the external tool to compute this.
<tool_call>
{"name": "calculate", "arguments": {"a": -3, "b": 8, "operand": "@"}}
</tool_call>user
<tool_response>
-25
</tool_response>
assistant
We have computed that \(-3 @ 8 = -25\). Now we can put this result into the original expression to get the final result of \(5 @ (3 @ 6 @ 8)\), which is \(5 @ (-25)\).
Let's compute \(5 @ (-25)\) using the external tool.
<tool_call>
{"name": "calculate", "arguments": {"a": 5, "b": -25, "operand": "@"}}
</tool_call>user
<tool_response>
65
</tool_response>
assistant
The final result of \(5 @ (3 @ 6 @ 8)\) is \(\boxed{65}\).
```

View File

@ -0,0 +1,2 @@
- name: math_expression
_target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop

View File

@ -0,0 +1,277 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Create dataset for calculator
"""
import random
import pandas as pd
def generate_math_expression(min_terms=2, max_terms=5, min_number=1, max_number=10, allow_decimals=False, max_depth=2):
"""
Generate a random mathematical expression with operators +, -, *, /, and parentheses.
Args:
min_terms (int): Minimum number of terms in the expression.
max_terms (int): Maximum number of terms in the expression.
min_number (int): Minimum value for numbers in the expression.
max_number (int): Maximum value for numbers in the expression.
allow_decimals (bool): Whether to allow decimal numbers.
max_depth (int): Maximum nesting depth for parentheses.
Returns:
str: A valid mathematical expression as a string.
"""
def generate_number():
"""Generate a random number (integer or float)."""
assert min_number < max_number
num = random.uniform(min_number, max_number)
if not allow_decimals:
num = int(num)
else:
num = round(num, random.randint(0, 2)) # Round to 0-2 decimal places
return str(num)
def generate_term(depth=0):
"""Generate a term (number or parenthesized expression)."""
if depth < max_depth and random.random() < 0.5: # 50% chance to add parentheses
expr = generate_expression(depth + 1)
return f"({expr})"
else:
return generate_number()
def generate_expression(depth=0):
"""Generate a full expression with multiple terms and operators."""
num_terms = random.randint(min_terms, max_terms)
terms = [generate_term(depth) for _ in range(num_terms)]
# Randomly select operators
operators = ["+", "-", "*", "/", "@"]
expr = terms[0]
for i in range(1, num_terms):
# Only the custom @ operator is generated (all other weights are zero)
op = random.choices(
operators,
weights=[0, 0, 0, 0, 1],  # only '@' has a non-zero weight
)[0]
expr += f" {op} " + terms[i]
return expr
return generate_expression()
def test():
# Example 1: Basic integer expression
print(generate_math_expression())
# Output: (3 + 7) * 2 - 5
# Example 2: Expression with decimals
print(generate_math_expression(allow_decimals=True))
# Output: 4.5 / (2.1 + 3.7) - 1.2
# Example 3: More complex expression with higher depth
print(generate_math_expression(max_terms=6, max_depth=3))
# Output: ((5 * 2) - (3 + 1)) / (7 - 2) + 4
# Example 4: Simplified expression
print(generate_math_expression(min_terms=2, max_terms=3, max_number=5))
# Output: 4 - 2 * 3
def calculate(expression: str) -> float:
"""
Evaluate a mathematical expression with +, -, *, /, @, and parentheses.
The @ operator is defined as: a @ b = 3a - 2b.
Args:
expression (str): Input mathematical expression (e.g., "3@2+4").
Returns:
float: Result of the evaluated expression.
Raises:
ValueError: For invalid expressions (e.g., mismatched parentheses, division by zero).
"""
def tokenize(s: str) -> list:
"""Convert the input string into tokens (numbers, operators, parentheses)."""
tokens = []
i = 0
while i < len(s):
if s[i].isdigit() or s[i] == ".":
# Parse number (integer or float)
j = i
while j < len(s) and (s[j].isdigit() or s[j] == "."):
j += 1
tokens.append(s[i:j])
i = j
elif s[i] in "+-*/@()":
# Operator or parenthesis
tokens.append(s[i])
i += 1
elif s[i].isspace():
# Skip whitespace
i += 1
else:
raise ValueError(f"Invalid character: {s[i]}")
return tokens
def infix_to_postfix(tokens: list) -> list:
"""Convert infix notation to postfix notation (Reverse Polish Notation)."""
output = []
stack = []
# Highest precedence for @ (above * and /)
precedence = {"@": 3, "*": 2, "/": 2, "+": 1, "-": 1}
for token in tokens:
if token.isdigit() or "." in token:
output.append(token)
elif token == "(":
stack.append(token)
elif token == ")":
while stack and stack[-1] != "(":
output.append(stack.pop())
if not stack or stack[-1] != "(":
raise ValueError("Mismatched parentheses")
stack.pop() # Discard '('
else: # Operator
while stack and stack[-1] != "(" and precedence.get(stack[-1], 0) >= precedence.get(token, 0):
output.append(stack.pop())
stack.append(token)
# Pop remaining operators
while stack:
if stack[-1] in "()":
raise ValueError("Mismatched parentheses")
output.append(stack.pop())
return output
def evaluate_postfix(postfix: list) -> float:
"""Evaluate postfix expression using a stack."""
stack = []
for token in postfix:
if token.isdigit() or "." in token:
stack.append(float(token))
else:
if len(stack) < 2:
raise ValueError("Invalid expression")
b = stack.pop()
a = stack.pop()
if token == "+":
res = a + b
elif token == "-":
res = a - b
elif token == "*":
res = a * b
elif token == "/":
if b == 0:
raise ValueError("Division by zero")
res = a / b
elif token == "@":
res = 3 * a - 2 * b # Custom @ operator implementation
else:
raise ValueError(f"Invalid operator: {token}")
stack.append(res)
if len(stack) != 1:
raise ValueError("Invalid expression")
return stack[0]
# Remove spaces and validate parentheses
expression = expression.replace(" ", "")
if expression.count("(") != expression.count(")"):
raise ValueError("Mismatched parentheses")
tokens = tokenize(expression)
postfix = infix_to_postfix(tokens)
result = evaluate_postfix(postfix)
# Convert integers to integer representation
if result.is_integer():
return int(result)
return result
def generate_data(total_num_dataset, split):
rl_dataset = {
"prompt": [],
"data_source": [],
"ability": [],
"reward_model": [],
"extra_info": [],
"agent_name": [],
}
for idx in range(total_num_dataset):
while True:
try:
expression: str = generate_math_expression(
min_terms=2, max_terms=3, min_number=1, max_number=10, allow_decimals=False, max_depth=1
)
num_plus = expression.count("+")
num_minus = expression.count("-")
num_mul = expression.count("*")
num_star = expression.count("@")
answer = str(calculate(expression))
# answer = str(eval(expression))
break
except Exception as e:
print(e)
continue
num_tool_calls = num_plus + num_minus + num_mul + num_star
prompt = (
f"We define a new math operator @, where you can only call an external tool to compute. "
f"Please put your final answer inside \\boxed{{}} only in the last turn. Now answer the "
f"following questions:\nCompute {expression}"
)
prompt_with_template = [
{
"role": "user",
"content": prompt,
}
]
rl_dataset["prompt"].append(prompt_with_template)
rl_dataset["data_source"].append("lighteval/MATH")
rl_dataset["ability"].append("math")
rl_dataset["reward_model"].append({"style": "lighteval/MATH", "ground_truth": answer})
rl_dataset["extra_info"].append(
{"index": idx, "expression": expression, "split": split, "expected_tool_calls": num_tool_calls}
)
rl_dataset["agent_name"].append("math_expression")
rl_dataset = pd.DataFrame(data=rl_dataset)
return rl_dataset
if __name__ == "__main__":
# print(calculate("3@2")) # Output: 5 (3*3 - 2*2)
# print(calculate("3@2+4")) # Output: 9 (5 + 4)
# print(calculate("3*(4@2)")) # Output: 24 (3 * 8)
# print(calculate("(5@3)*2")) # Output: 18 (9 * 2)
train_dataset = generate_data(total_num_dataset=5000, split="train")
test_dataset = generate_data(total_num_dataset=500, split="test")
train_dataset.to_parquet("train.parquet")
test_dataset.to_parquet("test.parquet")

View File

@ -0,0 +1,39 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from langchain_core.tools import tool
from recipe.langgraph_agent.react_agent_loop import ReactAgentLoop
@tool(parse_docstring=True)
def calculate(a: int, b: int, operand: str) -> int:
"""
Compute the results using operand with two integers
Args:
a: the first operand
b: the second operand
operand: '+' or '-' or '*' or '@'
"""
assert operand in ["+", "-", "*", "@"], f"unknown operand {operand}"
if operand == "@":
return 3 * a - 2 * b
return eval(f"{a} {operand} {b}")
class MathExpressionReactAgentLoop(ReactAgentLoop):
@classmethod
def init_class(cls, config, tokenizer, **kwargs):
cls.tools = [calculate]
super().init_class(config, tokenizer)

View File

@ -0,0 +1,99 @@
set -x
# ================= data/model/tool =================
HDFS_ROOT=${HDFS_ROOT:-$PWD}
DATA_ROOT=${DATA_ROOT:-$PWD}
model_path=$DATA_ROOT/model/Qwen2.5-3B-Instruct
train_files=$DATA_ROOT/dataset/math_expression_tool/train.parquet
test_files=$DATA_ROOT/dataset/math_expression_tool/test.parquet
# agent
agent_loop_config_path=recipe/langgraph_agent/example/agent.yaml
# wandb
project_name=math_expression_tool
experiment_name=qwen2.5-3b
default_local_dir=$DATA_ROOT/checkpoint/$experiment_name
# ================= algorithm =================
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
max_turns=8
max_prompt_length=1024
max_response_length=2048
actor_lr=1e-6
train_batch_size=128
ppo_mini_batch_size=16
n_resp_per_prompt=8
n_resp_per_prompt_val=1
# ================= performance =================
infer_tp=2 # vllm
train_sp=4 # train
offload=True
actor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 4 ))
log_prob_max_token_len_per_gpu=$(( actor_max_token_len_per_gpu * 2 ))
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=$adv_estimator \
algorithm.use_kl_in_reward=$use_kl_in_reward \
algorithm.kl_ctrl.kl_coef=$kl_coef \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.return_raw_chat=True \
data.train_batch_size=$train_batch_size \
data.max_prompt_length=$max_prompt_length \
data.max_response_length=$max_response_length \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.model.path=$model_path \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \
actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \
actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \
actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.actor.optim.lr=$actor_lr \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \
actor_rollout_ref.actor.fsdp_config.param_offload=$offload \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$log_prob_max_token_len_per_gpu \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.mode=async \
actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \
actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \
actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \
actor_rollout_ref.rollout.multi_turn.format=hermes \
actor_rollout_ref.rollout.agent.agent_loop_config_path=$agent_loop_config_path \
actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \
actor_rollout_ref.rollout.n=$n_resp_per_prompt \
actor_rollout_ref.rollout.val_kwargs.top_p=0.6 \
actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \
actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \
trainer.logger=['console','wandb'] \
trainer.project_name=$project_name \
trainer.experiment_name=$experiment_name \
trainer.n_gpus_per_node=$ARNOLD_WORKER_GPU \
trainer.val_before_train=True \
trainer.log_val_generations=50 \
trainer.nnodes=$ARNOLD_WORKER_NUM \
trainer.save_freq=-1 \
trainer.default_local_dir=$default_local_dir \
trainer.test_freq=5 \
trainer.total_epochs=1 $@

View File

@ -0,0 +1,133 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
LangGraph React Agent Loop.
This implementation is exactly the same as `ToolAgentLoop`.
Ref: https://langchain-ai.github.io/langgraph/tutorials/workflows/
"""
from typing import Any, Literal
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
from recipe.langgraph_agent.chat_model import (
ChatModel,
MaxTokenExceededError,
convert_to_agent_output,
)
from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput
async def call_model(state: MessagesState, config: RunnableConfig):
model = config["configurable"]["model"]
sampling_params = config["configurable"]["sampling_params"]
try:
message = await model.ainvoke(state["messages"], sampling_params=sampling_params)
return {"messages": [message]}
except MaxTokenExceededError:
# last message is ToolMessage
return {"messages": []}
def should_continue(state: MessagesState, config: RunnableConfig) -> Literal["tools", END]:
max_assistant_turns = config["configurable"]["max_assistant_turns"]
num_assistant_turns = 0
for message in state["messages"]:
if message.type == "ai":
num_assistant_turns += 1
last_message = state["messages"][-1]
# LLM call failed, e.g: max response length exceeded
if last_message.type == "tool":
return END
# max assistant turns exceeded
if max_assistant_turns and num_assistant_turns >= max_assistant_turns:
return END
# no tool calls
if not last_message.tool_calls:
return END
return "tools"
class ReactAgentLoop(AgentLoopBase):
@classmethod
def init_class(cls, config, tokenizer, **kwargs):
if cls._class_initialized:
return
cls._class_initialized = True
print("Performing class-level ReactAgentLoop initialization")
# build graph
cls.graph = cls.build_graph()
@classmethod
def build_graph(cls) -> StateGraph:
workflow = StateGraph(MessagesState)
workflow.add_node("agent", call_model)
workflow.add_node("tools", ToolNode(cls.tools))
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
"agent",
should_continue,
{
"tools": "tools",
END: END,
},
)
workflow.add_edge("tools", "agent")
graph = workflow.compile()
return graph
async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
model_path = self.config.actor_rollout_ref.model.path
model_name = "/".join(model_path.split("/")[-2:])
rollout = self.config.actor_rollout_ref.rollout
model = ChatModel(
model=model_name,
client=self.server_manager,
tokenizer=self.tokenizer,
max_tokens=rollout.response_length,
max_parallel_calls=rollout.multi_turn.max_parallel_calls,
tool_parser=rollout.multi_turn.format,
)
model = model.bind_tools(self.tools, tool_choice="any")
config = {
"configurable": {
"model": model,
"sampling_params": sampling_params,
"max_user_turns": rollout.multi_turn.max_user_turns,
"max_assistant_turns": rollout.multi_turn.max_assistant_turns,
}
}
# TODO: how to handle multiple trajectories in a graph invocation?
# Each graph node may have its own LLM calls and state, e.g:
# https://github.com/google-gemini/gemini-fullstack-langgraph-quickstart
state = await self.graph.ainvoke(input={"messages": messages}, config=config)
output = convert_to_agent_output(state["messages"], rollout.response_length)
return output

View File

@ -0,0 +1,199 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import numpy as np
import pytest
import ray
from langchain_core.tools import tool
from omegaconf import DictConfig
from recipe.langgraph_agent.react_agent_loop import ReactAgentLoop
from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager
from verl.protocol import DataProto
from verl.utils import hf_tokenizer
@pytest.fixture
def init_config() -> DictConfig:
from hydra import compose, initialize_config_dir
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
config = compose(config_name="ppo_trainer")
model_path = "Qwen/Qwen2.5-1.5B-Instruct"
config.actor_rollout_ref.model.path = model_path
config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
config.actor_rollout_ref.rollout.mode = "async"
config.actor_rollout_ref.rollout.prompt_length = 4096
config.actor_rollout_ref.rollout.response_length = 4096
config.actor_rollout_ref.rollout.n = 4
config.actor_rollout_ref.rollout.agent.num_workers = 2
# test sleep/wake_up with fsdp offload
config.actor_rollout_ref.actor.fsdp_config.param_offload = True
config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True
return config
@tool(parse_docstring=True)
def get_current_temperature(location: str, unit: str = "celsius"):
"""Get current temperature at a location.
Args:
location: The location to get the temperature for, in the format "City, State, Country".
unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"])
Returns:
the temperature, the location, and the unit in a dict
"""
print(f"[DEBUG] get_current_temperature: {location}, {unit}")
return {
"temperature": 26.1,
"location": location,
"unit": unit,
}
@tool(parse_docstring=True)
def get_temperature_date(location: str, date: str, unit: str = "celsius"):
"""Get temperature at a location and date.
Args:
location: The location to get the temperature for, in the format "City, State, Country".
date: The date to get the temperature for, in the format "Year-Month-Day".
unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"])
Returns:
the temperature, the location, the date and the unit in a dict
"""
print(f"[DEBUG] get_temperature_date: {location}, {date}, {unit}")
return {
"temperature": 25.9,
"location": location,
"date": date,
"unit": unit,
}
class TestReactAgentLoop(ReactAgentLoop):
@classmethod
def init_class(cls, config, tokenizer, **kwargs):
# TODO: find better way to configure tools
cls.tools = [get_current_temperature, get_temperature_date]
super().init_class(config, tokenizer, **kwargs)
def test_react_agent(init_config):
ray.init(
runtime_env={
"env_vars": {
"TOKENIZERS_PARALLELISM": "true",
"NCCL_DEBUG": "WARN",
"VLLM_LOGGING_LEVEL": "INFO",
"VLLM_USE_V1": "1",
}
}
)
# =========================== 1. Init rollout manager ===========================
agent_loop_config = [
{
"_target_": "recipe.langgraph_agent.test_react_agent_loop.TestReactAgentLoop",
"name": "react_agent",
},
]
agent_loop_config_path = "/tmp/agent_loop_config.json"
with open(agent_loop_config_path, "w") as f:
json.dump(agent_loop_config, f)
n = 2
init_config.actor_rollout_ref.rollout.n = n
# init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path
init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2
init_config.actor_rollout_ref.rollout.agent.agent_loop_config_path = agent_loop_config_path
agent_loop_manager = init_agent_loop_manager(init_config)
# =========================== 2. Generate sequences ===========================
raw_prompts = [
[
{"role": "user", "content": "How are you?"},
],
[
{"role": "user", "content": "What's the temperature in Los Angeles now?"},
],
[
{"role": "user", "content": "What's the temperature in New York now?"},
],
[
{
"role": "system",
"content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\n"
"Current Date: 2024-09-30",
},
{"role": "user", "content": "What's the temperature in San Francisco now? How about tomorrow?"},
],
]
batch = DataProto(
non_tensor_batch={
"raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),
"agent_name": np.array(["react_agent"] * len(raw_prompts)),
},
)
batch = batch.repeat(n)
result = agent_loop_manager.generate_sequences(prompts=batch)
assert len(result) == len(raw_prompts) * n
# Check turns
num_turns = result.non_tensor_batch["__num_turns__"]
print(f"num_turns: {num_turns}")
for i in range(len(num_turns)):
if i // n == 0:
# [user, assistant]
assert num_turns[i] == 2
else:
# [user, assistant, tool, assistant]
assert num_turns[i] == 4
# Check response_mask
tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
responses = result.batch["responses"]
response_mask = result.batch["response_mask"]
attention_mask = result.batch["attention_mask"]
assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}"
response_length = response_mask.size(1)
for i in range(len(responses)):
# response with tool response
valid_tokens = responses[i][attention_mask[i][-response_length:].bool()]
response_with_obs = tokenizer.decode(valid_tokens)
# response without tool response
valid_tokens = responses[i][response_mask[i].bool()]
response_without_obs = tokenizer.decode(valid_tokens)
assert "<tool_response>" not in response_without_obs, (
f"found <tool_response> in response: {response_without_obs}"
)
assert "</tool_response>" not in response_without_obs, (
f"found </tool_response> in response: {response_without_obs}"
)
print("=========================")
print(response_with_obs)
print("---")
print(response_without_obs)
print("Test passed!")
ray.shutdown()

View File

@ -28,7 +28,7 @@ from verl import DataProto
from verl.utils.device import get_device_name
from verl.utils.py_functional import append_to_dict
from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs
from .prime_core_algos import compute_ce_dpo_loss_rm, compute_detach_dpo_loss_rm
@ -101,7 +101,9 @@ class DataParallelPRIMERewardModel:
)
if self.ulysses_sequence_parallel_size > 1:
rm_log_labels = gather_outpus_and_unpad(rm_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size)
rm_log_labels = gather_outputs_and_unpad(
rm_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size
)
rm_log_labels = pad_input(
hidden_states=rm_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen
).squeeze(-1)[:, -num_actions - 1 : -1]
@ -149,7 +151,7 @@ class DataParallelPRIMERewardModel:
logits=ref_output_logits, labels=input_ids_rmpad_rolled
)
ref_log_labels = gather_outpus_and_unpad(
ref_log_labels = gather_outputs_and_unpad(
ref_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size
)
ref_log_labels = pad_input(

View File

@ -409,7 +409,7 @@ class RewardModelWorker(Worker):
def _forward_micro_batch(self, micro_batch):
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs
with torch.no_grad(), torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16):
input_ids = micro_batch["input_ids"]
@ -443,7 +443,7 @@ class RewardModelWorker(Worker):
# gather output if sp > 1
if self.ulysses_sequence_parallel_size > 1:
reward_rmpad = gather_outpus_and_unpad(
reward_rmpad = gather_outputs_and_unpad(
reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size
)

View File

@ -10,7 +10,7 @@ peft
pyarrow>=15.0.0
pybind11
pylatexenc
tensordict>=0.8.0,<=0.9.0
tensordict>=0.8.0,<=0.9.1,!=0.9.0
transformers==4.52.4
ray==2.46.0
wandb

View File

@ -14,7 +14,7 @@ pybind11
pylatexenc
pre-commit
ray[default]
tensordict>=0.8.0,<=0.9.0
tensordict>=0.8.0,<=0.9.1,!=0.9.0
torchdata
transformers
# vllm==0.8.4

View File

@ -12,7 +12,7 @@ pyarrow>=19.0.0
pybind11
pylatexenc
ray[default]>=2.10
tensordict>=0.8.0,<=0.9.0
tensordict>=0.8.0,<=0.9.1,!=0.9.0
torchdata
torchvision
transformers

View File

@ -17,9 +17,11 @@ import argparse
import os
import warnings
from contextlib import contextmanager
from typing import Any, Callable, ContextManager
from typing import Any, Callable, ContextManager, Optional
import numpy as np
import torch
import torch.distributed as dist
from accelerate import init_empty_weights
from megatron.core import dist_checkpointing
from megatron.core import parallel_state as mpu
@ -29,11 +31,22 @@ from megatron.core.models.gpt.gpt_model import ModelType
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from transformers import AutoConfig
from verl.model_merger.megatron_model_merger import get_dynamic_pipeline_shards
from verl.models.mcore import hf_to_mcore_config
from verl.utils.device import get_device_name, get_torch_device
from verl.utils.megatron_utils import get_model
def _init_args():
"""
Examples:
1. single rank conversion for any model:
> python converter_hf_to_mcore.py --hf_model_path ${hf_model} --output_path ${output_path}
2. distributed conversion for DeepseekV3 671B:
> torchrun --nproc_per_node 1 --nnodes 4 --node_rank ${RANK} converter_hf_to_mcore.py \
--hf_model_path ${hf_model} --output_path ${output_path}
"""
parser = argparse.ArgumentParser()
parser.add_argument("--hf_model_path", type=str, required=True, help="The path for the huggingface model")
parser.add_argument("--output_path", type=str, required=True, help="The path for the output mcore model")
@ -92,7 +105,17 @@ def test_conversion(megatron_model_provider, tfconfig, output_path, model):
print("Conversion test passed!")
def convert_checkpoint_from_transformers_to_megatron(hf_model, model, hf_config):
@torch.inference_mode()
def convert_checkpoint_from_transformers_to_megatron(
hf_model, model, hf_config, layer_start_end: Optional[tuple[int, int]] = None
):
if layer_start_end is None:
layer_start_end = (0, len(model.decoder.layers))
layer_start, layer_end = layer_start_end
pp_rank = mpu.get_pipeline_model_parallel_rank()
pp_size = mpu.get_pipeline_model_parallel_world_size()
numel = 0
num_attention_heads = hf_config.num_attention_heads
num_key_value_heads = hf_config.num_key_value_heads
hidden_dim = hf_config.hidden_size
@ -101,50 +124,61 @@ def convert_checkpoint_from_transformers_to_megatron(hf_model, model, hf_config)
print("[WARNING] Converting GQA model")
has_qkv_bias = getattr(hf_config, "qkv_bias", False) or getattr(hf_config, "attention_bias", False)
has_share_expert = getattr(hf_config, "shared_expert_intermediate_size", None)
with torch.no_grad():
model.embedding.word_embeddings.weight.copy_(hf_model.model.embed_tokens.weight)
for layer, hf_layer in zip(model.decoder.layers, hf_model.model.layers, strict=True):
layer.self_attention.linear_qkv.layer_norm_weight.copy_(hf_layer.input_layernorm.weight)
if pp_rank == 0:
numel += safe_copy(hf_model.model.embed_tokens.weight, model.embedding.word_embeddings.weight)
q = hf_layer.self_attn.q_proj.weight.view(
[num_key_value_heads, head_dim * num_attention_heads // num_key_value_heads, -1]
assert len(model.decoder.layers) == (layer_end - layer_start), (
f"Expected {len(model.decoder.layers)} layers, but got {layer_end - layer_start}"
)
for layer_idx, (layer, hf_layer) in enumerate(
zip(model.decoder.layers, hf_model.model.layers[layer_start:layer_end], strict=True)
):
global_layer_idx = layer_idx + layer_start
numel_cur = numel
numel += safe_copy(hf_layer.input_layernorm.weight, layer.self_attention.linear_qkv.layer_norm_weight)
q = hf_layer.self_attn.q_proj.weight.view(
[num_key_value_heads, head_dim * num_attention_heads // num_key_value_heads, -1]
)
k = hf_layer.self_attn.k_proj.weight.view([num_key_value_heads, head_dim, -1])
v = hf_layer.self_attn.v_proj.weight.view([num_key_value_heads, head_dim, -1])
qkv = torch.cat([q, k, v], dim=1).view(-1, hidden_dim).contiguous()
numel += safe_copy(qkv, layer.self_attention.linear_qkv.weight)
if has_qkv_bias:
q_bias = hf_layer.self_attn.q_proj.bias.view([num_key_value_heads, -1])
k_bias = hf_layer.self_attn.k_proj.bias.view([num_key_value_heads, -1])
v_bias = hf_layer.self_attn.v_proj.bias.view([num_key_value_heads, -1])
qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view(-1).contiguous()
numel += safe_copy(qkv_bias, layer.self_attention.linear_qkv.bias)
if hasattr(hf_layer.self_attn, "q_norm"):
numel += safe_copy(hf_layer.self_attn.q_norm.weight.data, layer.self_attention.q_layernorm.weight)
numel += safe_copy(hf_layer.self_attn.k_norm.weight.data, layer.self_attention.k_layernorm.weight)
numel += safe_copy(hf_layer.self_attn.o_proj.weight, layer.self_attention.linear_proj.weight)
numel += safe_copy(hf_layer.post_attention_layernorm.weight, layer.pre_mlp_layernorm.weight)
numel += safe_copy(hf_layer.mlp.gate.weight, layer.mlp.router.weight)
for idx, hf_expert in enumerate(hf_layer.mlp.experts):
fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight])
numel += safe_copy(fc1_weight, layer.mlp.experts.linear_fc1._parameters[f"weight{idx}"])
numel += safe_copy(hf_expert.down_proj.weight, layer.mlp.experts.linear_fc2._parameters[f"weight{idx}"])
if has_share_expert:
numel += safe_copy(hf_layer.mlp.shared_expert_gate.weight, layer.mlp.shared_experts.gate_weight)
shared_fc1_weight = torch.cat(
[hf_layer.mlp.shared_expert.gate_proj.weight, hf_layer.mlp.shared_expert.up_proj.weight]
)
k = hf_layer.self_attn.k_proj.weight.view([num_key_value_heads, head_dim, -1])
v = hf_layer.self_attn.v_proj.weight.view([num_key_value_heads, head_dim, -1])
qkv = torch.cat([q, k, v], dim=1).view(-1, hidden_dim).contiguous()
layer.self_attention.linear_qkv.weight.copy_(qkv)
numel += safe_copy(shared_fc1_weight, layer.mlp.shared_experts.linear_fc1.weight)
numel += safe_copy(hf_layer.mlp.shared_expert.down_proj.weight, layer.mlp.shared_experts.linear_fc2.weight)
print(f"{pp_rank=} {global_layer_idx=} {layer_idx=} {numel=} numel this layer={numel - numel_cur}")
if has_qkv_bias:
q_bias = hf_layer.self_attn.q_proj.bias.view([num_key_value_heads, -1])
k_bias = hf_layer.self_attn.k_proj.bias.view([num_key_value_heads, -1])
v_bias = hf_layer.self_attn.v_proj.bias.view([num_key_value_heads, -1])
qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view(-1).contiguous()
layer.self_attention.linear_qkv.bias.copy_(qkv_bias)
if hasattr(hf_layer.self_attn, "q_norm"):
layer.self_attention.q_layernorm.weight.copy_(hf_layer.self_attn.q_norm.weight.data)
layer.self_attention.k_layernorm.weight.copy_(hf_layer.self_attn.k_norm.weight.data)
layer.self_attention.linear_proj.weight.copy_(hf_layer.self_attn.o_proj.weight)
layer.pre_mlp_layernorm.weight.copy_(hf_layer.post_attention_layernorm.weight)
layer.mlp.router.weight.copy_(hf_layer.mlp.gate.weight)
for idx, hf_expert in enumerate(hf_layer.mlp.experts):
fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight])
layer.mlp.experts.linear_fc1._parameters[f"weight{idx}"].copy_(fc1_weight)
layer.mlp.experts.linear_fc2._parameters[f"weight{idx}"].copy_(hf_expert.down_proj.weight)
if has_share_expert:
layer.mlp.shared_experts.gate_weight.copy_(hf_layer.mlp.shared_expert_gate.weight)
shared_fc1_weight = torch.cat(
[hf_layer.mlp.shared_expert.gate_proj.weight, hf_layer.mlp.shared_expert.up_proj.weight]
)
layer.mlp.shared_experts.linear_fc1.weight.copy_(shared_fc1_weight)
layer.mlp.shared_experts.linear_fc2.weight.copy_(hf_layer.mlp.shared_expert.down_proj.weight)
model.decoder.final_layernorm.weight.copy_(hf_model.model.norm.weight)
model.output_layer.weight.copy_(hf_model.lm_head.weight)
if pp_rank == pp_size - 1:
numel += safe_copy(hf_model.model.norm.weight, model.decoder.final_layernorm.weight)
numel += safe_copy(hf_model.lm_head.weight, model.output_layer.weight)
return numel
def safe_copy(
@ -258,13 +292,31 @@ def convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hfmodel, mgmodel
assert n_params == copied_numel
@torch.no_grad()
def convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model, hf_config, tfconfig):
@torch.inference_mode()
def convert_checkpoint_from_transformers_to_megatron_dpskv3(
hf_model,
model,
hf_config,
tfconfig,
layer_start_end: Optional[tuple[int, int]] = None,
):
warnings.warn("MTP model is not supported yet", stacklevel=2)
if layer_start_end is None:
layer_start_end = (0, len(model.decoder.layers))
layer_start, layer_end = layer_start_end
numel: int = 0
numel += safe_copy(hf_model.model.embed_tokens.weight, model.embedding.word_embeddings.weight)
print(f"{numel=}")
for layer_idx, (layer, hf_layer) in enumerate(zip(model.decoder.layers, hf_model.model.layers, strict=True)):
pp_rank = mpu.get_pipeline_model_parallel_rank()
pp_size = mpu.get_pipeline_model_parallel_world_size()
if pp_rank == 0:
numel += safe_copy(hf_model.model.embed_tokens.weight, model.embedding.word_embeddings.weight)
assert len(model.decoder.layers) == (layer_end - layer_start), (
f"Expected {len(model.decoder.layers)} layers, but got {layer_end - layer_start}"
)
for layer_idx, (layer, hf_layer) in enumerate(
zip(model.decoder.layers, hf_model.model.layers[layer_start:layer_end], strict=True)
):
global_layer_idx = layer_idx + layer_start
numel_cur: int = numel
numel += safe_copy(hf_layer.input_layernorm.weight, layer.input_layernorm.weight)
@ -318,13 +370,14 @@ def convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model, hf_
)
numel += safe_copy(shared_fc1_weight, layer.mlp.shared_experts.linear_fc1.weight)
numel += safe_copy(hf_layer.mlp.shared_experts.down_proj.weight, layer.mlp.shared_experts.linear_fc2.weight)
print(f"{layer_idx=} {numel=} numel this layer={numel - numel_cur}")
print(f"{pp_rank=} {global_layer_idx=} {layer_idx=} {numel=} numel this layer={numel - numel_cur}")
assert numel - numel_cur == sum([i.numel() for i in hf_layer.state_dict().values()]), "numel mismatch"
numel += safe_copy(hf_model.model.norm.weight, model.decoder.final_layernorm.weight)
if not hf_config.tie_word_embeddings:
numel += safe_copy(hf_model.lm_head.weight, model.output_layer.weight)
print(f"{numel=}")
if pp_rank == pp_size - 1:
numel += safe_copy(hf_model.model.norm.weight, model.decoder.final_layernorm.weight)
if not hf_config.tie_word_embeddings:
numel += safe_copy(hf_model.lm_head.weight, model.output_layer.weight)
print(f"{pp_rank=} {numel=}")
return numel
@ -333,6 +386,13 @@ def noop_context() -> Any:
yield
def support_distributed_convert(hf_config: AutoConfig) -> bool:
for arch in ["DeepseekV3ForCausalLM", "Qwen3MoeForCausalLM", "Qwen2MoeForCausalLM"]:
if arch in hf_config.architectures:
return True
return False
def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False, test=False, trust_remote_code=False):
os.makedirs(output_path, exist_ok=True)
if len(os.listdir(output_path)) > 0 and not test:
@ -340,13 +400,22 @@ def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False
return
# init torch distributed and mpu
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
if "WORLD_SIZE" not in os.environ:
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
torch.distributed.init_process_group("nccl")
rank = dist.get_rank()
local_rank = os.getenv("LOCAL_RANK", 0)
world_size = dist.get_world_size()
get_torch_device().set_device(f"{get_device_name()}:{local_rank}")
mpu.initialize_model_parallel(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=world_size,
virtual_pipeline_model_parallel_size=None,
context_parallel_size=1,
expert_model_parallel_size=1,
@ -357,7 +426,18 @@ def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False
hf_config = AutoConfig.from_pretrained(hf_model_path)
print(hf_config, flush=True)
tfconfig = hf_to_mcore_config(hf_config, torch.bfloat16)
if world_size > 1 and not support_distributed_convert(hf_config):
raise NotImplementedError(f"distributed conversion is not supported for {hf_config.architectures} yet.")
pipeline_shards = get_dynamic_pipeline_shards(hf_config.num_hidden_layers, world_size)
print(f"Pipeline shards: {pipeline_shards}", flush=True)
tfconfig = hf_to_mcore_config(
hf_config,
torch.bfloat16,
num_layers_in_first_pipeline_stage=pipeline_shards[0] if len(pipeline_shards) > 1 else None,
num_layers_in_last_pipeline_stage=pipeline_shards[-1] if len(pipeline_shards) > 2 else None,
)
tfconfig.use_cpu_initialization = use_cpu_initialization
tie_word_embeddings = getattr(hf_config, "tie_word_embeddings", False)
@ -403,17 +483,36 @@ def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False
)
hf_state_dict = hf_model.state_dict()
# distributed convert
if world_size > 1 and support_distributed_convert(hf_config):
pipeline_cumsum = np.cumsum(pipeline_shards)
layer_start = 0 if rank == 0 else pipeline_cumsum[rank - 1]
layer_end = pipeline_cumsum[rank]
if "DeepseekV3ForCausalLM" in hf_config.architectures:
numel_partial: int = convert_checkpoint_from_transformers_to_megatron_dpskv3(
hf_model, model[0].module, hf_config, tfconfig=tfconfig, layer_start_end=(layer_start, layer_end)
)
elif "Qwen3MoeForCausalLM" in hf_config.architectures or "Qwen2MoeForCausalLM" in hf_config.architectures:
numel_partial: int = convert_checkpoint_from_transformers_to_megatron(
hf_model, model[0].module, hf_config, layer_start_end=(layer_start, layer_end)
)
else:
raise NotImplementedError(f"Distributed conversion is not supported for {hf_config.architectures} yet.")
numel_tensor = torch.tensor([numel_partial]).to(get_device_name())
dist.all_reduce(numel_tensor, op=dist.ReduceOp.SUM)
numel = int(numel_tensor.cpu().item())
print(f"total numel={numel} vs {hf_model.num_parameters()=}")
if numel != hf_model.num_parameters():
warnings.warn(f"numel mismatch: {numel=} != {hf_model.num_parameters()=}", stacklevel=1)
# load hf state dict to megatron model
if "Qwen2MoeForCausalLM" in hf_config.architectures:
elif "Qwen2MoeForCausalLM" in hf_config.architectures:
convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config)
elif "Qwen2_5_VLForConditionalGeneration" in hf_config.architectures:
convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hf_model, model[0].module, hf_config)
elif "DeepseekV3ForCausalLM" in hf_config.architectures:
numel: int = convert_checkpoint_from_transformers_to_megatron_dpskv3(
hf_model, model[0].module, hf_config, tfconfig=tfconfig
)
if numel != hf_model.num_parameters():
warnings.warn(f"numel mismatch: {numel=} != {hf_model.num_parameters()=}", stacklevel=1)
convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model[0].module, hf_config, tfconfig=tfconfig)
elif "Qwen3MoeForCausalLM" in hf_config.architectures:
convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config)
else:
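
As a side note on the distributed path above: each pipeline rank converts only a contiguous slice of the HF layers, derived from the dynamic pipeline shards via a cumulative sum. A minimal sketch of that mapping (the helper name `layer_range_for_rank` is hypothetical, introduced here only for illustration):

```python
import numpy as np


def layer_range_for_rank(pipeline_shards: list[int], rank: int) -> tuple[int, int]:
    """Return the [start, end) HF layer indices handled by a pipeline rank."""
    pipeline_cumsum = np.cumsum(pipeline_shards)
    layer_start = 0 if rank == 0 else int(pipeline_cumsum[rank - 1])
    layer_end = int(pipeline_cumsum[rank])
    return layer_start, layer_end


# e.g. 61 layers over 8 ranks -> shards [6, 8, 8, 8, 8, 8, 8, 7]
# rank 0 converts layers [0, 6), rank 1 converts [6, 14), ..., rank 7 converts [54, 61)
```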

View File

@ -37,7 +37,7 @@ install_requires = [
"pylatexenc",
"ray[default]>=2.41.0",
"torchdata",
"tensordict>=0.8.0,<=0.9.0",
"tensordict>=0.8.0,<=0.9.1,!=0.9.0",
"transformers",
"wandb",
"packaging>=20.0",
@ -48,9 +48,9 @@ PRIME_REQUIRES = ["pyext"]
GEO_REQUIRES = ["mathruler", "torchvision", "qwen_vl_utils"]
GPU_REQUIRES = ["liger-kernel", "flash-attn"]
MATH_REQUIRES = ["math-verify"] # Add math-verify as an optional dependency
VLLM_REQUIRES = ["tensordict>=0.8.0,<=0.9.0", "vllm>=0.7.3,<=0.8.5"]
VLLM_REQUIRES = ["tensordict>=0.8.0,<=0.9.1,!=0.9.0", "vllm>=0.7.3,<=0.8.5"]
SGLANG_REQUIRES = [
"tensordict>=0.8.0,<=0.9.0",
"tensordict>=0.8.0,<=0.9.1,!=0.9.0",
"sglang[srt,openai]==0.4.6.post5",
"torch-memory-saver>=0.0.5",
"torch==2.6.0",

View File

@ -109,6 +109,7 @@ class WeatherTool(BaseTool):
Returns:
the temperature, the location, and the unit in a dict
"""
print(f"[DEBUG] get_current_temperature: {location}, {unit}")
return {
"temperature": 26.1,
"location": location,
@ -143,6 +144,7 @@ class WeatherToolWithData(BaseTool):
Returns:
the temperature, the location, the date and the unit in a dict
"""
print(f"[DEBUG] get_temperature_date: {location}, {date}, {unit}")
return {
"temperature": 25.9,
"location": location,
@ -174,11 +176,11 @@ def test_tool_agent(init_config):
tool_config = {
"tools": [
{
"class_name": "tests.workers.rollout.rollout_vllm.test_vllm_chat_scheduler.WeatherTool",
"class_name": "tests.experimental.agent_loop.test_basic_agent_loop.WeatherTool",
"config": {"type": "native"},
},
{
"class_name": "tests.workers.rollout.rollout_vllm.test_vllm_chat_scheduler.WeatherToolWithData",
"class_name": "tests.experimental.agent_loop.test_basic_agent_loop.WeatherToolWithData",
"config": {"type": "native"},
},
]
@ -238,15 +240,29 @@ def test_tool_agent(init_config):
tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
responses = result.batch["responses"]
response_mask = result.batch["response_mask"]
attention_mask = result.batch["attention_mask"]
assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}"
response_length = response_mask.size(1)
# Decode responses with response_mask
for i in range(len(responses)):
# response with tool response
valid_tokens = responses[i][attention_mask[i][-response_length:].bool()]
response_with_obs = tokenizer.decode(valid_tokens)
# response without tool response
valid_tokens = responses[i][response_mask[i].bool()]
response_str = tokenizer.decode(valid_tokens)
assert "<tool_response>" not in response_str, f"found <tool_response> in response: {response_str}"
assert "</tool_response>" not in response_str, f"found </tool_response> in response: {response_str}"
print(f"response: {response_str}")
response_without_obs = tokenizer.decode(valid_tokens)
assert "<tool_response>" not in response_without_obs, (
f"found <tool_response> in response: {response_without_obs}"
)
assert "</tool_response>" not in response_without_obs, (
f"found </tool_response> in response: {response_without_obs}"
)
print("=========================")
print(response_with_obs)
print("---")
print(response_without_obs)
print("Test passed!")
ray.shutdown()

View File

@ -27,7 +27,7 @@ from verl.protocol import DataProto
from verl.utils.distributed import initialize_global_process_group
from verl.utils.model import compute_position_id_with_mask, create_random_mask
from verl.utils.ulysses import (
gather_outpus_and_unpad,
gather_outputs_and_unpad,
get_ulysses_sequence_parallel_world_size,
set_ulysses_sequence_parallel_group,
ulysses_pad_and_slice_inputs,
@ -155,7 +155,7 @@ def _hf_casual_fwd(config, sp_size, dp_size):
).logits # (1, total_nnz/n, vocab_size)
# all_gather output
logits_full = gather_outpus_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)
logits_full = gather_outputs_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)
# 2. perform normal forward
set_ulysses_sequence_parallel_group(None)
@ -234,7 +234,7 @@ def _hf_casual_fwd_bwd(config, sp_size, dp_size):
).logits # (1, total_nnz/n, vocab_size)
# all_gather output
logits_full = gather_outpus_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)
logits_full = gather_outputs_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)
# 2. perform normal forward
set_ulysses_sequence_parallel_group(None)

View File

@ -175,6 +175,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \
actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
actor_rollout_ref.rollout.update_weights_bucket_megabytes=128 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
actor_rollout_ref.ref.megatron.use_mbridge=${USE_MBRIDGE} \

View File

@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from verl.model_merger.megatron_model_merger import get_dynamic_pipeline_shards
from verl.utils.megatron.pipeline_parallel import make_batch_generator
@ -45,3 +48,23 @@ def test_make_batch_generator_empty():
assert len(generators) == vpp_size
for gen in generators:
assert list(gen) == []
@pytest.mark.parametrize(
"layer_num,pp_size,gt",
[
(61, 8, [6, 8, 8, 8, 8, 8, 8, 7]),
(61, 7, [8, 9, 9, 9, 9, 9, 8]),
(61, 1, [61]),
(61, 0, ValueError),
(10, 16, ValueError),
],
)
def test_get_dynamic_pipeline_shards(layer_num, pp_size, gt):
if isinstance(gt, list):
shards = get_dynamic_pipeline_shards(layer_num, pp_size)
assert len(shards) == len(gt) == pp_size, f"Expected {pp_size} shards, got {len(shards)}"
assert all([shard == gt[i] for i, shard in enumerate(shards)]), f"Expected shards {gt}, got {shards}"
elif issubclass(gt, Exception):
with pytest.raises(gt):
shards = get_dynamic_pipeline_shards(layer_num, pp_size)

View File

@ -666,3 +666,27 @@ class Solution:
assert "error" not in metadata_list[0]
assert metadata_list[0].get("status") != "compilation error"
assert metadata_list[0].get("status") != "runtime error"
@pytest.mark.skipif(skip_condition, reason=skip_reason)
def test_none_and_empty_stdin_passed_correctly():
"""
Tests that when stdin data is set to an empty string or None, it is still
passed correctly to Sandbox Fusion as an empty string.
"""
echo_code = """
import sys
print(f"You said '{sys.stdin.readline().strip()}'")
"""
in_outs = {
"inputs": [None, "", "hello"],
"outputs": ["You said ''", "You said ''", "You said 'hello'"],
}
# Use a short timeout for fast tests
results, metadata_list = check_correctness(SANDBOX_URL, in_outs, echo_code, timeout=5)
assert results == [True, True, True]
assert "error" not in metadata_list[0]
assert metadata_list[0].get("status") != "compilation error"
assert metadata_list[0].get("status") != "runtime error"

View File

@ -124,6 +124,60 @@ def _worker(rank, world_size, init_method, max_token_len, use_same_dp, min_mb):
dist.destroy_process_group()
def test_dataproto_split_uneven():
"""Test DataProto.split with uneven splits"""
# Create test data with 10 items
input_ids = torch.randint(low=0, high=10, size=(10, 5))
attention_mask = torch.ones(10, 5)
data = {"input_ids": input_ids, "attention_mask": attention_mask}
dataproto = DataProto.from_single_dict(data)
# Test split with size 3 (should create chunks of [3, 3, 3, 1])
splits = dataproto.split(3)
assert len(splits) == 4
assert len(splits[0]) == 3
assert len(splits[1]) == 3
assert len(splits[2]) == 3
assert len(splits[3]) == 1
reconstructed = DataProto.concat(splits)
torch.testing.assert_close(reconstructed.batch["input_ids"], dataproto.batch["input_ids"])
torch.testing.assert_close(reconstructed.batch["attention_mask"], dataproto.batch["attention_mask"])
# Test split with size equal to length (should create one chunk)
splits = dataproto.split(10)
assert len(splits) == 1
assert len(splits[0]) == 10
# Test split with size larger than length (should create one chunk with all data)
splits = dataproto.split(15)
assert len(splits) == 1
assert len(splits[0]) == 10
# Test with non-tensor batch data
import numpy as np
data_with_non_tensor = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": np.array([f"label_{i}" for i in range(10)], dtype=object),
}
dataproto_with_non_tensor = DataProto.from_single_dict(data_with_non_tensor)
splits = dataproto_with_non_tensor.split(3)
assert len(splits) == 4
assert len(splits[0]) == 3
assert len(splits[1]) == 3
assert len(splits[2]) == 3
assert len(splits[3]) == 1
# Verify non-tensor data integrity
reconstructed = DataProto.concat(splits)
np.testing.assert_array_equal(
reconstructed.non_tensor_batch["labels"], dataproto_with_non_tensor.non_tensor_batch["labels"]
)
def test_seqlen_balancing_distributed_params(tmp_path):
world_size = 2
init_file = tmp_path / "dist_init"

View File

@ -0,0 +1,57 @@
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets
_TENSOR_1MB = torch.zeros(512, 512)
_BYTES_1MB = 1 << 20
@pytest.mark.parametrize(
"named_tensors, bucket_size_mb, gt_groups",
[
(
[("a", _TENSOR_1MB), ("b", _TENSOR_1MB)],
0.5 * _BYTES_1MB,
[["a"], ["b"]],
),
(
[("a", _TENSOR_1MB), ("b", _TENSOR_1MB)],
1 * _BYTES_1MB,
[["a"], ["b"]],
),
(
[("a", _TENSOR_1MB), ("b", _TENSOR_1MB)],
1.5 * _BYTES_1MB,
[["a"], ["b"]],
),
(
[("a", _TENSOR_1MB), ("b", _TENSOR_1MB)],
2 * _BYTES_1MB,
[["a", "b"]],
),
],
)
def test_get_named_tensor_buckets(named_tensors, bucket_size_mb, gt_groups: list[list[str]]):
named_tensors_iter = iter(named_tensors)
groups = list(get_named_tensor_buckets(named_tensors_iter, bucket_size_mb))
assert len(groups) == len(gt_groups)
for group, gt_group in zip(groups, gt_groups, strict=True):
assert len(group) == len(gt_group)
for (name, _), (gt_name) in zip(group, gt_group, strict=True):
assert name == gt_name
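
The `get_named_tensor_buckets` implementation itself is not shown in this excerpt; as a rough sketch of the behavior the cases above imply (an assumption, not the actual code), it greedily accumulates tensors until the next one would push the bucket past its byte budget:

```python
from collections.abc import Iterable, Iterator

import torch


def bucket_named_tensors(
    named_tensors: Iterable[tuple[str, torch.Tensor]], bucket_size_bytes: float
) -> Iterator[list[tuple[str, torch.Tensor]]]:
    """Group (name, tensor) pairs into buckets of roughly bucket_size_bytes each."""
    bucket: list[tuple[str, torch.Tensor]] = []
    bucket_bytes = 0
    for name, tensor in named_tensors:
        nbytes = tensor.numel() * tensor.element_size()
        if bucket and bucket_bytes + nbytes > bucket_size_bytes:
            yield bucket
            bucket, bucket_bytes = [], 0
        bucket.append((name, tensor))
        bucket_bytes += nbytes
    if bucket:
        yield bucket
```

With two 1 MB tensors this yields `[["a"], ["b"]]` for budgets up to 1.5 MB and `[["a", "b"]]` at 2 MB, matching the parametrized cases.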

View File

@ -158,6 +158,8 @@ def get_rollout_config(
"prompt_length": max_prompt_length,
"response_length": max_response_length,
"tensor_model_parallel_size": tensor_parallel_size,
# set to 128MB only for testing
"update_weights_bucket_megabytes": 128,
"multi_turn": {
"max_assistant_turns": 4,
"max_user_turns": 4,

View File

@ -13,5 +13,9 @@
# limitations under the License.
from .agent_loop import AgentLoopBase, AgentLoopManager
from .single_turn_agent_loop import SingleTurnAgentLoop
from .tool_agent_loop import ToolAgentLoop
_ = [SingleTurnAgentLoop, ToolAgentLoop]
__all__ = ["AgentLoopBase", "AgentLoopManager"]

View File

@ -19,11 +19,12 @@ import random
from abc import ABC, abstractmethod
from typing import Any
import hydra
import numpy as np
import ray
import torch
from cachetools import LRUCache
from omegaconf import DictConfig
from omegaconf import DictConfig, OmegaConf
from pydantic import BaseModel
from tensordict import TensorDict
from transformers import AutoTokenizer
@ -120,29 +121,43 @@ class AgentLoopOutput(BaseModel):
metrics: AgentLoopMetrics
# make hydra.utils.instantiate happy
class _DummyConfig:
def __init__(self, config: DictConfig) -> None:
self.config = config
class AgentLoopBase(ABC):
"""An agent loop takes a input message, chat with OpenAI compatible LLM server and interact with various
environments."""
_class_initialized = False
def __init__(self, config: DictConfig, server_manager: AsyncLLMServerManager, tokenizer: AutoTokenizer):
"""Initialize agent loop.
def __init__(
self, trainer_config: _DummyConfig, server_manager: AsyncLLMServerManager, tokenizer: AutoTokenizer, **kwargs
):
"""Initialize agent loop, each sample will have its own loop instance.
Args:
config (DictConfig): YAML config.
trainer_config (_DummyConfig): trainer config.
server_manager (AsyncLLMServerManager): OpenAI compatible LLM server manager.
tokenizer (AutoTokenizer): Tokenizer for tokenize messages.
"""
self.config = config
self.init_class(trainer_config.config, tokenizer, **kwargs)
self.config = trainer_config.config
self.server_manager = server_manager
self.tokenizer = tokenizer
self.loop = asyncio.get_running_loop()
self.init_class(config, tokenizer)
@classmethod
def init_class(cls, config: DictConfig, tokenizer: AutoTokenizer):
"""Initialize class state shared across all instances."""
def init_class(cls, config: DictConfig, tokenizer: AutoTokenizer, **kwargs):
"""This is used to do heavy initialization work that should shared across all instances. It's only called once.
Args:
config (DictConfig): trainer config.
tokenizer (AutoTokenizer): Tokenizer for tokenize messages.
**kwargs: extra kwargs from config file passed in by `hydra.utils.instantiate`.
"""
if cls._class_initialized:
return
cls._class_initialized = True
@ -161,6 +176,25 @@ class AgentLoopBase(ABC):
raise NotImplementedError
"""Agent loop registry: key is agent_name, value is a dict of agent loop config
used by hydra.utils.instantiate to initialize agent loop instances.
https://hydra.cc/docs/advanced/instantiate_objects/overview/
"""
_agent_loop_registry: dict[str, dict] = {}
def register(agent_name: str):
"""Register agent loop class."""
def decorator(subclass: type[AgentLoopBase]) -> type[AgentLoopBase]:
fqdn = f"{subclass.__module__}.{subclass.__qualname__}"
_agent_loop_registry[agent_name] = {"_target_": fqdn}
return subclass
return decorator
@ray.remote
class AgentLoopWorker:
"""Agent loop worker takes a batch of messages and run each message in an agent loop."""
@ -180,6 +214,13 @@ class AgentLoopWorker:
local_path = copy_to_local(config.actor_rollout_ref.model.path)
self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True)
agent_loop_config_path = config.actor_rollout_ref.rollout.agent.agent_loop_config_path
if agent_loop_config_path:
agent_loop_configs = OmegaConf.load(agent_loop_config_path)
for agent_loop_config in agent_loop_configs:
_agent_loop_registry[agent_loop_config.name] = agent_loop_config
trace_config = config.trainer.get("rollout_trace", {})
trace_config = self.config.actor_rollout_ref.rollout.get("trace", {})
RolloutTraceConfig.init(
self.config.trainer.project_name,
@ -260,36 +301,20 @@ class AgentLoopWorker:
validate=trajectory["validate"],
name="agent_loop",
):
agent_loop_class = self.get_agent_loop_class(agent_name)
agent_loop = agent_loop_class(self.config, self.server_manager, self.tokenizer)
assert agent_name in _agent_loop_registry, (
f"Agent loop {agent_name} not registered, registered agent loops: {_agent_loop_registry.keys()}"
)
agent_loop_config = _agent_loop_registry[agent_name]
agent_loop = hydra.utils.instantiate(
config=agent_loop_config,
trainer_config=_DummyConfig(config=self.config),
server_manager=self.server_manager,
tokenizer=self.tokenizer,
)
output = await agent_loop.run(messages, sampling_params)
return output
def get_agent_loop_class(self, agent_name: str) -> type[AgentLoopBase]:
"""Get the appropriate agent loop class based on agent name.
Factory method that returns the correct agent loop class implementation
for the specified agent type.
Args:
agent_name (str): Name of the agent type ('single_turn_agent' or 'tool_agent').
Returns:
Type[AgentLoopBase]: Agent loop class corresponding to the agent name.
Raises:
ValueError: If the agent_name is not recognized.
"""
# TODO: add tool agent registrary
from verl.experimental.agent_loop.single_turn_agent_loop import SingleTurnAgentLoop
from verl.experimental.agent_loop.tool_agent_loop import ToolAgentLoop
if agent_name == "single_turn_agent":
return SingleTurnAgentLoop
elif agent_name == "tool_agent":
return ToolAgentLoop
raise ValueError(f"Unknown agent_name: {agent_name}")
def _postprocess(self, inputs: list[AgentLoopOutput]) -> DataProto:
# NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
# prompts: left pad
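
Taken together, the registry gives two ways to plug in a custom loop: decorate the class with `@register`, or list it in the file pointed to by `actor_rollout_ref.rollout.agent.agent_loop_config_path` (as the react-agent test above does). A minimal decorator-based sketch; the agent name `echo_agent` and class are purely illustrative, and `run` is left as a stub because AgentLoopOutput construction is outside this excerpt:

```python
from typing import Any

from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register


@register("echo_agent")  # hypothetical agent name, for illustration only
class EchoAgentLoop(AgentLoopBase):
    """Toy single-turn loop registered via the decorator."""

    async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
        # A real loop would call the LLM through self.server_manager and assemble an AgentLoopOutput.
        raise NotImplementedError
```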

View File

@ -16,20 +16,21 @@ import os
from typing import Any
from uuid import uuid4
from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput
from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register
from verl.utils.profiler import simple_timer
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
@register("single_turn_agent")
class SingleTurnAgentLoop(AgentLoopBase):
"""Naive agent loop that only do single turn chat completion."""
def __init__(self, config, server_manager, tokenizer):
super().__init__(config, server_manager, tokenizer)
self.prompt_length = config.actor_rollout_ref.rollout.prompt_length
self.response_length = config.actor_rollout_ref.rollout.response_length
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length
self.response_length = self.config.actor_rollout_ref.rollout.response_length
async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
metrics = {}

View File

@ -15,14 +15,11 @@ import asyncio
import json
import logging
import os
from abc import ABC, abstractmethod
from typing import Any
from uuid import uuid4
import regex as re
from pydantic import BaseModel
from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput
from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register
from verl.experimental.agent_loop.tool_parser import FunctionCall, ToolParser
from verl.tools.utils.tool_registry import initialize_tools_from_config
from verl.utils.profiler import simple_timer
from verl.utils.rollout_trace import rollout_trace_op
@ -31,68 +28,10 @@ logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
class FunctionCall(BaseModel):
arguments: str
"""
The arguments to call the function with, as generated by the model in JSON
format. Note that the model does not always generate valid JSON, and may
hallucinate parameters not defined by your function schema. Validate the
arguments in your code before calling your function.
"""
name: str
"""The name of the function to call."""
class ToolParser(ABC):
@abstractmethod
async def extract_tool_calls(self, responses_ids: list[int]) -> list[FunctionCall]:
"""Extract tool calls from the responses.
Args:
responses_ids (List[int]): The ids of the responses.
Returns:
List[FunctionCall]: The extracted tool calls.
"""
raise NotImplementedError
class HermesToolParser(ToolParser):
"""Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py"""
def __init__(self, tokenizer) -> None:
self.tokenizer = tokenizer
self.tool_call_start_token: str = "<tool_call>"
self.tool_call_end_token: str = "</tool_call>"
self.tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
@rollout_trace_op
async def extract_tool_calls(self, responses_ids: list[int]) -> list[FunctionCall]:
loop = asyncio.get_running_loop()
text = await loop.run_in_executor(None, self.tokenizer.decode, responses_ids)
if self.tool_call_start_token not in text or self.tool_call_end_token not in text:
return []
matches = self.tool_call_regex.findall(text)
function_calls = []
for match in matches:
try:
function_call = json.loads(match)
name, arguments = function_call["name"], function_call["arguments"]
function_calls.append(FunctionCall(name=name, arguments=json.dumps(arguments, ensure_ascii=False)))
except Exception as e:
logger.error(f"Failed to decode tool call: {e}")
return function_calls
@register("tool_agent")
class ToolAgentLoop(AgentLoopBase):
def __init__(self, config, server_manager, tokenizer):
super().__init__(config, server_manager, tokenizer)
@classmethod
def init_class(cls, config, tokenizer):
def init_class(cls, config, tokenizer, **kwargs):
if cls._class_initialized:
return
cls._class_initialized = True
@ -109,7 +48,7 @@ class ToolAgentLoop(AgentLoopBase):
tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else []
cls.tools = {tool.name: tool for tool in tool_list}
cls.tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list]
cls.tool_parser = cls.get_tool_parser(config.actor_rollout_ref.rollout.multi_turn.format)
cls.tool_parser = ToolParser.get_tool_parser(config.actor_rollout_ref.rollout.multi_turn.format, cls.tokenizer)
print(f"Initialized tools: {cls.tools}")
cls.prompt_length = config.actor_rollout_ref.rollout.prompt_length
@ -151,7 +90,7 @@ class ToolAgentLoop(AgentLoopBase):
break
# no tool calls
tool_calls = await self.tool_parser.extract_tool_calls(response_ids)
_, tool_calls = await self.tool_parser.extract_tool_calls(response_ids)
if not tool_calls:
break
@ -225,12 +164,3 @@ class ToolAgentLoop(AgentLoopBase):
"role": "tool",
"content": tool_response,
}
@classmethod
def get_tool_parser(cls, name: str) -> ToolParser:
tool_parsers = {
"hermes": HermesToolParser,
}
if name not in tool_parsers:
raise ValueError(f"Unknown tool parser: {name}")
return tool_parsers[name](cls.tokenizer)

View File

@ -0,0 +1,106 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import logging
import os
from abc import ABC, abstractmethod
import regex as re
from pydantic import BaseModel
from verl.utils.rollout_trace import rollout_trace_op
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
class FunctionCall(BaseModel):
arguments: str
"""
The arguments to call the function with, as generated by the model in JSON
format. Note that the model does not always generate valid JSON, and may
hallucinate parameters not defined by your function schema. Validate the
arguments in your code before calling your function.
"""
name: str
"""The name of the function to call."""
class ToolParser(ABC):
_registry: dict[str, type["ToolParser"]] = {}
def __init__(self, tokenizer) -> None:
self.tokenizer = tokenizer
@abstractmethod
async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]:
"""Extract tool calls from the responses.
Args:
responses_ids (List[int]): The ids of the responses.
Returns:
Tuple[str, List[FunctionCall]]: Content and extracted tool calls.
"""
raise NotImplementedError
@classmethod
def get_tool_parser(cls, name: str, tokenizer):
if name not in cls._registry:
raise ValueError(f"Unknown tool parser: {name}")
return cls._registry[name](tokenizer)
@classmethod
def register(cls, name: str):
def decorator(subclass: type[ToolParser]) -> type[ToolParser]:
cls._registry[name] = subclass
return subclass
return decorator
@ToolParser.register("hermes")
class HermesToolParser(ToolParser):
"""Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py"""
def __init__(self, tokenizer) -> None:
super().__init__(tokenizer)
self.tool_call_start_token: str = "<tool_call>"
self.tool_call_end_token: str = "</tool_call>"
self.tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
@rollout_trace_op
async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]:
loop = asyncio.get_running_loop()
text = await loop.run_in_executor(None, self.tokenizer.decode, responses_ids)
if self.tool_call_start_token not in text or self.tool_call_end_token not in text:
return text, []
matches = self.tool_call_regex.findall(text)
function_calls = []
for match in matches:
try:
function_call = json.loads(match)
name, arguments = function_call["name"], function_call["arguments"]
function_calls.append(FunctionCall(name=name, arguments=json.dumps(arguments, ensure_ascii=False)))
except Exception as e:
logger.error(f"Failed to decode tool call: {e}")
# remaining text, excluding tool call tokens
content = self.tool_call_regex.sub("", text)
return content, function_calls
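
A short usage sketch for the parser registry, assuming a locally available Qwen tokenizer and that rollout tracing is a no-op when not configured:

```python
import asyncio

from transformers import AutoTokenizer

from verl.experimental.agent_loop.tool_parser import ToolParser


async def main():
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    parser = ToolParser.get_tool_parser("hermes", tokenizer)
    text = '<tool_call>{"name": "get_current_temperature", "arguments": {"location": "Paris, France"}}</tool_call>'
    content, tool_calls = await parser.extract_tool_calls(tokenizer.encode(text))
    print(repr(content), [(c.name, c.arguments) for c in tool_calls])


asyncio.run(main())
```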

View File

@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -32,6 +32,16 @@ python -m verl.model_merger merge \
--target_dir /path/to/merged_hf_model
```
or use distributed merge for large models like DeepseekV3 671B
```sh
torchrun --nproc_per_node 1 --nnodes 8 --node_rank ${RANK} -m verl.model_merger merge \
--backend megatron \
--local_dir ./checkpoints/global_step_1/actor \
--target_dir /path/to/merged_hf_model
```
For more details, please refer to the documentation:
https://verl.readthedocs.io/en/latest/advance/checkpoint.html#convert-fsdp-and-megatron-checkpoints-to-huggingface-format-model
"""

View File

@ -45,6 +45,7 @@ def parse_args():
action="store_true",
help="Whether to tie word embedding weights (currently only Megatron supported)",
)
base_op_parser.add_argument("--trust-remote-code", action="store_true", help="Whether to trust remote code")
base_op_parser.add_argument(
"--is-value-model",
action="store_true",
@ -88,6 +89,7 @@ class ModelMergerConfig:
private: bool = False
test_hf_dir: Optional[str] = None
tie_word_embedding: bool = False
trust_remote_code: bool = False
is_value_model: bool = False
local_dir: Optional[str] = None
hf_model_config_path: Optional[str] = None
@ -107,6 +109,7 @@ def generate_config_from_args(args: argparse.Namespace) -> ModelMergerConfig:
"operation": args.operation,
"backend": args.backend,
"tie_word_embedding": args.tie_word_embedding,
"trust_remote_code": args.trust_remote_code,
"is_value_model": args.is_value_model,
"local_dir": args.local_dir,
"hf_model_config_path": os.path.join(args.local_dir, "huggingface"),
@ -161,7 +164,9 @@ class BaseModelMerger(ABC):
def __init__(self, config: ModelMergerConfig):
self.config = config
self.hf_model_config_path = config.hf_model_config_path
self.model_config = AutoConfig.from_pretrained(self.hf_model_config_path)
self.model_config = AutoConfig.from_pretrained(
self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code
)
def get_transformers_auto_model_class(self):
if "ForTokenClassification" in self.model_config.architectures[0]:
@ -250,7 +255,9 @@ class BaseModelMerger(ABC):
def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]):
auto_model_class = self.get_transformers_auto_model_class()
with init_empty_weights():
model = auto_model_class.from_config(self.model_config, torch_dtype=torch.bfloat16)
model = auto_model_class.from_config(
self.model_config, torch_dtype=torch.bfloat16, trust_remote_code=self.config.trust_remote_code
)
model.to_empty(device="cpu")
model = self.patch_model_generation_config(model)
@ -263,8 +270,8 @@ class BaseModelMerger(ABC):
del state_dict
del model
processor = hf_processor(self.hf_model_config_path)
tokenizer = hf_tokenizer(self.hf_model_config_path)
processor = hf_processor(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code)
tokenizer = hf_tokenizer(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code)
if processor is not None:
print(f"Saving processor to {self.config.target_dir}")
processor.save_pretrained(self.config.target_dir)

View File

@ -12,13 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, ContextManager
import numpy as np
import torch
import torch.distributed as dist
from accelerate import init_empty_weights
from megatron.core import mpu
from megatron.core.models.gpt.gpt_model import ModelType
@ -30,9 +33,10 @@ from transformers import (
)
from verl.models.mcore import hf_to_mcore_config
from verl.utils.device import get_nccl_backend
from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device
from verl.utils.megatron.dist_checkpointing import load_dist_checkpointing
from verl.utils.megatron_utils import get_model
from verl.utils.tokenizer import hf_processor, hf_tokenizer
from .base_model_merger import BaseModelMerger, ModelMergerConfig
@ -42,6 +46,50 @@ def noop_context() -> Any:
yield
def get_dynamic_pipeline_shards(layer_num: int, pp_size: int) -> list[int]:
"""Calculate the pipeline sharding configuration for Megatron-LM.
Args:
layer_num: Total number of layers in the model.
pp_size: Number of pipeline parallel ranks.
Returns:
The number of layers on each pp rank, making the pipeline sharding as uniform as possible.
"""
if layer_num < pp_size:
raise ValueError(f"layer_num {layer_num} must be greater than pp_size {pp_size}.")
if pp_size < 1:
raise ValueError(f"pp_size must be at least 1, got {pp_size}.")
if pp_size == 1:
return [layer_num]
if pp_size == 2:
return [
layer_num // 2,
layer_num - layer_num // 2,
]
middle_size = pp_size - 2
shards_strategy = []
for middle_layer_num in range(layer_num):
first_last_layer_num = layer_num - middle_layer_num * middle_size
first_layer_num = first_last_layer_num // 2
last_layer_num = first_last_layer_num - first_last_layer_num // 2
if 0 < first_layer_num <= middle_layer_num and 0 < last_layer_num <= middle_layer_num:
shards_strategy.append(
(
[first_layer_num] + [middle_layer_num] * middle_size + [last_layer_num],
abs(first_layer_num - middle_layer_num),
)
)
# sort by diff of layer_num, to make it as uniform as possible
res = sorted(shards_strategy, key=lambda x: x[1])[0][0]
assert sum(res) == layer_num, f"sum(res)={sum(res)} != layer_num={layer_num}, pp_size={pp_size}"
return res
class MegatronModelMerger(BaseModelMerger):
"""
Model merger for Megatron-LM distributed checkpoints.
@ -87,19 +135,31 @@ class MegatronModelMerger(BaseModelMerger):
def __init__(self, config: ModelMergerConfig):
super().__init__(config)
# Currently we use only 1 rank to merge the dist_ckpt, we will move to multi-process save shortly afterwards
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
if "WORLD_SIZE" not in os.environ:
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
torch.distributed.init_process_group(get_nccl_backend())
self.rank = torch.distributed.get_rank()
self.world_size = torch.distributed.get_world_size()
local_rank = os.environ.get("LOCAL_RANK", 0)
get_torch_device().set_device(f"{get_device_name()}:{local_rank}")
mpu.initialize_model_parallel(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=self.world_size,
virtual_pipeline_model_parallel_size=None,
context_parallel_size=1,
expert_model_parallel_size=1,
)
model_parallel_cuda_manual_seed(0)
self.hf_config = AutoConfig.from_pretrained(self.config.hf_model_config_path)
self.hf_config = AutoConfig.from_pretrained(
self.config.hf_model_config_path, trust_remote_code=self.config.trust_remote_code
)
print(self.hf_config, flush=True)
self.params_mapping = {
@ -107,6 +167,9 @@ class MegatronModelMerger(BaseModelMerger):
# NOTICE: It's a little bit tricky, when 2 keys have the same prefix, we need to make sure the
# longer key within the containing relationship is processed first.
"embedding.word_embeddings": "model.embed_tokens",
# input layer norm for dpskv3
"input_layernorm.weight": "input_layernorm.weight",
"input_layernorm.bias": "input_layernorm.bias",
# attn
"self_attention.linear_qkv.layer_norm_weight": "input_layernorm.weight",
"self_attention.linear_qkv.layer_norm_bias": "input_layernorm.bias",
@ -140,6 +203,11 @@ class MegatronModelMerger(BaseModelMerger):
"output_layer": "lm_head",
}
if "Qwen2MoeForCausalLM" in self.hf_config.architectures:
self.params_mapping["mlp.shared_experts.linear_fc1"] = "mlp.shared_expert.gate_up_proj"
self.params_mapping["mlp.shared_experts.linear_fc2"] = "mlp.shared_expert.down_proj"
self.params_mapping["mlp.shared_experts.gate_weight"] = "mlp.shared_expert_gate.weight"
def _load_state_dicts(self, model_ckpt_path: str) -> dict[str, Any]:
"""_summary_
Use Megatron dist_checkpointing to load the model state dicts from the checkpoint directory.
@ -152,7 +220,15 @@ class MegatronModelMerger(BaseModelMerger):
"""
# init hf config
tf_config = hf_to_mcore_config(self.hf_config, torch.bfloat16)
self.pipeline_shards = get_dynamic_pipeline_shards(self.hf_config.num_hidden_layers, self.world_size)
print(f"Pipeline shards: {self.pipeline_shards}, total layers: {sum(self.pipeline_shards)}")
tf_config = hf_to_mcore_config(
self.hf_config,
torch.bfloat16,
num_layers_in_first_pipeline_stage=self.pipeline_shards[0] if len(self.pipeline_shards) > 1 else None,
num_layers_in_last_pipeline_stage=self.pipeline_shards[-1] if len(self.pipeline_shards) > 2 else None,
)
tf_config.use_cpu_initialization = self.config.use_cpu_initialization
tie_word_embeddings = getattr(self.hf_config, "tie_word_embeddings", False)
@ -273,7 +349,11 @@ class MegatronModelMerger(BaseModelMerger):
def _merge_state_dicts(self, model_state_dict_list: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
state_dict = {}
layers_cum = 0
if self.world_size > 1:
pipeline_cumsum = np.cumsum(self.pipeline_shards)
layers_cum = 0 if self.rank == 0 else pipeline_cumsum[self.rank - 1]
print(f"{layers_cum=}")
for model_state_dict in model_state_dict_list:
layers_handled = 0
keys = model_state_dict.keys()
@ -297,6 +377,15 @@ class MegatronModelMerger(BaseModelMerger):
else:
warnings.warn(f"hf_name {hf_name} will not be fixed with layer number", stacklevel=2)
if "mlp.experts." in hf_name and ".weight" in hf_name:
name_prefix, expert_id = hf_name.split(".weight")
for proj in ["gate_up", "down"]:
if f"{proj}_proj" in hf_name:
hf_name = hf_name.replace(
f"mlp.experts.{proj}_proj.weight{expert_id}",
f"mlp.experts.{expert_id}.{proj}_proj.weight",
)
tensor = model_state_dict[key]
split_tensor = self._split_tensors(
key, tensor, self.hf_config, is_value_model=self.config.is_value_model
@ -321,6 +410,75 @@ class MegatronModelMerger(BaseModelMerger):
return state_dict
def save_hf_model_and_tokenizer(self, merged_state_dict):
if self.world_size == 1:
return super().save_hf_model_and_tokenizer(merged_state_dict)
from safetensors.torch import save_file
layer_num = self.hf_config.num_hidden_layers
# FIXME: make configurable
saves_per_layer = 1 if layer_num < 30 else 2
saves_total = saves_per_layer * layer_num
saves_indexes = {}
# calculate the layer start index and key chunks
layer_this_rank = self.pipeline_shards[self.rank]
pipeline_cumsum = np.cumsum(self.pipeline_shards)
layer_start = 0 if self.rank == 0 else pipeline_cumsum[self.rank - 1]
keys = list(merged_state_dict.keys())
keys_chunk = np.array_split(np.array(keys), layer_this_rank * saves_per_layer)
numel = 0
assert len(keys_chunk) == layer_this_rank * saves_per_layer, (
f"Expected {len(keys_chunk)} chunks, but got {layer_this_rank * saves_per_layer} for rank {self.rank}."
)
# save to model shards manually
target_dir = Path(self.config.target_dir)
for i, keys in enumerate(keys_chunk):
sd_to_save = {k: merged_state_dict[k] for k in keys}
numel += sum([sd_to_save[i].numel() for i in sd_to_save])
save_idx = layer_start * saves_per_layer + i
save_path = target_dir / f"model-{save_idx + 1:05d}-of-{saves_total:05d}.safetensors"
save_file(sd_to_save, save_path)
for k in keys:
saves_indexes[k] = str(save_path.name)
tensor = torch.tensor([numel]).to(get_device_name())
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
numel = tensor.cpu().item()
all_save_indexes = [{} for _ in range(self.world_size)]
dist.all_gather_object(all_save_indexes, saves_indexes)
saves_indexes = {k: v for i in all_save_indexes for k, v in i.items()}
if self.rank == 0:
with open(target_dir / "model.safetensors.index.json", "w") as f:
json.dump(
{
"metadata": {
"total_size": numel,
},
"weight_map": saves_indexes,
},
f,
indent=4,
)
print(f"model saved to {target_dir} with {numel=}")
self.model_config.save_pretrained(self.config.target_dir)
processor = hf_processor(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code)
tokenizer = hf_tokenizer(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code)
if processor is not None:
print(f"Saving processor to {self.config.target_dir}")
processor.save_pretrained(self.config.target_dir)
if tokenizer is not None:
print(f"Saving tokenizer to {self.config.target_dir}")
tokenizer.save_pretrained(self.config.target_dir)
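For intuition, a worked example of the shard naming used in `save_hf_model_and_tokenizer` above, under assumed values (24 layers, 4 equal pipeline shards, one file per layer):

```python
import numpy as np

# Assumed values for illustration only.
layer_num = 24
saves_per_layer = 1 if layer_num < 30 else 2   # 1
saves_total = saves_per_layer * layer_num       # 24
pipeline_shards = [6, 6, 6, 6]
pipeline_cumsum = np.cumsum(pipeline_shards)
rank = 1
layer_start = 0 if rank == 0 else pipeline_cumsum[rank - 1]   # 6

# Rank 1 writes shard files 7..12 of 24:
for i in range(pipeline_shards[rank] * saves_per_layer):
    save_idx = layer_start * saves_per_layer + i
    print(f"model-{save_idx + 1:05d}-of-{saves_total:05d}.safetensors")
```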
def merge_and_save(self):
from verl.utils.megatron_utils import get_dist_checkpoint_path
@ -370,6 +528,7 @@ class MegatronModelMerger(BaseModelMerger):
megatron_name = megatron_name.replace("decoder", "model")
param_name = megatron_name.replace(m_name, v_name)
return param_name
return None # Return None if no mapping found

View File

@ -736,11 +736,7 @@ class DataProto:
Returns:
List[DataProto]: a list of DataProto after splitting
"""
assert len(self) % split_size == 0, (
f"only support equal split. Got size of DataProto {len(self)} and chunk {split_size}."
)
chunks = len(self) // split_size
return self.chunk(chunks)
return [self[i : i + split_size] for i in range(0, len(self), split_size)]
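The new split no longer requires an even division; a toy sketch of the semantics, with a plain list standing in for DataProto:

```python
# Sketch only: the last chunk may be shorter than split_size.
data = list(range(10))
split_size = 4
chunks = [data[i : i + split_size] for i in range(0, len(data), split_size)]
print(chunks)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```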
@staticmethod
def concat(data: list["DataProto"]) -> "DataProto":

View File

@ -57,7 +57,7 @@ class WorkerHelper:
return sock.getsockname()[1]
def get_availale_master_addr_port(self):
return self._get_node_ip(), str(self._get_free_port())
return self._get_node_ip().strip("[]"), str(self._get_free_port())
# we assume that in each WorkerGroup, there is a Master Worker

View File

@ -127,9 +127,11 @@ actor_rollout_ref:
calculate_log_probs: false
agent:
num_workers: 8
agent_loop_config_path: null
custom_async_server:
path: null
name: null
update_weights_bucket_megabytes: 512
trace:
backend: null
token2text: false

View File

@ -170,6 +170,18 @@ agent:
# Number of agent loop workers
num_workers: 8
# custom agent loop config path, which should contain a list of configs to initialize AgentLoop instances.
# https://hydra.cc/docs/advanced/instantiate_objects/overview/
#
# - name: react_agent
# _target_: recipe.langgraph_agent.react_agent_loop.ReactAgentLoop
# tools: ["get_current_temperature"]
# - name: math_expression
# _target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop
# min_terms: 2
# max_terms: 6
agent_loop_config_path: null
# custom async server configs
custom_async_server:
@ -179,6 +191,20 @@ agent:
# Class name of the custom async server class (e.g. AsyncvLLMServer)
name: null
# Specifies the tensor bucket size (in megabytes) for batch weight updates during rollout operations.
# This parameter controls the maximum payload size for a single weight update request.
# Reference: https://github.com/volcengine/verl/pull/2418
# Currently only supported in SGLang rollout implementations
# Larger values may improve throughput but increase memory overhead
# Detailed performance comparison:
# https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/issues/169#issuecomment-3070686720
# Default value (512MB) is optimized for typical GPU memory configurations
# For the best performance of `rebuild_cuda_tensor`, it is recommended to:
# 1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES`
# 2. Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7`
# when using Tensor Parallelism (TP) >= 8.
update_weights_bucket_megabytes: 512
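For reference, the rollout code turns this setting into a byte budget per bucket with a left shift; a two-line sketch:

```python
update_weights_bucket_megabytes = 512                                  # value configured above
update_weights_bucket_bytes = update_weights_bucket_megabytes << 20    # 536870912 bytes (512 MiB)
```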
# trace rollout data
trace:

View File

@ -62,7 +62,7 @@ from verl.utils.torch_dtypes import PrecisionType
from verl.utils.torch_functional import get_cosine_schedule_with_warmup, get_wsd_schedule_with_warmup
from verl.utils.tracking import Tracking
from verl.utils.ulysses import (
gather_outpus_and_unpad,
gather_outputs_and_unpad,
get_ulysses_sequence_parallel_world_size,
ulysses_pad_and_slice_inputs,
)
@ -406,7 +406,7 @@ class FSDPSFTTrainer:
input_ids_rmpad_rolled = input_ids_rmpad_rolled.to(logits_rmpad.device)
loss = loss_fct(logits_rmpad, input_ids_rmpad_rolled)
# Gather and unpad for sequence parallelism
loss = gather_outpus_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size)
loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size)
# This is the loss collected from all ulysses ranks
full_loss = pad_input(

View File

@ -64,8 +64,8 @@ def run_ppo(config) -> None:
# Execute the `run` method of the TaskRunner instance remotely and wait for it to complete
if (
is_cuda_available
and OmegaConf.select(config.trainer, "profile_steps") is not None
and len(OmegaConf.select(config.trainer, "profile_steps")) > 0
and config.trainer.get("profile_steps") is not None
and len(config.trainer.get("profile_steps", [])) > 0
):
nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options)
runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
@ -106,9 +106,7 @@ class TaskRunner:
from verl.utils.fs import copy_to_local
print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
pprint(OmegaConf.to_container(config, resolve=True))
OmegaConf.resolve(config)
# Download the checkpoint from HDFS to the local machine.
@ -125,14 +123,6 @@ class TaskRunner:
# Used for multimodal LLM, could be None
processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
# Version validation for vllm.
if config.actor_rollout_ref.rollout.name in ["vllm"]:
from verl.utils.vllm_utils import is_version_ge
if config.actor_rollout_ref.model.get("lora_rank", 0) > 0:
if not is_version_ge(pkg="vllm", minver="0.7.3"):
raise NotImplementedError("PPO LoRA is not supported before vllm 0.7.3")
# Define worker classes based on the actor strategy.
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
assert config.critic.strategy in {"fsdp", "fsdp2"}

View File

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
PPO Trainer with Ray-based single controller.
This trainer supports model-agnostic model initialization with huggingface
"""
@ -550,16 +550,6 @@ class RayPPOTrainer:
"validation gen temperature should be greater than 0 when enabling do_sample"
)
# check multi_turn with tool config
if config.actor_rollout_ref.rollout.multi_turn.enable:
assert (
config.actor_rollout_ref.rollout.multi_turn.tool_config_path is not None
or config.actor_rollout_ref.rollout.multi_turn.interaction_config_path is not None
), (
"tool_config_path or interaction_config_path must be set when enabling multi_turn with tool, "
"due to no role-playing support"
)
print("[validate_config] All configuration checks passed successfully!")
def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):
@ -1049,6 +1039,28 @@ class RayPPOTrainer:
else:
print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch")
def _start_profiling(self, do_profile: bool) -> None:
"""Start profiling for all worker groups if profiling is enabled."""
if do_profile:
self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps)
if self.use_reference_policy:
self.ref_policy_wg.start_profile()
if self.use_critic:
self.critic_wg.start_profile()
if self.use_rm:
self.rm_wg.start_profile()
def _stop_profiling(self, do_profile: bool) -> None:
"""Stop profiling for all worker groups if profiling is enabled."""
if do_profile:
self.actor_rollout_wg.stop_profile()
if self.use_reference_policy:
self.ref_policy_wg.stop_profile()
if self.use_critic:
self.critic_wg.stop_profile()
if self.use_rm:
self.rm_wg.stop_profile()
def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"):
"""Reorder the data on single controller such that each dp rank gets similar total tokens"""
attention_mask = batch.batch["attention_mask"]
@ -1118,14 +1130,7 @@ class RayPPOTrainer:
else False
)
with marked_timer("start_profile", timing_raw):
if do_profile:
self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps)
if self.use_reference_policy:
self.ref_policy_wg.start_profile()
if self.use_critic:
self.critic_wg.start_profile()
if self.use_rm:
self.rm_wg.start_profile()
self._start_profiling(do_profile)
batch: DataProto = DataProto.from_single_dict(batch_dict)
@ -1319,7 +1324,6 @@ class RayPPOTrainer:
rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
if rollout_data_dir:
with marked_timer("dump_rollout_generations", timing_raw, color="green"):
print(batch.batch.keys())
inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
@ -1366,14 +1370,7 @@ class RayPPOTrainer:
self._save_checkpoint()
with marked_timer("stop_profile", timing_raw):
if do_profile:
self.actor_rollout_wg.stop_profile()
if self.use_reference_policy:
self.ref_policy_wg.stop_profile()
if self.use_critic:
self.critic_wg.stop_profile()
if self.use_rm:
self.rm_wg.stop_profile()
self._stop_profiling(do_profile)
steps_duration = timing_raw["step"]
self.max_steps_duration = max(self.max_steps_duration, steps_duration)

View File

@ -14,10 +14,18 @@
import re
_SOLUTION_CLIP_CHARS = 300
def extract_solution(solution_str, method="strict"):
assert method in ["strict", "flexible"]
# Optimization: Regular expression matching on very long strings can be slow.
# For math problems, the final answer is usually at the end.
# We only match on the last 300 characters, which is a safe approximation for 300 tokens.
if len(solution_str) > _SOLUTION_CLIP_CHARS:
solution_str = solution_str[-_SOLUTION_CLIP_CHARS:]
if method == "strict":
# this also tests the formatting of the model
solutions = re.findall("#### (\\-?[0-9\\.\\,]+)", solution_str)
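A minimal sketch of the strict extraction path on a made-up GSM8K-style completion, showing the tail clipping and the `#### <answer>` pattern:

```python
import re

_SOLUTION_CLIP_CHARS = 300

# Made-up completion; only the last 300 characters are scanned.
solution_str = "...step-by-step reasoning...\n#### 1,234"
if len(solution_str) > _SOLUTION_CLIP_CHARS:
    solution_str = solution_str[-_SOLUTION_CLIP_CHARS:]
solutions = re.findall("#### (\\-?[0-9\\.\\,]+)", solution_str)
print(solutions)  # ['1,234']
```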

View File

@ -67,7 +67,7 @@ SUPPORTED_LANGUAGES = [
def call_sandbox_api(
sandbox_fusion_url: str,
code: str,
stdin: str,
stdin: Optional[str],
compile_timeout: int,
run_timeout: int,
memory_limit_mb: int,
@ -259,9 +259,9 @@ def _execute_user_function():
# Attempt to instantiate and get method.
# Errors (e.g., Solution not a class, instantiation fails, method missing)
# will be caught by the broad except block below.
_solution_instance = _Solution_class()
_solution_instance = _Solution_class()
_target_callable = getattr(_solution_instance, _SANDBOX_FN_NAME)
if not _target_callable:
sys.stderr.write(f"WrapperError: Function or method '{{_SANDBOX_FN_NAME}}' not found.\\n")
return None, True # result, error_occurred
@ -286,10 +286,11 @@ if __name__ == '__main__':
print(str(_result))
# Optional: To explicitly exit with an error code if the sandbox relies on it
# else:
# sys.exit(1)
# sys.exit(1)
"""
current_generation_code = wrapper_code
stdin = None if stdin_data is None else str(stdin_data)
try:
if concurrent_semaphore:
# logger.debug(f"Case {case_index + 1}: Attempting to acquire semaphore.")
@ -298,7 +299,7 @@ if __name__ == '__main__':
api_response, error_msg = call_sandbox_api(
sandbox_fusion_url=sandbox_fusion_url,
code=current_generation_code,
stdin=str(stdin_data),
stdin=stdin,
compile_timeout=timeout,
run_timeout=timeout,
memory_limit_mb=memory_limit_mb,
@ -309,7 +310,7 @@ if __name__ == '__main__':
api_response, error_msg = call_sandbox_api(
sandbox_fusion_url=sandbox_fusion_url,
code=current_generation_code,
stdin=str(stdin_data),
stdin=stdin,
compile_timeout=timeout,
run_timeout=timeout,
memory_limit_mb=memory_limit_mb,
@ -322,7 +323,7 @@ if __name__ == '__main__':
metadata = {
"case_index": case_index,
"input": str(stdin_data),
"input": stdin,
"expected_output": str(expected_output),
"api_request_error": error_msg,
"api_response": None,
@ -346,7 +347,7 @@ if __name__ == '__main__':
# Log code and input only on error for brevity
generation_to_log = generation[:200] + "..." if len(generation) > 200 else generation
logger.error(f"Case {case_index}: code: {generation_to_log}")
logger.error(f"Case {case_index}: input: {str(stdin_data)}")
logger.error(f"Case {case_index}: input: {stdin}")
elif api_response:
# --- Add debug logging ---
logger.debug(f"Case {case_index}: API Response: {api_response}")

View File

@ -234,7 +234,13 @@ class Gather(torch.autograd.Function):
)
def gather_outpus_and_unpad(
def gather_outpus_and_unpad(*args, **kwargs):
raise RuntimeError(
"please use verl.utils.ulysses.gather_outputs_and_unpad instead of verl.utils.ulysses.gather_outpus_and_unpad"
)
def gather_outputs_and_unpad(
x: Tensor,
gather_dim: int,
unpad_dim: int = None,

View File

@ -33,7 +33,7 @@ from verl.utils.profiler import GPUMemoryLogger
from verl.utils.py_functional import append_to_dict
from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch
from verl.utils.torch_functional import logprobs_from_logits
from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs
from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs
from verl.workers.actor import BasePPOActor
if is_cuda_available:
@ -203,14 +203,14 @@ class DataParallelPPOActor(BasePPOActor):
# gather log_prob if sp > 1
if self.use_ulysses_sp:
# gather and unpad for the ulysses sp
log_probs = gather_outpus_and_unpad(
log_probs = gather_outputs_and_unpad(
log_probs,
gather_dim=0,
unpad_dim=0,
padding_size=pad_size,
)
if calculate_entropy:
entropy_rmpad = gather_outpus_and_unpad(
entropy_rmpad = gather_outputs_and_unpad(
entropy_rmpad,
gather_dim=0,
unpad_dim=0,

View File

@ -295,7 +295,15 @@ class MegatronPPOActor(BasePPOActor):
Returns:
"""
select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages"]
select_keys = [
"responses",
"input_ids",
"attention_mask",
"response_mask",
"position_ids",
"old_log_probs",
"advantages",
]
if self.config.use_kl_loss:
select_keys.append("ref_log_prob")
self.has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
@ -395,8 +403,7 @@ class MegatronPPOActor(BasePPOActor):
responses = data["responses"]
response_length = responses.size(1)
attention_mask = data["attention_mask"].to(bool)
response_mask = attention_mask[:, -response_length:]
response_mask = data["response_mask"].to(bool)
loss_agg_mode = self.config.loss_agg_mode
# compute policy loss

View File

@ -31,7 +31,7 @@ from verl.utils.profiler import GPUMemoryLogger
from verl.utils.py_functional import append_to_dict
from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch
from verl.utils.torch_functional import masked_mean
from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs
from verl.workers.critic import BasePPOCritic
if is_cuda_available:
@ -113,7 +113,7 @@ class DataParallelPPOCritic(BasePPOCritic):
# gather output if sp > 1
if self.ulysses_sequence_parallel_size > 1:
values_rmpad = gather_outpus_and_unpad(
values_rmpad = gather_outputs_and_unpad(
values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size
)

View File

@ -1438,7 +1438,7 @@ class RewardModelWorker(Worker, DistProfilerExtension):
unpad_input,
)
from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs
with torch.no_grad(), torch.autocast(device_type=device_name, dtype=torch.bfloat16):
input_ids = micro_batch["input_ids"]
@ -1481,7 +1481,7 @@ class RewardModelWorker(Worker, DistProfilerExtension):
# gather output if sp > 1
if self.ulysses_sequence_parallel_size > 1:
reward_rmpad = gather_outpus_and_unpad(
reward_rmpad = gather_outputs_and_unpad(
reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size
)

View File

@ -14,7 +14,7 @@
# limitations under the License.
import pickle
from typing import Any, Optional
from typing import Any, Iterator, Optional
import numpy as np
import torch
@ -66,3 +66,43 @@ def broadcast_pyobj(
serialized_data = bytes(tensor_data.cpu().numpy())
data = pickle.loads(serialized_data)
return data
def get_named_tensor_buckets(
iterable: Iterator[tuple[str, torch.Tensor]], bucket_bytes: int
) -> Iterator[list[tuple[str, torch.Tensor]]]:
"""
Group tensors into buckets based on a specified size in bytes.
Args:
iterable: An iterator of tuples containing tensor names and tensors.
bucket_bytes: The maximum size of each bucket in bytes.
Yields:
Lists of tuples, where each tuple contains a tensor name and its corresponding tensor.
Example:
>>> tensors = [('tensor1', torch.randn(1000, 1000)), ('tensor2', torch.randn(2000, 2000))]
>>> for bucket in get_named_tensor_buckets(tensors, bucket_bytes=32 * 1024 * 1024):
... print(bucket)
[('tensor1', tensor(...)), ('tensor2', tensor(...))]
"""
if bucket_bytes <= 0:
raise ValueError(f"bucket_bytes must be greater than 0, got {bucket_bytes}")
current_bucket = []
current_size = 0
for name, tensor in iterable:
tensor_size = tensor.element_size() * tensor.numel()
if current_size + tensor_size > bucket_bytes:
if current_bucket:
yield current_bucket
current_bucket = [(name, tensor)]
current_size = tensor_size
else:
current_bucket.append((name, tensor))
current_size += tensor_size
if current_bucket:
yield current_bucket
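A short usage sketch with small tensors (the import path is the module added in this diff; sizes are illustrative):

```python
import torch

from verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets

# Three ~4 MiB fp32 tensors against an 8 MiB budget: the first two share a bucket,
# the third would overflow it and starts a new one.
tensors = [
    ("a", torch.zeros(1024, 1024)),
    ("b", torch.zeros(1024, 1024)),
    ("c", torch.zeros(1024, 1024)),
]
buckets = list(get_named_tensor_buckets(iter(tensors), bucket_bytes=8 << 20))
print([[name for name, _ in bucket] for bucket in buckets])  # [['a', 'b'], ['c']]
```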

View File

@ -35,6 +35,7 @@ from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_
from verl.utils.model import convert_weight_keys
from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer
from verl.utils.torch_functional import check_device_is_available
from verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets
from .base import BaseShardingManager
@ -113,32 +114,63 @@ class FSDPSGLangShardingManager(BaseShardingManager):
# Most naive implementation, can optimize a lot if it is bottleneck from sglang Engine weight update
named_tensors = [(k, v) for k, v in params.items()]
load_format = None
for tensor_index, (name, tensor) in enumerate(named_tensors):
serialized_tensor = MultiprocessingSerializer.serialize(_preprocess_tensor_for_update_weights(tensor))
# convert megabytes to bytes
update_weights_bucket_bytes = int(self.rollout_config.update_weights_bucket_megabytes) << 20
for batch in get_named_tensor_buckets(named_tensors, update_weights_bucket_bytes):
# On each rank, serialize a batch of (name, tensor) tuples.
# named_tensors_batch will be a list like:
# [(name0, serialized_tensor0_tp0), (name1, serialized_tensor1_tp0), ...]
named_tensors_batch = [
(name, MultiprocessingSerializer.serialize(_preprocess_tensor_for_update_weights(tensor)))
for name, tensor in batch
]
if self.device_mesh["infer_tp"].get_local_rank() == 0:
gathered_serialized_tensors = [None for _ in range(self.device_mesh["infer_tp"].mesh.size()[0])]
# On rank 0, prepare a list to hold the gathered batches from all ranks.
gathered_serialized_batches = [None for _ in range(self.device_mesh["infer_tp"].mesh.size()[0])]
else:
gathered_serialized_tensors = None
gathered_serialized_batches = None
# Gather the named_tensors_batch from all ranks to rank 0.
# After this, on rank 0, gathered_serialized_batches will be a list of lists:
# [ [ (name0, s_t0_tp0), (name1, s_t1_tp0), ... ], # batch from TP rank 0
# [ (name0, s_t0_tp1), (name1, s_t1_tp1), ... ], # batch from TP rank 1
# ... ]
# On other ranks, gathered_serialized_batches will be None.
dist.gather_object(
obj=serialized_tensor,
object_gather_list=gathered_serialized_tensors,
obj=named_tensors_batch,
object_gather_list=gathered_serialized_batches,
dst=self.device_mesh["infer_tp"].mesh.tolist()[0],
group=self.device_mesh["infer_tp"].get_group(),
)
if self.device_mesh["infer_tp"].get_local_rank() == 0:
# Use zip(*) to "transpose" the data structure.
# This groups the serialized parts for each individual tensor across all TP ranks.
# Example: from [[(n0, t0_tp0), (n1, t1_tp0)], [(n0, t0_tp1), (n1, t1_tp1)]]
# to [ ( (n0, t0_tp0), (n0, t0_tp1) ), ( (n1, t1_tp0), (n1, t1_tp1) ) ]
logical_tensors = zip(*gathered_serialized_batches, strict=True)
await self.inference_engine.update_weights_from_tensor(
named_tensors=[
# 'tensor_group' represents a single logical tensor's data from all ranks.
(
name,
LocalSerializedTensor(values=gathered_serialized_tensors),
tensor_group[0][0], # Get the name from the first rank's data.
LocalSerializedTensor(
# 'rank_part' is the (name, serialized_tensor) tuple from one specific rank.
values=[rank_part[1] for rank_part in tensor_group]
),
)
for tensor_group in logical_tensors
# each tensor_group is like ( (n0, t0_tp0), (n0, t0_tp1) )
],
load_format=load_format,
flush_cache=tensor_index == len(named_tensors) - 1,
flush_cache=False,
)
if self.device_mesh["infer_tp"].get_local_rank() == 0:
await self.inference_engine.flush_cache()
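The `zip(*)` regrouping above is easier to see on toy data; here plain strings stand in for the serialized tensor shards:

```python
# Toy illustration of the rank-0 transpose: per-rank batches in, per-tensor groups out.
gathered_serialized_batches = [
    [("w0", "w0_tp0"), ("w1", "w1_tp0")],   # batch gathered from TP rank 0
    [("w0", "w0_tp1"), ("w1", "w1_tp1")],   # batch gathered from TP rank 1
]
logical_tensors = zip(*gathered_serialized_batches, strict=True)
rebuilt = [(group[0][0], [part[1] for part in group]) for group in logical_tensors]
print(rebuilt)  # [('w0', ['w0_tp0', 'w0_tp1']), ('w1', ['w1_tp0', 'w1_tp1'])]
```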
async def release_memory(self):
if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
await self.inference_engine.release_memory_occupation()

View File

@ -37,6 +37,7 @@ from verl.utils.megatron_utils import (
per_tensor_generator,
)
from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer
from verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets
from .base import BaseShardingManager
@ -130,37 +131,76 @@ class MegatronSGLangShardingManager(BaseShardingManager):
loop.run_until_complete(self.sleep())
async def update_weights(self, params):
"""
Update model weights using tensor buckets, similar to THUDM/slime's implementation.
Notes:
- For the best performance of `rebuild_cuda_tensor`, it is recommended to:
1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES`.
2. Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7`
when using Tensor Parallelism (TP >= 8).
- See reference implementations in SLIME:
- Main logic: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L452
- runtime envs: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L39
"""
if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
await self.inference_engine.resume_memory_occupation()
named_tensors = params
load_format = None
for tensor_index, (name, tensor) in enumerate(named_tensors):
serialized_tensor = MultiprocessingSerializer.serialize(tensor.detach())
update_weights_bucket_bytes = int(self.rollout_config.update_weights_bucket_megabytes) << 20
for batch in get_named_tensor_buckets(named_tensors, update_weights_bucket_bytes):
# On each rank, serialize a batch of (name, tensor) tuples.
# named_tensors_batch will be a list like:
# [(name0, serialized_tensor0_tp0), (name1, serialized_tensor1_tp0), ...]
named_tensors_batch = [
(name, MultiprocessingSerializer.serialize(tensor.detach())) for name, tensor in batch
]
if self.device_mesh["tp"].get_local_rank() == 0:
gathered_serialized_tensors = [None for _ in range(self.device_mesh["tp"].mesh.size()[0])]
# On rank 0, prepare a list to hold the gathered batches from all ranks.
gathered_serialized_batches = [None for _ in range(self.device_mesh["tp"].mesh.size()[0])]
else:
gathered_serialized_tensors = None
gathered_serialized_batches = None
# Gather the named_tensors_batch from all ranks to rank 0.
# After this, on rank 0, gathered_serialized_batches will be a list of lists:
# [ [ (name0, s_t0_tp0), (name1, s_t1_tp0), ... ], # batch from TP rank 0
# [ (name0, s_t0_tp1), (name1, s_t1_tp1), ... ], # batch from TP rank 1
# ... ]
# On other ranks, gathered_serialized_batches will be None.
dist.gather_object(
obj=serialized_tensor,
object_gather_list=gathered_serialized_tensors,
obj=named_tensors_batch,
object_gather_list=gathered_serialized_batches,
dst=self.device_mesh["tp"].mesh.tolist()[0],
group=self.device_mesh["tp"].get_group(),
)
if self.device_mesh["tp"].get_local_rank() == 0:
# Use zip(*) to "transpose" the data structure.
# This groups the serialized parts for each individual tensor across all TP ranks.
# Example: from [[(n0, t0_tp0), (n1, t1_tp0)], [(n0, t0_tp1), (n1, t1_tp1)]]
# to [ ( (n0, t0_tp0), (n0, t0_tp1) ), ( (n1, t1_tp0), (n1, t1_tp1) ) ]
logical_tensors = zip(*gathered_serialized_batches, strict=False)
await self.inference_engine.update_weights_from_tensor(
named_tensors=[
# 'tensor_group' represents a single logical tensor's data from all ranks.
(
name,
LocalSerializedTensor(values=gathered_serialized_tensors),
tensor_group[0][0], # Get the name from the first rank's data.
LocalSerializedTensor(
# 'rank_part' is the (name, serialized_tensor) tuple from one specific rank.
values=[rank_part[1] for rank_part in tensor_group]
),
)
for tensor_group in logical_tensors
# each tensor_group is like ( (n0, t0_tp0), (n0, t0_tp1) )
],
load_format=load_format,
flush_cache=False,
)
if self.device_mesh["tp"].get_local_rank() == 0:
await self.inference_engine.flush_cache()
if self.device_mesh["tp"].get_local_rank() == 0:
await self.inference_engine.flush_cache()
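The environment variables mentioned in the docstring notes above are ordinary process environment settings; a minimal sketch follows, though where to export them depends on your Ray launch setup, so treat this as illustrative only:

```python
import os

# Sketch only: these must be visible to the processes that own the GPUs,
# which usually means setting them in the job/launcher environment rather than here.
os.environ["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
```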
async def release_memory(self):
if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine: