Mirror of https://github.com/volcengine/verl.git, synced 2025-10-20 13:43:50 +08:00
Merge branch 'volcengine:main' into recipe/async_training
This commit is contained in:
.github/CODEOWNERS (vendored), 2 changed lines
@@ -8,7 +8,7 @@
/third_party/sglang @zhaochenyang20 @SwordFaith
/third_party/vllm @PeterSH6 @wuxibin89
-/verl/single_controller @zw0610 @wuxibin89
+/verl/single_controller @zw0610 @wuxibin89 @hongpeng-guo
/verl/trainer @eric-haibin-lin @vermouth1992 @tongyx361 @PeterSH6
/verl/workers/rollout/vllm_rollout @wuxibin89 @PeterSH6 @chenhaiq
/verl/workers/rollout/sglang_rollout @zhaochenyang20 @SwordFaith @chenhaiq
.github/workflows/checkpoint_converter.yml (vendored), 4 changed lines
@@ -131,6 +131,10 @@ jobs:
        run: |
          ray stop --force
          python scripts/converter_hf_to_mcore.py --hf_model_path=${HOME}/models/Qwen/Qwen1.5-MoE-A2.7B-Chat --output_path checkpoints/Qwen/Qwen1.5-MoE-A2.7B-Chat --use_cpu_initialization
+     - name: Running distributed Huggingface to Megatron dist_ckpt CPU converter (Qwen/Qwen1.5-MoE-A2.7B-Chat)
+       run: |
+         ray stop --force
+         torchrun --nproc_per_node 8 --nnodes 1 scripts/converter_hf_to_mcore.py --hf_model_path=${HOME}/models/Qwen/Qwen1.5-MoE-A2.7B-Chat --output_path checkpoints/Qwen/Qwen1.5-MoE-A2.7B-Chat_dist --use_cpu_initialization
      - name: clean up
        run: |
          rm -rf checkpoints
@@ -139,6 +139,10 @@ jobs:
          exp_name="deepseek-coder-1.3b-instruct-megatron-gsm8k-minimal"
          python -m verl.model_merger test --backend megatron --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface
          python -m verl.model_merger test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface
+     - name: Test Megatron distributed checkpoints merging function (DeepSeek)
+       run: |
+         exp_name="deepseek-coder-1.3b-instruct-megatron-gsm8k-minimal"
+         torchrun --nproc_per_node 4 --nnodes 1 -m verl.model_merger merge --backend megatron --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --target_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/hf_model
      - name: Running GRPO GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Deepseek)
        run: |
          ray stop --force
@@ -14,7 +14,7 @@ The first two types of images are hosted on dockerhub [verlai/verl](https://hub.

## Base Image

-The stable base image is ``verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4-te2.3``. The installed package versions can be found from tags, and the Dockerfile can be found in ``verl[version]-[packages]/Dockerfile.base``.
+The stable base image is ``verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4``. The installed package versions can be found from tags, and the Dockerfile can be found in ``verl[version]-[packages]/Dockerfile.base``.

The base images for preview are ``verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0`` and ``verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8.0`` with different CUDA versions.

@@ -76,4 +76,4 @@ pip3 install --no-deps -e .
git clone https://github.com/volcengine/verl && cd verl
-pip3 install -e .[vllm]
+pip3 install -e .[sglang]
```
```
@@ -99,6 +99,16 @@ Example usage for merging Megatron checkpoints:
        --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \
        --target_dir /path/to/merged_hf_model

Example usage for merging Megatron checkpoints in a distributed manner:

.. code:: bash

    torchrun --nproc_per_node 1 --nnodes 8 --node_rank ${RANK} -m verl.model_merger merge \
        --backend megatron \
        --tie-word-embedding \
        --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \
        --target_dir /path/to/merged_hf_model
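After merging, the target directory is a regular Hugging Face checkpoint, so an actor/policy merge can be sanity-checked by loading it with ``transformers`` (illustrative snippet only; the path matches the example above):

.. code:: python

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("/path/to/merged_hf_model")
    print(sum(p.numel() for p in model.parameters()), "parameters loaded")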
Example usage for merging FSDP checkpoints:

.. code:: bash

@@ -145,6 +155,15 @@ Example command to convert the model is as follows:
        --use_cpu_initialization # Only works for MoE models


Example command for distributed conversion of a huge model such as DeepSeek-V3 671B:

.. code:: bash

    torchrun --nproc_per_node 1 --nnodes 8 --node_rank ${RANK} scripts/converter_hf_to_mcore.py \
        --hf_model_path deepseek-ai/DeepSeek-V3 \
        --output_path /mnt/disk/deepseek-ai/DeepSeek-V3 \
        --use_cpu_initialization # Only works for MoE models

Original Checkpoint Utils
-------------------------
@@ -60,7 +60,7 @@ Ulysses Utilities
--------------------

.. automodule:: verl.utils.ulysses
-   :members: gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
+   :members: gather_outputs_and_unpad, ulysses_pad_and_slice_inputs

FSDP Utilities
------------------
@@ -33,6 +33,7 @@ verl is fast with:
   start/multinode
   start/ray_debug_tutorial
   start/more_resources
+   start/agentic_rl

.. toctree::
   :maxdepth: 2
docs/start/agentic_rl.rst (new file, 125 lines)
@@ -0,0 +1,125 @@
Agentic RL Training
===================

Last updated: 07/15/2025.

Overview
----------
The goal of Agentic RL is to improve the performance of the backend model behind an agent through reinforcement learning. To support this training process, a series of features has been developed:

1. Server-based asynchronous rollout
2. Multi-turn conversations and tool calls
3. LangGraph-based Agent

This document explains the underlying system design and how to use it, to help users implement Agentic RL.


Server-based Asynchronous Rollout
---------------------------------

Since agents need to interact with the environment through various tool calls, an asyncio-based coroutine mechanism is used to execute each rollout request asynchronously, so that GPUs do not sit idle while waiting for tool-call results, which improves training performance. To support asynchronous rollout, the inference engine (server) and the agent (client) are architecturally separated, implementing a server-based system with the following objectives:

1. Enabling load-balancing mechanisms to balance load across multiple GPUs and reduce the impact of long-tail requests on performance. For this purpose, scheduling capabilities in stream mode (recipe/stream_mode) are implemented as a recipe.
2. Preventing agent-specific features such as tracing from affecting the inference engine.
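The following minimal sketch (illustrative only, not the verl implementation) shows the coroutine idea: each rollout request runs as a coroutine, so the event loop keeps other requests progressing while one request waits on a tool call.

.. code-block:: python

    import asyncio

    async def rollout_one(agent_loop, sample):
        # May await tool calls and LLM generation without blocking other requests.
        return await agent_loop.run(sample)

    async def rollout_batch(agent_loop, batch):
        # All rollout requests of a batch progress concurrently on one event loop.
        return await asyncio.gather(*(rollout_one(agent_loop, s) for s in batch))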
System Architecture
~~~~~~~~~~~~~~~~~~~

.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/agent_loop.png?raw=true

System Components
~~~~~~~~~~~~~~~~~

+--------------------------+----------------------------------------------------------------------------+
| Component                | Role                                                                       |
+==========================+============================================================================+
| AgentLoop                | Client, implements Agent functions                                         |
+--------------------------+----------------------------------------------------------------------------+
| AsyncLLMServerManager    | Inference gateway, provides generate interface for AgentLoop              |
+--------------------------+----------------------------------------------------------------------------+
| AsyncServer              | Server, each instance is connected to one DP group of the inference engine|
+--------------------------+----------------------------------------------------------------------------+

**"generate" Interface**

The "generate" function, exposed through a ray actor, is used between the client and the server instead of the standard chat completion API. This is because the conversion between tokens and text can be irreversible: for example, the tokens obtained by re-encoding the text "<think>" may differ from the tokens originally generated by the LLM. During training it is necessary to use exactly the tokens produced by LLM inference, to avoid inaccuracies in computing the advantage that could hurt model performance. Having the server provide a token-based API helps the client maintain the relationship between the text generated by tool calls and the tokens returned by the LLM, so that it can output correct tokens for training.
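A minimal sketch of a client-side call to this token-based interface is shown below. The ``generate`` keyword arguments (``request_id``, ``prompt_ids``, ``sampling_params``) mirror how ``AsyncLLMServerManager`` is used in ``recipe/langgraph_agent/chat_model.py``; the surrounding tokenizer handling is illustrative only.

.. code-block:: python

    import uuid

    async def chat_once(server_manager, tokenizer, messages):
        # Encode the chat history ourselves so that training later sees exactly
        # the token ids the LLM produced, never re-tokenized text.
        prompt_ids = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True
        )
        response_ids = await server_manager.generate(
            request_id=str(uuid.uuid4()),  # identifies the request across turns
            prompt_ids=prompt_ids,         # token in
            sampling_params={"temperature": 1.0, "top_p": 1.0},
        )
        # Keep response_ids for the trajectory; decode only for tools or logging.
        return response_ids, tokenizer.decode(response_ids)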
**Inference Engine Adaptation**

AsyncServer uniformly provides a generate function to the upper layer, with separate implementations for SGLang and vLLM that hide the underlying differences:

1. The SGLang AsyncServer uses the async_generate interface of the SGLang engine, which is located on the first GPU of each TP group. Therefore, AsyncServer needs to call async_generate remotely through a ray actor.
2. The vLLM AsyncServer uses the generate interface of the vLLM engine, which can communicate with the GPUs in the TP group through ZMQ and can be called directly in AsyncServer.
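Schematically, the adaptation layer looks like the following (illustrative only; the class names, attributes, and call signatures are placeholders rather than the actual verl code):

.. code-block:: python

    class AsyncServerBase:
        """Uniform token-in/token-out interface exposed to the agent side."""

        async def generate(self, request_id, prompt_ids, sampling_params):
            raise NotImplementedError

    class SGLangAsyncServer(AsyncServerBase):
        async def generate(self, request_id, prompt_ids, sampling_params):
            # The SGLang engine sits on the first GPU of each TP group, so the call
            # is forwarded to that process through a ray actor handle.
            return await self.engine_actor.async_generate.remote(prompt_ids, sampling_params)

    class VLLMAsyncServer(AsyncServerBase):
        async def generate(self, request_id, prompt_ids, sampling_params):
            # The vLLM engine reaches the GPUs of its TP group over ZMQ, so it can
            # be called directly in this process.
            return await self.engine.generate(prompt_ids, sampling_params, request_id)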
Usage Example
~~~~~~~~~~~~~

Follow the :doc:`GSM8K example<../examples/gsm8k_example>` to prepare the dataset and model checkpoints.
This example uses the sglang inference engine by default; you can also modify ``rollout_name`` to use vllm.

.. code-block:: bash

    bash examples/grpo_trainer/run_qwen2-7b_seq_balance.sh


Multi-turn Conversations and Tool Calls
---------------------------------------

Follow :doc:`Multi-turn Rollout Support<../sglang_multiturn/multiturn>` to prepare the tool and configuration files.

The Tool Agent Loop has one additional requirement: an "agent_name" field must be added to the dataset. During rollout, this field determines whether ``tool_agent_loop`` or ``single_turn_agent`` (the default) is used.
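For reference, a dataset row with the ``agent_name`` field looks roughly like the entry constructed in ``examples/data_preprocess/gsm8k_tool_agent_loop.py`` (abridged here; the prompt and answer texts are shortened placeholders):

.. code-block:: python

    row = {
        "data_source": "openai/gsm8k",
        "agent_name": "tool_agent",  # selects tool_agent_loop during rollout
        "prompt": [
            {"role": "system", "content": "You are a math expert. ..."},
            {"role": "user", "content": "<question> Let's think step by step ..."},
        ],
        "ability": "math",
        "reward_model": {"style": "rule", "ground_truth": "<answer>"},
        "extra_info": {
            "need_tools_kwargs": True,
            "tools_kwargs": {
                "calc_gsm8k_reward": {"create_kwargs": {"ground_truth": "<answer>"}},
            },
        },
    }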
Usage Example
~~~~~~~~~~~~~

.. code-block:: bash

    # Install mlflow to view tool-call and LLM traces
    pip install mlflow

    # Download and preprocess the GSM8K dataset into ~/data/gsm8k/ and add the "agent_name" field
    python examples/data_preprocess/gsm8k_tool_agent_loop.py

    # Start training with tool calls and mlflow-based tracing enabled to help debug the rollout details
    bash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh

    # When training is done, start an mlflow server to view the traces
    mlflow ui -h 0.0.0.0 -p 5000 --backend-store-uri sqlite:////tmp/mlruns.db

    # Then open http://<your ip address>:5000 in a browser to view the traces


Note: during training the model may occasionally fail to generate correct tool-call tags, in which case the error message "Failed to decode tool call" is printed to the console. This does not indicate a problem with the training itself.


Follow :doc:`Rollout trace<../advance/rollout_trace>` to learn more about the trace feature.
Agent Framework
---------------

System Architecture
~~~~~~~~~~~~~~~~~~~

.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/langgraph_agent.png?raw=true

System Components
~~~~~~~~~~~~~~~~~

+--------------------------+-----------------------------------------------------------------------------------------------+
| Component                | Role                                                                                          |
+==========================+===============================================================================================+
| ChatModel                | LLM object of LangChain, used to adapt to the “generate” API provided by AsyncLLMServerManager|
+--------------------------+-----------------------------------------------------------------------------------------------+
| ReactAgentLoop           | Agent adaptation layer, which by default supports a naive LangGraph agent.                    |
|                          | New classes can be derived to support user-defined Agents, and the run function needs to be   |
|                          | implemented to complete Agent calls.                                                          |
+--------------------------+-----------------------------------------------------------------------------------------------+
| AsyncServer              | Server, each instance is connected to one DP group of the inference engine.                   |
+--------------------------+-----------------------------------------------------------------------------------------------+

Follow ``recipe/langgraph_agent/example/README.md`` for more details.
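As a sketch, a user-defined agent is plugged in by deriving from ``ReactAgentLoop`` and registering the class in an agent loop config, mirroring ``recipe/langgraph_agent/example/math_expression.py`` from this change (the tool body below is a placeholder):

.. code-block:: python

    from langchain_core.tools import tool

    from recipe.langgraph_agent.react_agent_loop import ReactAgentLoop

    @tool(parse_docstring=True)
    def lookup(query: str) -> str:
        """Look something up for the agent.

        Args:
            query: free-form query string
        """
        return f"result for {query}"  # placeholder tool body

    class MyAgentLoop(ReactAgentLoop):
        @classmethod
        def init_class(cls, config, tokenizer, **kwargs):
            cls.tools = [lookup]  # tools bound to the LangGraph react agent
            super().init_class(config, tokenizer)

The class is then referenced from a yaml config via ``_target_`` and selected per sample through the dataset's ``agent_name`` column, as in the example recipe.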
@@ -52,7 +52,7 @@ The first two types of images are hosted on dockerhub `verlai/verl <https://hub.
Base Image
::::::::::

-The stable base image is ``verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4-te2.3``. The installed package versions can be found from tags, and the Dockerfile can be found in ``docker/verl[version]-[packages]/Dockerfile.base``.
+The stable base image is ``verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4``. The installed package versions can be found from tags, and the Dockerfile can be found in ``docker/verl[version]-[packages]/Dockerfile.base``.

The base images for preview are ``verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0`` and ``verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8.0`` with different CUDA versions. From verl0.5, images are built with `Deep-EP <https://github.com/deepseek-ai/DeepEP>`_ for efficient EP communication.

@@ -255,7 +255,7 @@ If you encounter issues about package versions during running verl, please updat
Install with AMD GPUs - ROCM kernel support
------------------------------------------------------------------

When you run on AMD GPUs (MI300) with the ROCM platform, you cannot use the previous quickstart to run verl. You should follow the steps below to build a docker image and run it.
If you encounter any issues in using AMD GPUs running verl, feel free to contact me - `Yusheng Su <https://yushengsu-thu.github.io/>`_.

Find the docker for AMD ROCm: `docker/Dockerfile.rocm <https://github.com/volcengine/verl/blob/main/docker/Dockerfile.rocm>`_

@@ -336,6 +336,6 @@ Launch the container
    /bin/bash

If you do not want to run in root mode and need to assign yourself as the user,
please add ``-e HOST_UID=$(id -u)`` and ``-e HOST_GID=$(id -g)`` to the above docker launch script.

verl with AMD GPUs currently supports FSDP as the training engine, and vLLM and SGLang as the inference engines. We will support Megatron in the future.
examples/data_preprocess/gsm8k_tool_agent_loop.py (new file, 117 lines)
@ -0,0 +1,117 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Copyright 2025 ModelBest Inc. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Preprocess the GSM8k dataset to parquet format
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
import datasets
|
||||
|
||||
from verl.utils.hdfs_io import copy, makedirs
|
||||
|
||||
|
||||
def extract_solution(solution_str):
|
||||
solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
|
||||
assert solution is not None
|
||||
final_solution = solution.group(0)
|
||||
final_solution = final_solution.split("#### ")[1].replace(",", "")
|
||||
return final_solution
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--local_dir", default="~/data/gsm8k")
|
||||
parser.add_argument("--hdfs_dir", default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
data_source = "openai/gsm8k"
|
||||
dataset = datasets.load_dataset(data_source, "main")
|
||||
|
||||
train_dataset = dataset["train"]
|
||||
test_dataset = dataset["test"]
|
||||
|
||||
instruction_following = "Let's think step by step and output the final answer after `####`."
|
||||
|
||||
# add a row to each data item that represents a unique id
|
||||
def make_map_fn(split):
|
||||
def process_fn(example, idx):
|
||||
question_raw = example.pop("question")
|
||||
|
||||
question = question_raw + " " + instruction_following
|
||||
|
||||
answer_raw = example.pop("answer")
|
||||
solution = extract_solution(answer_raw)
|
||||
data = {
|
||||
"data_source": data_source,
|
||||
"agent_name": "tool_agent",
|
||||
"prompt": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are a math expert. You are given a question and you need to solve it step by step. "
|
||||
"Reasoning step by step before any tool call. "
|
||||
"You should use the `calc_gsm8k_reward` tool after step by step solving the question, "
|
||||
"before generate final answer at least once and refine your answer if necessary. "
|
||||
"Put your final answer in the format of `#### <answer>`."
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": question,
|
||||
},
|
||||
],
|
||||
"ability": "math",
|
||||
"reward_model": {"style": "rule", "ground_truth": solution},
|
||||
"extra_info": {
|
||||
"split": split,
|
||||
"index": idx,
|
||||
"answer": answer_raw,
|
||||
"question": question_raw,
|
||||
"need_tools_kwargs": True,
|
||||
"tools_kwargs": {
|
||||
"calc_gsm8k_reward": {
|
||||
"create_kwargs": {"ground_truth": solution},
|
||||
# "execute_kwargs": {},
|
||||
# "calc_reward_kwargs": {},
|
||||
# "release_kwargs": {},
|
||||
},
|
||||
},
|
||||
"interaction_kwargs": {
|
||||
"query": question,
|
||||
"ground_truth": solution,
|
||||
},
|
||||
},
|
||||
}
|
||||
return data
|
||||
|
||||
return process_fn
|
||||
|
||||
train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
|
||||
test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
|
||||
|
||||
local_dir = args.local_dir
|
||||
hdfs_dir = args.hdfs_dir
|
||||
|
||||
train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
|
||||
test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
|
||||
|
||||
if hdfs_dir is not None:
|
||||
makedirs(hdfs_dir)
|
||||
copy(src=local_dir, dst=hdfs_dir)
|
@@ -19,7 +19,7 @@ python3 -m verl.trainer.main_ppo \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    data.return_raw_chat=True \
-   actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
+   actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \

@@ -18,7 +18,7 @@ python3 -m verl.trainer.main_ppo \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    data.return_raw_chat=True \
-   actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
+   actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \

@@ -25,7 +25,7 @@ python3 -m verl.trainer.main_ppo \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    data.return_raw_chat=True \
-   actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
+   actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \

@@ -49,5 +49,6 @@ python3 -m verl.trainer.main_ppo \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \
-   trainer.total_epochs=15 $@
+   trainer.total_epochs=15 \
+   actor_rollout_ref.rollout.update_weights_bucket_megabytes=512 $@
@ -0,0 +1,57 @@
|
||||
# run on 8xH100
|
||||
# make sure your current working directory is the root of the project
|
||||
|
||||
set -x
|
||||
|
||||
ulimit -n 65535
|
||||
|
||||
PROJECT_DIR="$(pwd)"
|
||||
CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config"
|
||||
|
||||
python3 -m verl.trainer.main_ppo \
|
||||
--config-path="$CONFIG_PATH" \
|
||||
--config-name='gsm8k_multiturn_grpo' \
|
||||
algorithm.adv_estimator=grpo \
|
||||
data.train_batch_size=256 \
|
||||
data.max_prompt_length=1024 \
|
||||
data.max_response_length=1024 \
|
||||
data.filter_overlong_prompts=True \
|
||||
data.truncation='error' \
|
||||
data.return_raw_chat=True \
|
||||
actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.model.use_remove_padding=True \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
|
||||
actor_rollout_ref.actor.use_kl_loss=True \
|
||||
actor_rollout_ref.actor.kl_loss_coef=0.001 \
|
||||
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
|
||||
actor_rollout_ref.actor.entropy_coeff=0 \
|
||||
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=False \
|
||||
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
|
||||
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
|
||||
actor_rollout_ref.rollout.name=sglang \
|
||||
actor_rollout_ref.rollout.mode=async \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
|
||||
actor_rollout_ref.rollout.n=16 \
|
||||
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
|
||||
actor_rollout_ref.ref.fsdp_config.param_offload=True \
|
||||
actor_rollout_ref.rollout.trace.backend=mlflow \
|
||||
actor_rollout_ref.rollout.trace.token2text=True \
|
||||
algorithm.use_kl_in_reward=False \
|
||||
trainer.critic_warmup=0 \
|
||||
trainer.logger='["console","mlflow"]' \
|
||||
trainer.project_name='gsm8k_tool-agent' \
|
||||
trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-tool-agent-verify-n16' \
|
||||
trainer.n_gpus_per_node=8 \
|
||||
trainer.nnodes=1 \
|
||||
trainer.save_freq=-1 \
|
||||
trainer.test_freq=20 \
|
||||
trainer.total_training_steps=2 \
|
||||
data.train_files=$HOME/data/gsm8k/train.parquet \
|
||||
data.val_files=$HOME/data/gsm8k/test.parquet \
|
||||
actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \
|
||||
trainer.total_epochs=15 $@
|
||||
|
recipe/langgraph_agent/__init__.py (new file, 13 lines)
@ -0,0 +1,13 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
recipe/langgraph_agent/chat_model.py (new file, 357 lines)
@ -0,0 +1,357 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Ref: https://python.langchain.com/docs/how_to/custom_chat_model/
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.language_models.base import LanguageModelInput
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
convert_to_openai_messages,
|
||||
)
|
||||
from langchain_core.messages.tool import InvalidToolCall, ToolCall
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.runnables import Runnable, RunnableConfig
|
||||
from langchain_core.tools import StructuredTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from pydantic import Field
|
||||
|
||||
from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, AsyncLLMServerManager
|
||||
from verl.experimental.agent_loop.tool_parser import ToolParser
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
||||
|
||||
|
||||
class MaxTokenExceededError(Exception):
|
||||
"""Indicate that history chat messages + tool message exceeds LLM max_tokens."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ChatModel(BaseChatModel):
|
||||
model_name: str = Field(alias="model")
|
||||
"""The name of the model"""
|
||||
|
||||
client: AsyncLLMServerManager
|
||||
"""AsyncLLM server manager"""
|
||||
|
||||
tokenizer: Any
|
||||
"""Tokenizer for the model"""
|
||||
|
||||
max_tokens: int
|
||||
"""Max tokens to generate"""
|
||||
|
||||
tool_parser: str = "hermes"
|
||||
"""Tool parser for the model"""
|
||||
|
||||
max_parallel_calls: int = 1
|
||||
"""Max parallel tool calls"""
|
||||
|
||||
temperature: float = 1.0
|
||||
"""Temperature for sampling"""
|
||||
|
||||
top_p: float = 1.0
|
||||
"""Top p for sampling"""
|
||||
|
||||
repetition_penalty: float = 1.0
|
||||
"""Repetition penalty for sampling"""
|
||||
|
||||
def bind_tools(self, tools, **kwargs) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
"""Bind tools to the model.
|
||||
|
||||
Args:
|
||||
tools: Sequence of tools to bind to the model.
|
||||
|
||||
Returns:
|
||||
A Runnable that returns a message.
|
||||
"""
|
||||
formatted_tools: list = [convert_to_openai_tool(tool) for tool in tools]
|
||||
|
||||
# used to remove system prompt prefix when encoding tool response
|
||||
system_prompt = self.tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True)
|
||||
kwargs["system_prompt"] = system_prompt
|
||||
|
||||
return self.bind(tools=formatted_tools, **kwargs)
|
||||
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: dict | type,
|
||||
*,
|
||||
include_raw: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, dict | BaseChatModel]:
|
||||
"""Ref: https://langchain-ai.github.io/langgraph/how-tos/react-agent-structured-output/"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
raise NotImplementedError
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Asynchronously generate chat completion message.
|
||||
|
||||
Args:
|
||||
messages (list[BaseMessage]): List of messages.
|
||||
stop (Optional[list[str]], optional): Stop words to use when generating. Model output is cut off at the
|
||||
first occurrence of any of these substrings. Defaults to None.
|
||||
|
||||
Returns:
|
||||
ChatResult: Chat result.
|
||||
"""
|
||||
request_id, prompt_ids, response_mask = await self._preprocess(messages, **kwargs)
|
||||
|
||||
sampling_params = {
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"repetition_penalty": self.repetition_penalty,
|
||||
}
|
||||
if "sampling_params" in kwargs:
|
||||
sampling_params.update(kwargs["sampling_params"])
|
||||
|
||||
response_ids = await self.client.generate(
|
||||
request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params
|
||||
)
|
||||
|
||||
message = await self._postprocess(request_id, prompt_ids, response_mask, response_ids, **kwargs)
|
||||
generation = ChatGeneration(message=message)
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Get the type of language model used by this chat model."""
|
||||
return self.model_name
|
||||
|
||||
async def _preprocess(self, messages: list[BaseMessage], **kwargs: Any) -> tuple[str, list[int], list[int]]:
|
||||
"""Preprocess messages for chat completion.
|
||||
|
||||
To ensure strong consistency with policy model, AsyncLLM server generate response with token in token out
|
||||
instead of messages list.
|
||||
|
||||
But all agent frameworks use messages list to represent chat history. To mitigate the gap, we store trajectory
|
||||
(prompt_ids, response_mask) in the latest AIMessage.response_metadata.
|
||||
|
||||
1. Encode ToolMessage to token ids.
|
||||
2. Retrieve trajectory (prompt_ids, response_mask) from the latest AIMessage.response_metadata.
|
||||
3. Append ToolMessage token ids to prompt_ids, and append 0 to response_mask.
|
||||
|
||||
Ref: https://python.langchain.com/docs/concepts/chat_history/
|
||||
|
||||
Args:
|
||||
messages (list[BaseMessage]): List of messages.
|
||||
|
||||
Returns:
|
||||
tuple[str, list[int], list[int]]: Request id, prompt ids, response mask.
|
||||
"""
|
||||
# messages: [system], human, ai, human|tool, ai, human|tool, ...
|
||||
assert messages[-1].type in ["human", "tool"], (
|
||||
f"Last message must be human or tool, but got {messages[-1].type}"
|
||||
)
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# Case 1: initial chat completion: [system], human
|
||||
if messages[-1].type == "human" and (len(messages) == 1 or messages[-2].type != "ai"):
|
||||
prompt_ids = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.tokenizer.apply_chat_template(
|
||||
convert_to_openai_messages(messages),
|
||||
tools=kwargs.get("tools"),
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
),
|
||||
)
|
||||
return str(uuid.uuid4()), prompt_ids, []
|
||||
|
||||
# Case 2: follow up chat completion with tool/human response: [system], human, ai, human|tool, ...
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
if messages[i].type == "ai":
|
||||
break
|
||||
assert "prompt_ids" in messages[i].response_metadata, "Last message must have prompt_ids in response_metadata"
|
||||
assert "response_mask" in messages[i].response_metadata, (
|
||||
"Last message must have response_mask in response_metadata"
|
||||
)
|
||||
|
||||
# encode tool response
|
||||
tool_responses = convert_to_openai_messages(messages[i + 1 :])
|
||||
tool_response_ids = await loop.run_in_executor(
|
||||
None,
|
||||
lambda messages=tool_responses: self.tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=True, tokenize=True
|
||||
),
|
||||
)
|
||||
tool_response_ids = tool_response_ids[len(kwargs["system_prompt"]) :]
|
||||
|
||||
# stop generation if response length exceeds max response length
|
||||
if len(messages[i].response_metadata["response_mask"]) + len(tool_response_ids) >= self.max_tokens:
|
||||
raise MaxTokenExceededError(f"Max response length {self.max_tokens} exceeded")
|
||||
|
||||
# append tool response to prompt
|
||||
request_id = messages[i].response_metadata.pop("request_id")
|
||||
prompt_ids = messages[i].response_metadata.pop("prompt_ids")
|
||||
response_mask = messages[i].response_metadata.pop("response_mask")
|
||||
prompt_ids += tool_response_ids
|
||||
response_mask += [0] * len(tool_response_ids)
|
||||
|
||||
return request_id, prompt_ids, response_mask
|
||||
|
||||
async def _postprocess(
|
||||
self, request_id: str, prompt_ids: list[int], response_mask: list[int], response_ids: list[int], **kwargs: Any
|
||||
) -> AIMessage:
|
||||
"""Postprocess response_ids when chat completion is done.
|
||||
|
||||
1. Decode response_ids, parse tool calls to AIMessage.
|
||||
2. Append response_ids to prompt_ids, and append 1 to response_mask.
|
||||
3. Store trajectory (prompt_ids, response_mask) in AIMessage.response_metadata.
|
||||
|
||||
Args:
|
||||
request_id (str): Unique request id.
|
||||
prompt_ids (list[int]): Input prompt token ids in this chat completion.
|
||||
response_mask (list[int]): Response mask before this chat completion.
|
||||
response_ids (list[int]): LLM generated token ids in this chat completion.
|
||||
|
||||
Returns:
|
||||
AIMessage: Postprocessed message.
|
||||
"""
|
||||
prompt_ids += response_ids
|
||||
response_mask += [1] * len(response_ids)
|
||||
|
||||
tool_parser = ToolParser.get_tool_parser(self.tool_parser, self.tokenizer)
|
||||
content, function_calls = await tool_parser.extract_tool_calls(response_ids)
|
||||
|
||||
tool_calls, invalid_tool_calls = [], []
|
||||
for function_call in function_calls:
|
||||
try:
|
||||
args = json.loads(function_call.arguments)
|
||||
if not isinstance(args, dict):
|
||||
raise json.JSONDecodeError(f"Invalid json tool arguments: {args}", function_call.arguments, 0)
|
||||
tool_call = ToolCall(
|
||||
args=args,
|
||||
name=function_call.name,
|
||||
id=str(uuid.uuid4()),
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Invalid json tool arguments: {e}")
|
||||
tool_call = InvalidToolCall(
|
||||
args=function_call.arguments,
|
||||
name=function_call.name,
|
||||
error=f"Invalid json tool arguments: {e}",
|
||||
)
|
||||
invalid_tool_calls.append(tool_call)
|
||||
|
||||
message = AIMessage(
|
||||
content=content,
|
||||
tool_calls=tool_calls[: self.max_parallel_calls],
|
||||
invalid_tool_calls=invalid_tool_calls[: self.max_parallel_calls],
|
||||
response_metadata={
|
||||
"request_id": request_id,
|
||||
"prompt_ids": prompt_ids,
|
||||
"response_mask": response_mask,
|
||||
},
|
||||
)
|
||||
return message
|
||||
|
||||
|
||||
class TruncateStructuredTool(StructuredTool):
|
||||
"""Structured tool with response truncation."""
|
||||
|
||||
tool_response_truncate_side: str
|
||||
"""truncate side of tool response: left, middle, right"""
|
||||
|
||||
max_tool_response_length: int
|
||||
"""max length of tool response"""
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
*args: Any,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
tool_response = await super()._arun(*args, config=config, **kwargs)
|
||||
tool_response = str(tool_response)
|
||||
|
||||
if len(tool_response) > self.max_tool_response_length:
|
||||
if self.tool_response_truncate_side == "left":
|
||||
tool_response = tool_response[: self.max_tool_response_length] + "...(truncated)"
|
||||
elif self.tool_response_truncate_side == "right":
|
||||
tool_response = "(truncated)..." + tool_response[-self.max_tool_response_length :]
|
||||
else:
|
||||
length = self.max_tool_response_length // 2
|
||||
tool_response = tool_response[:length] + "...(truncated)..." + tool_response[-length:]
|
||||
|
||||
return tool_response
|
||||
|
||||
|
||||
def convert_to_agent_output(messages: list[BaseMessage], response_length: int) -> AgentLoopOutput:
|
||||
"""Convert messages to AgentLoopOutput.
|
||||
|
||||
Args:
|
||||
messages (List[BaseMessage]): List of messages, last message must be assistant
|
||||
with response_metadata containing `prompt_ids` and `response_mask`.
|
||||
response_length (int): Max length of response.
|
||||
|
||||
Returns:
|
||||
AgentLoopOutput: agent loop output trajectory used for training.
|
||||
"""
|
||||
# skip last tool calls
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
if messages[i].type != "tool":
|
||||
break
|
||||
last_message = messages[i]
|
||||
assert last_message.type == "ai", f"Last message must be assistant, but got {last_message.type}"
|
||||
assert "prompt_ids" in last_message.response_metadata, "Last message must have prompt_ids in response_metadata"
|
||||
assert "response_mask" in last_message.response_metadata, (
|
||||
"Last message must have response_mask in response_metadata"
|
||||
)
|
||||
|
||||
num_turns = 0
|
||||
for i in range(len(messages)):
|
||||
if messages[i].type == "system":
|
||||
continue
|
||||
# parallel tool calls are in single turn
|
||||
if i == 0 or messages[i].type != messages[i - 1].type:
|
||||
num_turns += 1
|
||||
|
||||
prompt_ids = last_message.response_metadata["prompt_ids"]
|
||||
response_mask = last_message.response_metadata["response_mask"]
|
||||
|
||||
response_ids = prompt_ids[-len(response_mask) :]
|
||||
prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)]
|
||||
|
||||
output = AgentLoopOutput(
|
||||
prompt_ids=prompt_ids,
|
||||
response_ids=response_ids[:response_length],
|
||||
response_mask=response_mask[:response_length],
|
||||
num_turns=num_turns,
|
||||
metrics={},
|
||||
)
|
||||
return output
|
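One plausible way to wire `ChatModel` into a LangGraph react agent is sketched below (illustrative only; `server_manager` and `tokenizer` are assumed to be provided by the surrounding agent loop, and the model label is arbitrary):

```python
from langgraph.prebuilt import create_react_agent

from recipe.langgraph_agent.chat_model import ChatModel, convert_to_agent_output

def build_agent(server_manager, tokenizer, tools, max_response_length=2048):
    model = ChatModel(
        model="policy",         # label only; routing is handled by server_manager
        client=server_manager,  # AsyncLLMServerManager instance
        tokenizer=tokenizer,
        max_tokens=max_response_length,
    )
    return create_react_agent(model, tools)

async def run_trajectory(agent, messages, max_response_length=2048):
    state = await agent.ainvoke({"messages": messages})
    # Convert the final message list into token-level training data.
    return convert_to_agent_output(state["messages"], response_length=max_response_length)
```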
recipe/langgraph_agent/example/README.md (new file, 111 lines)
@@ -0,0 +1,111 @@
# MathExpression: LangGraph Agent Example

MathExpression is a tiny example that demonstrates multi-turn rollout with a [LangGraph ReactAgent](https://langchain-ai.github.io/langgraph/agents/overview/).

### Define react agent with tool
First, to force the ReactAgent to evaluate math expressions with a tool, we define a special operand `@`:
```python
@tool(parse_docstring=True)
def calculate(a: int, b: int, operand: str) -> int:
    """
    Compute the results using operand with two integers

    Args:
        a: the first operand
        b: the second operand
        operand: '+' or '-' or '*' or '@'
    """
    assert operand in ["+", "-", "*", "@"], f"unknown operand {operand}"
    if operand == "@":
        return 3 * a - 2 * b
    return eval(f"{a} {operand} {b}")
```

Without calling `calculate`, the ReactAgent cannot evaluate math expressions correctly.

Then, we equip the ReactAgent with the `calculate` tool:
```python
class MathExpressionReactAgentLoop(ReactAgentLoop):
    @classmethod
    def init_class(cls, config, tokenizer):
        cls.tools = [calculate]
        super().init_class(config, tokenizer)
```

We define the agent loop config in a yaml file, which `AgentLoopWorker` uses to dynamically load the custom AgentLoop class.
```yaml
- name: math_expression
  _target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop
```

### Prepare dataset
Now, let's prepare two small datasets for training and evaluation:
```bash
python recipe/langgraph_agent/example/create_dataset.py
```

Note that the dataset should contain a column `agent_name` with the value `math_expression`, which is used by `AgentLoopWorker` to select the agent loop class.

| prompt                               | reward_model                 | agent_name      |
|--------------------------------------|------------------------------|-----------------|
| [{'role': 'user', 'content': '...'}] | {'ground_truth': '-10', ...} | math_expression |
| [{'role': 'user', 'content': '...'}] | {'ground_truth': '-10', ...} | math_expression |
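As a quick sanity check (not part of the original example), you can load the generated parquet and confirm the `agent_name` column written by `create_dataset.py`:

```python
import pandas as pd

df = pd.read_parquet("train.parquet")  # written by create_dataset.py
print(df.columns.tolist())             # prompt, data_source, ability, reward_model, extra_info, agent_name
assert (df["agent_name"] == "math_expression").all()
print(df.iloc[0]["extra_info"]["expression"])  # e.g. "(2 @ 3) @ 4"
```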
Generated math expressions look like the following, requiring the model to call `calculate` multiple times to solve sub-expressions.
```
(2 @ (8 @ 8 @ 5 @ 5 @ 3) @ 6 @ (1 @ 4 @ 4 @ 4) @ 2) @ 6
(4.6 @ (9.05 @ 4.0) @ 8.3 @ 1.21) @ 8.6
9 @ 4
((2 @ 2) @ (3 @ 3)) @ 4
```

### Training
Hook all of these up and start training:
```bash
bash recipe/langgraph_agent/example/run_qwen2.5_3b.sh 2>&1 | tee train.log
```

After 39 steps in total, the model should achieve 100% accuracy on the test dataset:
- val-aux/lighteval/MATH/reward: 1.0
- val-aux/num_turns/mean: 9.0, the average number of messages, including assistant and tool turns.

We can see that the model solves the math expression by calling the `calculate` tool:
```
system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.
user
We define a new math operator @, where you can only call an external tool to compute. Please put your final answer inside \boxed{} only in the last turn. Now answer the following questions:
Compute 5 @ (3 @ 6 @ 8)
assistant
Sure, let's break this down step by step. We need to compute the expression \(5 @ (3 @ 6 @ 8)\). First, we need to compute the inner expression \(3 @ 6 @ 8\).

Let's compute \(3 @ 6\) first using the external tool.
<tool_call>
{"name": "calculate", "arguments": {"a": 3, "b": 6, "operand": "@"}}
</tool_call>user
<tool_response>
-3
</tool_response>
assistant
We have computed that \(3 @ 6 = -3\). Now we need to compute \(-3 @ 8\).

Let's use the external tool to compute this.
<tool_call>
{"name": "calculate", "arguments": {"a": -3, "b": 8, "operand": "@"}}
</tool_call>user
<tool_response>
-25
</tool_response>
assistant
We have computed that \(-3 @ 8 = -25\). Now we can put this result into the original expression to get the final result of \(5 @ (3 @ 6 @ 8)\), which is \(5 @ (-25)\).

Let's compute \(5 @ (-25)\) using the external tool.
<tool_call>
{"name": "calculate", "arguments": {"a": 5, "b": -25, "operand": "@"}}
</tool_call>user
<tool_response>
65
</tool_response>
assistant
The final result of \(5 @ (3 @ 6 @ 8)\) is \(\boxed{65}\).
```
recipe/langgraph_agent/example/agent.yaml (new file, 2 lines)
@@ -0,0 +1,2 @@
- name: math_expression
  _target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop
recipe/langgraph_agent/example/create_dataset.py (new file, 277 lines)
@ -0,0 +1,277 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Create dataset for calculator
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def generate_math_expression(min_terms=2, max_terms=5, min_number=1, max_number=10, allow_decimals=False, max_depth=2):
|
||||
"""
|
||||
Generate a random mathematical expression with operators +, -, *, /, and parentheses.
|
||||
|
||||
Args:
|
||||
min_terms (int): Minimum number of terms in the expression.
|
||||
max_terms (int): Maximum number of terms in the expression.
|
||||
min_number (int): Minimum value for numbers in the expression.
max_number (int): Maximum value for numbers in the expression.
|
||||
allow_decimals (bool): Whether to allow decimal numbers.
|
||||
max_depth (int): Maximum nesting depth for parentheses.
|
||||
|
||||
Returns:
|
||||
str: A valid mathematical expression as a string.
|
||||
"""
|
||||
|
||||
def generate_number():
|
||||
"""Generate a random number (integer or float)."""
|
||||
assert min_number < max_number
|
||||
num = random.uniform(min_number, max_number)
|
||||
if not allow_decimals:
|
||||
num = int(num)
|
||||
else:
|
||||
num = round(num, random.randint(0, 2)) # Round to 0-2 decimal places
|
||||
return str(num)
|
||||
|
||||
def generate_term(depth=0):
|
||||
"""Generate a term (number or parenthesized expression)."""
|
||||
if depth < max_depth and random.random() < 0.5: # 50% chance to add parentheses
|
||||
expr = generate_expression(depth + 1)
|
||||
return f"({expr})"
|
||||
else:
|
||||
return generate_number()
|
||||
|
||||
def generate_expression(depth=0):
|
||||
"""Generate a full expression with multiple terms and operators."""
|
||||
num_terms = random.randint(min_terms, max_terms)
|
||||
terms = [generate_term(depth) for _ in range(num_terms)]
|
||||
|
||||
# Randomly select operators
|
||||
operators = ["+", "-", "*", "/", "@"]
|
||||
expr = terms[0]
|
||||
|
||||
for i in range(1, num_terms):
|
||||
# Force the custom '@' operator so every expression requires the tool
|
||||
op = random.choices(
|
||||
operators,
|
||||
weights=[0, 0, 0, 0, 1],  # only '@' is selected; +, -, *, / have zero weight
|
||||
)[0]
|
||||
expr += f" {op} " + terms[i]
|
||||
|
||||
return expr
|
||||
|
||||
return generate_expression()
|
||||
|
||||
|
||||
def test():
|
||||
# Example 1: Basic integer expression
|
||||
print(generate_math_expression())
|
||||
# Output: (3 + 7) * 2 - 5
|
||||
|
||||
# Example 2: Expression with decimals
|
||||
print(generate_math_expression(allow_decimals=True))
|
||||
# Output: 4.5 / (2.1 + 3.7) - 1.2
|
||||
|
||||
# Example 3: More complex expression with higher depth
|
||||
print(generate_math_expression(max_terms=6, max_depth=3))
|
||||
# Output: ((5 * 2) - (3 + 1)) / (7 - 2) + 4
|
||||
|
||||
# Example 4: Simplified expression
|
||||
print(generate_math_expression(min_terms=2, max_terms=3, max_number=5))
|
||||
# Output: 4 - 2 * 3
|
||||
|
||||
|
||||
def calculate(expression: str) -> float:
|
||||
"""
|
||||
Evaluate a mathematical expression with +, -, *, /, @, and parentheses.
|
||||
The @ operator is defined as: a @ b = 3a - 2b.
|
||||
|
||||
Args:
|
||||
expression (str): Input mathematical expression (e.g., "3@2+4").
|
||||
|
||||
Returns:
|
||||
float: Result of the evaluated expression.
|
||||
|
||||
Raises:
|
||||
ValueError: For invalid expressions (e.g., mismatched parentheses, division by zero).
|
||||
"""
|
||||
|
||||
def tokenize(s: str) -> list:
|
||||
"""Convert the input string into tokens (numbers, operators, parentheses)."""
|
||||
tokens = []
|
||||
i = 0
|
||||
while i < len(s):
|
||||
if s[i].isdigit() or s[i] == ".":
|
||||
# Parse number (integer or float)
|
||||
j = i
|
||||
while j < len(s) and (s[j].isdigit() or s[j] == "."):
|
||||
j += 1
|
||||
tokens.append(s[i:j])
|
||||
i = j
|
||||
elif s[i] in "+-*/@()":
|
||||
# Operator or parenthesis
|
||||
tokens.append(s[i])
|
||||
i += 1
|
||||
elif s[i].isspace():
|
||||
# Skip whitespace
|
||||
i += 1
|
||||
else:
|
||||
raise ValueError(f"Invalid character: {s[i]}")
|
||||
return tokens
|
||||
|
||||
def infix_to_postfix(tokens: list) -> list:
|
||||
"""Convert infix notation to postfix notation (Reverse Polish Notation)."""
|
||||
output = []
|
||||
stack = []
|
||||
# Higher precedence for @ (between * and +)
|
||||
precedence = {"@": 3, "*": 2, "/": 2, "+": 1, "-": 1}
|
||||
|
||||
for token in tokens:
|
||||
if token.isdigit() or "." in token:
|
||||
output.append(token)
|
||||
elif token == "(":
|
||||
stack.append(token)
|
||||
elif token == ")":
|
||||
while stack and stack[-1] != "(":
|
||||
output.append(stack.pop())
|
||||
if not stack or stack[-1] != "(":
|
||||
raise ValueError("Mismatched parentheses")
|
||||
stack.pop() # Discard '('
|
||||
else: # Operator
|
||||
while stack and stack[-1] != "(" and precedence.get(stack[-1], 0) >= precedence.get(token, 0):
|
||||
output.append(stack.pop())
|
||||
stack.append(token)
|
||||
|
||||
# Pop remaining operators
|
||||
while stack:
|
||||
if stack[-1] in "()":
|
||||
raise ValueError("Mismatched parentheses")
|
||||
output.append(stack.pop())
|
||||
|
||||
return output
|
||||
|
||||
def evaluate_postfix(postfix: list) -> float:
|
||||
"""Evaluate postfix expression using a stack."""
|
||||
stack = []
|
||||
for token in postfix:
|
||||
if token.isdigit() or "." in token:
|
||||
stack.append(float(token))
|
||||
else:
|
||||
if len(stack) < 2:
|
||||
raise ValueError("Invalid expression")
|
||||
b = stack.pop()
|
||||
a = stack.pop()
|
||||
if token == "+":
|
||||
res = a + b
|
||||
elif token == "-":
|
||||
res = a - b
|
||||
elif token == "*":
|
||||
res = a * b
|
||||
elif token == "/":
|
||||
if b == 0:
|
||||
raise ValueError("Division by zero")
|
||||
res = a / b
|
||||
elif token == "@":
|
||||
res = 3 * a - 2 * b # Custom @ operator implementation
|
||||
else:
|
||||
raise ValueError(f"Invalid operator: {token}")
|
||||
stack.append(res)
|
||||
|
||||
if len(stack) != 1:
|
||||
raise ValueError("Invalid expression")
|
||||
return stack[0]
|
||||
|
||||
# Remove spaces and validate parentheses
|
||||
expression = expression.replace(" ", "")
|
||||
if expression.count("(") != expression.count(")"):
|
||||
raise ValueError("Mismatched parentheses")
|
||||
|
||||
tokens = tokenize(expression)
|
||||
postfix = infix_to_postfix(tokens)
|
||||
result = evaluate_postfix(postfix)
|
||||
|
||||
# Convert integers to integer representation
|
||||
if result.is_integer():
|
||||
return int(result)
|
||||
return result
|
||||
|
||||
|
||||
def generate_data(total_num_dataset, split):
|
||||
rl_dataset = {
|
||||
"prompt": [],
|
||||
"data_source": [],
|
||||
"ability": [],
|
||||
"reward_model": [],
|
||||
"extra_info": [],
|
||||
"agent_name": [],
|
||||
}
|
||||
|
||||
for idx in range(total_num_dataset):
|
||||
while True:
|
||||
try:
|
||||
expression: str = generate_math_expression(
|
||||
min_terms=2, max_terms=3, min_number=1, max_number=10, allow_decimals=False, max_depth=1
|
||||
)
|
||||
|
||||
num_plus = expression.count("+")
|
||||
num_minus = expression.count("-")
|
||||
num_mul = expression.count("*")
|
||||
num_star = expression.count("@")
|
||||
|
||||
answer = str(calculate(expression))
|
||||
# answer = str(eval(expression))
|
||||
break
|
||||
except Exception as e:
|
||||
print(e)
|
||||
continue
|
||||
|
||||
num_tool_calls = num_plus + num_minus + num_mul + num_star
|
||||
|
||||
prompt = (
|
||||
f"We define a new math operator @, where you can only call an external tool to compute. "
|
||||
f"Please put your final answer inside \\boxed{{}} only in the last turn. Now answer the "
|
||||
f"following questions:\nCompute {expression}"
|
||||
)
|
||||
prompt_with_template = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
]
|
||||
|
||||
rl_dataset["prompt"].append(prompt_with_template)
|
||||
rl_dataset["data_source"].append("lighteval/MATH")
|
||||
rl_dataset["ability"].append("math")
|
||||
rl_dataset["reward_model"].append({"style": "lighteval/MATH", "ground_truth": answer})
|
||||
rl_dataset["extra_info"].append(
|
||||
{"index": idx, "expression": expression, "split": split, "expected_tool_calls": num_tool_calls}
|
||||
)
|
||||
rl_dataset["agent_name"].append("math_expression")
|
||||
|
||||
rl_dataset = pd.DataFrame(data=rl_dataset)
|
||||
return rl_dataset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# print(calculate("3@2")) # Output: 5 (3*3 - 2*2)
|
||||
# print(calculate("3@2+4")) # Output: 9 (5 + 4)
|
||||
# print(calculate("3*(4@2)")) # Output: 24 (3 * 8)
|
||||
# print(calculate("(5@3)*2")) # Output: 18 (9 * 2)
|
||||
|
||||
train_dataset = generate_data(total_num_dataset=5000, split="train")
|
||||
test_dataset = generate_data(total_num_dataset=500, split="test")
|
||||
|
||||
train_dataset.to_parquet("train.parquet")
|
||||
test_dataset.to_parquet("test.parquet")
|
recipe/langgraph_agent/example/math_expression.py (new file, 39 lines)
@ -0,0 +1,39 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from recipe.langgraph_agent.react_agent_loop import ReactAgentLoop
|
||||
|
||||
|
||||
@tool(parse_docstring=True)
|
||||
def calculate(a: int, b: int, operand: str) -> int:
|
||||
"""
|
||||
Compute the results using operand with two integers
|
||||
|
||||
Args:
|
||||
a: the first operand
|
||||
b: the second operand
|
||||
operand: '+' or '-' or '*' or '@'
|
||||
"""
|
||||
assert operand in ["+", "-", "*", "@"], f"unknown operand {operand}"
|
||||
if operand == "@":
|
||||
return 3 * a - 2 * b
|
||||
return eval(f"{a} {operand} {b}")
|
||||
|
||||
|
||||
class MathExpressionReactAgentLoop(ReactAgentLoop):
|
||||
@classmethod
|
||||
def init_class(cls, config, tokenizer, **kwargs):
|
||||
cls.tools = [calculate]
|
||||
super().init_class(config, tokenizer)
|
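Since `calculate` is a LangChain tool, it can be invoked directly with a dict of arguments; the values below follow the `a @ b = 3a - 2b` definition and match the README transcript:

```python
from recipe.langgraph_agent.example.math_expression import calculate

print(calculate.invoke({"a": 3, "b": 6, "operand": "@"}))    # 3*3 - 2*6 = -3
print(calculate.invoke({"a": 5, "b": -25, "operand": "@"}))  # 3*5 - 2*(-25) = 65
```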
recipe/langgraph_agent/example/run_qwen2.5_3b.sh (new file, 99 lines)
@ -0,0 +1,99 @@
|
||||
set -x
|
||||
|
||||
# ================= data/model/tool =================
|
||||
HDFS_ROOT=${HDFS_ROOT:-$PWD}
|
||||
DATA_ROOT=${DATA_ROOT:-$PWD}
|
||||
|
||||
model_path=$DATA_ROOT/model/Qwen2.5-3B-Instruct
|
||||
|
||||
train_files=$DATA_ROOT/dataset/math_expression_tool/train.parquet
|
||||
test_files=$DATA_ROOT/dataset/math_expression_tool/test.parquet
|
||||
|
||||
# agent
|
||||
agent_loop_config_path=recipe/langgraph_agent/example/agent.yaml
|
||||
|
||||
# wandb
|
||||
project_name=math_expression_tool
|
||||
experiment_name=qwen2.5-3b
|
||||
default_local_dir=$DATA_ROOT/checkpoint/$experiment_name
|
||||
|
||||
# ================= algorithm =================
|
||||
adv_estimator=grpo
|
||||
|
||||
use_kl_in_reward=False
|
||||
kl_coef=0.0
|
||||
use_kl_loss=False
|
||||
kl_loss_coef=0.0
|
||||
|
||||
clip_ratio_low=0.2
|
||||
clip_ratio_high=0.28
|
||||
|
||||
max_turns=8
|
||||
max_prompt_length=1024
|
||||
max_response_length=2048
|
||||
actor_lr=1e-6
|
||||
|
||||
train_batch_size=128
|
||||
ppo_mini_batch_size=16
|
||||
n_resp_per_prompt=8
|
||||
n_resp_per_prompt_val=1
|
||||
|
||||
# ================= performance =================
|
||||
infer_tp=2 # vllm
|
||||
train_sp=4 # train
|
||||
offload=True
|
||||
|
||||
actor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 4 ))
|
||||
log_prob_max_token_len_per_gpu=$(( actor_max_token_len_per_gpu * 2 ))
|
||||
|
||||
python3 -m verl.trainer.main_ppo \
|
||||
algorithm.adv_estimator=$adv_estimator \
|
||||
algorithm.use_kl_in_reward=$use_kl_in_reward \
|
||||
algorithm.kl_ctrl.kl_coef=$kl_coef \
|
||||
data.train_files="$train_files" \
|
||||
data.val_files="$test_files" \
|
||||
data.return_raw_chat=True \
|
||||
data.train_batch_size=$train_batch_size \
|
||||
data.max_prompt_length=$max_prompt_length \
|
||||
data.max_response_length=$max_response_length \
|
||||
data.filter_overlong_prompts=True \
|
||||
data.truncation='error' \
|
||||
actor_rollout_ref.model.path=$model_path \
|
||||
actor_rollout_ref.model.use_remove_padding=True \
|
||||
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||
actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \
|
||||
actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \
|
||||
actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \
|
||||
actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \
|
||||
actor_rollout_ref.actor.clip_ratio_c=10.0 \
|
||||
actor_rollout_ref.actor.optim.lr=$actor_lr \
|
||||
actor_rollout_ref.actor.use_dynamic_bsz=True \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \
|
||||
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \
|
||||
actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=$offload \
|
||||
actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \
|
||||
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$log_prob_max_token_len_per_gpu \
|
||||
actor_rollout_ref.rollout.name=vllm \
|
||||
actor_rollout_ref.rollout.mode=async \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \
|
||||
actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \
|
||||
actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \
|
||||
actor_rollout_ref.rollout.multi_turn.format=hermes \
|
||||
actor_rollout_ref.rollout.agent.agent_loop_config_path=$agent_loop_config_path \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \
|
||||
actor_rollout_ref.rollout.n=$n_resp_per_prompt \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_p=0.6 \
|
||||
actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \
|
||||
actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \
|
||||
trainer.logger=['console','wandb'] \
|
||||
trainer.project_name=$project_name \
|
||||
trainer.experiment_name=$experiment_name \
|
||||
trainer.n_gpus_per_node=$ARNOLD_WORKER_GPU \
|
||||
trainer.val_before_train=True \
|
||||
trainer.log_val_generations=50 \
|
||||
trainer.nnodes=$ARNOLD_WORKER_NUM \
|
||||
trainer.save_freq=-1 \
|
||||
trainer.default_local_dir=$default_local_dir \
|
||||
trainer.test_freq=5 \
|
||||
trainer.total_epochs=1 $@
|
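With `max_prompt_length=1024` and `max_response_length=2048`, the dynamic-batching budgets above work out to `(1024 + 2048) * 4 = 12288` tokens per GPU for PPO updates, and twice that, `24576`, for log-prob computation.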
133  recipe/langgraph_agent/react_agent_loop.py  (Normal file)
@ -0,0 +1,133 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
LangGraph React Agent Loop.

This implementation is functionally the same as `ToolAgentLoop`.

Ref: https://langchain-ai.github.io/langgraph/tutorials/workflows/
"""

from typing import Any, Literal

from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode

from recipe.langgraph_agent.chat_model import (
    ChatModel,
    MaxTokenExceededError,
    convert_to_agent_output,
)
from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput


async def call_model(state: MessagesState, config: RunnableConfig):
    model = config["configurable"]["model"]
    sampling_params = config["configurable"]["sampling_params"]
    try:
        message = await model.ainvoke(state["messages"], sampling_params=sampling_params)
        return {"messages": [message]}
    except MaxTokenExceededError:
        # last message is ToolMessage
        return {"messages": []}


def should_continue(state: MessagesState, config: RunnableConfig) -> Literal["tools", END]:
    max_assistant_turns = config["configurable"]["max_assistant_turns"]
    num_assistant_turns = 0
    for message in state["messages"]:
        if message.type == "ai":
            num_assistant_turns += 1

    last_message = state["messages"][-1]

    # LLM call failed, e.g. max response length exceeded
    if last_message.type == "tool":
        return END

    # max assistant turns exceeded
    if max_assistant_turns and num_assistant_turns >= max_assistant_turns:
        return END

    # no tool calls
    if not last_message.tool_calls:
        return END

    return "tools"


class ReactAgentLoop(AgentLoopBase):
    @classmethod
    def init_class(cls, config, tokenizer, **kwargs):
        if cls._class_initialized:
            return
        cls._class_initialized = True
        print("Performing class-level ReactAgentLoop initialization")

        # build graph
        cls.graph = cls.build_graph()

    @classmethod
    def build_graph(cls) -> StateGraph:
        workflow = StateGraph(MessagesState)

        workflow.add_node("agent", call_model)
        workflow.add_node("tools", ToolNode(cls.tools))
        workflow.set_entry_point("agent")
        workflow.add_conditional_edges(
            "agent",
            should_continue,
            {
                "tools": "tools",
                END: END,
            },
        )

        workflow.add_edge("tools", "agent")
        graph = workflow.compile()
        return graph

    async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
        model_path = self.config.actor_rollout_ref.model.path
        model_name = "/".join(model_path.split("/")[-2:])

        rollout = self.config.actor_rollout_ref.rollout
        model = ChatModel(
            model=model_name,
            client=self.server_manager,
            tokenizer=self.tokenizer,
            max_tokens=rollout.response_length,
            max_parallel_calls=rollout.multi_turn.max_parallel_calls,
            tool_parser=rollout.multi_turn.format,
        )

        model = model.bind_tools(self.tools, tool_choice="any")

        config = {
            "configurable": {
                "model": model,
                "sampling_params": sampling_params,
                "max_user_turns": rollout.multi_turn.max_user_turns,
                "max_assistant_turns": rollout.multi_turn.max_assistant_turns,
            }
        }

        # TODO: how to handle multiple trajectories in a graph invocation?
        # Each graph node may have its own LLM calls and state, e.g.:
        # https://github.com/google-gemini/gemini-fullstack-langgraph-quickstart
        state = await self.graph.ainvoke(input={"messages": messages}, config=config)

        output = convert_to_agent_output(state["messages"], rollout.response_length)
        return output
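The loop is selected at runtime by name through the agent-loop registry. The run script above points `agent_loop_config_path` at `recipe/langgraph_agent/example/agent.yaml`; that file's contents are not shown in this diff, but based on the config pattern used in `test_react_agent_loop.py` below, an equivalent entry would presumably look like the sketch here (illustration only, not the committed file):

```python
# Sketch only: an agent-loop registration entry equivalent to what agent.yaml is
# assumed to contain, written out as JSON for illustration.
import json

agent_loop_config = [
    {
        # hydra instantiation target: module path + class name of the loop
        "_target_": "recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop",
        # must match the `agent_name` attached to each dataset sample
        "name": "react_agent",
    },
]

with open("/tmp/agent_loop_config.json", "w") as f:
    json.dump(agent_loop_config, f)
```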
199
recipe/langgraph_agent/test_react_agent_loop.py
Normal file
199
recipe/langgraph_agent/test_react_agent_loop.py
Normal file
@ -0,0 +1,199 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import ray
|
||||
from langchain_core.tools import tool
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from recipe.langgraph_agent.react_agent_loop import ReactAgentLoop
|
||||
from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager
|
||||
from verl.protocol import DataProto
|
||||
from verl.utils import hf_tokenizer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def init_config() -> DictConfig:
|
||||
from hydra import compose, initialize_config_dir
|
||||
|
||||
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
|
||||
config = compose(config_name="ppo_trainer")
|
||||
model_path = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
config.actor_rollout_ref.model.path = model_path
|
||||
config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
|
||||
config.actor_rollout_ref.rollout.mode = "async"
|
||||
config.actor_rollout_ref.rollout.prompt_length = 4096
|
||||
config.actor_rollout_ref.rollout.response_length = 4096
|
||||
config.actor_rollout_ref.rollout.n = 4
|
||||
config.actor_rollout_ref.rollout.agent.num_workers = 2
|
||||
|
||||
# test sleep/wake_up with fsdp offload
|
||||
config.actor_rollout_ref.actor.fsdp_config.param_offload = True
|
||||
config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@tool(parse_docstring=True)
|
||||
def get_current_temperature(location: str, unit: str = "celsius"):
|
||||
"""Get current temperature at a location.
|
||||
|
||||
Args:
|
||||
location: The location to get the temperature for, in the format "City, State, Country".
|
||||
unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"])
|
||||
|
||||
Returns:
|
||||
the temperature, the location, and the unit in a dict
|
||||
"""
|
||||
print(f"[DEBUG] get_current_temperature: {location}, {unit}")
|
||||
return {
|
||||
"temperature": 26.1,
|
||||
"location": location,
|
||||
"unit": unit,
|
||||
}
|
||||
|
||||
|
||||
@tool(parse_docstring=True)
|
||||
def get_temperature_date(location: str, date: str, unit: str = "celsius"):
|
||||
"""Get temperature at a location and date.
|
||||
|
||||
Args:
|
||||
location: The location to get the temperature for, in the format "City, State, Country".
|
||||
date: The date to get the temperature for, in the format "Year-Month-Day".
|
||||
unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"])
|
||||
|
||||
Returns:
|
||||
the temperature, the location, the date and the unit in a dict
|
||||
"""
|
||||
print(f"[DEBUG] get_temperature_date: {location}, {date}, {unit}")
|
||||
return {
|
||||
"temperature": 25.9,
|
||||
"location": location,
|
||||
"date": date,
|
||||
"unit": unit,
|
||||
}
|
||||
|
||||
|
||||
class TestReactAgentLoop(ReactAgentLoop):
|
||||
@classmethod
|
||||
def init_class(cls, config, tokenizer, **kwargs):
|
||||
# TODO: find better way to configure tools
|
||||
cls.tools = [get_current_temperature, get_temperature_date]
|
||||
super().init_class(config, tokenizer, **kwargs)
|
||||
|
||||
|
||||
def test_react_agent(init_config):
|
||||
ray.init(
|
||||
runtime_env={
|
||||
"env_vars": {
|
||||
"TOKENIZERS_PARALLELISM": "true",
|
||||
"NCCL_DEBUG": "WARN",
|
||||
"VLLM_LOGGING_LEVEL": "INFO",
|
||||
"VLLM_USE_V1": "1",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# =========================== 1. Init rollout manager ===========================
|
||||
agent_loop_config = [
|
||||
{
|
||||
"_target_": "recipe.langgraph_agent.test_react_agent_loop.TestReactAgentLoop",
|
||||
"name": "react_agent",
|
||||
},
|
||||
]
|
||||
agent_loop_config_path = "/tmp/agent_loop_config.json"
|
||||
with open(agent_loop_config_path, "w") as f:
|
||||
json.dump(agent_loop_config, f)
|
||||
|
||||
n = 2
|
||||
init_config.actor_rollout_ref.rollout.n = n
|
||||
# init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path
|
||||
init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2
|
||||
init_config.actor_rollout_ref.rollout.agent.agent_loop_config_path = agent_loop_config_path
|
||||
agent_loop_manager = init_agent_loop_manager(init_config)
|
||||
|
||||
# =========================== 2. Generate sequences ===========================
|
||||
raw_prompts = [
|
||||
[
|
||||
{"role": "user", "content": "How are you?"},
|
||||
],
|
||||
[
|
||||
{"role": "user", "content": "What's the temperature in Los Angeles now?"},
|
||||
],
|
||||
[
|
||||
{"role": "user", "content": "What's the temperature in New York now?"},
|
||||
],
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\n"
|
||||
"Current Date: 2024-09-30",
|
||||
},
|
||||
{"role": "user", "content": "What's the temperature in San Francisco now? How about tomorrow?"},
|
||||
],
|
||||
]
|
||||
batch = DataProto(
|
||||
non_tensor_batch={
|
||||
"raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),
|
||||
"agent_name": np.array(["react_agent"] * len(raw_prompts)),
|
||||
},
|
||||
)
|
||||
batch = batch.repeat(n)
|
||||
result = agent_loop_manager.generate_sequences(prompts=batch)
|
||||
assert len(result) == len(raw_prompts) * n
|
||||
|
||||
# Check turns
|
||||
num_turns = result.non_tensor_batch["__num_turns__"]
|
||||
print(f"num_turns: {num_turns}")
|
||||
for i in range(len(num_turns)):
|
||||
if i // n == 0:
|
||||
# [user, assistant]
|
||||
assert num_turns[i] == 2
|
||||
else:
|
||||
# [user, assistant, tool, assistant]
|
||||
assert num_turns[i] == 4
|
||||
|
||||
# Check response_mask
|
||||
tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
|
||||
responses = result.batch["responses"]
|
||||
response_mask = result.batch["response_mask"]
|
||||
attention_mask = result.batch["attention_mask"]
|
||||
assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}"
|
||||
response_length = response_mask.size(1)
|
||||
|
||||
for i in range(len(responses)):
|
||||
# response with tool response
|
||||
valid_tokens = responses[i][attention_mask[i][-response_length:].bool()]
|
||||
response_with_obs = tokenizer.decode(valid_tokens)
|
||||
|
||||
# response without tool response
|
||||
valid_tokens = responses[i][response_mask[i].bool()]
|
||||
response_without_obs = tokenizer.decode(valid_tokens)
|
||||
|
||||
assert "<tool_response>" not in response_without_obs, (
|
||||
f"found <tool_response> in response: {response_without_obs}"
|
||||
)
|
||||
assert "</tool_response>" not in response_without_obs, (
|
||||
f"found </tool_response> in response: {response_without_obs}"
|
||||
)
|
||||
print("=========================")
|
||||
print(response_with_obs)
|
||||
print("---")
|
||||
print(response_without_obs)
|
||||
|
||||
print("Test passed!")
|
||||
ray.shutdown()
|
@ -28,7 +28,7 @@ from verl import DataProto
from verl.utils.device import get_device_name
from verl.utils.py_functional import append_to_dict
from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs

from .prime_core_algos import compute_ce_dpo_loss_rm, compute_detach_dpo_loss_rm

@ -101,7 +101,9 @@ class DataParallelPRIMERewardModel:
)

if self.ulysses_sequence_parallel_size > 1:
rm_log_labels = gather_outpus_and_unpad(rm_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size)
rm_log_labels = gather_outputs_and_unpad(
rm_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size
)
rm_log_labels = pad_input(
hidden_states=rm_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen
).squeeze(-1)[:, -num_actions - 1 : -1]
@ -149,7 +151,7 @@ class DataParallelPRIMERewardModel:
logits=ref_output_logits, labels=input_ids_rmpad_rolled
)

ref_log_labels = gather_outpus_and_unpad(
ref_log_labels = gather_outputs_and_unpad(
ref_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size
)
ref_log_labels = pad_input(

@ -409,7 +409,7 @@ class RewardModelWorker(Worker):
def _forward_micro_batch(self, micro_batch):
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input

from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs

with torch.no_grad(), torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16):
input_ids = micro_batch["input_ids"]
@ -443,7 +443,7 @@ class RewardModelWorker(Worker):

# gather output if sp > 1
if self.ulysses_sequence_parallel_size > 1:
reward_rmpad = gather_outpus_and_unpad(
reward_rmpad = gather_outputs_and_unpad(
reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size
)
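The change in these hunks is mechanical: the misspelled Ulysses helper is renamed and the call sites are updated, so any downstream code only needs to switch the import:

```python
# Before (misspelled name, now removed):
# from verl.utils.ulysses import gather_outpus_and_unpad
# After:
from verl.utils.ulysses import gather_outputs_and_unpad
```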
@ -10,7 +10,7 @@ peft
pyarrow>=15.0.0
pybind11
pylatexenc
tensordict>=0.8.0,<=0.9.0
tensordict>=0.8.0,<=0.9.1,!=0.9.0
transformers==4.52.4
ray==2.46.0
wandb

@ -14,7 +14,7 @@ pybind11
pylatexenc
pre-commit
ray[default]
tensordict>=0.8.0,<=0.9.0
tensordict>=0.8.0,<=0.9.1,!=0.9.0
torchdata
transformers
# vllm==0.8.4

@ -12,7 +12,7 @@ pyarrow>=19.0.0
pybind11
pylatexenc
ray[default]>=2.10
tensordict>=0.8.0,<=0.9.0
tensordict>=0.8.0,<=0.9.1,!=0.9.0
torchdata
torchvision
transformers
@ -17,9 +17,11 @@ import argparse
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, ContextManager
|
||||
from typing import Any, Callable, ContextManager, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from accelerate import init_empty_weights
|
||||
from megatron.core import dist_checkpointing
|
||||
from megatron.core import parallel_state as mpu
|
||||
@ -29,11 +31,22 @@ from megatron.core.models.gpt.gpt_model import ModelType
|
||||
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
|
||||
from transformers import AutoConfig
|
||||
|
||||
from verl.model_merger.megatron_model_merger import get_dynamic_pipeline_shards
|
||||
from verl.models.mcore import hf_to_mcore_config
|
||||
from verl.utils.device import get_device_name, get_torch_device
|
||||
from verl.utils.megatron_utils import get_model
|
||||
|
||||
|
||||
def _init_args():
|
||||
"""
|
||||
Examples:
|
||||
|
||||
1. single rank conversion for any model:
|
||||
> python converter_hf_to_mcore.py --hf_model_path %{hf_model} --output_path ${output_path}
|
||||
2. distributed conversion for DeepseekV3 671B:
|
||||
> torchrun --nproc_per_node 1 --nnodes 4 --node_rank ${RANK} converter_hf_to_mcore.py \
|
||||
--hf_model_path %{hf_model} --output_path ${output_path}
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--hf_model_path", type=str, required=True, help="The path for the huggingface model")
|
||||
parser.add_argument("--output_path", type=str, required=True, help="The path for the output mcore model")
|
||||
@ -92,7 +105,17 @@ def test_conversion(megatron_model_provider, tfconfig, output_path, model):
|
||||
print("Conversion test passed!")
|
||||
|
||||
|
||||
def convert_checkpoint_from_transformers_to_megatron(hf_model, model, hf_config):
|
||||
@torch.inference_mode()
|
||||
def convert_checkpoint_from_transformers_to_megatron(
|
||||
hf_model, model, hf_config, layer_start_end: Optional[tuple[int, int]] = None
|
||||
):
|
||||
if layer_start_end is None:
|
||||
layer_start_end = (0, len(model.decoder.layers))
|
||||
layer_start, layer_end = layer_start_end
|
||||
pp_rank = mpu.get_pipeline_model_parallel_rank()
|
||||
pp_size = mpu.get_pipeline_model_parallel_world_size()
|
||||
numel = 0
|
||||
|
||||
num_attention_heads = hf_config.num_attention_heads
|
||||
num_key_value_heads = hf_config.num_key_value_heads
|
||||
hidden_dim = hf_config.hidden_size
|
||||
@ -101,50 +124,61 @@ def convert_checkpoint_from_transformers_to_megatron(hf_model, model, hf_config)
|
||||
print("[WARNING] Converting GQA model")
|
||||
has_qkv_bias = getattr(hf_config, "qkv_bias", False) or getattr(hf_config, "attention_bias", False)
|
||||
has_share_expert = getattr(hf_config, "shared_expert_intermediate_size", None)
|
||||
with torch.no_grad():
|
||||
model.embedding.word_embeddings.weight.copy_(hf_model.model.embed_tokens.weight)
|
||||
for layer, hf_layer in zip(model.decoder.layers, hf_model.model.layers, strict=True):
|
||||
layer.self_attention.linear_qkv.layer_norm_weight.copy_(hf_layer.input_layernorm.weight)
|
||||
if pp_rank == 0:
|
||||
numel += safe_copy(hf_model.model.embed_tokens.weight, model.embedding.word_embeddings.weight)
|
||||
|
||||
q = hf_layer.self_attn.q_proj.weight.view(
|
||||
[num_key_value_heads, head_dim * num_attention_heads // num_key_value_heads, -1]
|
||||
assert len(model.decoder.layers) == (layer_end - layer_start), (
|
||||
f"Expected {len(model.decoder.layers)} layers, but got {layer_end - layer_start}"
|
||||
)
|
||||
for layer_idx, (layer, hf_layer) in enumerate(
|
||||
zip(model.decoder.layers, hf_model.model.layers[layer_start:layer_end], strict=True)
|
||||
):
|
||||
global_layer_idx = layer_idx + layer_start
|
||||
numel_cur = numel
|
||||
numel += safe_copy(hf_layer.input_layernorm.weight, layer.self_attention.linear_qkv.layer_norm_weight)
|
||||
|
||||
q = hf_layer.self_attn.q_proj.weight.view(
|
||||
[num_key_value_heads, head_dim * num_attention_heads // num_key_value_heads, -1]
|
||||
)
|
||||
k = hf_layer.self_attn.k_proj.weight.view([num_key_value_heads, head_dim, -1])
|
||||
v = hf_layer.self_attn.v_proj.weight.view([num_key_value_heads, head_dim, -1])
|
||||
qkv = torch.cat([q, k, v], dim=1).view(-1, hidden_dim).contiguous()
|
||||
numel += safe_copy(qkv, layer.self_attention.linear_qkv.weight)
|
||||
|
||||
if has_qkv_bias:
|
||||
q_bias = hf_layer.self_attn.q_proj.bias.view([num_key_value_heads, -1])
|
||||
k_bias = hf_layer.self_attn.k_proj.bias.view([num_key_value_heads, -1])
|
||||
v_bias = hf_layer.self_attn.v_proj.bias.view([num_key_value_heads, -1])
|
||||
qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view(-1).contiguous()
|
||||
numel += safe_copy(qkv_bias, layer.self_attention.linear_qkv.bias)
|
||||
|
||||
if hasattr(hf_layer.self_attn, "q_norm"):
|
||||
numel += safe_copy(hf_layer.self_attn.q_norm.weight.data, layer.self_attention.q_layernorm.weight)
|
||||
numel += safe_copy(hf_layer.self_attn.k_norm.weight.data, layer.self_attention.k_layernorm.weight)
|
||||
|
||||
numel += safe_copy(hf_layer.self_attn.o_proj.weight, layer.self_attention.linear_proj.weight)
|
||||
numel += safe_copy(hf_layer.post_attention_layernorm.weight, layer.pre_mlp_layernorm.weight)
|
||||
|
||||
numel += safe_copy(hf_layer.mlp.gate.weight, layer.mlp.router.weight)
|
||||
|
||||
for idx, hf_expert in enumerate(hf_layer.mlp.experts):
|
||||
fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight])
|
||||
numel += safe_copy(fc1_weight, layer.mlp.experts.linear_fc1._parameters[f"weight{idx}"])
|
||||
numel += safe_copy(hf_expert.down_proj.weight, layer.mlp.experts.linear_fc2._parameters[f"weight{idx}"])
|
||||
|
||||
if has_share_expert:
|
||||
numel += safe_copy(hf_layer.mlp.shared_expert_gate.weight, layer.mlp.shared_experts.gate_weight)
|
||||
shared_fc1_weight = torch.cat(
|
||||
[hf_layer.mlp.shared_expert.gate_proj.weight, hf_layer.mlp.shared_expert.up_proj.weight]
|
||||
)
|
||||
k = hf_layer.self_attn.k_proj.weight.view([num_key_value_heads, head_dim, -1])
|
||||
v = hf_layer.self_attn.v_proj.weight.view([num_key_value_heads, head_dim, -1])
|
||||
qkv = torch.cat([q, k, v], dim=1).view(-1, hidden_dim).contiguous()
|
||||
layer.self_attention.linear_qkv.weight.copy_(qkv)
|
||||
numel += safe_copy(shared_fc1_weight, layer.mlp.shared_experts.linear_fc1.weight)
|
||||
numel += safe_copy(hf_layer.mlp.shared_expert.down_proj.weight, layer.mlp.shared_experts.linear_fc2.weight)
|
||||
print(f"{pp_rank=} {global_layer_idx=} {layer_idx=} {numel=} numel this layer={numel - numel_cur}")
|
||||
|
||||
if has_qkv_bias:
|
||||
q_bias = hf_layer.self_attn.q_proj.bias.view([num_key_value_heads, -1])
|
||||
k_bias = hf_layer.self_attn.k_proj.bias.view([num_key_value_heads, -1])
|
||||
v_bias = hf_layer.self_attn.v_proj.bias.view([num_key_value_heads, -1])
|
||||
qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view(-1).contiguous()
|
||||
layer.self_attention.linear_qkv.bias.copy_(qkv_bias)
|
||||
|
||||
if hasattr(hf_layer.self_attn, "q_norm"):
|
||||
layer.self_attention.q_layernorm.weight.copy_(hf_layer.self_attn.q_norm.weight.data)
|
||||
layer.self_attention.k_layernorm.weight.copy_(hf_layer.self_attn.k_norm.weight.data)
|
||||
|
||||
layer.self_attention.linear_proj.weight.copy_(hf_layer.self_attn.o_proj.weight)
|
||||
layer.pre_mlp_layernorm.weight.copy_(hf_layer.post_attention_layernorm.weight)
|
||||
|
||||
layer.mlp.router.weight.copy_(hf_layer.mlp.gate.weight)
|
||||
|
||||
for idx, hf_expert in enumerate(hf_layer.mlp.experts):
|
||||
fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight])
|
||||
layer.mlp.experts.linear_fc1._parameters[f"weight{idx}"].copy_(fc1_weight)
|
||||
layer.mlp.experts.linear_fc2._parameters[f"weight{idx}"].copy_(hf_expert.down_proj.weight)
|
||||
|
||||
if has_share_expert:
|
||||
layer.mlp.shared_experts.gate_weight.copy_(hf_layer.mlp.shared_expert_gate.weight)
|
||||
shared_fc1_weight = torch.cat(
|
||||
[hf_layer.mlp.shared_expert.gate_proj.weight, hf_layer.mlp.shared_expert.up_proj.weight]
|
||||
)
|
||||
layer.mlp.shared_experts.linear_fc1.weight.copy_(shared_fc1_weight)
|
||||
layer.mlp.shared_experts.linear_fc2.weight.copy_(hf_layer.mlp.shared_expert.down_proj.weight)
|
||||
|
||||
model.decoder.final_layernorm.weight.copy_(hf_model.model.norm.weight)
|
||||
model.output_layer.weight.copy_(hf_model.lm_head.weight)
|
||||
if pp_rank == pp_size - 1:
|
||||
numel += safe_copy(hf_model.model.norm.weight, model.decoder.final_layernorm.weight)
|
||||
numel += safe_copy(hf_model.lm_head.weight, model.output_layer.weight)
|
||||
return numel
|
||||
|
||||
|
||||
def safe_copy(
|
||||
@ -258,13 +292,31 @@ def convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hfmodel, mgmodel
|
||||
assert n_params == copied_numel
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model, hf_config, tfconfig):
|
||||
@torch.inference_mode()
|
||||
def convert_checkpoint_from_transformers_to_megatron_dpskv3(
|
||||
hf_model,
|
||||
model,
|
||||
hf_config,
|
||||
tfconfig,
|
||||
layer_start_end: Optional[tuple[int, int]] = None,
|
||||
):
|
||||
warnings.warn("MTP model is not supported yet", stacklevel=2)
|
||||
if layer_start_end is None:
|
||||
layer_start_end = (0, len(model.decoder.layers))
|
||||
layer_start, layer_end = layer_start_end
|
||||
numel: int = 0
|
||||
numel += safe_copy(hf_model.model.embed_tokens.weight, model.embedding.word_embeddings.weight)
|
||||
print(f"{numel=}")
|
||||
for layer_idx, (layer, hf_layer) in enumerate(zip(model.decoder.layers, hf_model.model.layers, strict=True)):
|
||||
pp_rank = mpu.get_pipeline_model_parallel_rank()
|
||||
pp_size = mpu.get_pipeline_model_parallel_world_size()
|
||||
if pp_rank == 0:
|
||||
numel += safe_copy(hf_model.model.embed_tokens.weight, model.embedding.word_embeddings.weight)
|
||||
|
||||
assert len(model.decoder.layers) == (layer_end - layer_start), (
|
||||
f"Expected {len(model.decoder.layers)} layers, but got {layer_end - layer_start}"
|
||||
)
|
||||
for layer_idx, (layer, hf_layer) in enumerate(
|
||||
zip(model.decoder.layers, hf_model.model.layers[layer_start:layer_end], strict=True)
|
||||
):
|
||||
global_layer_idx = layer_idx + layer_start
|
||||
numel_cur: int = numel
|
||||
numel += safe_copy(hf_layer.input_layernorm.weight, layer.input_layernorm.weight)
|
||||
|
||||
@ -318,13 +370,14 @@ def convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model, hf_
|
||||
)
|
||||
numel += safe_copy(shared_fc1_weight, layer.mlp.shared_experts.linear_fc1.weight)
|
||||
numel += safe_copy(hf_layer.mlp.shared_experts.down_proj.weight, layer.mlp.shared_experts.linear_fc2.weight)
|
||||
print(f"{layer_idx=} {numel=} numel this layer={numel - numel_cur}")
|
||||
print(f"{pp_rank=} {global_layer_idx=} {layer_idx=} {numel=} numel this layer={numel - numel_cur}")
|
||||
assert numel - numel_cur == sum([i.numel() for i in hf_layer.state_dict().values()]), "numel mismatch"
|
||||
|
||||
numel += safe_copy(hf_model.model.norm.weight, model.decoder.final_layernorm.weight)
|
||||
|
||||
if not hf_config.tie_word_embeddings:
|
||||
numel += safe_copy(hf_model.lm_head.weight, model.output_layer.weight)
|
||||
print(f"{numel=}")
|
||||
if pp_rank == pp_size - 1:
|
||||
numel += safe_copy(hf_model.model.norm.weight, model.decoder.final_layernorm.weight)
|
||||
if not hf_config.tie_word_embeddings:
|
||||
numel += safe_copy(hf_model.lm_head.weight, model.output_layer.weight)
|
||||
print(f"{pp_rank=} {numel=}")
|
||||
return numel
|
||||
|
||||
|
||||
@ -333,6 +386,13 @@ def noop_context() -> Any:
|
||||
yield
|
||||
|
||||
|
||||
def support_distributed_convert(hf_config: AutoConfig) -> bool:
|
||||
for arch in ["DeepseekV3ForCausalLM", "Qwen3MoeForCausalLM", "Qwen2MoeForCausalLM"]:
|
||||
if arch in hf_config.architectures:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False, test=False, trust_remote_code=False):
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
if len(os.listdir(output_path)) > 0 and not test:
|
||||
@ -340,13 +400,22 @@ def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False
|
||||
return
|
||||
|
||||
# init torch distributed and mpu
|
||||
os.environ["RANK"] = "0"
|
||||
os.environ["WORLD_SIZE"] = "1"
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = "12355"
|
||||
if "WORLD_SIZE" not in os.environ:
|
||||
os.environ["RANK"] = "0"
|
||||
os.environ["WORLD_SIZE"] = "1"
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = "12355"
|
||||
|
||||
torch.distributed.init_process_group("nccl")
|
||||
|
||||
rank = dist.get_rank()
|
||||
local_rank = os.getenv("LOCAL_RANK", 0)
|
||||
world_size = dist.get_world_size()
|
||||
get_torch_device().set_device(f"{get_device_name()}:{local_rank}")
|
||||
|
||||
mpu.initialize_model_parallel(
|
||||
tensor_model_parallel_size=1,
|
||||
pipeline_model_parallel_size=world_size,
|
||||
virtual_pipeline_model_parallel_size=None,
|
||||
context_parallel_size=1,
|
||||
expert_model_parallel_size=1,
|
||||
@ -357,7 +426,18 @@ def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False
|
||||
hf_config = AutoConfig.from_pretrained(hf_model_path)
|
||||
print(hf_config, flush=True)
|
||||
|
||||
tfconfig = hf_to_mcore_config(hf_config, torch.bfloat16)
|
||||
if world_size > 1 and not support_distributed_convert(hf_config):
|
||||
raise NotImplementedError(f"distributed conversion is not supported for {hf_config.architectures} yet.")
|
||||
|
||||
pipeline_shards = get_dynamic_pipeline_shards(hf_config.num_hidden_layers, world_size)
|
||||
print(f"Pipeline shards: {pipeline_shards}", flush=True)
|
||||
|
||||
tfconfig = hf_to_mcore_config(
|
||||
hf_config,
|
||||
torch.bfloat16,
|
||||
num_layers_in_first_pipeline_stage=pipeline_shards[0] if len(pipeline_shards) > 1 else None,
|
||||
num_layers_in_last_pipeline_stage=pipeline_shards[-1] if len(pipeline_shards) > 2 else None,
|
||||
)
|
||||
tfconfig.use_cpu_initialization = use_cpu_initialization
|
||||
tie_word_embeddings = getattr(hf_config, "tie_word_embeddings", False)
|
||||
|
||||
@ -403,17 +483,36 @@ def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False
|
||||
)
|
||||
hf_state_dict = hf_model.state_dict()
|
||||
|
||||
# distributed convert
|
||||
if world_size > 1 and support_distributed_convert(hf_config):
|
||||
pipeline_cumsum = np.cumsum(pipeline_shards)
|
||||
layer_start = 0 if rank == 0 else pipeline_cumsum[rank - 1]
|
||||
layer_end = pipeline_cumsum[rank]
|
||||
if "DeepseekV3ForCausalLM" in hf_config.architectures:
|
||||
numel_partial: int = convert_checkpoint_from_transformers_to_megatron_dpskv3(
|
||||
hf_model, model[0].module, hf_config, tfconfig=tfconfig, layer_start_end=(layer_start, layer_end)
|
||||
)
|
||||
elif "Qwen3MoeForCausalLM" in hf_config.architectures or "Qwen2MoeForCausalLM" in hf_config.architectures:
|
||||
numel_partial: int = convert_checkpoint_from_transformers_to_megatron(
|
||||
hf_model, model[0].module, hf_config, layer_start_end=(layer_start, layer_end)
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Distributed conversion is not supported for {hf_config.architectures} yet.")
|
||||
|
||||
numel_tensor = torch.tensor([numel_partial]).to(get_device_name())
|
||||
dist.all_reduce(numel_tensor, op=dist.ReduceOp.SUM)
|
||||
numel = int(numel_tensor.cpu().item())
|
||||
print(f"total numel={numel} vs {hf_model.num_parameters()=}")
|
||||
if numel != hf_model.num_parameters():
|
||||
warnings.warn(f"numel mismatch: {numel=} != {hf_model.num_parameters()=}", stacklevel=1)
|
||||
|
||||
# load hf state dict to megatron model
|
||||
if "Qwen2MoeForCausalLM" in hf_config.architectures:
|
||||
elif "Qwen2MoeForCausalLM" in hf_config.architectures:
|
||||
convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config)
|
||||
elif "Qwen2_5_VLForConditionalGeneration" in hf_config.architectures:
|
||||
convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hf_model, model[0].module, hf_config)
|
||||
elif "DeepseekV3ForCausalLM" in hf_config.architectures:
|
||||
numel: int = convert_checkpoint_from_transformers_to_megatron_dpskv3(
|
||||
hf_model, model[0].module, hf_config, tfconfig=tfconfig
|
||||
)
|
||||
if numel != hf_model.num_parameters():
|
||||
warnings.warn(f"numel mismatch: {numel=} != {hf_model.num_parameters()=}", stacklevel=1)
|
||||
convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model[0].module, hf_config, tfconfig=tfconfig)
|
||||
elif "Qwen3MoeForCausalLM" in hf_config.architectures:
|
||||
convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config)
|
||||
else:
|
||||
|
6  setup.py
@ -37,7 +37,7 @@ install_requires = [
"pylatexenc",
"ray[default]>=2.41.0",
"torchdata",
"tensordict>=0.8.0,<=0.9.0",
"tensordict>=0.8.0,<=0.9.1,!=0.9.0",
"transformers",
"wandb",
"packaging>=20.0",
@ -48,9 +48,9 @@ PRIME_REQUIRES = ["pyext"]
GEO_REQUIRES = ["mathruler", "torchvision", "qwen_vl_utils"]
GPU_REQUIRES = ["liger-kernel", "flash-attn"]
MATH_REQUIRES = ["math-verify"]  # Add math-verify as an optional dependency
VLLM_REQUIRES = ["tensordict>=0.8.0,<=0.9.0", "vllm>=0.7.3,<=0.8.5"]
VLLM_REQUIRES = ["tensordict>=0.8.0,<=0.9.1,!=0.9.0", "vllm>=0.7.3,<=0.8.5"]
SGLANG_REQUIRES = [
"tensordict>=0.8.0,<=0.9.0",
"tensordict>=0.8.0,<=0.9.1,!=0.9.0",
"sglang[srt,openai]==0.4.6.post5",
"torch-memory-saver>=0.0.5",
"torch==2.6.0",
@ -109,6 +109,7 @@ class WeatherTool(BaseTool):
|
||||
Returns:
|
||||
the temperature, the location, and the unit in a dict
|
||||
"""
|
||||
print(f"[DEBUG] get_current_temperature: {location}, {unit}")
|
||||
return {
|
||||
"temperature": 26.1,
|
||||
"location": location,
|
||||
@ -143,6 +144,7 @@ class WeatherToolWithData(BaseTool):
|
||||
Returns:
|
||||
the temperature, the location, the date and the unit in a dict
|
||||
"""
|
||||
print(f"[DEBUG] get_temperature_date: {location}, {date}, {unit}")
|
||||
return {
|
||||
"temperature": 25.9,
|
||||
"location": location,
|
||||
@ -174,11 +176,11 @@ def test_tool_agent(init_config):
|
||||
tool_config = {
|
||||
"tools": [
|
||||
{
|
||||
"class_name": "tests.workers.rollout.rollout_vllm.test_vllm_chat_scheduler.WeatherTool",
|
||||
"class_name": "tests.experimental.agent_loop.test_basic_agent_loop.WeatherTool",
|
||||
"config": {"type": "native"},
|
||||
},
|
||||
{
|
||||
"class_name": "tests.workers.rollout.rollout_vllm.test_vllm_chat_scheduler.WeatherToolWithData",
|
||||
"class_name": "tests.experimental.agent_loop.test_basic_agent_loop.WeatherToolWithData",
|
||||
"config": {"type": "native"},
|
||||
},
|
||||
]
|
||||
@ -238,15 +240,29 @@ def test_tool_agent(init_config):
|
||||
tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
|
||||
responses = result.batch["responses"]
|
||||
response_mask = result.batch["response_mask"]
|
||||
attention_mask = result.batch["attention_mask"]
|
||||
assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}"
|
||||
response_length = response_mask.size(1)
|
||||
|
||||
# Decode responses with response_mask
|
||||
for i in range(len(responses)):
|
||||
# response with tool response
|
||||
valid_tokens = responses[i][attention_mask[i][-response_length:].bool()]
|
||||
response_with_obs = tokenizer.decode(valid_tokens)
|
||||
|
||||
# response without tool response
|
||||
valid_tokens = responses[i][response_mask[i].bool()]
|
||||
response_str = tokenizer.decode(valid_tokens)
|
||||
assert "<tool_response>" not in response_str, f"found <tool_response> in response: {response_str}"
|
||||
assert "</tool_response>" not in response_str, f"found </tool_response> in response: {response_str}"
|
||||
print(f"response: {response_str}")
|
||||
response_without_obs = tokenizer.decode(valid_tokens)
|
||||
|
||||
assert "<tool_response>" not in response_without_obs, (
|
||||
f"found <tool_response> in response: {response_without_obs}"
|
||||
)
|
||||
assert "</tool_response>" not in response_without_obs, (
|
||||
f"found </tool_response> in response: {response_without_obs}"
|
||||
)
|
||||
print("=========================")
|
||||
print(response_with_obs)
|
||||
print("---")
|
||||
print(response_without_obs)
|
||||
|
||||
print("Test passed!")
|
||||
ray.shutdown()
|
||||
|
@ -27,7 +27,7 @@ from verl.protocol import DataProto
from verl.utils.distributed import initialize_global_process_group
from verl.utils.model import compute_position_id_with_mask, create_random_mask
from verl.utils.ulysses import (
gather_outpus_and_unpad,
gather_outputs_and_unpad,
get_ulysses_sequence_parallel_world_size,
set_ulysses_sequence_parallel_group,
ulysses_pad_and_slice_inputs,
@ -155,7 +155,7 @@ def _hf_casual_fwd(config, sp_size, dp_size):
).logits  # (1, total_nnz/n, vocab_size)

# all_gather output
logits_full = gather_outpus_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)
logits_full = gather_outputs_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)

# 2. perform normal forward
set_ulysses_sequence_parallel_group(None)
@ -234,7 +234,7 @@ def _hf_casual_fwd_bwd(config, sp_size, dp_size):
).logits  # (1, total_nnz/n, vocab_size)

# all_gather output
logits_full = gather_outpus_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)
logits_full = gather_outputs_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)

# 2. perform normal forward
set_ulysses_sequence_parallel_group(None)
@ -175,6 +175,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \
actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
actor_rollout_ref.rollout.update_weights_bucket_megabytes=128 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
actor_rollout_ref.ref.megatron.use_mbridge=${USE_MBRIDGE} \
@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from verl.model_merger.megatron_model_merger import get_dynamic_pipeline_shards
from verl.utils.megatron.pipeline_parallel import make_batch_generator


@ -45,3 +48,23 @@ def test_make_batch_generator_empty():
assert len(generators) == vpp_size
for gen in generators:
assert list(gen) == []


@pytest.mark.parametrize(
"layer_num,pp_size,gt",
[
(61, 8, [6, 8, 8, 8, 8, 8, 8, 7]),
(61, 7, [8, 9, 9, 9, 9, 9, 8]),
(61, 1, [61]),
(61, 0, ValueError),
(10, 16, ValueError),
],
)
def test_get_dynamic_pipeline_shards(layer_num, pp_size, gt):
if isinstance(gt, list):
shards = get_dynamic_pipeline_shards(layer_num, pp_size)
assert len(shards) == len(gt) == pp_size, f"Expected {pp_size} shards, got {len(shards)}"
assert all([shard == gt[i] for i, shard in enumerate(shards)]), f"Expected shards {gt}, got {shards}"
elif issubclass(gt, Exception):
with pytest.raises(gt):
shards = get_dynamic_pipeline_shards(layer_num, pp_size)
@ -666,3 +666,27 @@ class Solution:
assert "error" not in metadata_list[0]
assert metadata_list[0].get("status") != "compilation error"
assert metadata_list[0].get("status") != "runtime error"


@pytest.mark.skipif(skip_condition, reason=skip_reason)
def test_none_and_empty_stdin_passed_correctly():
"""
Tests that when stdin data is set to an empty string or None, it is still
passed correctly to Sandbox Fusion as an empty string.
"""
echo_code = """
import sys
print(f"You said '{sys.stdin.readline().strip()}'")
"""
in_outs = {
"inputs": [None, "", "hello"],
"outputs": ["You said ''", "You said ''", "You said 'hello'"],
}

# Use a short timeout for fast tests
results, metadata_list = check_correctness(SANDBOX_URL, in_outs, echo_code, timeout=5)

assert results == [True, True, True]
assert "error" not in metadata_list[0]
assert metadata_list[0].get("status") != "compilation error"
assert metadata_list[0].get("status") != "runtime error"
@ -124,6 +124,60 @@ def _worker(rank, world_size, init_method, max_token_len, use_same_dp, min_mb):
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def test_dataproto_split_uneven():
|
||||
"""Test DataProto.split with uneven splits"""
|
||||
# Create test data with 10 items
|
||||
input_ids = torch.randint(low=0, high=10, size=(10, 5))
|
||||
attention_mask = torch.ones(10, 5)
|
||||
data = {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||
dataproto = DataProto.from_single_dict(data)
|
||||
|
||||
# Test split with size 3 (should create chunks of [3, 3, 3, 1])
|
||||
splits = dataproto.split(3)
|
||||
assert len(splits) == 4
|
||||
assert len(splits[0]) == 3
|
||||
assert len(splits[1]) == 3
|
||||
assert len(splits[2]) == 3
|
||||
assert len(splits[3]) == 1
|
||||
|
||||
reconstructed = DataProto.concat(splits)
|
||||
torch.testing.assert_close(reconstructed.batch["input_ids"], dataproto.batch["input_ids"])
|
||||
torch.testing.assert_close(reconstructed.batch["attention_mask"], dataproto.batch["attention_mask"])
|
||||
|
||||
# Test split with size equal to length (should create one chunk)
|
||||
splits = dataproto.split(10)
|
||||
assert len(splits) == 1
|
||||
assert len(splits[0]) == 10
|
||||
|
||||
# Test split with size larger than length (should create one chunk with all data)
|
||||
splits = dataproto.split(15)
|
||||
assert len(splits) == 1
|
||||
assert len(splits[0]) == 10
|
||||
|
||||
# Test with non-tensor batch data
|
||||
import numpy as np
|
||||
|
||||
data_with_non_tensor = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": np.array([f"label_{i}" for i in range(10)], dtype=object),
|
||||
}
|
||||
dataproto_with_non_tensor = DataProto.from_single_dict(data_with_non_tensor)
|
||||
|
||||
splits = dataproto_with_non_tensor.split(3)
|
||||
assert len(splits) == 4
|
||||
assert len(splits[0]) == 3
|
||||
assert len(splits[1]) == 3
|
||||
assert len(splits[2]) == 3
|
||||
assert len(splits[3]) == 1
|
||||
|
||||
# Verify non-tensor data integrity
|
||||
reconstructed = DataProto.concat(splits)
|
||||
np.testing.assert_array_equal(
|
||||
reconstructed.non_tensor_batch["labels"], dataproto_with_non_tensor.non_tensor_batch["labels"]
|
||||
)
|
||||
|
||||
|
||||
def test_seqlen_balancing_distributed_params(tmp_path):
|
||||
world_size = 2
|
||||
init_file = tmp_path / "dist_init"
|
||||
|
@ -0,0 +1,57 @@
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets

_TENSOR_1MB = torch.zeros(512, 512)
_BYTES_1MB = 1 << 20


@pytest.mark.parametrize(
    "named_tensors, bucket_size_mb, gt_groups",
    [
        (
            [("a", _TENSOR_1MB), ("b", _TENSOR_1MB)],
            0.5 * _BYTES_1MB,
            [["a"], ["b"]],
        ),
        (
            [("a", _TENSOR_1MB), ("b", _TENSOR_1MB)],
            1 * _BYTES_1MB,
            [["a"], ["b"]],
        ),
        (
            [("a", _TENSOR_1MB), ("b", _TENSOR_1MB)],
            1.5 * _BYTES_1MB,
            [["a"], ["b"]],
        ),
        (
            [("a", _TENSOR_1MB), ("b", _TENSOR_1MB)],
            2 * _BYTES_1MB,
            [["a", "b"]],
        ),
    ],
)
def test_get_named_tensor_buckets(named_tensors, bucket_size_mb, gt_groups: list[list[str]]):
    named_tensors_iter = iter(named_tensors)
    groups = list(get_named_tensor_buckets(named_tensors_iter, bucket_size_mb))
    assert len(groups) == len(gt_groups)
    for group, gt_group in zip(groups, gt_groups, strict=True):
        assert len(group) == len(gt_group)
        for (name, _), (gt_name) in zip(group, gt_group, strict=True):
            assert name == gt_name
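The `_TENSOR_1MB` constant is exact: `torch.zeros(512, 512)` defaults to float32, so it occupies 512 * 512 * 4 bytes = 1,048,576 bytes = 1 MiB (`1 << 20`). That is why the 0.5 MiB, 1 MiB, and 1.5 MiB buckets all split the two tensors into separate groups, while a 2 MiB bucket packs them together.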
@ -158,6 +158,8 @@ def get_rollout_config(
"prompt_length": max_prompt_length,
"response_length": max_response_length,
"tensor_model_parallel_size": tensor_parallel_size,
# set to 128MB only for testing
"update_weights_bucket_megabytes": 128,
"multi_turn": {
"max_assistant_turns": 4,
"max_user_turns": 4,
@ -13,5 +13,9 @@
# limitations under the License.

from .agent_loop import AgentLoopBase, AgentLoopManager
from .single_turn_agent_loop import SingleTurnAgentLoop
from .tool_agent_loop import ToolAgentLoop

_ = [SingleTurnAgentLoop, ToolAgentLoop]

__all__ = ["AgentLoopBase", "AgentLoopManager"]
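Importing the two built-in loops here is what runs their `@register(...)` decorators (registering "single_turn_agent" and "tool_agent"), and the throwaway `_ = [...]` assignment keeps linters from flagging those imports as unused. Since `__all__` limits the public surface, downstream code only needs:

```python
from verl.experimental.agent_loop import AgentLoopBase, AgentLoopManager
```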
@ -19,11 +19,12 @@ import random
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import ray
|
||||
import torch
|
||||
from cachetools import LRUCache
|
||||
from omegaconf import DictConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pydantic import BaseModel
|
||||
from tensordict import TensorDict
|
||||
from transformers import AutoTokenizer
|
||||
@ -120,29 +121,43 @@ class AgentLoopOutput(BaseModel):
|
||||
metrics: AgentLoopMetrics
|
||||
|
||||
|
||||
# make hydra.utils.instantiate happy
|
||||
class _DummyConfig:
|
||||
def __init__(self, config: DictConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
|
||||
class AgentLoopBase(ABC):
|
||||
"""An agent loop takes a input message, chat with OpenAI compatible LLM server and interact with various
|
||||
environments."""
|
||||
|
||||
_class_initialized = False
|
||||
|
||||
def __init__(self, config: DictConfig, server_manager: AsyncLLMServerManager, tokenizer: AutoTokenizer):
|
||||
"""Initialize agent loop.
|
||||
def __init__(
|
||||
self, trainer_config: _DummyConfig, server_manager: AsyncLLMServerManager, tokenizer: AutoTokenizer, **kwargs
|
||||
):
|
||||
"""Initialize agent loop, each sample will have its own loop instance.
|
||||
|
||||
Args:
|
||||
config (DictConfig): YAML config.
|
||||
trainer_config (_DummyConfig): trainer config.
|
||||
server_manager (AsyncLLMServerManager): OpenAI compatible LLM server manager.
|
||||
tokenizer (AutoTokenizer): Tokenizer for tokenize messages.
|
||||
"""
|
||||
self.config = config
|
||||
self.init_class(trainer_config.config, tokenizer, **kwargs)
|
||||
self.config = trainer_config.config
|
||||
self.server_manager = server_manager
|
||||
self.tokenizer = tokenizer
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self.init_class(config, tokenizer)
|
||||
|
||||
@classmethod
|
||||
def init_class(cls, config: DictConfig, tokenizer: AutoTokenizer):
|
||||
"""Initialize class state shared across all instances."""
|
||||
def init_class(cls, config: DictConfig, tokenizer: AutoTokenizer, **kwargs):
|
||||
"""This is used to do heavy initialization work that should shared across all instances. It's only called once.
|
||||
|
||||
Args:
|
||||
config (DictConfig): trainer config.
|
||||
tokenizer (AutoTokenizer): Tokenizer for tokenize messages.
|
||||
**kwargs: extra kwargs from config file passed in by `hydra.utils.instantiate`.
|
||||
"""
|
||||
if cls._class_initialized:
|
||||
return
|
||||
cls._class_initialized = True
|
||||
@ -161,6 +176,25 @@ class AgentLoopBase(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
"""Agent loop registry: key is agent_name, value is a dict of agent loop config
|
||||
used by hydra.utils.instantiate to initialize agent loop instance.
|
||||
|
||||
https://hydra.cc/docs/advanced/instantiate_objects/overview/
|
||||
"""
|
||||
_agent_loop_registry: dict[str, dict] = {}
|
||||
|
||||
|
||||
def register(agent_name: str):
|
||||
"""Register agent loop class."""
|
||||
|
||||
def decorator(subclass: type[AgentLoopBase]) -> type[AgentLoopBase]:
|
||||
fqdn = f"{subclass.__module__}.{subclass.__qualname__}"
|
||||
_agent_loop_registry[agent_name] = {"_target_": fqdn}
|
||||
return subclass
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@ray.remote
|
||||
class AgentLoopWorker:
|
||||
"""Agent loop worker takes a batch of messages and run each message in an agent loop."""
|
||||
@ -180,6 +214,13 @@ class AgentLoopWorker:
|
||||
local_path = copy_to_local(config.actor_rollout_ref.model.path)
|
||||
self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True)
|
||||
|
||||
agent_loop_config_path = config.actor_rollout_ref.rollout.agent.agent_loop_config_path
|
||||
if agent_loop_config_path:
|
||||
agent_loop_configs = OmegaConf.load(agent_loop_config_path)
|
||||
for agent_loop_config in agent_loop_configs:
|
||||
_agent_loop_registry[agent_loop_config.name] = agent_loop_config
|
||||
|
||||
trace_config = config.trainer.get("rollout_trace", {})
|
||||
trace_config = self.config.actor_rollout_ref.rollout.get("trace", {})
|
||||
RolloutTraceConfig.init(
|
||||
self.config.trainer.project_name,
|
||||
@ -260,36 +301,20 @@ class AgentLoopWorker:
|
||||
validate=trajectory["validate"],
|
||||
name="agent_loop",
|
||||
):
|
||||
agent_loop_class = self.get_agent_loop_class(agent_name)
|
||||
agent_loop = agent_loop_class(self.config, self.server_manager, self.tokenizer)
|
||||
assert agent_name in _agent_loop_registry, (
|
||||
f"Agent loop {agent_name} not registered, registered agent loops: {_agent_loop_registry.keys()}"
|
||||
)
|
||||
|
||||
agent_loop_config = _agent_loop_registry[agent_name]
|
||||
agent_loop = hydra.utils.instantiate(
|
||||
config=agent_loop_config,
|
||||
trainer_config=_DummyConfig(config=self.config),
|
||||
server_manager=self.server_manager,
|
||||
tokenizer=self.tokenizer,
|
||||
)
|
||||
output = await agent_loop.run(messages, sampling_params)
|
||||
return output
|
||||
|
||||
def get_agent_loop_class(self, agent_name: str) -> type[AgentLoopBase]:
|
||||
"""Get the appropriate agent loop class based on agent name.
|
||||
|
||||
Factory method that returns the correct agent loop class implementation
|
||||
for the specified agent type.
|
||||
|
||||
Args:
|
||||
agent_name (str): Name of the agent type ('single_turn_agent' or 'tool_agent').
|
||||
|
||||
Returns:
|
||||
Type[AgentLoopBase]: Agent loop class corresponding to the agent name.
|
||||
|
||||
Raises:
|
||||
ValueError: If the agent_name is not recognized.
|
||||
"""
|
||||
# TODO: add tool agent registrary
|
||||
from verl.experimental.agent_loop.single_turn_agent_loop import SingleTurnAgentLoop
|
||||
from verl.experimental.agent_loop.tool_agent_loop import ToolAgentLoop
|
||||
|
||||
if agent_name == "single_turn_agent":
|
||||
return SingleTurnAgentLoop
|
||||
elif agent_name == "tool_agent":
|
||||
return ToolAgentLoop
|
||||
raise ValueError(f"Unknown agent_name: {agent_name}")
|
||||
|
||||
def _postprocess(self, inputs: list[AgentLoopOutput]) -> DataProto:
|
||||
# NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
|
||||
# prompts: left pad
|
||||
|
@ -16,20 +16,21 @@ import os
from typing import Any
from uuid import uuid4

from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput
from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register
from verl.utils.profiler import simple_timer

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


@register("single_turn_agent")
class SingleTurnAgentLoop(AgentLoopBase):
"""Naive agent loop that only do single turn chat completion."""

def __init__(self, config, server_manager, tokenizer):
super().__init__(config, server_manager, tokenizer)
self.prompt_length = config.actor_rollout_ref.rollout.prompt_length
self.response_length = config.actor_rollout_ref.rollout.response_length
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length
self.response_length = self.config.actor_rollout_ref.rollout.response_length

async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
metrics = {}
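Combined with the registry introduced in `agent_loop.py`, any custom loop can now be exposed by name instead of being hard-coded in a factory method. A minimal sketch of the pattern (the class and agent name below are hypothetical, for illustration only):

```python
from typing import Any

from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register


@register("my_custom_agent")  # hypothetical agent name
class MyCustomAgentLoop(AgentLoopBase):
    async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
        # A real loop would drive self.server_manager / self.tokenizer here;
        # this stub only marks where the rollout logic goes.
        raise NotImplementedError
```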
@ -15,14 +15,11 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import regex as re
|
||||
from pydantic import BaseModel
|
||||
|
||||
from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput
|
||||
from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register
|
||||
from verl.experimental.agent_loop.tool_parser import FunctionCall, ToolParser
|
||||
from verl.tools.utils.tool_registry import initialize_tools_from_config
|
||||
from verl.utils.profiler import simple_timer
|
||||
from verl.utils.rollout_trace import rollout_trace_op
|
||||
@ -31,68 +28,10 @@ logger = logging.getLogger(__file__)
|
||||
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
||||
|
||||
|
||||
class FunctionCall(BaseModel):
|
||||
arguments: str
|
||||
"""
|
||||
The arguments to call the function with, as generated by the model in JSON
|
||||
format. Note that the model does not always generate valid JSON, and may
|
||||
hallucinate parameters not defined by your function schema. Validate the
|
||||
arguments in your code before calling your function.
|
||||
"""
|
||||
|
||||
name: str
|
||||
"""The name of the function to call."""
|
||||
|
||||
|
||||
class ToolParser(ABC):
|
||||
@abstractmethod
|
||||
async def extract_tool_calls(self, responses_ids: list[int]) -> list[FunctionCall]:
|
||||
"""Extract tool calls from the responses.
|
||||
|
||||
Args:
|
||||
responses_ids (List[int]): The ids of the responses.
|
||||
|
||||
Returns:
|
||||
List[FunctionCall]: The extracted tool calls.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class HermesToolParser(ToolParser):
|
||||
"""Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py"""
|
||||
|
||||
def __init__(self, tokenizer) -> None:
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.tool_call_start_token: str = "<tool_call>"
|
||||
self.tool_call_end_token: str = "</tool_call>"
|
||||
self.tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
|
||||
|
||||
@rollout_trace_op
|
||||
async def extract_tool_calls(self, responses_ids: list[int]) -> list[FunctionCall]:
|
||||
loop = asyncio.get_running_loop()
|
||||
text = await loop.run_in_executor(None, self.tokenizer.decode, responses_ids)
|
||||
if self.tool_call_start_token not in text or self.tool_call_end_token not in text:
|
||||
return []
|
||||
|
||||
matches = self.tool_call_regex.findall(text)
|
||||
function_calls = []
|
||||
for match in matches:
|
||||
try:
|
||||
function_call = json.loads(match)
|
||||
name, arguments = function_call["name"], function_call["arguments"]
|
||||
function_calls.append(FunctionCall(name=name, arguments=json.dumps(arguments, ensure_ascii=False)))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decode tool call: {e}")
|
||||
return function_calls
|
||||
|
||||
|
||||
@register("tool_agent")
|
||||
class ToolAgentLoop(AgentLoopBase):
|
||||
def __init__(self, config, server_manager, tokenizer):
|
||||
super().__init__(config, server_manager, tokenizer)
|
||||
|
||||
@classmethod
|
||||
def init_class(cls, config, tokenizer):
|
||||
def init_class(cls, config, tokenizer, **kwargs):
|
||||
if cls._class_initialized:
|
||||
return
|
||||
cls._class_initialized = True
|
||||
@ -109,7 +48,7 @@ class ToolAgentLoop(AgentLoopBase):
|
||||
tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else []
|
||||
cls.tools = {tool.name: tool for tool in tool_list}
|
||||
cls.tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list]
|
||||
cls.tool_parser = cls.get_tool_parser(config.actor_rollout_ref.rollout.multi_turn.format)
|
||||
cls.tool_parser = ToolParser.get_tool_parser(config.actor_rollout_ref.rollout.multi_turn.format, cls.tokenizer)
|
||||
print(f"Initialized tools: {cls.tools}")
|
||||
|
||||
cls.prompt_length = config.actor_rollout_ref.rollout.prompt_length
|
||||
@ -151,7 +90,7 @@ class ToolAgentLoop(AgentLoopBase):
|
||||
break
|
||||
|
||||
# no tool calls
|
||||
tool_calls = await self.tool_parser.extract_tool_calls(response_ids)
|
||||
_, tool_calls = await self.tool_parser.extract_tool_calls(response_ids)
|
||||
if not tool_calls:
|
||||
break
|
||||
|
||||
@ -225,12 +164,3 @@ class ToolAgentLoop(AgentLoopBase):
|
||||
"role": "tool",
|
||||
"content": tool_response,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_tool_parser(cls, name: str) -> ToolParser:
|
||||
tool_parsers = {
|
||||
"hermes": HermesToolParser,
|
||||
}
|
||||
if name not in tool_parsers:
|
||||
raise ValueError(f"Unknown tool parser: {name}")
|
||||
return tool_parsers[name](cls.tokenizer)
|
||||
|
106
verl/experimental/agent_loop/tool_parser.py
Normal file
106
verl/experimental/agent_loop/tool_parser.py
Normal file
@ -0,0 +1,106 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import regex as re
|
||||
from pydantic import BaseModel
|
||||
|
||||
from verl.utils.rollout_trace import rollout_trace_op
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
||||
|
||||
|
||||
class FunctionCall(BaseModel):
|
||||
arguments: str
|
||||
"""
|
||||
The arguments to call the function with, as generated by the model in JSON
|
||||
format. Note that the model does not always generate valid JSON, and may
|
||||
hallucinate parameters not defined by your function schema. Validate the
|
||||
arguments in your code before calling your function.
|
||||
"""
|
||||
|
||||
name: str
|
||||
"""The name of the function to call."""
|
||||
|
||||
|
||||
class ToolParser(ABC):
|
||||
_registry: dict[str, type["ToolParser"]] = {}
|
||||
|
||||
def __init__(self, tokenizer) -> None:
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
@abstractmethod
|
||||
async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]:
|
||||
"""Extract tool calls from the responses.
|
||||
|
||||
Args:
|
||||
responses_ids (List[int]): The ids of the responses.
|
||||
|
||||
Returns:
|
||||
Tuple[str, List[FunctionCall]]: Content and extracted tool calls.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_tool_parser(cls, name: str, tokenizer):
|
||||
if name not in cls._registry:
|
||||
raise ValueError(f"Unknown tool parser: {name}")
|
||||
return cls._registry[name](tokenizer)
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str):
|
||||
def decorator(subclass: type[ToolParser]) -> type[ToolParser]:
|
||||
cls._registry[name] = subclass
|
||||
return subclass
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@ToolParser.register("hermes")
|
||||
class HermesToolParser(ToolParser):
|
||||
"""Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py"""
|
||||
|
||||
def __init__(self, tokenizer) -> None:
|
||||
super().__init__(tokenizer)
|
||||
|
||||
self.tool_call_start_token: str = "<tool_call>"
|
||||
self.tool_call_end_token: str = "</tool_call>"
|
||||
self.tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
|
||||
|
||||
@rollout_trace_op
|
||||
async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]:
|
||||
loop = asyncio.get_running_loop()
|
||||
text = await loop.run_in_executor(None, self.tokenizer.decode, responses_ids)
|
||||
if self.tool_call_start_token not in text or self.tool_call_end_token not in text:
|
||||
return text, []
|
||||
|
||||
matches = self.tool_call_regex.findall(text)
|
||||
function_calls = []
|
||||
for match in matches:
|
||||
try:
|
||||
function_call = json.loads(match)
|
||||
name, arguments = function_call["name"], function_call["arguments"]
|
||||
function_calls.append(FunctionCall(name=name, arguments=json.dumps(arguments, ensure_ascii=False)))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decode tool call: {e}")
|
||||
|
||||
# remaining text, excluding tool call tokens
|
||||
content = self.tool_call_regex.sub("", text)
|
||||
|
||||
return content, function_calls
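
A brief, hedged sketch of how the new registry above is meant to be extended; the parser name "my_format" and the class MyToolParser are hypothetical names used only for illustration, assuming the ToolParser and FunctionCall definitions in this file are in scope:

```python
# Illustrative only: register an additional parser format with the new ToolParser registry.
@ToolParser.register("my_format")
class MyToolParser(ToolParser):
    async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]:
        # Decode and return the raw text with no tool calls; a real parser would
        # extract structured calls here.
        text = self.tokenizer.decode(responses_ids)
        return text, []

# Lookup then mirrors how ToolAgentLoop.init_class resolves a parser:
# parser = ToolParser.get_tool_parser("my_format", tokenizer)
```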
|
13
verl/experimental/dataset/__init__.py
Normal file
13
verl/experimental/dataset/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
13
verl/experimental/dynamic_dataset/__init__.py
Normal file
13
verl/experimental/dynamic_dataset/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -32,6 +32,16 @@ python -m verl.model_merger merge \
|
||||
--target_dir /path/to/merged_hf_model
|
||||
```
|
||||
|
||||
or use distributed merge for large models such as DeepSeek-V3 671B
|
||||
|
||||
```sh
|
||||
torchrun --nproc_per_node 1 --nnodes 8 --node_rank ${RANK} -m verl.model_merger merge \
|
||||
--backend megatron \
|
||||
--local_dir ./checkpoints/global_step_1/actor \
|
||||
--target_dir /path/to/merged_hf_model
|
||||
```
|
||||
|
||||
|
||||
For more details, please refer to documentation:
|
||||
https://verl.readthedocs.io/en/latest/advance/checkpoint.html#convert-fsdp-and-megatron-checkpoints-to-huggingface-format-model
|
||||
"""
|
||||
|
@ -45,6 +45,7 @@ def parse_args():
|
||||
action="store_true",
|
||||
help="Whether to tie word embedding weights (currently only Megatron supported)",
|
||||
)
|
||||
base_op_parser.add_argument("--trust-remote-code", action="store_true", help="Whether to trust remote code")
|
||||
base_op_parser.add_argument(
|
||||
"--is-value-model",
|
||||
action="store_true",
|
||||
@ -88,6 +89,7 @@ class ModelMergerConfig:
|
||||
private: bool = False
|
||||
test_hf_dir: Optional[str] = None
|
||||
tie_word_embedding: bool = False
|
||||
trust_remote_code: bool = False
|
||||
is_value_model: bool = False
|
||||
local_dir: Optional[str] = None
|
||||
hf_model_config_path: Optional[str] = None
|
||||
@ -107,6 +109,7 @@ def generate_config_from_args(args: argparse.Namespace) -> ModelMergerConfig:
|
||||
"operation": args.operation,
|
||||
"backend": args.backend,
|
||||
"tie_word_embedding": args.tie_word_embedding,
|
||||
"trust_remote_code": args.trust_remote_code,
|
||||
"is_value_model": args.is_value_model,
|
||||
"local_dir": args.local_dir,
|
||||
"hf_model_config_path": os.path.join(args.local_dir, "huggingface"),
|
||||
@ -161,7 +164,9 @@ class BaseModelMerger(ABC):
|
||||
def __init__(self, config: ModelMergerConfig):
|
||||
self.config = config
|
||||
self.hf_model_config_path = config.hf_model_config_path
|
||||
self.model_config = AutoConfig.from_pretrained(self.hf_model_config_path)
|
||||
self.model_config = AutoConfig.from_pretrained(
|
||||
self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code
|
||||
)
|
||||
|
||||
def get_transformers_auto_model_class(self):
|
||||
if "ForTokenClassification" in self.model_config.architectures[0]:
|
||||
@ -250,7 +255,9 @@ class BaseModelMerger(ABC):
|
||||
def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]):
|
||||
auto_model_class = self.get_transformers_auto_model_class()
|
||||
with init_empty_weights():
|
||||
model = auto_model_class.from_config(self.model_config, torch_dtype=torch.bfloat16)
|
||||
model = auto_model_class.from_config(
|
||||
self.model_config, torch_dtype=torch.bfloat16, trust_remote_code=self.config.trust_remote_code
|
||||
)
|
||||
model.to_empty(device="cpu")
|
||||
model = self.patch_model_generation_config(model)
|
||||
|
||||
@ -263,8 +270,8 @@ class BaseModelMerger(ABC):
|
||||
del state_dict
|
||||
del model
|
||||
|
||||
processor = hf_processor(self.hf_model_config_path)
|
||||
tokenizer = hf_tokenizer(self.hf_model_config_path)
|
||||
processor = hf_processor(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code)
|
||||
tokenizer = hf_tokenizer(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code)
|
||||
if processor is not None:
|
||||
print(f"Saving processor to {self.config.target_dir}")
|
||||
processor.save_pretrained(self.config.target_dir)
|
||||
|
@ -12,13 +12,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, ContextManager
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from accelerate import init_empty_weights
|
||||
from megatron.core import mpu
|
||||
from megatron.core.models.gpt.gpt_model import ModelType
|
||||
@ -30,9 +33,10 @@ from transformers import (
|
||||
)
|
||||
|
||||
from verl.models.mcore import hf_to_mcore_config
|
||||
from verl.utils.device import get_nccl_backend
|
||||
from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device
|
||||
from verl.utils.megatron.dist_checkpointing import load_dist_checkpointing
|
||||
from verl.utils.megatron_utils import get_model
|
||||
from verl.utils.tokenizer import hf_processor, hf_tokenizer
|
||||
|
||||
from .base_model_merger import BaseModelMerger, ModelMergerConfig
|
||||
|
||||
@ -42,6 +46,50 @@ def noop_context() -> Any:
|
||||
yield
|
||||
|
||||
|
||||
def get_dynamic_pipeline_shards(layer_num: int, pp_size: int) -> list[int]:
|
||||
"""Calculate the pipeline sharding configuration for Megatron-LM.
|
||||
|
||||
Args:
|
||||
layer_num: Total number of layers in the model.
|
||||
pp_size: Number of pipeline parallel ranks.
|
||||
|
||||
Returns:
|
||||
Number of layers on each pp rank, making the pipeline sharding as uniform as possible.
|
||||
"""
|
||||
if layer_num < pp_size:
|
||||
raise ValueError(f"layer_num {layer_num} must be greater than pp_size {pp_size}.")
|
||||
|
||||
if pp_size < 1:
|
||||
raise ValueError(f"pp_size must be at least 1, got {pp_size}.")
|
||||
if pp_size == 1:
|
||||
return [layer_num]
|
||||
|
||||
if pp_size == 2:
|
||||
return [
|
||||
layer_num // 2,
|
||||
layer_num - layer_num // 2,
|
||||
]
|
||||
|
||||
middle_size = pp_size - 2
|
||||
shards_strategy = []
|
||||
for middle_layer_num in range(layer_num):
|
||||
first_last_layer_num = layer_num - middle_layer_num * middle_size
|
||||
first_layer_num = first_last_layer_num // 2
|
||||
last_layer_num = first_last_layer_num - first_last_layer_num // 2
|
||||
if 0 < first_layer_num <= middle_layer_num and 0 < last_layer_num <= middle_layer_num:
|
||||
shards_strategy.append(
|
||||
(
|
||||
[first_layer_num] + [middle_layer_num] * middle_size + [last_layer_num],
|
||||
abs(first_layer_num - middle_layer_num),
|
||||
)
|
||||
)
|
||||
|
||||
# sort by diff of layer_num, to make it as uniform as possible
|
||||
res = sorted(shards_strategy, key=lambda x: x[1])[0][0]
|
||||
assert sum(res) == layer_num, f"sum(res)={sum(res)} != layer_num={layer_num}, pp_size={pp_size}"
|
||||
return res
|
||||
|
||||
|
||||
class MegatronModelMerger(BaseModelMerger):
|
||||
"""
|
||||
Model merger for Megatron-LM distributed checkpoints.
|
||||
@ -87,19 +135,31 @@ class MegatronModelMerger(BaseModelMerger):
|
||||
def __init__(self, config: ModelMergerConfig):
|
||||
super().__init__(config)
|
||||
# Currently we use only 1 rank to merge the dist_ckpt, we will move to multi-process save shortly afterwards
|
||||
os.environ["RANK"] = "0"
|
||||
os.environ["WORLD_SIZE"] = "1"
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = "12355"
|
||||
if "WORLD_SIZE" not in os.environ:
|
||||
os.environ["RANK"] = "0"
|
||||
os.environ["LOCAL_RANK"] = "0"
|
||||
os.environ["WORLD_SIZE"] = "1"
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = "12355"
|
||||
|
||||
torch.distributed.init_process_group(get_nccl_backend())
|
||||
|
||||
self.rank = torch.distributed.get_rank()
|
||||
self.world_size = torch.distributed.get_world_size()
|
||||
local_rank = os.environ.get("LOCAL_RANK", 0)
|
||||
get_torch_device().set_device(f"{get_device_name()}:{local_rank}")
|
||||
|
||||
mpu.initialize_model_parallel(
|
||||
tensor_model_parallel_size=1,
|
||||
pipeline_model_parallel_size=self.world_size,
|
||||
virtual_pipeline_model_parallel_size=None,
|
||||
context_parallel_size=1,
|
||||
expert_model_parallel_size=1,
|
||||
)
|
||||
model_parallel_cuda_manual_seed(0)
|
||||
self.hf_config = AutoConfig.from_pretrained(self.config.hf_model_config_path)
|
||||
self.hf_config = AutoConfig.from_pretrained(
|
||||
self.config.hf_model_config_path, trust_remote_code=self.config.trust_remote_code
|
||||
)
|
||||
print(self.hf_config, flush=True)
|
||||
|
||||
self.params_mapping = {
|
||||
@ -107,6 +167,9 @@ class MegatronModelMerger(BaseModelMerger):
|
||||
# NOTICE: It's a little bit tricky, when 2 keys have the same prefix, we need to make sure the
|
||||
# longer key within the containing relationship is processed first.
|
||||
"embedding.word_embeddings": "model.embed_tokens",
|
||||
# input layer norm for DeepSeek-V3
|
||||
"input_layernorm.weight": "input_layernorm.weight",
|
||||
"input_layernorm.bias": "input_layernorm.bias",
|
||||
# attn
|
||||
"self_attention.linear_qkv.layer_norm_weight": "input_layernorm.weight",
|
||||
"self_attention.linear_qkv.layer_norm_bias": "input_layernorm.bias",
|
||||
@ -140,6 +203,11 @@ class MegatronModelMerger(BaseModelMerger):
|
||||
"output_layer": "lm_head",
|
||||
}
|
||||
|
||||
if "Qwen2MoeForCausalLM" in self.hf_config.architectures:
|
||||
self.params_mapping["mlp.shared_experts.linear_fc1"] = "mlp.shared_expert.gate_up_proj"
|
||||
self.params_mapping["mlp.shared_experts.linear_fc2"] = "mlp.shared_expert.down_proj"
|
||||
self.params_mapping["mlp.shared_experts.gate_weight"] = "mlp.shared_expert_gate.weight"
|
||||
|
||||
def _load_state_dicts(self, model_ckpt_path: str) -> dict[str, Any]:
|
||||
"""_summary_
|
||||
Use Megatron dist_checkpointing to load the model state dicts from the checkpoint directory.
|
||||
@ -152,7 +220,15 @@ class MegatronModelMerger(BaseModelMerger):
|
||||
"""
|
||||
|
||||
# init hf config
|
||||
tf_config = hf_to_mcore_config(self.hf_config, torch.bfloat16)
|
||||
self.pipeline_shards = get_dynamic_pipeline_shards(self.hf_config.num_hidden_layers, self.world_size)
|
||||
print(f"Pipeline shards: {self.pipeline_shards}, total layers: {sum(self.pipeline_shards)}")
|
||||
|
||||
tf_config = hf_to_mcore_config(
|
||||
self.hf_config,
|
||||
torch.bfloat16,
|
||||
num_layers_in_first_pipeline_stage=self.pipeline_shards[0] if len(self.pipeline_shards) > 1 else None,
|
||||
num_layers_in_last_pipeline_stage=self.pipeline_shards[-1] if len(self.pipeline_shards) > 2 else None,
|
||||
)
|
||||
tf_config.use_cpu_initialization = self.config.use_cpu_initialization
|
||||
tie_word_embeddings = getattr(self.hf_config, "tie_word_embeddings", False)
|
||||
|
||||
@ -273,7 +349,11 @@ class MegatronModelMerger(BaseModelMerger):
|
||||
def _merge_state_dicts(self, model_state_dict_list: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
|
||||
state_dict = {}
|
||||
layers_cum = 0
|
||||
if self.world_size > 1:
|
||||
pipeline_cumsum = np.cumsum(self.pipeline_shards)
|
||||
layers_cum = 0 if self.rank == 0 else pipeline_cumsum[self.rank - 1]
|
||||
|
||||
print(f"{layers_cum=}")
|
||||
for model_state_dict in model_state_dict_list:
|
||||
layers_handled = 0
|
||||
keys = model_state_dict.keys()
|
||||
@ -297,6 +377,15 @@ class MegatronModelMerger(BaseModelMerger):
|
||||
else:
|
||||
warnings.warn(f"hf_name {hf_name} will not be fixed with layer number", stacklevel=2)
|
||||
|
||||
if "mlp.experts." in hf_name and ".weight" in hf_name:
|
||||
name_prefix, expert_id = hf_name.split(".weight")
|
||||
for proj in ["gate_up", "down"]:
|
||||
if f"{proj}_proj" in hf_name:
|
||||
hf_name = hf_name.replace(
|
||||
f"mlp.experts.{proj}_proj.weight{expert_id}",
|
||||
f"mlp.experts.{expert_id}.{proj}_proj.weight",
|
||||
)
|
||||
|
||||
tensor = model_state_dict[key]
|
||||
split_tensor = self._split_tensors(
|
||||
key, tensor, self.hf_config, is_value_model=self.config.is_value_model
|
||||
@ -321,6 +410,75 @@ class MegatronModelMerger(BaseModelMerger):
|
||||
|
||||
return state_dict
|
||||
|
||||
def save_hf_model_and_tokenizer(self, merged_state_dict):
|
||||
if self.world_size == 1:
|
||||
return super().save_hf_model_and_tokenizer(merged_state_dict)
|
||||
|
||||
from safetensors.torch import save_file
|
||||
|
||||
layer_num = self.hf_config.num_hidden_layers
|
||||
|
||||
# FIXME: make configurable
|
||||
saves_per_layer = 1 if layer_num < 30 else 2
|
||||
saves_total = saves_per_layer * layer_num
|
||||
saves_indexes = {}
|
||||
|
||||
# calculate the layer start index and key chunks
|
||||
layer_this_rank = self.pipeline_shards[self.rank]
|
||||
pipeline_cumsum = np.cumsum(self.pipeline_shards)
|
||||
layer_start = 0 if self.rank == 0 else pipeline_cumsum[self.rank - 1]
|
||||
keys = list(merged_state_dict.keys())
|
||||
keys_chunk = np.array_split(np.array(keys), layer_this_rank * saves_per_layer)
|
||||
numel = 0
|
||||
|
||||
assert len(keys_chunk) == layer_this_rank * saves_per_layer, (
|
||||
f"Expected {len(keys_chunk)} chunks, but got {layer_this_rank * saves_per_layer} for rank {self.rank}."
|
||||
)
|
||||
|
||||
# save to model shards manually
|
||||
target_dir = Path(self.config.target_dir)
|
||||
for i, keys in enumerate(keys_chunk):
|
||||
sd_to_save = {k: merged_state_dict[k] for k in keys}
|
||||
numel += sum([sd_to_save[i].numel() for i in sd_to_save])
|
||||
save_idx = layer_start * saves_per_layer + i
|
||||
save_path = target_dir / f"model-{save_idx + 1:05d}-of-{saves_total:05d}.safetensors"
|
||||
|
||||
save_file(sd_to_save, save_path)
|
||||
for k in keys:
|
||||
saves_indexes[k] = str(save_path.name)
|
||||
|
||||
tensor = torch.tensor([numel]).to(get_device_name())
|
||||
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
|
||||
numel = tensor.cpu().item()
|
||||
|
||||
all_save_indexes = [{} for _ in range(self.world_size)]
|
||||
dist.all_gather_object(all_save_indexes, saves_indexes)
|
||||
saves_indexes = {k: v for i in all_save_indexes for k, v in i.items()}
|
||||
if self.rank == 0:
|
||||
with open(target_dir / "model.safetensors.index.json", "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"metadata": {
|
||||
"total_size": numel,
|
||||
},
|
||||
"weight_map": saves_indexes,
|
||||
},
|
||||
f,
|
||||
indent=4,
|
||||
)
|
||||
print(f"model saved to {target_dir} with {numel=}")
|
||||
|
||||
self.model_config.save_pretrained(self.config.target_dir)
|
||||
|
||||
processor = hf_processor(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code)
|
||||
tokenizer = hf_tokenizer(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code)
|
||||
if processor is not None:
|
||||
print(f"Saving processor to {self.config.target_dir}")
|
||||
processor.save_pretrained(self.config.target_dir)
|
||||
if tokenizer is not None:
|
||||
print(f"Saving tokenizer to {self.config.target_dir}")
|
||||
tokenizer.save_pretrained(self.config.target_dir)
|
||||
|
||||
def merge_and_save(self):
|
||||
from verl.utils.megatron_utils import get_dist_checkpoint_path
|
||||
|
||||
@ -370,6 +528,7 @@ class MegatronModelMerger(BaseModelMerger):
|
||||
|
||||
megatron_name = megatron_name.replace("decoder", "model")
|
||||
param_name = megatron_name.replace(m_name, v_name)
|
||||
|
||||
return param_name
|
||||
|
||||
return None # Return None if no mapping found
|
||||
|
@ -736,11 +736,7 @@ class DataProto:
|
||||
Returns:
|
||||
List[DataProto]: a list of DataProto after splitting
|
||||
"""
|
||||
assert len(self) % split_size == 0, (
|
||||
f"only support equal split. Got size of DataProto {len(self)} and chunk {split_size}."
|
||||
)
|
||||
chunks = len(self) // split_size
|
||||
return self.chunk(chunks)
|
||||
return [self[i : i + split_size] for i in range(0, len(self), split_size)]
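
A behavioral sketch of the new split semantics, with a plain Python list standing in for DataProto (an assumption made only to keep the example self-contained):

```python
# Illustrative: split no longer requires len(data) to be divisible by split_size.
data = list(range(7))
split_size = 3
chunks = [data[i : i + split_size] for i in range(0, len(data), split_size)]
print(chunks)  # [[0, 1, 2], [3, 4, 5], [6]] -- the last chunk may now be smaller
```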
|
||||
|
||||
@staticmethod
|
||||
def concat(data: list["DataProto"]) -> "DataProto":
|
||||
|
@ -57,7 +57,7 @@ class WorkerHelper:
|
||||
return sock.getsockname()[1]
|
||||
|
||||
def get_availale_master_addr_port(self):
|
||||
return self._get_node_ip(), str(self._get_free_port())
|
||||
return self._get_node_ip().strip("[]"), str(self._get_free_port())
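
A small sketch of what the added strip("[]") does; the address is illustrative, and the assumption is that downstream consumers expect the bare IPv6 address rather than the bracketed form:

```python
# Illustrative: IPv6 node IPs may come back in bracketed form.
print("[fe80::1]".strip("[]"))  # fe80::1
```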
|
||||
|
||||
|
||||
# we assume that in each WorkerGroup, there is a Master Worker
|
||||
|
@ -127,9 +127,11 @@ actor_rollout_ref:
|
||||
calculate_log_probs: false
|
||||
agent:
|
||||
num_workers: 8
|
||||
agent_loop_config_path: null
|
||||
custom_async_server:
|
||||
path: null
|
||||
name: null
|
||||
update_weights_bucket_megabytes: 512
|
||||
trace:
|
||||
backend: null
|
||||
token2text: false
|
||||
|
@ -170,6 +170,18 @@ agent:
|
||||
# Number of agent loop workers
|
||||
num_workers: 8
|
||||
|
||||
# custom agent loop config path, which should contain a list of configs to initialize AgentLoop instances.
|
||||
# https://hydra.cc/docs/advanced/instantiate_objects/overview/
|
||||
#
|
||||
# - name: react_agent
|
||||
# _target_: recipe.langgraph_agent.react_agent_loop.ReactAgentLoop
|
||||
# tools: ["get_current_temperature"]
|
||||
# - name: math_expression
|
||||
# _target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop
|
||||
# min_terms: 2
|
||||
# max_terms: 6
|
||||
agent_loop_config_path: null
|
||||
|
||||
# custom async server configs
|
||||
custom_async_server:
|
||||
|
||||
@ -179,6 +191,20 @@ agent:
|
||||
# Class name of the custom async server class (e.g. AsyncvLLMServer)
|
||||
name: null
|
||||
|
||||
# Specifies the tensor bucket size (in megabytes) for batch weight updates during rollout operations.
|
||||
# This parameter controls the maximum payload size for a single weight update request.
|
||||
# Reference: https://github.com/volcengine/verl/pull/2418
|
||||
# Currently only supported in SGLang rollout implementations
|
||||
# Larger values may improve throughput but increase memory overhead
|
||||
# Detailed performance comparison:
|
||||
# https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/issues/169#issuecomment-3070686720
|
||||
# Default value (512MB) is optimized for typical GPU memory configurations
|
||||
# For the best performance of `rebuild_cuda_tensor`, it is recommended to:
|
||||
# 1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES`
|
||||
# 2. Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7`
|
||||
# when using Tensor Parallelism (TP) >= 8.
|
||||
update_weights_bucket_megabytes: 512
|
||||
|
||||
# trace rollout data
|
||||
trace:
|
||||
|
||||
|
@ -62,7 +62,7 @@ from verl.utils.torch_dtypes import PrecisionType
|
||||
from verl.utils.torch_functional import get_cosine_schedule_with_warmup, get_wsd_schedule_with_warmup
|
||||
from verl.utils.tracking import Tracking
|
||||
from verl.utils.ulysses import (
|
||||
gather_outpus_and_unpad,
|
||||
gather_outputs_and_unpad,
|
||||
get_ulysses_sequence_parallel_world_size,
|
||||
ulysses_pad_and_slice_inputs,
|
||||
)
|
||||
@ -406,7 +406,7 @@ class FSDPSFTTrainer:
|
||||
input_ids_rmpad_rolled = input_ids_rmpad_rolled.to(logits_rmpad.device)
|
||||
loss = loss_fct(logits_rmpad, input_ids_rmpad_rolled)
|
||||
# Gather and unpad for sequence parallelism
|
||||
loss = gather_outpus_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size)
|
||||
loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size)
|
||||
|
||||
# This is the loss collected from all ulysses ranks
|
||||
full_loss = pad_input(
|
||||
|
@ -64,8 +64,8 @@ def run_ppo(config) -> None:
|
||||
# Execute the `run` method of the TaskRunner instance remotely and wait for it to complete
|
||||
if (
|
||||
is_cuda_available
|
||||
and OmegaConf.select(config.trainer, "profile_steps") is not None
|
||||
and len(OmegaConf.select(config.trainer, "profile_steps")) > 0
|
||||
and config.trainer.get("profile_steps") is not None
|
||||
and len(config.trainer.get("profile_steps", [])) > 0
|
||||
):
|
||||
nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options)
|
||||
runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
|
||||
@ -106,9 +106,7 @@ class TaskRunner:
|
||||
from verl.utils.fs import copy_to_local
|
||||
|
||||
print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
|
||||
|
||||
pprint(OmegaConf.to_container(config, resolve=True))
|
||||
|
||||
OmegaConf.resolve(config)
|
||||
|
||||
# Download the checkpoint from HDFS to the local machine.
|
||||
@ -125,14 +123,6 @@ class TaskRunner:
|
||||
# Used for multimodal LLM, could be None
|
||||
processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
|
||||
|
||||
# Version validation for vllm.
|
||||
if config.actor_rollout_ref.rollout.name in ["vllm"]:
|
||||
from verl.utils.vllm_utils import is_version_ge
|
||||
|
||||
if config.actor_rollout_ref.model.get("lora_rank", 0) > 0:
|
||||
if not is_version_ge(pkg="vllm", minver="0.7.3"):
|
||||
raise NotImplementedError("PPO LoRA is not supported before vllm 0.7.3")
|
||||
|
||||
# Define worker classes based on the actor strategy.
|
||||
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
|
||||
assert config.critic.strategy in {"fsdp", "fsdp2"}
|
||||
|
@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
FSDP PPO Trainer with Ray-based single controller.
|
||||
PPO Trainer with Ray-based single controller.
|
||||
This trainer supports model-agnostic model initialization with HuggingFace.
|
||||
"""
|
||||
|
||||
@ -550,16 +550,6 @@ class RayPPOTrainer:
|
||||
"validation gen temperature should be greater than 0 when enabling do_sample"
|
||||
)
|
||||
|
||||
# check multi_turn with tool config
|
||||
if config.actor_rollout_ref.rollout.multi_turn.enable:
|
||||
assert (
|
||||
config.actor_rollout_ref.rollout.multi_turn.tool_config_path is not None
|
||||
or config.actor_rollout_ref.rollout.multi_turn.interaction_config_path is not None
|
||||
), (
|
||||
"tool_config_path or interaction_config_path must be set when enabling multi_turn with tool, "
|
||||
"due to no role-playing support"
|
||||
)
|
||||
|
||||
print("[validate_config] All configuration checks passed successfully!")
|
||||
|
||||
def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):
|
||||
@ -1049,6 +1039,28 @@ class RayPPOTrainer:
|
||||
else:
|
||||
print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch")
|
||||
|
||||
def _start_profiling(self, do_profile: bool) -> None:
|
||||
"""Start profiling for all worker groups if profiling is enabled."""
|
||||
if do_profile:
|
||||
self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps)
|
||||
if self.use_reference_policy:
|
||||
self.ref_policy_wg.start_profile()
|
||||
if self.use_critic:
|
||||
self.critic_wg.start_profile()
|
||||
if self.use_rm:
|
||||
self.rm_wg.start_profile()
|
||||
|
||||
def _stop_profiling(self, do_profile: bool) -> None:
|
||||
"""Stop profiling for all worker groups if profiling is enabled."""
|
||||
if do_profile:
|
||||
self.actor_rollout_wg.stop_profile()
|
||||
if self.use_reference_policy:
|
||||
self.ref_policy_wg.stop_profile()
|
||||
if self.use_critic:
|
||||
self.critic_wg.stop_profile()
|
||||
if self.use_rm:
|
||||
self.rm_wg.stop_profile()
|
||||
|
||||
def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"):
|
||||
"""Reorder the data on single controller such that each dp rank gets similar total tokens"""
|
||||
attention_mask = batch.batch["attention_mask"]
|
||||
@ -1118,14 +1130,7 @@ class RayPPOTrainer:
|
||||
else False
|
||||
)
|
||||
with marked_timer("start_profile", timing_raw):
|
||||
if do_profile:
|
||||
self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps)
|
||||
if self.use_reference_policy:
|
||||
self.ref_policy_wg.start_profile()
|
||||
if self.use_critic:
|
||||
self.critic_wg.start_profile()
|
||||
if self.use_rm:
|
||||
self.rm_wg.start_profile()
|
||||
self._start_profiling(do_profile)
|
||||
|
||||
batch: DataProto = DataProto.from_single_dict(batch_dict)
|
||||
|
||||
@ -1319,7 +1324,6 @@ class RayPPOTrainer:
|
||||
rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
|
||||
if rollout_data_dir:
|
||||
with marked_timer("dump_rollout_generations", timing_raw, color="green"):
|
||||
print(batch.batch.keys())
|
||||
inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
|
||||
outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
|
||||
scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
|
||||
@ -1366,14 +1370,7 @@ class RayPPOTrainer:
|
||||
self._save_checkpoint()
|
||||
|
||||
with marked_timer("stop_profile", timing_raw):
|
||||
if do_profile:
|
||||
self.actor_rollout_wg.stop_profile()
|
||||
if self.use_reference_policy:
|
||||
self.ref_policy_wg.stop_profile()
|
||||
if self.use_critic:
|
||||
self.critic_wg.stop_profile()
|
||||
if self.use_rm:
|
||||
self.rm_wg.stop_profile()
|
||||
self._stop_profiling(do_profile)
|
||||
|
||||
steps_duration = timing_raw["step"]
|
||||
self.max_steps_duration = max(self.max_steps_duration, steps_duration)
|
||||
|
@ -14,10 +14,18 @@
|
||||
|
||||
import re
|
||||
|
||||
_SOLUTION_CLIP_CHARS = 300
|
||||
|
||||
|
||||
def extract_solution(solution_str, method="strict"):
|
||||
assert method in ["strict", "flexible"]
|
||||
|
||||
# Optimization: Regular expression matching on very long strings can be slow.
|
||||
# For math problems, the final answer is usually at the end.
|
||||
# We only match on the last 300 characters, which is a safe approximation for 300 tokens.
|
||||
if len(solution_str) > _SOLUTION_CLIP_CHARS:
|
||||
solution_str = solution_str[-_SOLUTION_CLIP_CHARS:]
|
||||
|
||||
if method == "strict":
|
||||
# this also tests the formatting of the model
|
||||
solutions = re.findall("#### (\\-?[0-9\\.\\,]+)", solution_str)
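
A quick behavioral sketch of the clipping optimization above; the synthetic solution string is illustrative only:

```python
import re

# Illustrative: only the tail of a very long completion is scanned for "#### <answer>".
_SOLUTION_CLIP_CHARS = 300
long_solution = "x" * 10_000 + " #### 42"
clipped = long_solution[-_SOLUTION_CLIP_CHARS:]
print(re.findall("#### (\\-?[0-9\\.\\,]+)", clipped))  # ['42']
```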
|
||||
|
@ -67,7 +67,7 @@ SUPPORTED_LANGUAGES = [
|
||||
def call_sandbox_api(
|
||||
sandbox_fusion_url: str,
|
||||
code: str,
|
||||
stdin: str,
|
||||
stdin: Optional[str],
|
||||
compile_timeout: int,
|
||||
run_timeout: int,
|
||||
memory_limit_mb: int,
|
||||
@ -259,9 +259,9 @@ def _execute_user_function():
|
||||
# Attempt to instantiate and get method.
|
||||
# Errors (e.g., Solution not a class, instantiation fails, method missing)
|
||||
# will be caught by the broad except block below.
|
||||
_solution_instance = _Solution_class()
|
||||
_solution_instance = _Solution_class()
|
||||
_target_callable = getattr(_solution_instance, _SANDBOX_FN_NAME)
|
||||
|
||||
|
||||
if not _target_callable:
|
||||
sys.stderr.write(f"WrapperError: Function or method '{{_SANDBOX_FN_NAME}}' not found.\\n")
|
||||
return None, True # result, error_occurred
|
||||
@ -286,10 +286,11 @@ if __name__ == '__main__':
|
||||
print(str(_result))
|
||||
# Optional: To explicitly exit with an error code if the sandbox relies on it
|
||||
# else:
|
||||
# sys.exit(1)
|
||||
# sys.exit(1)
|
||||
"""
|
||||
current_generation_code = wrapper_code
|
||||
|
||||
stdin = None if stdin_data is None else str(stdin_data)
|
||||
try:
|
||||
if concurrent_semaphore:
|
||||
# logger.debug(f"Case {case_index + 1}: Attempting to acquire semaphore.")
|
||||
@ -298,7 +299,7 @@ if __name__ == '__main__':
|
||||
api_response, error_msg = call_sandbox_api(
|
||||
sandbox_fusion_url=sandbox_fusion_url,
|
||||
code=current_generation_code,
|
||||
stdin=str(stdin_data),
|
||||
stdin=stdin,
|
||||
compile_timeout=timeout,
|
||||
run_timeout=timeout,
|
||||
memory_limit_mb=memory_limit_mb,
|
||||
@ -309,7 +310,7 @@ if __name__ == '__main__':
|
||||
api_response, error_msg = call_sandbox_api(
|
||||
sandbox_fusion_url=sandbox_fusion_url,
|
||||
code=current_generation_code,
|
||||
stdin=str(stdin_data),
|
||||
stdin=stdin,
|
||||
compile_timeout=timeout,
|
||||
run_timeout=timeout,
|
||||
memory_limit_mb=memory_limit_mb,
|
||||
@ -322,7 +323,7 @@ if __name__ == '__main__':
|
||||
|
||||
metadata = {
|
||||
"case_index": case_index,
|
||||
"input": str(stdin_data),
|
||||
"input": stdin,
|
||||
"expected_output": str(expected_output),
|
||||
"api_request_error": error_msg,
|
||||
"api_response": None,
|
||||
@ -346,7 +347,7 @@ if __name__ == '__main__':
|
||||
# Log code and input only on error for brevity
|
||||
generation_to_log = generation[:200] + "..." if len(generation) > 200 else generation
|
||||
logger.error(f"Case {case_index}: code: {generation_to_log}")
|
||||
logger.error(f"Case {case_index}: input: {str(stdin_data)}")
|
||||
logger.error(f"Case {case_index}: input: {stdin}")
|
||||
elif api_response:
|
||||
# --- Add debug logging ---
|
||||
logger.debug(f"Case {case_index}: API Response: {api_response}")
|
||||
|
@ -234,7 +234,13 @@ class Gather(torch.autograd.Function):
|
||||
)
|
||||
|
||||
|
||||
def gather_outpus_and_unpad(
|
||||
def gather_outpus_and_unpad(*args, **kwargs):
|
||||
raise RuntimeError(
|
||||
"please use verl.utils.ulysses.gather_outputs_and_unpad instead of verl.utils.ulysses.gather_outpus_and_unpad"
|
||||
)
|
||||
|
||||
|
||||
def gather_outputs_and_unpad(
|
||||
x: Tensor,
|
||||
gather_dim: int,
|
||||
unpad_dim: int = None,
|
||||
|
@ -33,7 +33,7 @@ from verl.utils.profiler import GPUMemoryLogger
|
||||
from verl.utils.py_functional import append_to_dict
|
||||
from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch
|
||||
from verl.utils.torch_functional import logprobs_from_logits
|
||||
from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs
|
||||
from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs
|
||||
from verl.workers.actor import BasePPOActor
|
||||
|
||||
if is_cuda_available:
|
||||
@ -203,14 +203,14 @@ class DataParallelPPOActor(BasePPOActor):
|
||||
# gather log_prob if sp > 1
|
||||
if self.use_ulysses_sp:
|
||||
# gather and unpad for the ulysses sp
|
||||
log_probs = gather_outpus_and_unpad(
|
||||
log_probs = gather_outputs_and_unpad(
|
||||
log_probs,
|
||||
gather_dim=0,
|
||||
unpad_dim=0,
|
||||
padding_size=pad_size,
|
||||
)
|
||||
if calculate_entropy:
|
||||
entropy_rmpad = gather_outpus_and_unpad(
|
||||
entropy_rmpad = gather_outputs_and_unpad(
|
||||
entropy_rmpad,
|
||||
gather_dim=0,
|
||||
unpad_dim=0,
|
||||
|
@ -295,7 +295,15 @@ class MegatronPPOActor(BasePPOActor):
|
||||
Returns:
|
||||
|
||||
"""
|
||||
select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages"]
|
||||
select_keys = [
|
||||
"responses",
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"response_mask",
|
||||
"position_ids",
|
||||
"old_log_probs",
|
||||
"advantages",
|
||||
]
|
||||
if self.config.use_kl_loss:
|
||||
select_keys.append("ref_log_prob")
|
||||
self.has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
|
||||
@ -395,8 +403,7 @@ class MegatronPPOActor(BasePPOActor):
|
||||
|
||||
responses = data["responses"]
|
||||
response_length = responses.size(1)
|
||||
attention_mask = data["attention_mask"].to(bool)
|
||||
response_mask = attention_mask[:, -response_length:]
|
||||
response_mask = data["response_mask"].to(bool)
|
||||
loss_agg_mode = self.config.loss_agg_mode
|
||||
|
||||
# compute policy loss
|
||||
|
@ -31,7 +31,7 @@ from verl.utils.profiler import GPUMemoryLogger
|
||||
from verl.utils.py_functional import append_to_dict
|
||||
from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch
|
||||
from verl.utils.torch_functional import masked_mean
|
||||
from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
|
||||
from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs
|
||||
from verl.workers.critic import BasePPOCritic
|
||||
|
||||
if is_cuda_available:
|
||||
@ -113,7 +113,7 @@ class DataParallelPPOCritic(BasePPOCritic):
|
||||
|
||||
# gather output if sp > 1
|
||||
if self.ulysses_sequence_parallel_size > 1:
|
||||
values_rmpad = gather_outpus_and_unpad(
|
||||
values_rmpad = gather_outputs_and_unpad(
|
||||
values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size
|
||||
)
|
||||
|
||||
|
@ -1438,7 +1438,7 @@ class RewardModelWorker(Worker, DistProfilerExtension):
|
||||
unpad_input,
|
||||
)
|
||||
|
||||
from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
|
||||
from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device_name, dtype=torch.bfloat16):
|
||||
input_ids = micro_batch["input_ids"]
|
||||
@ -1481,7 +1481,7 @@ class RewardModelWorker(Worker, DistProfilerExtension):
|
||||
|
||||
# gather output if sp > 1
|
||||
if self.ulysses_sequence_parallel_size > 1:
|
||||
reward_rmpad = gather_outpus_and_unpad(
|
||||
reward_rmpad = gather_outputs_and_unpad(
|
||||
reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size
|
||||
)
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import pickle
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Iterator, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -66,3 +66,43 @@ def broadcast_pyobj(
|
||||
serialized_data = bytes(tensor_data.cpu().numpy())
|
||||
data = pickle.loads(serialized_data)
|
||||
return data
|
||||
|
||||
|
||||
def get_named_tensor_buckets(
|
||||
iterable: Iterator[tuple[str, torch.Tensor]], bucket_bytes: int
|
||||
) -> Iterator[list[tuple[str, torch.Tensor]]]:
|
||||
"""
|
||||
Group tensors into buckets based on a specified size in bytes.
|
||||
|
||||
Args:
|
||||
iterable: An iterator of tuples containing tensor names and tensors.
|
||||
bucket_bytes: The maximum size of each bucket in bytes.
|
||||
|
||||
Yields:
|
||||
Lists of tuples, where each tuple contains a tensor name and its corresponding tensor.
|
||||
|
||||
Example:
|
||||
>>> tensors = [('tensor1', torch.randn(1000, 1000)), ('tensor2', torch.randn(2000, 2000))]
|
||||
>>> for bucket in get_named_tensor_buckets(tensors, bucket_bytes=25 << 20):
|
||||
... print(bucket)
|
||||
[('tensor1', tensor(...)), ('tensor2', tensor(...))]
|
||||
|
||||
"""
|
||||
if bucket_bytes <= 0:
|
||||
raise ValueError(f"bucket_bytes must be greater than 0, got {bucket_bytes}")
|
||||
|
||||
current_bucket = []
|
||||
current_size = 0
|
||||
for name, tensor in iterable:
|
||||
tensor_size = tensor.element_size() * tensor.numel()
|
||||
if current_size + tensor_size > bucket_bytes:
|
||||
if current_bucket:
|
||||
yield current_bucket
|
||||
current_bucket = [(name, tensor)]
|
||||
current_size = tensor_size
|
||||
else:
|
||||
current_bucket.append((name, tensor))
|
||||
current_size += tensor_size
|
||||
|
||||
if current_bucket:
|
||||
yield current_bucket
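
A usage sketch mirroring how the sharding managers call this helper; the tensor shapes are illustrative, and the 512 MB figure mirrors `update_weights_bucket_megabytes`:

```python
import torch

# Illustrative shapes only: bucket named weights so each update payload stays
# under the configured size.
named_params = [("w1", torch.zeros(1024, 1024)), ("w2", torch.zeros(4096, 4096))]
bucket_bytes = 512 << 20  # megabytes -> bytes, as done by the sharding managers
for bucket in get_named_tensor_buckets(iter(named_params), bucket_bytes):
    print([name for name, _ in bucket])
```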
|
||||
|
@ -35,6 +35,7 @@ from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_
|
||||
from verl.utils.model import convert_weight_keys
|
||||
from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer
|
||||
from verl.utils.torch_functional import check_device_is_available
|
||||
from verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets
|
||||
|
||||
from .base import BaseShardingManager
|
||||
|
||||
@ -113,32 +114,63 @@ class FSDPSGLangShardingManager(BaseShardingManager):
|
||||
# Most naive implementation; can be optimized a lot if it becomes a bottleneck in the SGLang engine weight update
|
||||
named_tensors = [(k, v) for k, v in params.items()]
|
||||
load_format = None
|
||||
for tensor_index, (name, tensor) in enumerate(named_tensors):
|
||||
serialized_tensor = MultiprocessingSerializer.serialize(_preprocess_tensor_for_update_weights(tensor))
|
||||
# convert megabytes to bytes
|
||||
update_weights_bucket_bytes = int(self.rollout_config.update_weights_bucket_megabytes) << 20
|
||||
for batch in get_named_tensor_buckets(named_tensors, update_weights_bucket_bytes):
|
||||
# On each rank, serialize a batch of (name, tensor) tuples.
|
||||
# named_tensors_batch will be a list like:
|
||||
# [(name0, serialized_tensor0_tp0), (name1, serialized_tensor1_tp0), ...]
|
||||
named_tensors_batch = [
|
||||
(name, MultiprocessingSerializer.serialize(_preprocess_tensor_for_update_weights(tensor)))
|
||||
for name, tensor in batch
|
||||
]
|
||||
|
||||
if self.device_mesh["infer_tp"].get_local_rank() == 0:
|
||||
gathered_serialized_tensors = [None for _ in range(self.device_mesh["infer_tp"].mesh.size()[0])]
|
||||
# On rank 0, prepare a list to hold the gathered batches from all ranks.
|
||||
gathered_serialized_batches = [None for _ in range(self.device_mesh["infer_tp"].mesh.size()[0])]
|
||||
else:
|
||||
gathered_serialized_tensors = None
|
||||
gathered_serialized_batches = None
|
||||
|
||||
# Gather the named_tensors_batch from all ranks to rank 0.
|
||||
# After this, on rank 0, gathered_serialized_batches will be a list of lists:
|
||||
# [ [ (name0, s_t0_tp0), (name1, s_t1_tp0), ... ], # batch from TP rank 0
|
||||
# [ (name0, s_t0_tp1), (name1, s_t1_tp1), ... ], # batch from TP rank 1
|
||||
# ... ]
|
||||
# On other ranks, gathered_serialized_batches will be None.
|
||||
dist.gather_object(
|
||||
obj=serialized_tensor,
|
||||
object_gather_list=gathered_serialized_tensors,
|
||||
obj=named_tensors_batch,
|
||||
object_gather_list=gathered_serialized_batches,
|
||||
dst=self.device_mesh["infer_tp"].mesh.tolist()[0],
|
||||
group=self.device_mesh["infer_tp"].get_group(),
|
||||
)
|
||||
|
||||
if self.device_mesh["infer_tp"].get_local_rank() == 0:
|
||||
# Use zip(*) to "transpose" the data structure.
|
||||
# This groups the serialized parts for each individual tensor across all TP ranks.
|
||||
# Example: from [[(n0, t0_tp0), (n1, t1_tp0)], [(n0, t0_tp1), (n1, t1_tp1)]]
|
||||
# to [ ( (n0, t0_tp0), (n0, t0_tp1) ), ( (n1, t1_tp0), (n1, t1_tp1) ) ]
|
||||
logical_tensors = zip(*gathered_serialized_batches, strict=True)
|
||||
|
||||
await self.inference_engine.update_weights_from_tensor(
|
||||
named_tensors=[
|
||||
# 'tensor_group' represents a single logical tensor's data from all ranks.
|
||||
(
|
||||
name,
|
||||
LocalSerializedTensor(values=gathered_serialized_tensors),
|
||||
tensor_group[0][0], # Get the name from the first rank's data.
|
||||
LocalSerializedTensor(
|
||||
# 'rank_part' is the (name, serialized_tensor) tuple from one specific rank.
|
||||
values=[rank_part[1] for rank_part in tensor_group]
|
||||
),
|
||||
)
|
||||
for tensor_group in logical_tensors
|
||||
# each tensor_group is like ( (n0, t0_tp0), (n0, t0_tp1) )
|
||||
],
|
||||
load_format=load_format,
|
||||
flush_cache=tensor_index == len(named_tensors) - 1,
|
||||
flush_cache=False,
|
||||
)
|
||||
|
||||
if self.device_mesh["infer_tp"].get_local_rank() == 0:
|
||||
await self.inference_engine.flush_cache()
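
A minimal sketch of the zip(*) "transpose" used above, with strings standing in for serialized tensors; the names and values are illustrative only:

```python
# Illustrative: gathered_serialized_batches, one inner list per TP rank.
gathered = [
    [("w0", "w0_tp0"), ("w1", "w1_tp0")],  # batch from TP rank 0
    [("w0", "w0_tp1"), ("w1", "w1_tp1")],  # batch from TP rank 1
]
for tensor_group in zip(*gathered, strict=True):
    name = tensor_group[0][0]                               # name from the first rank's entry
    shards = [rank_part[1] for rank_part in tensor_group]   # per-rank serialized parts
    print(name, shards)  # w0 ['w0_tp0', 'w0_tp1'] then w1 ['w1_tp0', 'w1_tp1']
```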
|
||||
|
||||
async def release_memory(self):
|
||||
if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
|
||||
await self.inference_engine.release_memory_occupation()
|
||||
|
@ -37,6 +37,7 @@ from verl.utils.megatron_utils import (
|
||||
per_tensor_generator,
|
||||
)
|
||||
from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer
|
||||
from verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets
|
||||
|
||||
from .base import BaseShardingManager
|
||||
|
||||
@ -130,37 +131,76 @@ class MegatronSGLangShardingManager(BaseShardingManager):
|
||||
loop.run_until_complete(self.sleep())
|
||||
|
||||
async def update_weights(self, params):
|
||||
"""
|
||||
Update model weights using tensor buckets, similar to THUDM/slime's implementation.
|
||||
|
||||
Notes:
|
||||
- For the best performance of `rebuild_cuda_tensor`, it is recommended to:
|
||||
1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES`.
|
||||
2. Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7`
|
||||
when using Tensor Parallelism (TP >= 8).
|
||||
- See reference implementations in SLIME:
|
||||
- Main logic: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L452
|
||||
- runtime envs: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L39
|
||||
"""
|
||||
if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
|
||||
await self.inference_engine.resume_memory_occupation()
|
||||
named_tensors = params
|
||||
load_format = None
|
||||
for tensor_index, (name, tensor) in enumerate(named_tensors):
|
||||
serialized_tensor = MultiprocessingSerializer.serialize(tensor.detach())
|
||||
|
||||
update_weights_bucket_bytes = int(self.rollout_config.update_weights_bucket_megabytes) << 20
|
||||
for batch in get_named_tensor_buckets(named_tensors, update_weights_bucket_bytes):
|
||||
# On each rank, serialize a batch of (name, tensor) tuples.
|
||||
# named_tensors_batch will be a list like:
|
||||
# [(name0, serialized_tensor0_tp0), (name1, serialized_tensor1_tp0), ...]
|
||||
named_tensors_batch = [
|
||||
(name, MultiprocessingSerializer.serialize(tensor.detach())) for name, tensor in batch
|
||||
]
|
||||
|
||||
if self.device_mesh["tp"].get_local_rank() == 0:
|
||||
gathered_serialized_tensors = [None for _ in range(self.device_mesh["tp"].mesh.size()[0])]
|
||||
# On rank 0, prepare a list to hold the gathered batches from all ranks.
|
||||
gathered_serialized_batches = [None for _ in range(self.device_mesh["tp"].mesh.size()[0])]
|
||||
else:
|
||||
gathered_serialized_tensors = None
|
||||
gathered_serialized_batches = None
|
||||
|
||||
# Gather the named_tensors_batch from all ranks to rank 0.
|
||||
# After this, on rank 0, gathered_serialized_batches will be a list of lists:
|
||||
# [ [ (name0, s_t0_tp0), (name1, s_t1_tp0), ... ], # batch from TP rank 0
|
||||
# [ (name0, s_t0_tp1), (name1, s_t1_tp1), ... ], # batch from TP rank 1
|
||||
# ... ]
|
||||
# On other ranks, gathered_serialized_batches will be None.
|
||||
dist.gather_object(
|
||||
obj=serialized_tensor,
|
||||
object_gather_list=gathered_serialized_tensors,
|
||||
obj=named_tensors_batch,
|
||||
object_gather_list=gathered_serialized_batches,
|
||||
dst=self.device_mesh["tp"].mesh.tolist()[0],
|
||||
group=self.device_mesh["tp"].get_group(),
|
||||
)
|
||||
|
||||
if self.device_mesh["tp"].get_local_rank() == 0:
|
||||
# Use zip(*) to "transpose" the data structure.
|
||||
# This groups the serialized parts for each individual tensor across all TP ranks.
|
||||
# Example: from [[(n0, t0_tp0), (n1, t1_tp0)], [(n0, t0_tp1), (n1, t1_tp1)]]
|
||||
# to [ ( (n0, t0_tp0), (n0, t0_tp1) ), ( (n1, t1_tp0), (n1, t1_tp1) ) ]
|
||||
logical_tensors = zip(*gathered_serialized_batches, strict=False)
|
||||
await self.inference_engine.update_weights_from_tensor(
|
||||
named_tensors=[
|
||||
# 'tensor_group' represents a single logical tensor's data from all ranks.
|
||||
(
|
||||
name,
|
||||
LocalSerializedTensor(values=gathered_serialized_tensors),
|
||||
tensor_group[0][0], # Get the name from the first rank's data.
|
||||
LocalSerializedTensor(
|
||||
# 'rank_part' is the (name, serialized_tensor) tuple from one specific rank.
|
||||
values=[rank_part[1] for rank_part in tensor_group]
|
||||
),
|
||||
)
|
||||
for tensor_group in logical_tensors
|
||||
# each tensor_group is like ( (n0, t0_tp0), (n0, t0_tp1) )
|
||||
],
|
||||
load_format=load_format,
|
||||
flush_cache=False,
|
||||
)
|
||||
if self.device_mesh["tp"].get_local_rank() == 0:
|
||||
await self.inference_engine.flush_cache()
|
||||
|
||||
if self.device_mesh["tp"].get_local_rank() == 0:
|
||||
await self.inference_engine.flush_cache()
|
||||
|
||||
async def release_memory(self):
|