mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 21:53:50 +08:00
[recipe] feat: CollabLLM integration for multiturn training (#3574)
### What does this PR do? This PR add [CollabLLM](https://aka.ms/CollabLLM) as a training recipe. The added components include - A customized `CollabLLMRewardManager` inheriting from `AbstractRewardManager` to compute multiturn-aware rewards. - A customized `CollabLLMAgentLoop` inheriting from `AgentLoop` to sample future conversations with simulated users, which imports `CollabLLMInteraction` from `verl/interactions/collabllm_interation.py`. ### Checklist Before Starting - [X] Search for similar PRs. Paste at least one query link here: ... - [X] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. The training rewards when running `train_rl_collabllm.sh` is increasing in a relatively stable manner (on 8xH200): <img width="964" height="480" alt="9baeb0700e3fa6a56596e14a54bc1049" src="https://github.com/user-attachments/assets/53a810d8-1dd7-4145-bb28-4e475e9d7d9d" /> Validation reward: <img width="974" height="538" alt="39364fd10523b0fde13d48645809f5e3" src="https://github.com/user-attachments/assets/c34fe9e7-3d83-4132-8e1a-67e82c221d09" /> #### Samples of model generation After training, when user asks generic questions with missing information, the model learns to ask for clarification <img width="1213" height="562" alt="c8e0ab31948a48ca396c7eccddd13673" src="https://github.com/user-attachments/assets/ae41cd77-3c77-4402-b9d3-21993b046a18" /> and give suggestions: <img width="1534" height="190" alt="7adb7d33eb9120d337c2a249c6a2dd22" src="https://github.com/user-attachments/assets/84e1d8c1-f954-403f-b931-bce45cff1612" /> (In contrast, with the same prompt, **GPT-5** doesn't ask for any clarification:) <img width="1754" height="1126" alt="be8d8577584c0b2356cb352d6f294205" src="https://github.com/user-attachments/assets/9b734848-9ed0-4496-af11-68bb8f8d8e08" /> ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # No change on the existing APIs ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. Changes: - Main files under `recipe/collabllm` - Registered `CollabLLMRewardManager` in `workers/reward_manager/collabllm.py` - Added `CollabLLMInteraction` in `verl/interactions/collabllm_interation.py` ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [X] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [X] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [X] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). Added to `verl/docs/algo/collabllm.md`. - [X] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: The scripts `train_rl_collabllm.sh` and `train_sft_collabllm.sh` are tested multiple times. - [X] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: Chen Haiquan <chenhaiquan@bytedance.com>
This commit is contained in:
105
docs/algo/collabllm.md
Normal file
105
docs/algo/collabllm.md
Normal file
@ -0,0 +1,105 @@
|
||||
# Recipe: CollabLLM
|
||||
|
||||
Last updated: 09/22/2025.
|
||||
|
||||
> Open-Source Algorithm Implementation & Expriement Running: [Haiquan Chen](https://github.com/chenhaiq), [Shirley Wu](https://github.com/Wuyxin)
|
||||
|
||||
🏠 [Homepage](https://aka.ms/CollabLLM) | 📝 [Paper](https://arxiv.org/pdf/2502.00640) | 🤗 [Datasets & Models](https://huggingface.co/collabllm) | ⭐️ [Original Implementation](https://github.com/Wuyxin/collabllm)
|
||||
|
||||
`verl` provides a recipe for the Outstanding Paper at ICML 2025, **"CollabLLM: From Passive Responders to Active Collaborators"**. [CollabLLM](https://aka.ms/CollabLLM) is a unified fine-tuning framework that optimizes LLMs for effective and efficient multiturn collaboration with users.
|
||||
|
||||
**Core Idea:** Models are rewarded based on how well their responses enable effective *future* collaboration with users.
|
||||
|
||||
Paper Authors: [Shirley Wu](https://cs.stanford.edu/~shirwu/), [Michel Galley](https://www.microsoft.com/en-us/research/people/mgalley/), Baolin Peng, Hao Cheng, Gavin Li, Yao Dou, Weixin Cai, [James Zou](https://www.james-zou.com/), [Jure Leskovec](https://cs.stanford.edu/people/jure/), [Jianfeng Gao](https://www.microsoft.com/en-us/research/people/jfgao/)
|
||||
|
||||
|
||||
---
|
||||
## Quick Start
|
||||
|
||||
### 0. Environment
|
||||
Make sure the required packages for `verl` are installed. Additionally, install `litellm` and export the required API keys. The API model will be used for user simulators and, optionally, LLM Judges (see the Configuration section below).
|
||||
|
||||
### 1. Prepare Your Dataset
|
||||
|
||||
First, process your dataset using the provided script (see example commands and usage in `process_dataset.py`):
|
||||
|
||||
```bash
|
||||
python process_dataset.py --dataset <> ... --dataset_type <sft or rl>
|
||||
```
|
||||
|
||||
|
||||
**Requirements:**
|
||||
- Input: A Hugging Face multiturn dataset. Existing datasets: `collabllm/collabllm-multiturn-$DATASET`, with `DATASET` in one of [`math-hard(-large)`, `medium(-large)`, `bigcodebench(-large)`] (*-large are the datasets used in the CollabLLM paper)
|
||||
- Example format: See [collabllm-multiturn-math-hard](https://huggingface.co/datasets/collabllm/collabllm-multiturn-math-hard)
|
||||
- To generate your own dataset: Use [build_dataset.py](https://github.com/Wuyxin/collabllm/blob/main/scripts/engine/build_dataset.py) from the original CollabLLM repository
|
||||
|
||||
|
||||
### 2. Train Your Model
|
||||
|
||||
**(Optional) For Supervised Fine-Tuning (SFT):**
|
||||
```bash
|
||||
bash train_sft_collabllm.sh
|
||||
```
|
||||
|
||||
**For Reinforcement Learning (RL):**
|
||||
|
||||
```bash
|
||||
bash train_rl_collabllm.sh
|
||||
```
|
||||
|
||||
The RL script shows an example to train CollabLLM on `math-hard-large`.
|
||||
|
||||
- The config to sample future conversations are in `recipe/collabllm/config/collabllm_interaction_config.yaml`.
|
||||
- The Multiturn-aware Reward is aggregated from these three conversational-level rewards:
|
||||
|
||||
```
|
||||
+reward_model.reward_kwargs.metric_weights.accuracy=1 \
|
||||
+reward_model.reward_kwargs.metric_weights.interactivity=1 \
|
||||
+reward_model.reward_kwargs.metric_weights.token_amount=-0.0001 \
|
||||
```
|
||||
|
||||
You can remove, add, or modify the weights depending on your task. A list of implemented metrics you can already add are under `recipe/collabllm/metrics`. For example, on `medium-large`, you can replace `accuracy` with `bleu_score` via
|
||||
```
|
||||
+reward_model.reward_kwargs.metric_weights.bleu_score=1
|
||||
```
|
||||
which will instead apply bleu score on the sampled future conversations.
|
||||
|
||||
## Algorithm
|
||||
|
||||
| Step | Name | Description |
|
||||
|------|-------------------------------|-----------------------------------------------------------------------------|
|
||||
| 1 | Model response generation | The model generates multiple responses for each prompt in a batch. |
|
||||
| 2 | Collaborative simulation | A user simulator (e.g., GPT or Claude) samples `num_repeat_rollouts` conversations for up to `max_user_turns` additional turns. |
|
||||
| 3 | Compute Multiturn-aware Reward | Customized conversational reward functions are applied to the sampled conversations. Rewards are aggregated, then averaged across rollouts. |
|
||||
| 4 | Update model | The model weights are updated using the computed multiturn-aware rewards. |
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
The primary configuration is managed through the launch script `train_rl_collabllm.sh` and the YAML file `recipe/collabllm/config/collabllm_interaction_config.yaml`. Key configuration sections:
|
||||
|
||||
| Section | Key Parameters / Notes |
|
||||
|----------------------|-----------------------------------------------------------------------------------------|
|
||||
| `data` | Paths to training/validation files, batch sizes, sequence lengths. |
|
||||
| `actor_rollout_ref` (common) | Base model path (used for actor + initial reference), FSDP settings, optimization (LR, scheduler). |
|
||||
| `actor_rollout_ref` (CollabLLM-specific) | Hyperparameters under `actor_rollout_ref.rollout.multi_turn`: `max_user_turns`, `max_assistant_turns`, `num_repeat_rollouts`. |
|
||||
| `interaction` | Defined in `collabllm_interaction_config.yaml`. Specifies user simulator and hyperparameters. Requires exported API keys. |
|
||||
| `reward_model` | Manager set to `collabllm` by default. Modify `reward_model.reward_kwargs.metric_weights` for conversational rewards and weights. LLM Judge hyperparameters (e.g., `model`, `temperature`) go under `reward_model.reward_kwargs.llm_judge_kwargs`. |
|
||||
| `algorithm` | GRPO-specific hyperparameters such as `actor_rollout_ref.rollout.n`. |
|
||||
| `trainer` | Distributed training (nodes, GPUs per node), logging (WandB), checkpointing frequency. |
|
||||
|
||||
---
|
||||
|
||||
## Key Files
|
||||
|
||||
| File Path | Purpose |
|
||||
|-----------|---------|
|
||||
| `recipe/collabllm/collabllm_agent_loop.py` | Main logic to sample future conversations, using `CollabLLMInteraction` from `verl/interactions/collabllm_interaction.py`. |
|
||||
| `verl/workers/reward_manager/collabllm.py` | Computes rewards for future conversations, leveraging `recipe/collabllm/reward_function.py` to apply each metric. |
|
||||
|
||||
---
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
We sincerely thank the `verl` community and advisors for their contributions and guidance!
|
@ -70,6 +70,7 @@ verl is fast with:
|
||||
|
||||
algo/ppo.md
|
||||
algo/grpo.md
|
||||
algo/collabllm.md
|
||||
algo/dapo.md
|
||||
algo/spin.md
|
||||
algo/sppo.md
|
||||
|
@ -1,23 +1,22 @@
|
||||
set -x
|
||||
|
||||
|
||||
python3 -m verl.trainer.main_ppo \
|
||||
algorithm.adv_estimator=grpo \
|
||||
trainer.val_before_train=False \
|
||||
data.train_files=$HOME/data/gsm8k/train.parquet \
|
||||
data.val_files=$HOME/data/gsm8k/test.parquet \
|
||||
data.train_batch_size=1024 \
|
||||
data.train_batch_size=16 \
|
||||
data.max_prompt_length=512 \
|
||||
data.max_response_length=1024 \
|
||||
data.filter_overlong_prompts=True \
|
||||
data.truncation='error' \
|
||||
data.shuffle=False \
|
||||
actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
|
||||
actor_rollout_ref.model.use_shm=True \
|
||||
actor_rollout_ref.model.lora_rank=64 \
|
||||
actor_rollout_ref.model.lora_alpha=32 \
|
||||
actor_rollout_ref.actor.optim.lr=3e-6 \
|
||||
actor_rollout_ref.model.use_remove_padding=True \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=16 \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \
|
||||
actor_rollout_ref.actor.use_kl_loss=True \
|
||||
actor_rollout_ref.actor.kl_loss_coef=0.001 \
|
||||
@ -40,8 +39,13 @@ python3 -m verl.trainer.main_ppo \
|
||||
trainer.logger='["console","wandb"]' \
|
||||
trainer.project_name='verl_grpo_example_gsm8k' \
|
||||
trainer.experiment_name='qwen2.5_3b_grpo_lora' \
|
||||
trainer.n_gpus_per_node=8 \
|
||||
trainer.n_gpus_per_node=2 \
|
||||
trainer.nnodes=1 \
|
||||
trainer.save_freq=20 \
|
||||
trainer.test_freq=5 \
|
||||
trainer.total_epochs=15 $@
|
||||
|
||||
# actor_rollout_ref.actor.ppo_mini_batch_size=256 \
|
||||
# data.train_batch_size=1024 \
|
||||
# trainer.n_gpus_per_node=8 \
|
||||
# actor_rollout_ref.model.use_shm=True \
|
||||
|
@ -1,5 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||
export CUDA_VISIBLE_DEVICES=4
|
||||
NOW=$(date +%Y%m%d)
|
||||
export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-0.5b-${NOW}
|
||||
export WANDB_PROJECT=${WANDB_DIR}
|
||||
@ -7,7 +7,7 @@ export WANDB_EXP=0.5b-${NOW}
|
||||
MODEL_PATH=Qwen/Qwen2.5-0.5B-Instruct
|
||||
|
||||
set -x
|
||||
nproc_per_gpu=116
|
||||
nproc_per_gpu=1
|
||||
nnodes=1
|
||||
ngpu_per_node=1
|
||||
total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node ))
|
||||
@ -15,8 +15,9 @@ mini_batch_size=$(( total_procs ))
|
||||
|
||||
python3 -m verl.trainer.main_ppo \
|
||||
algorithm.adv_estimator=grpo \
|
||||
data.train_files=data/gsm8k/train.parquet \
|
||||
data.val_files=data/gsm8k/test.parquet \
|
||||
trainer.val_before_train=False \
|
||||
data.train_files=$HOME/data/gsm8k/train.parquet \
|
||||
data.val_files=$HOME/data/gsm8k/test.parquet \
|
||||
data.train_batch_size=${total_procs} \
|
||||
data.val_batch_size=${total_procs} \
|
||||
data.max_prompt_length=512 \
|
||||
@ -25,7 +26,6 @@ python3 -m verl.trainer.main_ppo \
|
||||
data.truncation='error' \
|
||||
data.shuffle=False \
|
||||
actor_rollout_ref.model.path=$MODEL_PATH \
|
||||
actor_rollout_ref.model.use_shm=True \
|
||||
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||
actor_rollout_ref.model.lora_rank=32 \
|
||||
actor_rollout_ref.model.lora_alpha=32 \
|
||||
@ -33,7 +33,7 @@ python3 -m verl.trainer.main_ppo \
|
||||
actor_rollout_ref.actor.optim.lr=3e-5 \
|
||||
actor_rollout_ref.model.use_remove_padding=True \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${mini_batch_size} \
|
||||
actor_rollout_ref.actor.use_kl_loss=True \
|
||||
actor_rollout_ref.actor.kl_loss_coef=0.001 \
|
||||
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
|
||||
@ -45,7 +45,7 @@ python3 -m verl.trainer.main_ppo \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
|
||||
actor_rollout_ref.rollout.name=vllm \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.1 \
|
||||
actor_rollout_ref.rollout.n=5 \
|
||||
actor_rollout_ref.rollout.n=1 \
|
||||
actor_rollout_ref.rollout.max_num_seqs=512 \
|
||||
actor_rollout_ref.rollout.max_model_len=1536 \
|
||||
actor_rollout_ref.rollout.max_num_batched_tokens=1536 \
|
||||
|
74
recipe/collabllm/README.md
Normal file
74
recipe/collabllm/README.md
Normal file
@ -0,0 +1,74 @@
|
||||
# CollabLLM
|
||||
|
||||
This repository implements [CollabLLM](https://arxiv.org/pdf/2502.00640) (ICML 2025) using the verl framework. For the original implementation, see the [CollabLLM repository](https://github.com/Wuyxin/collabllm).
|
||||
|
||||
|
||||
CollabLLM is a method for training language models to collaborate effectively in multi-turn conversations. This implementation adapts the original imlpementation to work with the Verl training framework.
|
||||
|
||||
## Quick start
|
||||
|
||||
### 0. Environment
|
||||
Make sure the required packages for `verl` are installed. Additionally, install `litellm` and export the required API keys. The API model will be used for user simulators and, optionally, LLM Judges (see the Configuration section below).
|
||||
|
||||
### 1. Prepare Your Dataset
|
||||
|
||||
First, process your dataset using the provided script:
|
||||
|
||||
```bash
|
||||
python process_dataset.py --dataset <> ... --dataset_type <sft or rl>
|
||||
```
|
||||
|
||||
|
||||
**Requirements:**
|
||||
- Input: A Hugging Face multiturn dataset. Existing datasets: `collabllm/collabllm-multiturn-$DATASET`, with `DATASET` in one of [`math-hard(-large)`, `medium(-large)`, `bigcodebench(-large)`] (*-large are the datasets used in the CollabLLM paper)
|
||||
- Example format: See [collabllm-multiturn-math-hard](https://huggingface.co/datasets/collabllm/collabllm-multiturn-math-hard)
|
||||
- To generate your own dataset: Use [build_dataset.py](https://github.com/Wuyxin/collabllm/blob/main/scripts/engine/build_dataset.py) from the original CollabLLM repository
|
||||
|
||||
*Note: Check `process_dataset.py` for example commands and usage.*
|
||||
|
||||
### 2. Train Your Model
|
||||
|
||||
**(Optional) For Supervised Fine-Tuning (SFT):**
|
||||
```bash
|
||||
bash train_sft_collabllm.sh
|
||||
```
|
||||
|
||||
**For Reinforcement Learning (RL):**
|
||||
|
||||
```bash
|
||||
bash train_rl_collabllm.sh
|
||||
```
|
||||
|
||||
The RL script shows an example to train CollabLLM on `math-hard-large`.
|
||||
|
||||
- The config to sample future conversations are in `recipe/collabllm/config/collabllm_interaction_config.yaml`.
|
||||
- The Multiturn-aware Reward is aggregated from these three conversational-level rewards:
|
||||
|
||||
```
|
||||
+reward_model.reward_kwargs.metric_weights.accuracy=1 \
|
||||
+reward_model.reward_kwargs.metric_weights.interactivity=1 \
|
||||
+reward_model.reward_kwargs.metric_weights.token_amount=-0.0001 \
|
||||
```
|
||||
|
||||
You can remove, add, or modify the weights depending on your task. A list of implemented metrics you can already add are under `recipe/collabllm/metrics`. For example, on `medium-large`, you can replace `accuracy` with `bleu_score` via
|
||||
```
|
||||
+reward_model.reward_kwargs.metric_weights.bleu_score=1
|
||||
```
|
||||
which will instead apply bleu score on the sampled future conversations.
|
||||
|
||||
## Configuration
|
||||
Read [doc](https://verl.readthedocs.io/en/latest/) for detailed configurations.
|
||||
|
||||
## Citation
|
||||
If you find CollabLLM useful in your research, please cite the following:
|
||||
|
||||
```bibtex
|
||||
@inproceedings{collabllm2025,
|
||||
title={CollabLLM: From Passive Responders to Active Collaborators},
|
||||
author={Shirley Wu and Michel Galley and Baolin Peng and Hao Cheng and
|
||||
Gavin Li and Yao Dou and Weixin Cai and James Zou and
|
||||
Jure Leskovec and Jianfeng Gao},
|
||||
booktitle={International Conference on Machine Learning (ICML)},
|
||||
year={2025}
|
||||
}
|
||||
```
|
139
recipe/collabllm/collabllm_agent_loop.py
Normal file
139
recipe/collabllm/collabllm_agent_loop.py
Normal file
@ -0,0 +1,139 @@
|
||||
# Copyright 2025 CollabLLM team and/or its affiliates
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from recipe.collabllm.utils import is_valid_messages
|
||||
from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, register
|
||||
from verl.experimental.agent_loop.tool_agent_loop import AgentData, AgentState, ToolAgentLoop
|
||||
from verl.utils.rollout_trace import rollout_trace_op
|
||||
from verl.workers.rollout.schemas import Message
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
||||
|
||||
|
||||
@register("collabllm_agent")
|
||||
class CollabLLMAgentLoop(ToolAgentLoop):
|
||||
@rollout_trace_op
|
||||
async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
|
||||
messages = list(kwargs["raw_prompt"])
|
||||
image_data = deepcopy(kwargs.get("multi_modal_data", {}).get("image", None))
|
||||
metrics = {}
|
||||
request_id = uuid4().hex
|
||||
tools_kwargs = kwargs.get("tools_kwargs", {})
|
||||
|
||||
# Initialize interaction if needed
|
||||
interaction = None
|
||||
interaction_kwargs = {}
|
||||
if self.interaction_config_file:
|
||||
interaction_kwargs = kwargs["extra_info"]["interaction_kwargs"]
|
||||
if "name" not in interaction_kwargs:
|
||||
raise ValueError("'name' key is required in interaction_kwargs")
|
||||
interaction_name = interaction_kwargs["name"]
|
||||
if interaction_name not in self.interaction_map:
|
||||
raise ValueError(
|
||||
f"Interaction '{interaction_name}' not found in interaction_map. Available interactions: "
|
||||
f"{list(self.interaction_map.keys())}"
|
||||
)
|
||||
interaction = self.interaction_map[interaction_name]
|
||||
await interaction.start_interaction(request_id, **interaction_kwargs)
|
||||
# Create AgentData instance to encapsulate all state
|
||||
agent_data = AgentData(
|
||||
messages=messages,
|
||||
image_data=image_data,
|
||||
metrics=metrics,
|
||||
request_id=request_id,
|
||||
tools_kwargs=tools_kwargs,
|
||||
interaction=interaction,
|
||||
interaction_kwargs=interaction_kwargs,
|
||||
)
|
||||
# for collabllm, firstly generate model reponses
|
||||
await self._handle_pending_state(agent_data, sampling_params)
|
||||
|
||||
status = await self._handle_generating_state(agent_data, sampling_params)
|
||||
|
||||
if status == AgentState.TERMINATED:
|
||||
# tell reward manager to score -1 and skip future interaction
|
||||
# to avoid reward hacking with incompleted message
|
||||
num_repeats = 0
|
||||
else:
|
||||
# then, collect interaction rollouts
|
||||
num_repeats = self.config.actor_rollout_ref.rollout.multi_turn.num_repeat_rollouts
|
||||
|
||||
interaction_requests = [deepcopy(agent_data) for _ in range(num_repeats)]
|
||||
|
||||
# messages are only used in collabllm reward manager
|
||||
messages_lst = []
|
||||
for _agent_data in interaction_requests:
|
||||
if not is_valid_messages(_agent_data.messages[-1]):
|
||||
break
|
||||
|
||||
prev_msg_len = len(_agent_data.messages)
|
||||
await self.run_agent_data_loop(_agent_data, sampling_params, AgentState.INTERACTING)
|
||||
messages_lst.append([Message(**msg) for msg in _agent_data.messages])
|
||||
|
||||
if interaction.config.get("enable_log"):
|
||||
print(f"Assistant: ...{messages_lst[-1][prev_msg_len - 1].content[-100:]}")
|
||||
print(f"User: {messages_lst[-1][prev_msg_len].content[:100]}...")
|
||||
|
||||
# Finalize output
|
||||
response_ids = agent_data.prompt_ids[-len(agent_data.response_mask) :]
|
||||
prompt_ids = agent_data.prompt_ids[: len(agent_data.prompt_ids) - len(agent_data.response_mask)]
|
||||
multi_modal_data = {"image": agent_data.image_data} if agent_data.image_data is not None else {}
|
||||
|
||||
output = AgentLoopOutput(
|
||||
prompt_ids=prompt_ids,
|
||||
response_ids=response_ids[: self.response_length],
|
||||
response_mask=agent_data.response_mask[: self.response_length],
|
||||
multi_modal_data=multi_modal_data,
|
||||
response_logprobs=agent_data.response_logprobs[: self.response_length]
|
||||
if agent_data.response_logprobs
|
||||
else None,
|
||||
num_turns=agent_data.user_turns + agent_data.assistant_turns + 1,
|
||||
metrics=agent_data.metrics,
|
||||
extra_fields={
|
||||
"turn_scores": agent_data.turn_scores,
|
||||
"messages": {"messages": messages_lst}, # compatiable with sglang interaction
|
||||
},
|
||||
)
|
||||
return output
|
||||
|
||||
async def run_agent_data_loop(self, agent_data: AgentData, sampling_params: dict[str, Any], state: AgentState):
|
||||
"""
|
||||
Run the agent data loop to process the agent data.
|
||||
|
||||
Args:
|
||||
agent_data (AgentData): The agent data to process.
|
||||
sampling_params (dict[str, Any]): The sampling parameters.
|
||||
state (AgentState, optional): The initial state of the agent. Defaults to None.
|
||||
"""
|
||||
|
||||
while state != AgentState.TERMINATED:
|
||||
if state == AgentState.PENDING:
|
||||
state = await self._handle_pending_state(agent_data, sampling_params)
|
||||
elif state == AgentState.GENERATING:
|
||||
state = await self._handle_generating_state(agent_data, sampling_params)
|
||||
elif state == AgentState.PROCESSING_TOOLS:
|
||||
state = await self._handle_processing_tools_state(agent_data)
|
||||
elif state == AgentState.INTERACTING:
|
||||
state = await self._handle_interacting_state(agent_data)
|
||||
else:
|
||||
logger.error(f"Invalid state: {state}")
|
||||
state = AgentState.TERMINATED
|
10
recipe/collabllm/config/collabllm_interaction_config.yaml
Normal file
10
recipe/collabllm/config/collabllm_interaction_config.yaml
Normal file
@ -0,0 +1,10 @@
|
||||
interaction:
|
||||
- name: "collabllm"
|
||||
class_name: "verl.interactions.collabllm_interation.CollabLLMInteraction"
|
||||
config: {
|
||||
"user_model": "gpt-4o-mini",
|
||||
"num_retries": 3,
|
||||
"max_tokens": 512,
|
||||
"temperature": 1.0,
|
||||
"enable_log": True
|
||||
}
|
104
recipe/collabllm/metrics/accuracy.py
Normal file
104
recipe/collabllm/metrics/accuracy.py
Normal file
@ -0,0 +1,104 @@
|
||||
# Copyright 2025 CollabLLM team and/or its affiliates
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from recipe.collabllm.utils import extract_json, parse_messages
|
||||
|
||||
ACCURACY_PROMPT = '''You are a helpful and meticulous evaluator. Your task is to \
|
||||
evaluate the *accuracy* of an AI model's answer to a target question. \
|
||||
You will be given the target question, the ground truth answer, and the conversation between the AI and the user.
|
||||
|
||||
Provided Information:
|
||||
|
||||
<|The Start of Target Question and Ground Truth Answer|>
|
||||
Target Question: {single_turn_prompt}
|
||||
Ground Truth Answer: {ground_truth}
|
||||
<|The End of Target Question and Ground Truth Answer|>
|
||||
|
||||
<|The Start of The Conversation|>
|
||||
{chat_history}
|
||||
<|The End of The Conversation|>
|
||||
|
||||
You should determine whether the model's final response to the target question is \
|
||||
factually correct and consistent with the provided ground truth.
|
||||
|
||||
Rating criteria (binary):
|
||||
• 1 = Correct — the response matches the ground truth.
|
||||
• 0 = Incorrect — the response contradicts or misses the ground truth.
|
||||
|
||||
Output format (JSON):
|
||||
{{
|
||||
"thought": "<your reasoning here>",
|
||||
"accuracy": <0 or 1>
|
||||
}}
|
||||
|
||||
Double check if the JSON object is formatted correctly. Ensure that all fields are present and properly structured. \
|
||||
Use " or """ to wrap up the thought and use single quotes inside the "thought" field to avoid JSON escape issues.
|
||||
|
||||
Your evaluation:
|
||||
'''
|
||||
|
||||
|
||||
async def compute_score(data_source, messages, ground_truth, extra_info, **kwargs):
|
||||
# Check if litellm is available, fallback to openai if not
|
||||
try:
|
||||
import litellm
|
||||
|
||||
use_litellm = True
|
||||
except ImportError:
|
||||
# litellm not found, falling back to openai
|
||||
import openai
|
||||
|
||||
use_litellm = False
|
||||
|
||||
chat_history = parse_messages(messages, strip_sys_prompt=True)
|
||||
prompt = ACCURACY_PROMPT.format(
|
||||
single_turn_prompt=extra_info["interaction_kwargs"]["single_turn_prompt"],
|
||||
ground_truth=ground_truth,
|
||||
chat_history=chat_history,
|
||||
)
|
||||
|
||||
if use_litellm:
|
||||
full_response = (
|
||||
(
|
||||
await litellm.acompletion(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
.choices[0]
|
||||
.message.content
|
||||
)
|
||||
else:
|
||||
client = openai.AsyncOpenAI() # Assumes API key is set in environment
|
||||
full_response = (
|
||||
(
|
||||
await client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
.choices[0]
|
||||
.message.content
|
||||
)
|
||||
|
||||
full_response = extract_json(full_response)
|
||||
|
||||
assert isinstance(full_response, dict), f"Expected a dict, got {type(full_response)}"
|
||||
assert {"accuracy", "thought"}.issubset(full_response.keys()), (
|
||||
f"Expected keys not found from {full_response.keys()}"
|
||||
)
|
||||
|
||||
accuracy = full_response.pop("accuracy")
|
||||
return float(accuracy)
|
116
recipe/collabllm/metrics/bleu_score.py
Normal file
116
recipe/collabllm/metrics/bleu_score.py
Normal file
@ -0,0 +1,116 @@
|
||||
# Copyright 2025 CollabLLM team and/or its affiliates
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from nltk.translate.bleu_score import sentence_bleu
|
||||
|
||||
from recipe.collabllm.utils import extract_json, parse_messages
|
||||
|
||||
EXTRACT_MULTITURN_COMPLETION_PROMPT = '''You are a thorough and diligent conversation analyzer. \
|
||||
Your task is to extract the final and complete version of a document that was generated during \
|
||||
a multiturn conversation between a user and a chat assistant. \
|
||||
The extracted content should reflect the final and comprehensive response provided by the assistant \
|
||||
based on the user’s request.
|
||||
|
||||
You will be provided with the conversation:
|
||||
|
||||
<|The Start of The Conversation|>
|
||||
{chat_history}
|
||||
<|The End of The Conversation|>
|
||||
|
||||
Instructions for Extraction:
|
||||
|
||||
1. Identify the Most Update-to-Date Contents: Review the entire conversation to identify the most updated parts \
|
||||
of the content provided by the assistant. This may include:
|
||||
- Different sections of text (e.g., an essay, report, or article).
|
||||
|
||||
2. Integrate Revisions: If the assistant made revisions, updates, or added sections throughout the conversation, \
|
||||
ensure that these changes are fully integrated into the final content. The goal is to extract a single, cohesive \
|
||||
output that incorporates all modifications and additions made during the conversation. For example, if the assistant \
|
||||
writes an introducation at the beginning and move on to the conclusion, the final output should include both the \
|
||||
introduction and the conclusion.
|
||||
|
||||
3. Focus on Completeness:
|
||||
- For text-based documents: Ensure that the extracted content is comprehensive and represents the full document \
|
||||
or section as discussed in the conversation.
|
||||
|
||||
You should output a JSON object with two entries:
|
||||
- "thought" (str): Output your thought process when extracting the final content.
|
||||
1. How do different parts of the conversation contribute to the final output?
|
||||
2. How do you make sure you included the most updated and complete information?
|
||||
3. How do you make sure you did not include any information that is not necessary?
|
||||
- "final_completion" (str): The final and complete version of the document extracted from the conversation.
|
||||
|
||||
Note:
|
||||
1. If there are multiple lines, you should use triple quotes (""") to wrap the content. For example, \
|
||||
"final_completion": """first line.
|
||||
second line.""" or "thought": """first line;
|
||||
second line.""".
|
||||
2. In the "final_completion" entry, replace all double quotes (") with single quotes (') to prevent JSON formatting \
|
||||
issues. For example, you can output "final_completion": "'Hello World' is a common phrase."
|
||||
|
||||
Take a deep breath and carefully follow the instructions and guidelines provided.
|
||||
'''
|
||||
|
||||
|
||||
async def compute_score(data_source, messages, ground_truth, extra_info, **kwargs):
|
||||
# Check if litellm is available, fallback to openai if not
|
||||
try:
|
||||
import litellm
|
||||
|
||||
use_litellm = True
|
||||
except ImportError:
|
||||
# litellm not found, falling back to openai
|
||||
import openai
|
||||
|
||||
use_litellm = False
|
||||
|
||||
chat_history = parse_messages(messages, strip_sys_prompt=True)
|
||||
prompt = EXTRACT_MULTITURN_COMPLETION_PROMPT.format(chat_history=chat_history)
|
||||
|
||||
if use_litellm:
|
||||
full_response = (
|
||||
(
|
||||
await litellm.acompletion(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
.choices[0]
|
||||
.message.content
|
||||
)
|
||||
else:
|
||||
client = openai.AsyncOpenAI() # Assumes API key is set in environment
|
||||
full_response = (
|
||||
(
|
||||
await client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
.choices[0]
|
||||
.message.content
|
||||
)
|
||||
|
||||
full_response = extract_json(full_response)
|
||||
|
||||
assert isinstance(full_response, dict), f"Expected a dict, got {type(full_response)}"
|
||||
assert {"final_completion", "thought"}.issubset(full_response.keys()), (
|
||||
f"Expected keys not found from {full_response.keys()}"
|
||||
)
|
||||
|
||||
final_completion = full_response.pop("final_completion")
|
||||
|
||||
bleu = sentence_bleu([ground_truth], final_completion)
|
||||
return float(bleu)
|
108
recipe/collabllm/metrics/interactivity.py
Normal file
108
recipe/collabllm/metrics/interactivity.py
Normal file
@ -0,0 +1,108 @@
|
||||
# Copyright 2025 CollabLLM team and/or its affiliates
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from recipe.collabllm.utils import extract_json, parse_messages
|
||||
|
||||
INTERACTIVITY_PROMPT = '''You are a helpful and meticulous conversation evaluator. \
|
||||
Your task is to evaluate the interactivity of the responses provided by an AI assistant \
|
||||
to user questions in a given conversation:
|
||||
|
||||
<|The Start of the Conversation to be Evaluated|>
|
||||
{chat_history}
|
||||
<|The End of the Conversation to be Evaluated|>
|
||||
|
||||
You should assess the assistant's engagement, clarity, and ability to understand the user's needs. \
|
||||
Give a float number between 0 and 1.
|
||||
|
||||
Scoring Criteria:
|
||||
- Let U = user understanding & response clarity ∈ [0,1]
|
||||
- 1.0 = Fully understands the user's intent and gives a clear answer.
|
||||
- 0.7 = Mostly understands and the answer is generally clear.
|
||||
- 0.3 = Partially misunderstands or the answer is hard to follow.
|
||||
- 0.0 = Misunderstands the intent and gives an unclear or irrelevant answer.
|
||||
- Let Q = clarification in [0,1]
|
||||
- 1.0 = Asks precise, necessary clarifying questions when needed.
|
||||
- 0.7 = Asks somewhat helpful but incomplete clarifications.
|
||||
- 0.3 = Only asks generic questions (e.g., “Does that help?”).
|
||||
- 0.0 = Asks no clarifying questions when needed.
|
||||
- Let S = suggestion helpfulness in [0,1]
|
||||
- 1.0 = Provides useful, actionable suggestions.
|
||||
- 0.7 = Suggestions are somewhat helpful but limited.
|
||||
- 0.3 = Suggestions are vague or generic.
|
||||
- 0.0 = No suggestions when they would clearly help.
|
||||
score = average([U, Q, S])
|
||||
|
||||
Output format (JSON):
|
||||
{{
|
||||
"thought": "<How interactive is the assistant?>",
|
||||
"interactivity": <score>
|
||||
}}
|
||||
|
||||
Double check if the JSON object is formatted correctly. Ensure that all fields are present and properly structured. \
|
||||
Use " or """ to wrap up the thought. You should not use other triple quotes inside the "thought" field. \
|
||||
Instead you should use single quotes to avoid JSON escape issues.
|
||||
|
||||
Your evaluation:
|
||||
'''
|
||||
|
||||
|
||||
async def compute_score(data_source, messages, ground_truth, extra_info, **kwargs):
|
||||
# Check if litellm is available, fallback to openai if not
|
||||
try:
|
||||
import litellm
|
||||
|
||||
use_litellm = True
|
||||
except ImportError:
|
||||
# litellm not found, falling back to openai
|
||||
import openai
|
||||
|
||||
use_litellm = False
|
||||
|
||||
chat_history = parse_messages(messages, strip_sys_prompt=True)
|
||||
prompt = INTERACTIVITY_PROMPT.format(chat_history=chat_history)
|
||||
|
||||
if use_litellm:
|
||||
full_response = (
|
||||
(
|
||||
await litellm.acompletion(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
.choices[0]
|
||||
.message.content
|
||||
)
|
||||
else:
|
||||
client = openai.AsyncOpenAI() # Assumes API key is set in environment
|
||||
full_response = (
|
||||
(
|
||||
await client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
.choices[0]
|
||||
.message.content
|
||||
)
|
||||
|
||||
full_response = extract_json(full_response)
|
||||
|
||||
assert isinstance(full_response, dict), f"Expected a dict, got {type(full_response)}"
|
||||
assert {"interactivity", "thought"}.issubset(full_response.keys()), (
|
||||
f"Expected keys not found from {full_response.keys()}"
|
||||
)
|
||||
|
||||
interactivity = full_response.pop("interactivity")
|
||||
return float(interactivity)
|
139
recipe/collabllm/metrics/pass_rate.py
Normal file
139
recipe/collabllm/metrics/pass_rate.py
Normal file
@ -0,0 +1,139 @@
|
||||
# Copyright 2025 CollabLLM team and/or its affiliates
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from bigcodebench.eval import untrusted_check
|
||||
|
||||
from recipe.collabllm.utils import extract_json, parse_messages
|
||||
|
||||
EXTRACT_MULTITURN_COMPLETION_PROMPT = '''You are a thorough and diligent conversation analyzer. \
|
||||
Your task is to extract the final and complete version of a code function {entry_point} that was generated \
|
||||
during a multiturn conversation between a user and a chat assistant. \
|
||||
The extracted content should reflect the final and comprehensive response provided by the \
|
||||
assistant based on the user’s request.
|
||||
|
||||
You will be provided with the task and the conversation:
|
||||
|
||||
<|The Start of The Task|>
|
||||
{single_turn_prompt}
|
||||
<|The End of The Task|>
|
||||
|
||||
<|The Start of The Conversation|>
|
||||
{chat_history}
|
||||
<|The End of The Conversation|>
|
||||
|
||||
Instructions for Extraction:
|
||||
|
||||
1. Identify the Most Update-to-Date Contents: Review the entire conversation to identify the most updated parts of \
|
||||
the content provided by the assistant. This may include:
|
||||
- Different parts of the code snippet, function, class, or script.
|
||||
|
||||
2. Integrate Revisions: If the assistant made revisions, updates, or added sections throughout the conversation, \
|
||||
ensure that these changes are fully integrated into the final content. The goal is to extract a single, cohesive \
|
||||
output that incorporates all modifications and additions made during the conversation. For example, if the assistant \
|
||||
writes a function at the beginning and changes a part, the final output should take the modification into account.
|
||||
|
||||
3. Focus on Completeness:
|
||||
- For code: Extract a complete and functional code snippet, including all necessary components such as imports, \
|
||||
functions, classes, and any other essential elements. The code should be runnable, but you do not need to \
|
||||
include any testing examples including the contents after `if __name__ == "__main__":`. Only the function code \
|
||||
is required.
|
||||
|
||||
You should output a JSON object with two entries:
|
||||
- "thought" (str): Output your thought process when extracting the final content.
|
||||
1. How do different parts of the conversation contribute to the final output?
|
||||
2. How do you make sure you included the most updated and complete information?
|
||||
3. How do you make sure you did not include any information that is not necessary?
|
||||
- "final_completion" (str): The final and complete version of the code extracted from the conversation. \
|
||||
Rename main function name for the task to {entry_point} if needed. Remove any comments wrapped by """.
|
||||
|
||||
Note:
|
||||
1. If there are multiple lines, you should use triple quotes (""") to wrap the content. For example, \
|
||||
"final_completion": """first line.
|
||||
second line.""" or "thought": """first line;
|
||||
second line.""". You should not use other triple quotes inside.
|
||||
2. In the "final_completion" entry, replace all double quotes (") with single quotes (') to prevent JSON formatting \
|
||||
issues. For example, you can output "final_completion": "'Hello World' is a common phrase."
|
||||
|
||||
Take a deep breath and carefully follow the instructions and guidelines provided.
|
||||
'''
|
||||
|
||||
|
||||
async def compute_score(data_source, messages, ground_truth, extra_info, **kwargs):
|
||||
# Check if litellm is available, fallback to openai if not
|
||||
try:
|
||||
import litellm
|
||||
|
||||
use_litellm = True
|
||||
except ImportError:
|
||||
# litellm not found, falling back to openai
|
||||
import openai
|
||||
|
||||
use_litellm = False
|
||||
|
||||
chat_history = parse_messages(messages, strip_sys_prompt=True)
|
||||
|
||||
prompt = EXTRACT_MULTITURN_COMPLETION_PROMPT.format(
|
||||
chat_history=chat_history,
|
||||
single_turn_prompt=extra_info["interaction_kwargs"]["single_turn_prompt"],
|
||||
entry_point=extra_info["single_turn_metadata"]["entry_point"],
|
||||
)
|
||||
|
||||
if use_litellm:
|
||||
full_response = (
|
||||
(
|
||||
await litellm.acompletion(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
.choices[0]
|
||||
.message.content
|
||||
)
|
||||
else:
|
||||
client = openai.AsyncOpenAI() # Assumes API key is set in environment
|
||||
full_response = (
|
||||
(
|
||||
await client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
.choices[0]
|
||||
.message.content
|
||||
)
|
||||
|
||||
full_response = extract_json(full_response)
|
||||
|
||||
assert isinstance(full_response, dict), f"Expected a dict, got {type(full_response)}"
|
||||
assert {"final_completion", "thought"}.issubset(full_response.keys()), (
|
||||
f"Expected keys not found from {full_response.keys()}"
|
||||
)
|
||||
|
||||
final_completion = full_response.pop("final_completion")
|
||||
metadata = extra_info["single_turn_metadata"]
|
||||
res = untrusted_check(
|
||||
final_completion,
|
||||
metadata["test"],
|
||||
metadata["entry_point"],
|
||||
max_as_limit=300 * 1024,
|
||||
max_data_limit=300 * 1024,
|
||||
max_stack_limit=300 * 1024,
|
||||
min_time_limit=60,
|
||||
gt_time_limit=60,
|
||||
)
|
||||
passed = res[0] == "pass"
|
||||
|
||||
# info = res[1] # for printing extra info
|
||||
return float(passed)
|
26
recipe/collabllm/metrics/token_amount.py
Normal file
26
recipe/collabllm/metrics/token_amount.py
Normal file
@ -0,0 +1,26 @@
|
||||
# Copyright 2025 CollabLLM team and/or its affiliates
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
def compute_score(data_source, messages, ground_truth, extra_info, **kwargs):
|
||||
prompt = extra_info["prompt"]
|
||||
|
||||
# Calculate the token penalty based on the length of the prompt
|
||||
future_conv = messages[len(prompt) :]
|
||||
|
||||
# simple length estimation
|
||||
total_tokens = sum(len(m.content.split()) for m in future_conv)
|
||||
|
||||
return total_tokens
|
239
recipe/collabllm/process_dataset.py
Normal file
239
recipe/collabllm/process_dataset.py
Normal file
@ -0,0 +1,239 @@
|
||||
# Copyright 2025 CollabLLM team and/or its affiliates
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
# available datasets:
|
||||
# math-hard(-large), medium(-large), bigcodebench(-large)
|
||||
# to create your own dataset, refer to https://github.com/Wuyxin/collabllm
|
||||
|
||||
DATASET=math-hard-large
|
||||
|
||||
python recipe/collabllm/process_dataset.py \
|
||||
--dataset collabllm/collabllm-multiturn-$DATASET \
|
||||
--local_dir $HOME/data/collabllm-$DATASET \
|
||||
--dataset_type sft
|
||||
|
||||
python recipe/collabllm/process_dataset.py \
|
||||
--dataset collabllm/collabllm-multiturn-$DATASET \
|
||||
--local_dir $HOME/data/collabllm-$DATASET \
|
||||
--dataset_type rl
|
||||
|
||||
|
||||
Preprocess collabllm/collabllm-multiturn-math-hard into (ground_truth, extra_info).
|
||||
|
||||
- ground_truth: picked from --prefer_field (default: single_turn_completion),
|
||||
falling back to --fallback_field (default: completion)
|
||||
- extra_info: a shallow copy of the original example plus bookkeeping fields
|
||||
- reward_model: {"style": "rule", "ground_truth": ground_truth}
|
||||
|
||||
Saves one parquet per split into --local_dir and a small JSON preview.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
|
||||
from datasets import Dataset, concatenate_datasets, load_dataset
|
||||
|
||||
SYSTEM_PROMPT = """The assistant is designed to be helpful, proactive, and highly interactive.
|
||||
|
||||
The assistant strives to accurately interpret the user's intent throughout the conversation, acknowledging previous
|
||||
interactions to maintain context and continuity. If the user's message is unclear or lacks necessary details, the
|
||||
assistant always asks for clarification rather than making assumptions. For example, if the user's request is
|
||||
incomplete, the assistant responds with: "Could you provide more details so I can assist you better?"
|
||||
|
||||
The assistant asks specific follow-up questions and offers suggestions based on the user's needs, avoiding vague or
|
||||
generic prompts. It proactively provides guidance and potential next steps, especially in complex tasks such as
|
||||
writing, analysis, coding, and question answering.
|
||||
|
||||
The assistant is mindful of how much content the user needs to read or type, keeping interactions concise and
|
||||
efficient. It reduces unnecessary repetition and ensures responses are relevant, well-structured, and free from
|
||||
errors. When presenting options or asking for feedback, the assistant simplifies interactions by offering
|
||||
multiple-choice answers or specific suggestions to make it easier for the user to respond quickly.
|
||||
|
||||
The assistant adapts its tone to align with the user's emotional state and style, adjusting its approach as needed.
|
||||
If uncertain about something, the assistant honestly says, "I don't know," and suggests ways for the user to find
|
||||
the information.
|
||||
|
||||
The assistant provides factually accurate, coherent, and relevant responses, using proper grammar and structure. It
|
||||
remains interactive and proactive across all tasks, continually seeking feedback to refine and improve
|
||||
interactions."""
|
||||
|
||||
|
||||
# Required fields: "prompt", "ground_truth", "extra_info"
|
||||
# In "extra_info" dict:
|
||||
# (1) Rquired: "single_turn_prompt", which is the specific problem used to inform the user simulator,
|
||||
# (2) Optional: "task_desc" (a short task description),
|
||||
# (3) Optional: other fields for customized reward computation
|
||||
def collapse_example(example: dict[str, Any]) -> dict[str, Any]:
|
||||
if "prompt" not in example:
|
||||
raise ValueError("Missing required 'prompt' field.")
|
||||
|
||||
ground_truth = (
|
||||
example.get("ground_truth") or example.get("single_turn_completion") or example.get("completion") or ""
|
||||
)
|
||||
|
||||
extra_info = {}
|
||||
for k, v in example.items():
|
||||
if k in ("prompt", "ground_truth", "extra_info"):
|
||||
continue
|
||||
extra_info.setdefault(k, v) # keep extra_info values if keys overlap
|
||||
|
||||
# make sure extra_info has the required fields
|
||||
assert "single_turn_prompt" in extra_info, "Missing 'single_turn_prompt' in extra_info."
|
||||
|
||||
# add system prompt as the beginning of the list
|
||||
example["prompt"] = [{"role": "system", "content": SYSTEM_PROMPT}] + example["prompt"]
|
||||
|
||||
extra_info.setdefault("prompt", example["prompt"]) # save the original prompt
|
||||
extra_info.setdefault(
|
||||
"interaction_kwargs",
|
||||
{
|
||||
"name": "collabllm",
|
||||
"single_turn_prompt": extra_info.pop("single_turn_prompt"),
|
||||
"task_desc": extra_info.pop("task_desc", "general ask-for-assistance task"),
|
||||
},
|
||||
)
|
||||
return {
|
||||
"prompt": example["prompt"],
|
||||
"ground_truth": ground_truth,
|
||||
"raw_prompt": example["prompt"], # save the original prompt
|
||||
"extra_info": extra_info,
|
||||
"reward_model": {"style": "rule", "ground_truth": ground_truth},
|
||||
"data_source": "collabllm",
|
||||
"agent_name": "collabllm_agent",
|
||||
"index": str(uuid.uuid4()),
|
||||
}
|
||||
|
||||
|
||||
# ---------- IO helpers ----------
|
||||
def save_parquet(ds_split: Dataset, filename: str, out_dir: str) -> None:
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
path = os.path.join(out_dir, f"{filename}.parquet")
|
||||
ds_split.to_parquet(path)
|
||||
print(f"[OK] Wrote {filename}.parquet → {path} ({len(ds_split)} rows)")
|
||||
|
||||
|
||||
def maybe_copy_to_hdfs(local_dir: str, hdfs_dir: Optional[str]) -> None:
|
||||
if not hdfs_dir:
|
||||
return
|
||||
try:
|
||||
from verl.utils.hdfs_io import copy, makedirs # type: ignore
|
||||
except Exception as e:
|
||||
print(f"[WARN] Skipping HDFS copy (verl not available): {e}")
|
||||
return
|
||||
makedirs(hdfs_dir)
|
||||
copy(src=local_dir, dst=hdfs_dir)
|
||||
print(f"[OK] Copied {local_dir} → {hdfs_dir}")
|
||||
|
||||
|
||||
# ---------- Main ----------
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument(
|
||||
"--dataset", default="collabllm/collabllm-multiturn-math-hard", help="HF dataset path or local dir/file."
|
||||
)
|
||||
ap.add_argument("--task_desc", default="solving math problems", help="Task description for the dataset.")
|
||||
ap.add_argument("--local_dir", default="~/data/collabllm-math-hard", help="Output directory.")
|
||||
ap.add_argument("--hdfs_dir", default=None, help="Optional HDFS destination (requires verl).")
|
||||
ap.add_argument(
|
||||
"--validation_size", type=float, default=0.1, help="Validation split size (fraction or absolute int)."
|
||||
)
|
||||
ap.add_argument("--seed", type=int, default=42, help="Random seed for splitting.")
|
||||
ap.add_argument("--num_proc", type=int, default=1, help="Parallel workers for map().")
|
||||
ap.add_argument("--dataset_type", default="rl", choices=["rl", "sft"], help="Type of dataset (e.g., 'rl', 'sft').")
|
||||
args = ap.parse_args()
|
||||
|
||||
out_dir = os.path.expanduser(args.local_dir)
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
print(f"[INFO] Loading dataset: {args.dataset}")
|
||||
ds_dict = load_dataset(args.dataset)
|
||||
parts = list(ds_dict.values())
|
||||
ds_all: Dataset = parts[0] if len(parts) == 1 else concatenate_datasets(parts)
|
||||
# Dataset({
|
||||
# features: ['prompt', 'completion', 'conv_id', 'score', 'single_turn_prompt',
|
||||
# 'single_turn_completion', 'single_turn_metadata', 'turn_id', 'sessions', 'rewards'],
|
||||
# num_rows: xxx
|
||||
# })
|
||||
|
||||
if args.dataset_type == "rl":
|
||||
# If multiple splits exist, merge them before collapsing/splitting.
|
||||
ds_all = ds_all.map(lambda x: {"task_desc": args.task_desc}, num_proc=args.num_proc)
|
||||
|
||||
print(f"[INFO] Collapsing to formatted fields on {len(ds_all)} rows…")
|
||||
ds_all = ds_all.map(
|
||||
function=collapse_example,
|
||||
remove_columns=ds_all.column_names,
|
||||
num_proc=args.num_proc,
|
||||
)
|
||||
|
||||
def dedup_by_prompt(dataset):
|
||||
seen = set()
|
||||
unique_rows = []
|
||||
for ex in dataset:
|
||||
prompt_key = json.dumps(ex["prompt"], sort_keys=True, ensure_ascii=False)
|
||||
if prompt_key not in seen:
|
||||
seen.add(prompt_key)
|
||||
unique_rows.append(ex)
|
||||
return Dataset.from_list(unique_rows)
|
||||
|
||||
ds_all = dedup_by_prompt(ds_all)
|
||||
|
||||
elif args.dataset_type == "sft":
|
||||
df = ds_all.to_pandas()
|
||||
|
||||
# Sort so that within each conv_id the highest turn_id is first,
|
||||
# and if multiple rows share the same turn_id, the highest score comes first
|
||||
df = df.sort_values(["conv_id", "turn_id", "score"], ascending=[True, False, False])
|
||||
|
||||
# Keep only the top row per conv_id
|
||||
df = df.drop_duplicates(subset="conv_id", keep="first")
|
||||
|
||||
# Back to HF Dataset
|
||||
ds_all = Dataset.from_pandas(df, preserve_index=False)
|
||||
|
||||
# Append assistant response into prompt list
|
||||
def append_completion(example):
|
||||
example["prompt"] = (
|
||||
[{"role": "system", "content": SYSTEM_PROMPT}]
|
||||
+ example["prompt"]
|
||||
+ [{"role": "assistant", "content": example["completion"]}]
|
||||
)
|
||||
return example
|
||||
|
||||
ds_all = ds_all.map(append_completion)
|
||||
|
||||
# Keep only prompt column
|
||||
cols_to_remove = [col for col in ds_all.column_names if col != "prompt"]
|
||||
ds_all = ds_all.remove_columns(cols_to_remove)
|
||||
|
||||
print(f"[INFO] Splitting with validation_size={args.validation_size}, seed={args.seed}")
|
||||
split = ds_all.train_test_split(test_size=args.validation_size, seed=args.seed, shuffle=True)
|
||||
train_ds, val_ds = split["train"], split["test"]
|
||||
print(train_ds, val_ds)
|
||||
|
||||
save_parquet(train_ds, f"{args.dataset_type}_train", out_dir)
|
||||
save_parquet(val_ds, f"{args.dataset_type}_validation", out_dir)
|
||||
|
||||
maybe_copy_to_hdfs(local_dir=out_dir, hdfs_dir=args.hdfs_dir)
|
||||
print(f"[DONE] {args.dataset_type}_train.parquet and {args.dataset_type}_validation.parquet written.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
95
recipe/collabllm/reward_function.py
Normal file
95
recipe/collabllm/reward_function.py
Normal file
@ -0,0 +1,95 @@
|
||||
# Copyright 2025 CollabLLM team and/or its affiliates
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import os
|
||||
import sys
|
||||
|
||||
import litellm
|
||||
import torch
|
||||
|
||||
|
||||
async def conversation_level_reward_func(
|
||||
data_source, messages, ground_truth, extra_info, metrics, **kwargs
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Async version of conversation-level reward function.
|
||||
|
||||
Apply conversation-level reward function to the future interactions between the user simulator
|
||||
and policy model, which are generated from `verl/interactions/collabllm_interation.py`
|
||||
"""
|
||||
num_retries = kwargs.get("num_retries", 6)
|
||||
|
||||
rewards = {}
|
||||
for metric in metrics:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
metric_file_path = os.path.join(current_dir, f"metrics/{metric}.py")
|
||||
|
||||
if not os.path.exists(metric_file_path):
|
||||
print(f"Error: Metric file '{metric_file_path}' not found. Assigning 0 to metric '{metric}'.")
|
||||
rewards[metric] = 0.0
|
||||
continue
|
||||
|
||||
spec = importlib.util.spec_from_file_location(f"metric_{metric}", metric_file_path)
|
||||
if spec is None:
|
||||
print(f"Error: Could not create spec for metric '{metric}'. Assigning 0 to metric '{metric}'.")
|
||||
rewards[metric] = 0.0
|
||||
continue
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
|
||||
try:
|
||||
sys.modules[f"metric_{metric}"] = module
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
except Exception as e:
|
||||
print(f"Error loading metric module from '{metric_file_path}': {e}. Assigning 0 to metric '{metric}'.")
|
||||
rewards[metric] = 0.0
|
||||
continue
|
||||
|
||||
# Assume each metric file has a compute_score function
|
||||
if not hasattr(module, "compute_score"):
|
||||
print(
|
||||
f"Error: Function 'compute_score' not found in '{metric_file_path}'. Assigning 0 to metric '{metric}'."
|
||||
)
|
||||
rewards[metric] = 0.0
|
||||
continue
|
||||
|
||||
compute_score_fn = module.compute_score
|
||||
|
||||
# Retry mechanism for calling the metric function
|
||||
for attempt in range(num_retries):
|
||||
try:
|
||||
# Call the metric function (await if it's async)
|
||||
if asyncio.iscoroutinefunction(compute_score_fn):
|
||||
rewards[metric] = await compute_score_fn(data_source, messages, ground_truth, extra_info, **kwargs)
|
||||
else:
|
||||
rewards[metric] = compute_score_fn(data_source, messages, ground_truth, extra_info, **kwargs)
|
||||
break # Success, exit retry loop
|
||||
except Exception as e:
|
||||
if attempt == num_retries - 1: # Last attempt
|
||||
print(
|
||||
f"Error: Failed to compute metric '{metric}' after {num_retries} attempts. "
|
||||
f"Last error: {e}. Assigning 0 to metric '{metric}'."
|
||||
)
|
||||
rewards[metric] = 0.0
|
||||
else:
|
||||
print(f"Attempt {attempt + 1} failed for metric '{metric}': {e}. Retrying...")
|
||||
if isinstance(e, litellm.RateLimitError):
|
||||
await asyncio.sleep(max(2**attempt, 60)) # Exponential backoff
|
||||
|
||||
# Return dict with metric names as keys
|
||||
return {metric: torch.tensor(reward, dtype=torch.float32) for metric, reward in rewards.items()}
|
75
recipe/collabllm/train_rl_collabllm.sh
Normal file
75
recipe/collabllm/train_rl_collabllm.sh
Normal file
@ -0,0 +1,75 @@
|
||||
# Usage: sh recipe/collabllm/train_rl_collabllm.sh <optional resume path>
|
||||
|
||||
set -x
|
||||
|
||||
PROJECT_DIR="$(pwd)"
|
||||
export VLLM_USE_V1=1
|
||||
|
||||
RESUME_PATH="${1:-}"
|
||||
|
||||
if [ -z "$RESUME_PATH" ]; then
|
||||
RESUME_PATH=null
|
||||
fi
|
||||
|
||||
DATASET=math-hard-large
|
||||
PROJECT_DIR="$(pwd)"
|
||||
|
||||
|
||||
python3 -m verl.trainer.main_ppo \
|
||||
trainer.val_before_train=False \
|
||||
algorithm.adv_estimator=grpo \
|
||||
data.train_files=$HOME/data/collabllm-$DATASET/rl_train.parquet \
|
||||
data.val_files=$HOME/data/collabllm-$DATASET/rl_validation.parquet \
|
||||
reward_model.reward_manager=collabllm \
|
||||
+reward_model.reward_kwargs.metric_weights.accuracy=1 \
|
||||
+reward_model.reward_kwargs.metric_weights.interactivity=1 \
|
||||
+reward_model.reward_kwargs.metric_weights.token_amount=-0.0001 \
|
||||
+reward_model.reward_kwargs.llm_judge_kwargs.model=gpt-4o-mini \
|
||||
+reward_model.reward_kwargs.llm_judge_kwargs.max_tokens=2048 \
|
||||
+reward_model.reward_kwargs.llm_judge_kwargs.temperature=0 \
|
||||
data.train_batch_size=16 \
|
||||
data.max_prompt_length=8196 \
|
||||
data.max_response_length=2048 \
|
||||
data.filter_overlong_prompts=True \
|
||||
data.truncation='error' \
|
||||
actor_rollout_ref.model.path="Qwen/Qwen2.5-7B-Instruct" \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.model.use_remove_padding=True \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=8 \
|
||||
actor_rollout_ref.actor.use_dynamic_bsz=True \
|
||||
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
|
||||
actor_rollout_ref.actor.use_kl_loss=True \
|
||||
actor_rollout_ref.actor.kl_loss_coef=0.001 \
|
||||
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
|
||||
actor_rollout_ref.actor.entropy_coeff=0 \
|
||||
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=True \
|
||||
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
|
||||
actor_rollout_ref.rollout.name=vllm \
|
||||
actor_rollout_ref.rollout.mode=async \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
|
||||
actor_rollout_ref.rollout.n=8 \
|
||||
actor_rollout_ref.rollout.temperature=1.0 \
|
||||
actor_rollout_ref.rollout.free_cache_engine=True \
|
||||
actor_rollout_ref.rollout.multi_turn.enable=true \
|
||||
actor_rollout_ref.rollout.multi_turn.format=hermes \
|
||||
actor_rollout_ref.rollout.multi_turn.max_user_turns=2 \
|
||||
actor_rollout_ref.rollout.multi_turn.max_assistant_turns=3 \
|
||||
actor_rollout_ref.rollout.multi_turn.num_repeat_rollouts=3 \
|
||||
actor_rollout_ref.rollout.trace.token2text=True \
|
||||
actor_rollout_ref.ref.fsdp_config.param_offload=True \
|
||||
algorithm.use_kl_in_reward=False \
|
||||
trainer.critic_warmup=0 \
|
||||
trainer.logger='["console", "wandb"]' \
|
||||
trainer.project_name=verlxcollabllm \
|
||||
trainer.experiment_name=collabllm-qwen2.5-7B-$DATASET \
|
||||
trainer.nnodes=1 \
|
||||
trainer.n_gpus_per_node=8 \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
|
||||
trainer.save_freq=100 \
|
||||
trainer.test_freq=10 \
|
||||
trainer.total_epochs=20 \
|
||||
custom_reward_function.path=recipe/collabllm/reward_function.py \
|
||||
custom_reward_function.name=conversation_level_reward_func \
|
||||
actor_rollout_ref.rollout.multi_turn.interaction_config_path="$PROJECT_DIR/recipe/collabllm/config/collabllm_interaction_config.yaml" \
|
||||
trainer.resume_from_path=$RESUME_PATH
|
32
recipe/collabllm/train_sft_collabllm.sh
Normal file
32
recipe/collabllm/train_sft_collabllm.sh
Normal file
@ -0,0 +1,32 @@
|
||||
#!/bin/bash
|
||||
set -x
|
||||
|
||||
if [ "$#" -lt 1 ]; then
|
||||
echo "Usage: sft_train_collabllm.sh [<nproc_per_node> other_configs...]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
nproc_per_node=$1
|
||||
|
||||
# Shift the arguments so $@ refers to the rest
|
||||
shift 1
|
||||
|
||||
DATASET=math-hard-large
|
||||
|
||||
torchrun --nnodes=1 --nproc_per_node=$nproc_per_node \
|
||||
-m verl.trainer.fsdp_sft_trainer \
|
||||
data.train_files=$HOME/data/collabllm-$DATASET/sft_train.parquet \
|
||||
data.val_files=$HOME/data/collabllm-$DATASET/sft_validation.parquet \
|
||||
data.multiturn.enable=true \
|
||||
data.multiturn.messages_key=prompt \
|
||||
optim.lr=1e-6 \
|
||||
data.train_batch_size=64 \
|
||||
data.micro_batch_size_per_gpu=2 \
|
||||
data.max_length=8196 \
|
||||
model.partial_pretrain=Qwen/Qwen2.5-7B-Instruct \
|
||||
trainer.project_name=collabllm-sft-$DATASET \
|
||||
trainer.experiment_name=collabllm-sft-qwen2.5-7B-$DATASET \
|
||||
trainer.logger=console \
|
||||
trainer.total_epochs=3 $@ \
|
||||
ulysses_sequence_parallel_size=1 \
|
||||
use_remove_padding=true $@
|
280
recipe/collabllm/utils.py
Normal file
280
recipe/collabllm/utils.py
Normal file
@ -0,0 +1,280 @@
|
||||
# Copyright 2025 CollabLLM team and/or its affiliates
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
||||
|
||||
|
||||
def parse_messages(messages, strip_sys_prompt=True):
|
||||
"""
|
||||
Args:
|
||||
messages: List[dict]
|
||||
List of dictionaries with keys 'role' and 'content'
|
||||
Example: messages = [{'role': 'user', 'content': 'Hello!'},
|
||||
{'role': 'assistant', 'content': 'Hi!'}, ...]
|
||||
"""
|
||||
if messages is None:
|
||||
return ""
|
||||
|
||||
if strip_sys_prompt:
|
||||
messages = strip_system_prompt(messages)
|
||||
|
||||
chat = "\n".join(f"**{m.role.capitalize()}**: {m.content}" for m in messages)
|
||||
|
||||
return chat
|
||||
|
||||
|
||||
def strip_system_prompt(messages):
|
||||
"""
|
||||
Args:
|
||||
messages: List[dict]
|
||||
List of dictionaries with keys 'role' and 'content'
|
||||
Example: messages = [{'role': 'user', 'content': 'Hello!'},
|
||||
{'role': 'assistant', 'content': 'Hi!'}, ...]
|
||||
"""
|
||||
return [msg for msg in messages if msg.role != "system"]
|
||||
|
||||
|
||||
def extract_json(s):
|
||||
def convert_value(value):
|
||||
true_values = {"true": True, "false": False, "null": None}
|
||||
value_lower = value.lower()
|
||||
if value_lower in true_values:
|
||||
return true_values[value_lower]
|
||||
try:
|
||||
if "." in value or "e" in value.lower():
|
||||
return float(value)
|
||||
else:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return value # Return as string if not a number
|
||||
|
||||
def parse_number(s, pos):
|
||||
start = pos
|
||||
while pos < len(s) and s[pos] in "-+0123456789.eE":
|
||||
pos += 1
|
||||
num_str = s[start:pos]
|
||||
try:
|
||||
if "." in num_str or "e" in num_str.lower():
|
||||
return float(num_str), pos
|
||||
else:
|
||||
return int(num_str), pos
|
||||
except ValueError:
|
||||
logger.error(f"Invalid number at position {start}: {num_str}")
|
||||
raise
|
||||
|
||||
def skip_whitespace(s, pos):
|
||||
while pos < len(s) and s[pos] in " \t\n\r":
|
||||
pos += 1
|
||||
return pos
|
||||
|
||||
def parse_string(s, pos):
|
||||
quote_char = s[pos]
|
||||
assert quote_char in ('"', "'")
|
||||
pos += 1
|
||||
result = ""
|
||||
while pos < len(s):
|
||||
c = s[pos]
|
||||
if c == "\\":
|
||||
pos += 1
|
||||
if pos >= len(s):
|
||||
raise ValueError("Invalid escape sequence")
|
||||
c = s[pos]
|
||||
escape_sequences = {"n": "\n", "t": "\t", "r": "\r", "\\": "\\", quote_char: quote_char}
|
||||
result += escape_sequences.get(c, c)
|
||||
elif c == quote_char:
|
||||
pos += 1
|
||||
# Attempt to convert to a number if possible
|
||||
converted_value = convert_value(result)
|
||||
return converted_value, pos
|
||||
else:
|
||||
result += c
|
||||
pos += 1
|
||||
raise ValueError("Unterminated string")
|
||||
|
||||
def parse_key(s, pos):
|
||||
pos = skip_whitespace(s, pos)
|
||||
if s[pos] in ('"', "'"):
|
||||
key, pos = parse_string(s, pos)
|
||||
return key, pos
|
||||
else:
|
||||
raise ValueError(f"Expected string for key at position {pos}")
|
||||
|
||||
def parse_object(s, pos):
|
||||
obj = {}
|
||||
assert s[pos] == "{"
|
||||
pos += 1
|
||||
pos = skip_whitespace(s, pos)
|
||||
while pos < len(s) and s[pos] != "}":
|
||||
pos = skip_whitespace(s, pos)
|
||||
key, pos = parse_key(s, pos)
|
||||
pos = skip_whitespace(s, pos)
|
||||
if pos >= len(s) or s[pos] != ":":
|
||||
raise ValueError(f'Expected ":" at position {pos}')
|
||||
pos += 1
|
||||
pos = skip_whitespace(s, pos)
|
||||
value, pos = parse_value(s, pos)
|
||||
obj[key] = value
|
||||
pos = skip_whitespace(s, pos)
|
||||
if pos < len(s) and s[pos] == ",":
|
||||
pos += 1
|
||||
pos = skip_whitespace(s, pos)
|
||||
elif pos < len(s) and s[pos] == "}":
|
||||
break
|
||||
elif pos < len(s) and s[pos] != "}":
|
||||
raise ValueError(f'Expected "," or "}}" at position {pos}')
|
||||
if pos >= len(s) or s[pos] != "}":
|
||||
raise ValueError(f'Expected "}}" at position {pos}')
|
||||
pos += 1
|
||||
return obj, pos
|
||||
|
||||
def parse_array(s, pos):
|
||||
lst = []
|
||||
assert s[pos] == "["
|
||||
pos += 1
|
||||
pos = skip_whitespace(s, pos)
|
||||
while pos < len(s) and s[pos] != "]":
|
||||
value, pos = parse_value(s, pos)
|
||||
lst.append(value)
|
||||
pos = skip_whitespace(s, pos)
|
||||
if pos < len(s) and s[pos] == ",":
|
||||
pos += 1
|
||||
pos = skip_whitespace(s, pos)
|
||||
elif pos < len(s) and s[pos] == "]":
|
||||
break
|
||||
elif pos < len(s) and s[pos] != "]":
|
||||
raise ValueError(f'Expected "," or "]" at position {pos}')
|
||||
if pos >= len(s) or s[pos] != "]":
|
||||
raise ValueError(f'Expected "]" at position {pos}')
|
||||
pos += 1
|
||||
return lst, pos
|
||||
|
||||
def parse_triple_quoted_string(s, pos):
|
||||
if s[pos : pos + 3] == "'''":
|
||||
quote_str = "'''"
|
||||
elif s[pos : pos + 3] == '"""':
|
||||
quote_str = '"""'
|
||||
else:
|
||||
raise ValueError(f"Expected triple quotes at position {pos}")
|
||||
pos += 3
|
||||
result = ""
|
||||
while pos < len(s):
|
||||
if s[pos : pos + 3] == quote_str:
|
||||
pos += 3
|
||||
# Attempt to convert to a number if possible
|
||||
converted_value = convert_value(result)
|
||||
return converted_value, pos
|
||||
else:
|
||||
result += s[pos]
|
||||
pos += 1
|
||||
raise ValueError("Unterminated triple-quoted string")
|
||||
|
||||
def parse_value(s, pos):
|
||||
pos = skip_whitespace(s, pos)
|
||||
if pos >= len(s):
|
||||
raise ValueError("Unexpected end of input")
|
||||
if s[pos] == "{":
|
||||
return parse_object(s, pos)
|
||||
elif s[pos] == "[":
|
||||
return parse_array(s, pos)
|
||||
elif s[pos : pos + 3] in ("'''", '"""'):
|
||||
return parse_triple_quoted_string(s, pos)
|
||||
elif s[pos] in ('"', "'"):
|
||||
return parse_string(s, pos)
|
||||
elif s[pos : pos + 4].lower() == "true":
|
||||
return True, pos + 4
|
||||
elif s[pos : pos + 5].lower() == "false":
|
||||
return False, pos + 5
|
||||
elif s[pos : pos + 4].lower() == "null":
|
||||
return None, pos + 4
|
||||
elif s[pos] in "-+0123456789.":
|
||||
return parse_number(s, pos)
|
||||
else:
|
||||
raise ValueError(f"Unexpected character at position {pos}: {s[pos]}")
|
||||
|
||||
json_start = s.index("{")
|
||||
json_end = s.rfind("}")
|
||||
s = s[json_start : json_end + 1]
|
||||
|
||||
s = s.strip()
|
||||
result, pos = parse_value(s, 0)
|
||||
pos = skip_whitespace(s, pos)
|
||||
if pos != len(s):
|
||||
raise ValueError(f"Unexpected content at position {pos}")
|
||||
return result
|
||||
|
||||
|
||||
def remove_think_block(msg: dict):
|
||||
"""
|
||||
remove <think>.*?</think> from content
|
||||
"""
|
||||
if "content" in msg and isinstance(msg["content"], str):
|
||||
msg["content"] = re.sub(r"<think>.*?</think>", "", msg["content"], flags=re.DOTALL).strip()
|
||||
return msg
|
||||
|
||||
|
||||
def is_valid_messages(msg: dict) -> bool:
|
||||
"""
|
||||
check if is valid messages, including:
|
||||
1. <think> is paried with </think>
|
||||
2. is not empty inside and outside <think>
|
||||
3. is not nested, and at most one <think> block is allowed.
|
||||
4. can not be empty if remove ending "<|im_end|>"
|
||||
"""
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, str):
|
||||
return True
|
||||
|
||||
# Base case: empty or whitespace-only content is invalid.
|
||||
if not content.strip():
|
||||
return False
|
||||
|
||||
num_think_open = content.count("<think>")
|
||||
num_think_close = content.count("</think>")
|
||||
|
||||
# Rule 1: Check for paired tags.
|
||||
if num_think_open != num_think_close:
|
||||
return False
|
||||
|
||||
# Rule 3: Allow at most one think block.
|
||||
if num_think_open > 1:
|
||||
return False
|
||||
|
||||
# Case 1: No <think> blocks.
|
||||
if num_think_open == 0:
|
||||
visible_content = content
|
||||
# Case 2: Exactly one <think> block.
|
||||
else:
|
||||
# Rule 2: Check for empty content inside the think block.
|
||||
match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
|
||||
if not match or not match.group(1).strip():
|
||||
return False
|
||||
|
||||
# The "visible" content is what's outside the think block.
|
||||
visible_content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL)
|
||||
|
||||
visible_content = visible_content.strip()
|
||||
|
||||
# Rule 4 & 2 (outside): Check if visible content is empty after handling <|im_end|>.
|
||||
if visible_content.endswith("<|im_end|>"):
|
||||
visible_content = visible_content[: -len("<|im_end|>")]
|
||||
|
||||
if not visible_content.strip():
|
||||
return False
|
||||
|
||||
return True
|
@ -327,27 +327,47 @@ class RewardManagerWorker:
|
||||
config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
|
||||
)
|
||||
self.rm_executor = rm_executor
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
def compute_score(
|
||||
async def compute_score(
|
||||
self,
|
||||
data: DataProto,
|
||||
) -> dict:
|
||||
"""Compute reward score for agent loop output.
|
||||
|
||||
NOTE: Since `reward_manager.__call__` is blocking function, we run it in thread pool to
|
||||
compute multiple samples in parallel.
|
||||
|
||||
Args:
|
||||
data: reward function input
|
||||
|
||||
Returns:
|
||||
dict: Reward score and reward extra info.
|
||||
"""
|
||||
result = await self.loop.run_in_executor(
|
||||
None,
|
||||
self.reward_wrapper,
|
||||
data,
|
||||
True, # return_dict
|
||||
)
|
||||
|
||||
reward_score = result["reward_tensor"].sum(dim=-1).item()
|
||||
reward_extra_info = {k: v[0] for k, v in result.get("reward_extra_info", {}).items()}
|
||||
return {"reward_score": reward_score, "reward_extra_info": reward_extra_info}
|
||||
|
||||
def reward_wrapper(self, data: DataProto, return_dict=False) -> torch.Tensor:
|
||||
"""Assemble reward functions and reward model into one function and expose it to the event loop
|
||||
Args:
|
||||
return_dict: whether return as dict
|
||||
data: DataProto from compute reward score
|
||||
Returns:
|
||||
torch.Tensor: Reward score tensor.
|
||||
"""
|
||||
if self.rm_executor is not None:
|
||||
res = ray.get(self.rm_executor.submit_task.remote(data))
|
||||
data = data.union(res)
|
||||
|
||||
result = self.reward_manager(data, return_dict=True)
|
||||
reward_score = result["reward_tensor"].sum(dim=-1).item()
|
||||
reward_extra_info = {k: v[0] for k, v in result.get("reward_extra_info", {}).items()}
|
||||
return {"reward_score": reward_score, "reward_extra_info": reward_extra_info}
|
||||
return self.reward_manager(data, return_dict)
|
||||
|
||||
|
||||
@ray.remote
|
||||
@ -597,6 +617,11 @@ class AgentLoopWorker:
|
||||
**{k: np.array([v]) for k, v in kwargs.items()},
|
||||
"__num_turns__": np.array([output.num_turns]),
|
||||
}
|
||||
extra_fields = {}
|
||||
for key, val in output.extra_fields.items():
|
||||
extra_fields[key] = np.array([val], dtype=object)
|
||||
|
||||
non_tensor_batch.update(extra_fields)
|
||||
data = DataProto(
|
||||
batch=batch,
|
||||
non_tensor_batch=non_tensor_batch,
|
||||
|
@ -133,7 +133,6 @@ class ToolAgentLoop(AgentLoopBase):
|
||||
)
|
||||
interaction = self.interaction_map[interaction_name]
|
||||
await interaction.start_interaction(request_id, **interaction_kwargs)
|
||||
|
||||
# Create AgentData instance to encapsulate all state
|
||||
agent_data = AgentData(
|
||||
messages=messages,
|
||||
@ -152,12 +151,10 @@ class ToolAgentLoop(AgentLoopBase):
|
||||
state = await self._handle_pending_state(agent_data, sampling_params)
|
||||
elif state == AgentState.GENERATING:
|
||||
state = await self._handle_generating_state(agent_data, sampling_params)
|
||||
agent_data.assistant_turns += 1
|
||||
elif state == AgentState.PROCESSING_TOOLS:
|
||||
state = await self._handle_processing_tools_state(agent_data)
|
||||
elif state == AgentState.INTERACTING:
|
||||
state = await self._handle_interacting_state(agent_data)
|
||||
agent_data.user_turns += 1
|
||||
else:
|
||||
logger.error(f"Invalid state: {state}")
|
||||
state = AgentState.TERMINATED
|
||||
@ -209,7 +206,9 @@ class ToolAgentLoop(AgentLoopBase):
|
||||
)
|
||||
return AgentState.GENERATING
|
||||
|
||||
async def _handle_generating_state(self, agent_data: AgentData, sampling_params: dict[str, Any]) -> AgentState:
|
||||
async def _handle_generating_state(
|
||||
self, agent_data: AgentData, sampling_params: dict[str, Any], ignore_termination: bool = False
|
||||
) -> AgentState:
|
||||
"""Handle the generating state: generate model response and check for tool calls."""
|
||||
add_messages: list[dict[str, Any]] = []
|
||||
|
||||
@ -221,6 +220,7 @@ class ToolAgentLoop(AgentLoopBase):
|
||||
image_data=agent_data.image_data,
|
||||
)
|
||||
|
||||
agent_data.assistant_turns += 1
|
||||
agent_data.response_ids = output.token_ids
|
||||
agent_data.prompt_ids += agent_data.response_ids
|
||||
agent_data.response_mask += [1] * len(agent_data.response_ids)
|
||||
@ -228,7 +228,7 @@ class ToolAgentLoop(AgentLoopBase):
|
||||
agent_data.response_logprobs += output.log_probs
|
||||
|
||||
# Check termination conditions
|
||||
if len(agent_data.response_mask) >= self.response_length:
|
||||
if not ignore_termination and len(agent_data.response_mask) >= self.response_length:
|
||||
return AgentState.TERMINATED
|
||||
if self.max_assistant_turns and agent_data.assistant_turns >= self.max_assistant_turns:
|
||||
return AgentState.TERMINATED
|
||||
@ -241,7 +241,7 @@ class ToolAgentLoop(AgentLoopBase):
|
||||
# Handle interaction if needed
|
||||
if self.interaction_config_file:
|
||||
assistant_message = await self.loop.run_in_executor(
|
||||
None, lambda: self.tokenizer.decode(agent_data.response_ids)
|
||||
None, lambda: self.tokenizer.decode(agent_data.response_ids, skip_special_tokens=True)
|
||||
)
|
||||
add_messages.append({"role": "assistant", "content": assistant_message})
|
||||
agent_data.messages.extend(add_messages)
|
||||
@ -368,8 +368,10 @@ class ToolAgentLoop(AgentLoopBase):
|
||||
) = await agent_data.interaction.generate_response(
|
||||
agent_data.request_id, agent_data.messages, **agent_data.interaction_kwargs
|
||||
)
|
||||
agent_data.user_turns += 1
|
||||
|
||||
add_messages: list[dict[str, Any]] = [{"role": "user", "content": interaction_responses}]
|
||||
agent_data.messages.extend(add_messages)
|
||||
|
||||
if reward is not None:
|
||||
agent_data.turn_scores.append(reward)
|
||||
@ -400,6 +402,7 @@ class ToolAgentLoop(AgentLoopBase):
|
||||
if agent_data.response_logprobs:
|
||||
agent_data.response_logprobs += [0.0] * len(response_ids)
|
||||
|
||||
# double check prompt
|
||||
# Check termination condition
|
||||
if should_terminate_sequence:
|
||||
return AgentState.TERMINATED
|
||||
|
374
verl/interactions/collabllm_interation.py
Normal file
374
verl/interactions/collabllm_interation.py
Normal file
@ -0,0 +1,374 @@
|
||||
# Copyright 2024 CollabLLM Ltd. and/or its affiliates
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from recipe.collabllm.utils import remove_think_block
|
||||
from verl.utils.rollout_trace import rollout_trace_op
|
||||
|
||||
from .base import BaseInteraction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
||||
|
||||
TERMINATION_SIGNAL = "[[TERMINATE CHAT]]"
|
||||
USER_PROMPT_TEMPLATE = """You are role-playing as a human USER interacting with an AI collaborator to complete a specific task. Your goal is to generate realistic, natural responses that a user might give in this scenario.
|
||||
|
||||
## Input Information:
|
||||
You will be provided with:
|
||||
- Task Description: The type of task you are trying to accomplish.
|
||||
- Complete Prompt or Reference Goal: This field may include the complete user request/query or a reference answer to user's request. Use this field to understand the user's intent, requirements, or what would count as a satisfactory outcome.
|
||||
- Chat History: The ongoing conversation between you (as the user) and the AI
|
||||
|
||||
Inputs:
|
||||
<|The Start of Task Description (Not visible to the AI)|>
|
||||
{task_desc}
|
||||
<|The End of Task Description|>
|
||||
|
||||
<|The Start of Complete Prompt or Reference Goal (Not visible to the AI)|>
|
||||
{single_turn_prompt}
|
||||
<|The End of Complete Prompt or Reference Goal|>
|
||||
|
||||
<|The Start of Chat History|>
|
||||
{chat_history}
|
||||
<|The End of Chat History|>
|
||||
|
||||
|
||||
## Guidelines:
|
||||
- Stay in Character: Role-play as a human USER. You are NOT an AI. Maintain a consistent personality throughout the chat.
|
||||
- Minimize Effort: IMPORTANT! As a user, avoid being too detailed in your responses. Provide vague or incomplete demands in the early stages of the conversation to minimize your effort. Let the AI ask for clarification rather than providing everything upfront.
|
||||
- Knowledge Background: Reflect the user's knowledge level in the role-playing. If the user is less knowledgeable about a task, they might not notice incorrect statements. Ask questions that demonstrate your current understanding and areas of confusion.
|
||||
- Occasionally Make Mistakes: Real-world users might misspell words, provide incorrect dates, give wrong information, or ask unclear questions. Simulate this behavior to reflect natural interactions.
|
||||
- Mention Personal Preferences: Include preferences or constraints that might influence your requests or responses. For example, "I prefer short answers," "I need this done quickly," or "I like detailed comments in code."
|
||||
- Goal-Oriented: Keep the chat focused on your intent. Avoid small talk or digressions. Redirect the chat back to the main objective if it starts to stray.
|
||||
|
||||
## Output Format:
|
||||
You should output a JSON object with three entries:
|
||||
- "current_answer" (str): Briefly summerize the AI's current solution to the task.
|
||||
- "thought" (str): Output your thought process as a user deciding what to say next. Consider:
|
||||
1. Have you obtained a satisfactory solution from the AI? If yes, you can terminate this chat.
|
||||
2. If not, what specific part of the problem or solution are you struggling with?
|
||||
3. Has the AI asked you to perform a task or answer a question? If so, how should you approach it?
|
||||
4. Are you noticing any patterns or potential misunderstandings that need clarification?
|
||||
5. If you're stuck, how can you phrase your question to get the most helpful response while demonstrating your current understanding?
|
||||
- "response" (str): Based on your thought process, respond to the AI as the user you are role-playing. Stop immediately when the user's response is completed.
|
||||
|
||||
## Important Notes:
|
||||
- Respond Based on Previous Messages: Your responses should be based on the context of the current chat history. Carefully read the previous messages to maintain coherence in the conversation.
|
||||
- Conversation Flow: If "Current Chat History" is empty, start the conversation from scratch with an initial request. Otherwise, continue based on the existing conversation.
|
||||
- Don't Copy Input Directly: Use the provided information for understanding context only. Avoid copying target queries or any provided information directly in your responses.
|
||||
- Completion Signal: Use "{termination_signal}" as your response when you believe your goal has been solved or if you determine the AI cannot help further.
|
||||
- Double check if the JSON object is formatted correctly. Ensure that all fields are present and properly structured.
|
||||
|
||||
Remember to stay in character as a user throughout your response, and follow the instructions and guidelines carefully.""" # noqa
|
||||
|
||||
|
||||
class CollabLLMInteraction(BaseInteraction):
|
||||
"""A demo interaction for calculating the reward of CollabLLM.
|
||||
|
||||
- `start_interaction`: start a interaction instance for a trajectory.
|
||||
- `generate_response`: generate the response of the assistant.
|
||||
- `calculate_score`: calculate the score of the interaction.
|
||||
- `finalize_interaction`: finalize the interaction instance.
|
||||
"""
|
||||
|
||||
def __init__(self, config: dict):
|
||||
super().__init__(config)
|
||||
_config = copy.deepcopy(config)
|
||||
|
||||
_config.pop("enable_log", None)
|
||||
|
||||
self.name = _config.pop("name")
|
||||
self.user_model = _config.pop("user_model")
|
||||
|
||||
self.termination_signal = _config.pop("termination_signal", TERMINATION_SIGNAL)
|
||||
self.num_retries = _config.pop("num_retries", 3)
|
||||
|
||||
self.user_model_kwargs = _config
|
||||
|
||||
self._instance_dict = {}
|
||||
|
||||
async def start_interaction(
|
||||
self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs
|
||||
) -> str:
|
||||
if instance_id is None:
|
||||
instance_id = str(uuid4())
|
||||
self._instance_dict[instance_id] = {
|
||||
"response": "",
|
||||
"ground_truth": ground_truth,
|
||||
"reward": 0.0,
|
||||
}
|
||||
self.interaction_kwargs = kwargs
|
||||
assert "single_turn_prompt" in kwargs, "single_turn_prompt is required in interaction_kwargs"
|
||||
return instance_id
|
||||
|
||||
@rollout_trace_op
|
||||
async def generate_response(
|
||||
self, instance_id: str, messages: list[dict[str, Any]], **kwargs
|
||||
) -> tuple[bool, str, float, dict]:
|
||||
assert messages[-1]["role"] in ["system", "assistant"], (
|
||||
"Last message input to the user model must be from system or assistant role"
|
||||
)
|
||||
|
||||
import litellm
|
||||
|
||||
chat_history = self._parse_messages(messages, strip_sys_prompt=True)
|
||||
prompt = USER_PROMPT_TEMPLATE.format(
|
||||
task_desc=self.interaction_kwargs.get("task_desc", "general assistance task"),
|
||||
single_turn_prompt=self.interaction_kwargs["single_turn_prompt"],
|
||||
chat_history=chat_history,
|
||||
termination_signal=self.termination_signal,
|
||||
)
|
||||
response = ""
|
||||
for i in range(self.num_retries):
|
||||
try:
|
||||
full_response = (
|
||||
(
|
||||
await litellm.acompletion(
|
||||
model=self.user_model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
**self.user_model_kwargs,
|
||||
)
|
||||
)
|
||||
.choices[0]
|
||||
.message.content
|
||||
)
|
||||
except litellm.RateLimitError as e:
|
||||
logger.warning(f"[CollabLLMInteraction] hit RateLimitError: {e}. Retrying...")
|
||||
await asyncio.sleep(max(2**i, 60))
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.exception(f"An unexpected error occurred in CollabLLMAgentLoop: {e}")
|
||||
continue
|
||||
|
||||
try:
|
||||
if isinstance(full_response, str):
|
||||
full_response = extract_json(full_response)
|
||||
except Exception as e:
|
||||
logger.warning(f"[CollabLLMInteraction] Error extracting JSON: {e}. Retrying...")
|
||||
continue
|
||||
|
||||
if isinstance(full_response, dict):
|
||||
keys = full_response.keys()
|
||||
if {"current_answer", "thought", "response"}.issubset(keys):
|
||||
response = full_response.pop("response")
|
||||
if isinstance(response, str):
|
||||
break
|
||||
else:
|
||||
logger.warning(
|
||||
f"[CollabLLMInteraction] got an invaild response {response} full_response {full_response}. \
|
||||
Retrying..."
|
||||
)
|
||||
continue
|
||||
else:
|
||||
logger.warning(f"[CollabLLMInteraction] Keys {keys} do not match expected keys. Retrying...")
|
||||
continue
|
||||
|
||||
self._instance_dict[instance_id]["response"] = response
|
||||
logger.debug(f"[CollabLLMInteraction] User: {response}")
|
||||
should_terminate_sequence = self.termination_signal in response
|
||||
reward = 0.0
|
||||
|
||||
return should_terminate_sequence, response, reward, {}
|
||||
|
||||
async def finalize_interaction(self, instance_id: str, **kwargs) -> None:
|
||||
del self._instance_dict[instance_id]
|
||||
|
||||
def _parse_messages(self, messages, strip_sys_prompt=True):
|
||||
if messages is None:
|
||||
return ""
|
||||
|
||||
if strip_sys_prompt:
|
||||
messages = [msg for msg in messages if msg["role"] != "system"]
|
||||
|
||||
messages = [remove_think_block(msg) for msg in messages]
|
||||
|
||||
chat = "\n".join(f"**{m['role'].capitalize()}**: {m['content']}" for m in messages)
|
||||
|
||||
return chat
|
||||
|
||||
|
||||
def extract_json(s):
|
||||
def convert_value(value):
|
||||
true_values = {"true": True, "false": False, "null": None}
|
||||
value_lower = value.lower()
|
||||
if value_lower in true_values:
|
||||
return true_values[value_lower]
|
||||
try:
|
||||
if "." in value or "e" in value.lower():
|
||||
return float(value)
|
||||
else:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return value # Return as string if not a number
|
||||
|
||||
def parse_number(s, pos):
|
||||
start = pos
|
||||
while pos < len(s) and s[pos] in "-+0123456789.eE":
|
||||
pos += 1
|
||||
num_str = s[start:pos]
|
||||
try:
|
||||
if "." in num_str or "e" in num_str.lower():
|
||||
return float(num_str), pos
|
||||
else:
|
||||
return int(num_str), pos
|
||||
except ValueError:
|
||||
logger.error(f"Invalid number at position {start}: {num_str}")
|
||||
raise
|
||||
|
||||
def skip_whitespace(s, pos):
|
||||
while pos < len(s) and s[pos] in " \t\n\r":
|
||||
pos += 1
|
||||
return pos
|
||||
|
||||
def parse_string(s, pos):
|
||||
quote_char = s[pos]
|
||||
assert quote_char in ('"', "'")
|
||||
pos += 1
|
||||
result = ""
|
||||
while pos < len(s):
|
||||
c = s[pos]
|
||||
if c == "\\":
|
||||
pos += 1
|
||||
if pos >= len(s):
|
||||
raise ValueError("Invalid escape sequence")
|
||||
c = s[pos]
|
||||
escape_sequences = {"n": "\n", "t": "\t", "r": "\r", "\\": "\\", quote_char: quote_char}
|
||||
result += escape_sequences.get(c, c)
|
||||
elif c == quote_char:
|
||||
pos += 1
|
||||
# Attempt to convert to a number if possible
|
||||
converted_value = convert_value(result)
|
||||
return converted_value, pos
|
||||
else:
|
||||
result += c
|
||||
pos += 1
|
||||
raise ValueError("Unterminated string")
|
||||
|
||||
def parse_key(s, pos):
|
||||
pos = skip_whitespace(s, pos)
|
||||
if s[pos] in ('"', "'"):
|
||||
key, pos = parse_string(s, pos)
|
||||
return key, pos
|
||||
else:
|
||||
raise ValueError(f"Expected string for key at position {pos}")
|
||||
|
||||
def parse_object(s, pos):
|
||||
obj = {}
|
||||
assert s[pos] == "{"
|
||||
pos += 1
|
||||
pos = skip_whitespace(s, pos)
|
||||
while pos < len(s) and s[pos] != "}":
|
||||
pos = skip_whitespace(s, pos)
|
||||
key, pos = parse_key(s, pos)
|
||||
pos = skip_whitespace(s, pos)
|
||||
if pos >= len(s) or s[pos] != ":":
|
||||
raise ValueError(f'Expected ":" at position {pos}')
|
||||
pos += 1
|
||||
pos = skip_whitespace(s, pos)
|
||||
value, pos = parse_value(s, pos)
|
||||
obj[key] = value
|
||||
pos = skip_whitespace(s, pos)
|
||||
if pos < len(s) and s[pos] == ",":
|
||||
pos += 1
|
||||
pos = skip_whitespace(s, pos)
|
||||
elif pos < len(s) and s[pos] == "}":
|
||||
break
|
||||
elif pos < len(s) and s[pos] != "}":
|
||||
raise ValueError(f'Expected "," or "}}" at position {pos}')
|
||||
if pos >= len(s) or s[pos] != "}":
|
||||
raise ValueError(f'Expected "}}" at position {pos}')
|
||||
pos += 1
|
||||
return obj, pos
|
||||
|
||||
def parse_array(s, pos):
|
||||
lst = []
|
||||
assert s[pos] == "["
|
||||
pos += 1
|
||||
pos = skip_whitespace(s, pos)
|
||||
while pos < len(s) and s[pos] != "]":
|
||||
value, pos = parse_value(s, pos)
|
||||
lst.append(value)
|
||||
pos = skip_whitespace(s, pos)
|
||||
if pos < len(s) and s[pos] == ",":
|
||||
pos += 1
|
||||
pos = skip_whitespace(s, pos)
|
||||
elif pos < len(s) and s[pos] == "]":
|
||||
break
|
||||
elif pos < len(s) and s[pos] != "]":
|
||||
raise ValueError(f'Expected "," or "]" at position {pos}')
|
||||
if pos >= len(s) or s[pos] != "]":
|
||||
raise ValueError(f'Expected "]" at position {pos}')
|
||||
pos += 1
|
||||
return lst, pos
|
||||
|
||||
def parse_triple_quoted_string(s, pos):
|
||||
if s[pos : pos + 3] == "'''":
|
||||
quote_str = "'''"
|
||||
elif s[pos : pos + 3] == '"""':
|
||||
quote_str = '"""'
|
||||
else:
|
||||
raise ValueError(f"Expected triple quotes at position {pos}")
|
||||
pos += 3
|
||||
result = ""
|
||||
while pos < len(s):
|
||||
if s[pos : pos + 3] == quote_str:
|
||||
pos += 3
|
||||
# Attempt to convert to a number if possible
|
||||
converted_value = convert_value(result)
|
||||
return converted_value, pos
|
||||
else:
|
||||
result += s[pos]
|
||||
pos += 1
|
||||
raise ValueError("Unterminated triple-quoted string")
|
||||
|
||||
def parse_value(s, pos):
|
||||
pos = skip_whitespace(s, pos)
|
||||
if pos >= len(s):
|
||||
raise ValueError("Unexpected end of input")
|
||||
if s[pos] == "{":
|
||||
return parse_object(s, pos)
|
||||
elif s[pos] == "[":
|
||||
return parse_array(s, pos)
|
||||
elif s[pos : pos + 3] in ("'''", '"""'):
|
||||
return parse_triple_quoted_string(s, pos)
|
||||
elif s[pos] in ('"', "'"):
|
||||
return parse_string(s, pos)
|
||||
elif s[pos : pos + 4].lower() == "true":
|
||||
return True, pos + 4
|
||||
elif s[pos : pos + 5].lower() == "false":
|
||||
return False, pos + 5
|
||||
elif s[pos : pos + 4].lower() == "null":
|
||||
return None, pos + 4
|
||||
elif s[pos] in "-+0123456789.":
|
||||
return parse_number(s, pos)
|
||||
else:
|
||||
raise ValueError(f"Unexpected character at position {pos}: {s[pos]}")
|
||||
|
||||
json_start = s.index("{")
|
||||
json_end = s.rfind("}")
|
||||
s = s[json_start : json_end + 1]
|
||||
|
||||
s = s.strip()
|
||||
result, pos = parse_value(s, 0)
|
||||
pos = skip_whitespace(s, pos)
|
||||
if pos != len(s):
|
||||
raise ValueError(f"Unexpected content at position {pos}")
|
||||
return result
|
@ -226,6 +226,7 @@ actor_rollout_ref:
|
||||
use_inference_chat_template: false
|
||||
tokenization_sanity_check_mode: strict
|
||||
format: hermes
|
||||
num_repeat_rollouts: null
|
||||
calculate_log_probs: false
|
||||
agent:
|
||||
_target_: verl.workers.config.AgentLoopConfig
|
||||
|
@ -213,6 +213,7 @@ actor_rollout_ref:
|
||||
use_inference_chat_template: false
|
||||
tokenization_sanity_check_mode: strict
|
||||
format: hermes
|
||||
num_repeat_rollouts: null
|
||||
calculate_log_probs: false
|
||||
agent:
|
||||
_target_: verl.workers.config.AgentLoopConfig
|
||||
|
@ -180,6 +180,9 @@ multi_turn:
|
||||
# Format of the multi-turn interaction. Options: hermes, llama3_json, ...
|
||||
format: hermes
|
||||
|
||||
# Number of repeat rollouts for each interaction
|
||||
num_repeat_rollouts: null
|
||||
|
||||
# support logging rollout prob for debugging purpose
|
||||
# "Truncated importance sampling" requires rollout log probs, set to True when turning on Truncated importance sampling
|
||||
calculate_log_probs: False
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
|
||||
Note that we don't combine the main with ray_trainer as ray_trainer is used by other mpain.
|
||||
"""
|
||||
|
||||
import os
|
||||
@ -312,6 +312,7 @@ class TaskRunner:
|
||||
)
|
||||
# Initialize the workers of the trainer.
|
||||
trainer.init_workers()
|
||||
|
||||
# Start the training process.
|
||||
trainer.fit()
|
||||
|
||||
@ -349,7 +350,6 @@ def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=Tr
|
||||
|
||||
dataset_cls = DynamicGenDataset
|
||||
print("Using DynamicGenDataset for data generation.")
|
||||
|
||||
else:
|
||||
# Use the default RLHFDataset class if no custom class is specified
|
||||
dataset_cls = RLHFDataset
|
||||
|
@ -229,6 +229,7 @@ def compute_advantage(
|
||||
elif adv_estimator == AdvantageEstimator.GRPO:
|
||||
# Initialize the mask for GRPO calculation
|
||||
grpo_calculation_mask = data.batch["response_mask"]
|
||||
|
||||
# Call compute_grpo_outcome_advantage with parameters matching its definition
|
||||
advantages, returns = core_algos.compute_grpo_outcome_advantage(
|
||||
token_level_rewards=data.batch["token_level_rewards"],
|
||||
@ -981,7 +982,6 @@ class RayPPOTrainer:
|
||||
if self.config.global_profiler.profile_continuous_steps
|
||||
else curr_step_profile
|
||||
)
|
||||
|
||||
batch: DataProto = DataProto.from_single_dict(batch_dict)
|
||||
|
||||
# add uid to batch
|
||||
@ -996,7 +996,6 @@ class RayPPOTrainer:
|
||||
gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
|
||||
|
||||
is_last_step = self.global_steps >= self.total_training_steps
|
||||
|
||||
with marked_timer("step", timing_raw):
|
||||
# generate a batch
|
||||
with marked_timer("gen", timing_raw, color="red"):
|
||||
@ -1004,6 +1003,7 @@ class RayPPOTrainer:
|
||||
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
|
||||
else:
|
||||
gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
|
||||
|
||||
timing_raw.update(gen_batch_output.meta_info["timing"])
|
||||
gen_batch_output.meta_info.pop("timing", None)
|
||||
|
||||
@ -1027,7 +1027,6 @@ class RayPPOTrainer:
|
||||
batch.batch["reward_baselines"] = reward_baseline_tensor
|
||||
|
||||
del gen_baseline_batch, gen_baseline_output
|
||||
|
||||
# repeat to align with repeated responses in rollout
|
||||
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
|
||||
batch = batch.union(gen_batch_output)
|
||||
@ -1109,7 +1108,6 @@ class RayPPOTrainer:
|
||||
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
|
||||
|
||||
# compute advantages, executed on the driver process
|
||||
|
||||
norm_adv_by_std_in_grpo = self.config.algorithm.get(
|
||||
"norm_adv_by_std_in_grpo", True
|
||||
) # GRPO adv normalization factor
|
||||
|
@ -55,6 +55,7 @@ class MultiTurnConfig(BaseConfig):
|
||||
use_inference_chat_template: bool = False
|
||||
tokenization_sanity_check_mode: str = "strict"
|
||||
format: str = "hermes"
|
||||
num_repeat_rollouts: Optional[int] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
from .registry import get_reward_manager_cls, register # noqa: I001
|
||||
from .batch import BatchRewardManager
|
||||
from .collabllm import CollabLLMRewardManager
|
||||
from .dapo import DAPORewardManager
|
||||
from .naive import NaiveRewardManager
|
||||
from .prime import PrimeRewardManager
|
||||
@ -21,6 +22,7 @@ from .prime import PrimeRewardManager
|
||||
# Note(haibin.lin): no need to include all reward managers here in case of complicated dependencies
|
||||
__all__ = [
|
||||
"BatchRewardManager",
|
||||
"CollabLLMRewardManager",
|
||||
"DAPORewardManager",
|
||||
"NaiveRewardManager",
|
||||
"PrimeRewardManager",
|
||||
|
152
verl/workers/reward_manager/collabllm.py
Normal file
152
verl/workers/reward_manager/collabllm.py
Normal file
@ -0,0 +1,152 @@
|
||||
# Copyright 2025 CollabLLM team and/or its affiliates
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from verl import DataProto
|
||||
from verl.utils.reward_score import default_compute_score
|
||||
from verl.workers.reward_manager import register
|
||||
from verl.workers.reward_manager.abstract import AbstractRewardManager
|
||||
|
||||
TERMINATION_SIGNAL = "[[TERMINATE CHAT]]"
|
||||
|
||||
|
||||
@register("collabllm")
|
||||
class CollabLLMRewardManager(AbstractRewardManager):
|
||||
"""
|
||||
The Reward Manager used in https://github.com/Wuyxin/collabllm/
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
num_examine: int,
|
||||
metric_weights: dict,
|
||||
llm_judge_kwargs: dict,
|
||||
reward_fn_key: str = "data_source",
|
||||
compute_score: Optional[Callable] = None,
|
||||
normalize_by_data_source=False,
|
||||
) -> None:
|
||||
self.tokenizer = tokenizer
|
||||
self.num_examine = num_examine # the number of batches of decoded responses to print to the console
|
||||
self.compute_score = compute_score or default_compute_score
|
||||
self.reward_fn_key = reward_fn_key
|
||||
|
||||
self.metric_weights = metric_weights
|
||||
self.llm_judge_kwargs = llm_judge_kwargs
|
||||
self.normalize_by_data_source = normalize_by_data_source
|
||||
|
||||
self.metrics = list(self.metric_weights.keys())
|
||||
# force CollabLLMAgentLoop to be registered
|
||||
from recipe.collabllm.collabllm_agent_loop import CollabLLMAgentLoop # noqa
|
||||
|
||||
def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
|
||||
# If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
|
||||
if "rm_scores" in data.batch.keys():
|
||||
if return_dict:
|
||||
return {"reward_tensor": data.batch["rm_scores"]}
|
||||
else:
|
||||
return data.batch["rm_scores"]
|
||||
# Use thread-compatible async loop management instead of asyncio.run()
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(self._compute_rewards_async(data, return_dict))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
async def _compute_rewards_async(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
|
||||
# batched scoring
|
||||
prompt_ids = data.batch["prompts"]
|
||||
prompt_length = prompt_ids.shape[-1]
|
||||
valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum(dim=-1)
|
||||
|
||||
data_source = data.non_tensor_batch["data_source"]
|
||||
ground_truth = data.non_tensor_batch["ground_truth"]
|
||||
extra_info = data.non_tensor_batch["extra_info"]
|
||||
message_lst = data.non_tensor_batch["messages"]
|
||||
|
||||
# batch the messages into multiple
|
||||
num_repeat_rollouts = len(message_lst[0]["messages"])
|
||||
batch_size = len(data_source)
|
||||
|
||||
grouped_messages = [
|
||||
[message_lst[i]["messages"][j] for i in range(len(message_lst))] for j in range(num_repeat_rollouts)
|
||||
]
|
||||
|
||||
# Flatten lists for all batch items across all rollouts
|
||||
flattened_data_sources = [data_source[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)]
|
||||
flattened_ground_truths = [ground_truth[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)]
|
||||
flattened_extra_infos = [extra_info[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)]
|
||||
flattened_messages = [grouped_messages[j][i] for j in range(num_repeat_rollouts) for i in range(batch_size)]
|
||||
|
||||
if num_repeat_rollouts > 0:
|
||||
tasks = [
|
||||
self.compute_score(
|
||||
flattened_data_sources[i],
|
||||
flattened_messages[i],
|
||||
flattened_ground_truths[i],
|
||||
flattened_extra_infos[i],
|
||||
self.metrics,
|
||||
**self.llm_judge_kwargs,
|
||||
)
|
||||
for i in range(len(flattened_data_sources))
|
||||
]
|
||||
score_dicts = await asyncio.gather(*tasks)
|
||||
|
||||
# Aggregate scores for each metric across repeated rollouts
|
||||
scores_by_metrics = {
|
||||
metric: torch.stack([score_dict[metric] for score_dict in score_dicts])
|
||||
.view(num_repeat_rollouts, -1)
|
||||
.sum(dim=0)
|
||||
for metric in self.metrics
|
||||
}
|
||||
|
||||
# Apply metric-specific weights
|
||||
weighted_scores_by_metrics = {
|
||||
metric: torch.clamp(
|
||||
scores_by_metrics[metric] * self.metric_weights[metric] / num_repeat_rollouts,
|
||||
min=-1.0,
|
||||
max=1.0,
|
||||
)
|
||||
for metric in self.metrics
|
||||
}
|
||||
# Compute mean of weighted scores for each metric
|
||||
mean_weighted_scores_by_metrics = {
|
||||
metric: weighted_scores_by_metrics[metric].mean(dim=0) for metric in self.metrics
|
||||
}
|
||||
|
||||
# Combine weighted scores from all metrics into a single tensor
|
||||
scores = torch.stack([weighted_scores_by_metrics[metric] for metric in self.metrics]).sum(dim=0)
|
||||
else:
|
||||
score_dicts = []
|
||||
scores = torch.full((batch_size,), 0.0, dtype=torch.float32, device=prompt_ids.device)
|
||||
mean_weighted_scores_by_metrics = {metric: 0.0 for metric in self.metrics}
|
||||
|
||||
print("Scores:", scores, mean_weighted_scores_by_metrics)
|
||||
|
||||
reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
|
||||
|
||||
for i in range(len(data)):
|
||||
reward_tensor[i, valid_response_length[i].item() - 1] = scores[i]
|
||||
|
||||
if return_dict:
|
||||
return {"reward_tensor": reward_tensor}
|
||||
else:
|
||||
return reward_tensor
|
@ -270,6 +270,7 @@ class SGLangRollout(BaseRollout):
|
||||
self._function_call_parser,
|
||||
) = self._initialize_tools(config, processing_class)
|
||||
self.interaction_map: dict[str, BaseInteraction] = self._initialize_interactions(config)
|
||||
|
||||
# If turn on `free_cache_engine`, SGLang engine's KV cache
|
||||
# will be freed after each `generate_sequences` call.
|
||||
logger.info(
|
||||
@ -287,7 +288,6 @@ class SGLangRollout(BaseRollout):
|
||||
self._init_sampling_params(**kwargs)
|
||||
|
||||
self.processing_class = processing_class
|
||||
|
||||
try:
|
||||
# This is when processing_class is a tokenizer
|
||||
self.pad_token_id = self.processing_class.pad_token_id
|
||||
@ -964,6 +964,9 @@ class SGLangRollout(BaseRollout):
|
||||
):
|
||||
_req.state = AsyncRolloutRequestStateEnum.INTERACTING
|
||||
else:
|
||||
# Add ending condition
|
||||
finish_reason_type = FinishReasonTypeEnum.STOP
|
||||
_req.state = AsyncRolloutRequestStateEnum.COMPLETED
|
||||
break
|
||||
elif _req.state == AsyncRolloutRequestStateEnum.INTERACTING:
|
||||
user_turns += 1
|
||||
@ -980,11 +983,17 @@ class SGLangRollout(BaseRollout):
|
||||
)
|
||||
|
||||
interaction = self.interaction_map[interaction_name]
|
||||
|
||||
should_terminate_sequence, content, reward, metrics = await interaction.generate_response(
|
||||
_req.request_id, messages, **_req.interaction_kwargs
|
||||
)
|
||||
user_turn_rewards.append(reward)
|
||||
if should_terminate_sequence:
|
||||
# Add turn check
|
||||
if (
|
||||
should_terminate_sequence
|
||||
or user_turns > self.config.multi_turn.max_user_turns
|
||||
or current_turns > self.config.multi_turn.max_assistant_turns
|
||||
):
|
||||
finish_reason_type = FinishReasonTypeEnum.STOP
|
||||
_req.state = AsyncRolloutRequestStateEnum.COMPLETED
|
||||
break
|
||||
@ -1013,6 +1022,7 @@ class SGLangRollout(BaseRollout):
|
||||
tool_reward_scores = dict(tool_reward_scores)
|
||||
all_rewards = {**tool_reward_scores, **{"user_turn_rewards": user_turn_rewards}}
|
||||
_req.finalize(self.processing_class, all_rewards, finish_reason_type)
|
||||
|
||||
if self.config.calculate_log_probs:
|
||||
debug_sampling_params = {**self.sampling_params}
|
||||
debug_sampling_params["max_new_tokens"] = 0
|
||||
|
Reference in New Issue
Block a user