[recipe] feat: CollabLLM integration for multiturn training (#3574)

### What does this PR do?

This PR adds [CollabLLM](https://aka.ms/CollabLLM) as a training recipe.
The added components include:
- A customized `CollabLLMRewardManager` inheriting from
`AbstractRewardManager` to compute multiturn-aware rewards.
- A customized `CollabLLMAgentLoop` inheriting from `AgentLoop` to
sample future conversations with simulated users, which imports
`CollabLLMInteraction` from `verl/interactions/collabllm_interation.py`.

### Checklist Before Starting

- [X] Search for similar PRs. Paste at least one query link here: ...
- [X] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

The training reward when running `train_rl_collabllm.sh` increases in a
relatively stable manner (on 8xH200):
<img width="964" height="480" alt="9baeb0700e3fa6a56596e14a54bc1049"
src="https://github.com/user-attachments/assets/53a810d8-1dd7-4145-bb28-4e475e9d7d9d"
/>

Validation reward:
<img width="974" height="538" alt="39364fd10523b0fde13d48645809f5e3"
src="https://github.com/user-attachments/assets/c34fe9e7-3d83-4132-8e1a-67e82c221d09"
/>

#### Samples of model generation
After training, when the user asks generic questions with missing
information, the model learns to ask for clarification
<img width="1213" height="562" alt="c8e0ab31948a48ca396c7eccddd13673"
src="https://github.com/user-attachments/assets/ae41cd77-3c77-4402-b9d3-21993b046a18"
/>
and give suggestions:
<img width="1534" height="190" alt="7adb7d33eb9120d337c2a249c6a2dd22"
src="https://github.com/user-attachments/assets/84e1d8c1-f954-403f-b931-bce45cff1612"
/>

(In contrast, with the same prompt, **GPT-5** doesn't ask for any
clarification:)
<img width="1754" height="1126" alt="be8d8577584c0b2356cb352d6f294205"
src="https://github.com/user-attachments/assets/9b734848-9ed0-4496-af11-68bb8f8d8e08"
/>


### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# No change to the existing APIs
```
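
For reference, an end-to-end invocation of the recipe (abridged; paths and dataset names follow the defaults in `recipe/collabllm/process_dataset.py` and `recipe/collabllm/train_rl_collabllm.sh`):

```bash
# Prepare the RL dataset from an existing CollabLLM multiturn dataset
python recipe/collabllm/process_dataset.py \
    --dataset collabllm/collabllm-multiturn-math-hard-large \
    --local_dir $HOME/data/collabllm-math-hard-large \
    --dataset_type rl

# Launch GRPO training with the CollabLLM reward manager and agent loop
bash recipe/collabllm/train_rl_collabllm.sh
```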

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

Changes:
- Main files under `recipe/collabllm`
- Registered `CollabLLMRewardManager` in
`workers/reward_manager/collabllm.py`
- Added `CollabLLMInteraction` in
`verl/interactions/collabllm_interation.py`

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [X] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [X] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [X] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs). Added
to `verl/docs/algo/collabllm.md`.
- [X] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: the scripts
`train_rl_collabllm.sh` and `train_sft_collabllm.sh` have been tested multiple
times.
- [X] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: Chen Haiquan <chenhaiquan@bytedance.com>
Commit 25d78fa913 (parent ba8555120a), authored by Shirley Wu and committed via GitHub on 2025-09-24 18:53:39 -07:00.
29 changed files with 2148 additions and 31 deletions.

docs/algo/collabllm.md (new file)

@ -0,0 +1,105 @@
# Recipe: CollabLLM
Last updated: 09/22/2025.
> Open-Source Algorithm Implementation & Experiment Running: [Haiquan Chen](https://github.com/chenhaiq), [Shirley Wu](https://github.com/Wuyxin)
🏠 [Homepage](https://aka.ms/CollabLLM) | 📝 [Paper](https://arxiv.org/pdf/2502.00640) | 🤗 [Datasets & Models](https://huggingface.co/collabllm) | ⭐️ [Original Implementation](https://github.com/Wuyxin/collabllm)
`verl` provides a recipe for the Outstanding Paper at ICML 2025, **"CollabLLM: From Passive Responders to Active Collaborators"**. [CollabLLM](https://aka.ms/CollabLLM) is a unified fine-tuning framework that optimizes LLMs for effective and efficient multiturn collaboration with users.
**Core Idea:** Models are rewarded based on how well their responses enable effective *future* collaboration with users.
Paper Authors: [Shirley Wu](https://cs.stanford.edu/~shirwu/), [Michel Galley](https://www.microsoft.com/en-us/research/people/mgalley/), Baolin Peng, Hao Cheng, Gavin Li, Yao Dou, Weixin Cai, [James Zou](https://www.james-zou.com/), [Jure Leskovec](https://cs.stanford.edu/people/jure/), [Jianfeng Gao](https://www.microsoft.com/en-us/research/people/jfgao/)
---
## Quick Start
### 0. Environment
Make sure the required packages for `verl` are installed. Additionally, install `litellm` and export the required API keys. The API model will be used for user simulators and, optionally, LLM Judges (see the Configuration section below).
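For example, assuming the default OpenAI-hosted judge/user-simulator model (`gpt-4o-mini`); other providers supported by `litellm` require their own keys:
```bash
pip install litellm
export OPENAI_API_KEY=<your-api-key>
```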
### 1. Prepare Your Dataset
First, process your dataset using the provided script (see example commands and usage in `process_dataset.py`):
```bash
python process_dataset.py --dataset <> ... --dataset_type <sft or rl>
```
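For example, to build both the SFT and RL splits of `math-hard-large` (mirroring the example commands in `process_dataset.py`):
```bash
DATASET=math-hard-large
python recipe/collabllm/process_dataset.py \
    --dataset collabllm/collabllm-multiturn-$DATASET \
    --local_dir $HOME/data/collabllm-$DATASET \
    --dataset_type sft
python recipe/collabllm/process_dataset.py \
    --dataset collabllm/collabllm-multiturn-$DATASET \
    --local_dir $HOME/data/collabllm-$DATASET \
    --dataset_type rl
```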
**Requirements:**
- Input: A Hugging Face multiturn dataset. Existing datasets: `collabllm/collabllm-multiturn-$DATASET`, with `DATASET` being one of [`math-hard(-large)`, `medium(-large)`, `bigcodebench(-large)`] (the `-large` variants are the datasets used in the CollabLLM paper)
- Example format: See [collabllm-multiturn-math-hard](https://huggingface.co/datasets/collabllm/collabllm-multiturn-math-hard)
- To generate your own dataset: Use [build_dataset.py](https://github.com/Wuyxin/collabllm/blob/main/scripts/engine/build_dataset.py) from the original CollabLLM repository
### 2. Train Your Model
**(Optional) For Supervised Fine-Tuning (SFT):**
```bash
bash train_sft_collabllm.sh
```
**For Reinforcement Learning (RL):**
```bash
bash train_rl_collabllm.sh
```
The RL script shows an example of training CollabLLM on `math-hard-large`.
- The configuration for sampling future conversations is in `recipe/collabllm/config/collabllm_interaction_config.yaml`.
- The Multiturn-aware Reward is aggregated from these three conversation-level rewards:
```
+reward_model.reward_kwargs.metric_weights.accuracy=1 \
+reward_model.reward_kwargs.metric_weights.interactivity=1 \
+reward_model.reward_kwargs.metric_weights.token_amount=-0.0001 \
```
You can remove, add, or modify the weights depending on your task. The implemented metrics you can add are listed under `recipe/collabllm/metrics`. For example, on `medium-large`, you can replace `accuracy` with `bleu_score` via
```
+reward_model.reward_kwargs.metric_weights.bleu_score=1
```
which will instead apply the BLEU score to the sampled future conversations.
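Each metric file under `recipe/collabllm/metrics` exposes a `compute_score(data_source, messages, ground_truth, extra_info, **kwargs)` function (sync or async) that scores one sampled future conversation. Below is a minimal sketch of a hypothetical custom metric; the file name `brevity.py` and its scoring rule are illustrative only, not part of the recipe:
```python
# Hypothetical metric sketch (e.g., recipe/collabllm/metrics/brevity.py) that
# favors shorter assistant turns in the sampled future conversation. It follows
# the same compute_score interface as the built-in metrics.
def compute_score(data_source, messages, ground_truth, extra_info, **kwargs):
    prompt = extra_info["prompt"]
    future_conv = messages[len(prompt):]  # turns sampled after the original prompt
    assistant_words = sum(
        len(m.content.split()) for m in future_conv if m.role == "assistant"
    )
    # Map word count to a score in [0, 1]; the constant 500 is arbitrary.
    return 1.0 / (1.0 + assistant_words / 500.0)
```
Such a metric could then be weighted with `+reward_model.reward_kwargs.metric_weights.brevity=1` (again, a hypothetical name).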
## Algorithm
| Step | Name | Description |
|------|-------------------------------|-----------------------------------------------------------------------------|
| 1 | Model response generation | The model generates multiple responses for each prompt in a batch. |
| 2 | Collaborative simulation | A user simulator (e.g., GPT or Claude) samples `num_repeat_rollouts` conversations for up to `max_user_turns` additional turns. |
| 3 | Compute Multiturn-aware Reward | Customized conversational reward functions are applied to the sampled conversations. Rewards are aggregated, then averaged across rollouts. |
| 4 | Update model | The model weights are updated using the computed multiturn-aware rewards. |
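As a rough illustration of step 3, per-metric scores for each sampled future conversation are combined with the configured `metric_weights` and then averaged over the `num_repeat_rollouts` rollouts. The snippet below is a simplified sketch of that aggregation, not the exact `CollabLLMRewardManager` implementation:
```python
# Simplified sketch of multiturn-aware reward aggregation (illustrative only).
def aggregate_multiturn_reward(rollout_scores, metric_weights):
    """rollout_scores: one {metric_name: score} dict per sampled future conversation."""
    if not rollout_scores:
        return 0.0
    per_rollout = [
        sum(metric_weights.get(name, 0.0) * score for name, score in scores.items())
        for scores in rollout_scores
    ]
    return sum(per_rollout) / len(per_rollout)

# Weights from train_rl_collabllm.sh: accuracy=1, interactivity=1, token_amount=-0.0001
weights = {"accuracy": 1.0, "interactivity": 1.0, "token_amount": -0.0001}
aggregate_multiturn_reward(
    [{"accuracy": 1.0, "interactivity": 0.7, "token_amount": 350},
     {"accuracy": 0.0, "interactivity": 1.0, "token_amount": 420}],
    weights,
)  # -> 1.3115
```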
---
## Configuration
The primary configuration is managed through the launch script `train_rl_collabllm.sh` and the YAML file `recipe/collabllm/config/collabllm_interaction_config.yaml`. Key configuration sections:
| Section | Key Parameters / Notes |
|----------------------|-----------------------------------------------------------------------------------------|
| `data` | Paths to training/validation files, batch sizes, sequence lengths. |
| `actor_rollout_ref` (common) | Base model path (used for actor + initial reference), FSDP settings, optimization (LR, scheduler). |
| `actor_rollout_ref` (CollabLLM-specific) | Hyperparameters under `actor_rollout_ref.rollout.multi_turn`: `max_user_turns`, `max_assistant_turns`, `num_repeat_rollouts`. |
| `interaction` | Defined in `collabllm_interaction_config.yaml`. Specifies user simulator and hyperparameters. Requires exported API keys. |
| `reward_model` | Manager set to `collabllm` by default. Modify `reward_model.reward_kwargs.metric_weights` for conversational rewards and weights. LLM Judge hyperparameters (e.g., `model`, `temperature`) go under `reward_model.reward_kwargs.llm_judge_kwargs`. |
| `algorithm` | GRPO-specific hyperparameters such as `actor_rollout_ref.rollout.n`. |
| `trainer` | Distributed training (nodes, GPUs per node), logging (WandB), checkpointing frequency. |
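For example, the CollabLLM-specific overrides in `train_rl_collabllm.sh` look roughly like the following (abridged; see the full script for the data, actor, and trainer settings):
```bash
python3 -m verl.trainer.main_ppo \
    reward_model.reward_manager=collabllm \
    +reward_model.reward_kwargs.metric_weights.accuracy=1 \
    +reward_model.reward_kwargs.metric_weights.interactivity=1 \
    +reward_model.reward_kwargs.metric_weights.token_amount=-0.0001 \
    +reward_model.reward_kwargs.llm_judge_kwargs.model=gpt-4o-mini \
    actor_rollout_ref.rollout.multi_turn.enable=true \
    actor_rollout_ref.rollout.multi_turn.max_user_turns=2 \
    actor_rollout_ref.rollout.multi_turn.max_assistant_turns=3 \
    actor_rollout_ref.rollout.multi_turn.num_repeat_rollouts=3 \
    actor_rollout_ref.rollout.multi_turn.interaction_config_path=recipe/collabllm/config/collabllm_interaction_config.yaml \
    custom_reward_function.path=recipe/collabllm/reward_function.py \
    custom_reward_function.name=conversation_level_reward_func
```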
---
## Key Files
| File Path | Purpose |
|-----------|---------|
| `recipe/collabllm/collabllm_agent_loop.py` | Main logic to sample future conversations, using `CollabLLMInteraction` from `verl/interactions/collabllm_interaction.py`. |
| `verl/workers/reward_manager/collabllm.py` | Computes rewards for future conversations, leveraging `recipe/collabllm/reward_function.py` to apply each metric. |
---
## Acknowledgement
We sincerely thank the `verl` community and advisors for their contributions and guidance!

@ -70,6 +70,7 @@ verl is fast with:
algo/ppo.md
algo/grpo.md
algo/collabllm.md
algo/dapo.md
algo/spin.md
algo/sppo.md

@ -1,23 +1,22 @@
set -x
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
trainer.val_before_train=False \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.train_batch_size=16 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
data.filter_overlong_prompts=True \
data.truncation='error' \
data.shuffle=False \
actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
actor_rollout_ref.model.use_shm=True \
actor_rollout_ref.model.lora_rank=64 \
actor_rollout_ref.model.lora_alpha=32 \
actor_rollout_ref.actor.optim.lr=3e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_mini_batch_size=16 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
@ -40,8 +39,13 @@ python3 -m verl.trainer.main_ppo \
trainer.logger='["console","wandb"]' \
trainer.project_name='verl_grpo_example_gsm8k' \
trainer.experiment_name='qwen2.5_3b_grpo_lora' \
trainer.n_gpus_per_node=8 \
trainer.n_gpus_per_node=2 \
trainer.nnodes=1 \
trainer.save_freq=20 \
trainer.test_freq=5 \
trainer.total_epochs=15 $@
# actor_rollout_ref.actor.ppo_mini_batch_size=256 \
# data.train_batch_size=1024 \
# trainer.n_gpus_per_node=8 \
# actor_rollout_ref.model.use_shm=True \

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export CUDA_VISIBLE_DEVICES=4
NOW=$(date +%Y%m%d)
export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-0.5b-${NOW}
export WANDB_PROJECT=${WANDB_DIR}
@ -7,7 +7,7 @@ export WANDB_EXP=0.5b-${NOW}
MODEL_PATH=Qwen/Qwen2.5-0.5B-Instruct
set -x
nproc_per_gpu=116
nproc_per_gpu=1
nnodes=1
ngpu_per_node=1
total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node ))
@ -15,8 +15,9 @@ mini_batch_size=$(( total_procs ))
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=data/gsm8k/train.parquet \
data.val_files=data/gsm8k/test.parquet \
trainer.val_before_train=False \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=${total_procs} \
data.val_batch_size=${total_procs} \
data.max_prompt_length=512 \
@ -25,7 +26,6 @@ python3 -m verl.trainer.main_ppo \
data.truncation='error' \
data.shuffle=False \
actor_rollout_ref.model.path=$MODEL_PATH \
actor_rollout_ref.model.use_shm=True \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.model.lora_rank=32 \
actor_rollout_ref.model.lora_alpha=32 \
@ -33,7 +33,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.actor.optim.lr=3e-5 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \
actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${mini_batch_size} \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
@ -45,7 +45,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.1 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.rollout.n=1 \
actor_rollout_ref.rollout.max_num_seqs=512 \
actor_rollout_ref.rollout.max_model_len=1536 \
actor_rollout_ref.rollout.max_num_batched_tokens=1536 \

@ -0,0 +1,74 @@
# CollabLLM
This recipe implements [CollabLLM](https://arxiv.org/pdf/2502.00640) (ICML 2025) using the verl framework. For the original implementation, see the [CollabLLM repository](https://github.com/Wuyxin/collabllm).
CollabLLM is a method for training language models to collaborate effectively in multi-turn conversations. This recipe adapts the original implementation to work with the `verl` training framework.
## Quick start
### 0. Environment
Make sure the required packages for `verl` are installed. Additionally, install `litellm` and export the required API keys. The API model will be used for user simulators and, optionally, LLM Judges (see the Configuration section below).
### 1. Prepare Your Dataset
First, process your dataset using the provided script:
```bash
python process_dataset.py --dataset <> ... --dataset_type <sft or rl>
```
**Requirements:**
- Input: A Hugging Face multiturn dataset. Existing datasets: `collabllm/collabllm-multiturn-$DATASET`, with `DATASET` being one of [`math-hard(-large)`, `medium(-large)`, `bigcodebench(-large)`] (the `-large` variants are the datasets used in the CollabLLM paper)
- Example format: See [collabllm-multiturn-math-hard](https://huggingface.co/datasets/collabllm/collabllm-multiturn-math-hard)
- To generate your own dataset: Use [build_dataset.py](https://github.com/Wuyxin/collabllm/blob/main/scripts/engine/build_dataset.py) from the original CollabLLM repository
*Note: Check `process_dataset.py` for example commands and usage.*
### 2. Train Your Model
**(Optional) For Supervised Fine-Tuning (SFT):**
```bash
bash train_sft_collabllm.sh
```
**For Reinforcement Learning (RL):**
```bash
bash train_rl_collabllm.sh
```
The RL script shows an example of training CollabLLM on `math-hard-large`.
- The configuration for sampling future conversations is in `recipe/collabllm/config/collabllm_interaction_config.yaml`.
- The Multiturn-aware Reward is aggregated from these three conversation-level rewards:
```
+reward_model.reward_kwargs.metric_weights.accuracy=1 \
+reward_model.reward_kwargs.metric_weights.interactivity=1 \
+reward_model.reward_kwargs.metric_weights.token_amount=-0.0001 \
```
You can remove, add, or modify the weights depending on your task. The implemented metrics you can add are listed under `recipe/collabllm/metrics`. For example, on `medium-large`, you can replace `accuracy` with `bleu_score` via
```
+reward_model.reward_kwargs.metric_weights.bleu_score=1
```
which will instead apply the BLEU score to the sampled future conversations.
## Configuration
Read the [documentation](https://verl.readthedocs.io/en/latest/) for detailed configuration options.
## Citation
If you find CollabLLM useful in your research, please cite the following:
```bibtex
@inproceedings{collabllm2025,
title={CollabLLM: From Passive Responders to Active Collaborators},
author={Shirley Wu and Michel Galley and Baolin Peng and Hao Cheng and
Gavin Li and Yao Dou and Weixin Cai and James Zou and
Jure Leskovec and Jianfeng Gao},
booktitle={International Conference on Machine Learning (ICML)},
year={2025}
}
```

@ -0,0 +1,139 @@
# Copyright 2025 CollabLLM team and/or its affiliates
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from copy import deepcopy
from typing import Any
from uuid import uuid4
from recipe.collabllm.utils import is_valid_messages
from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, register
from verl.experimental.agent_loop.tool_agent_loop import AgentData, AgentState, ToolAgentLoop
from verl.utils.rollout_trace import rollout_trace_op
from verl.workers.rollout.schemas import Message
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
@register("collabllm_agent")
class CollabLLMAgentLoop(ToolAgentLoop):
@rollout_trace_op
async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
messages = list(kwargs["raw_prompt"])
image_data = deepcopy(kwargs.get("multi_modal_data", {}).get("image", None))
metrics = {}
request_id = uuid4().hex
tools_kwargs = kwargs.get("tools_kwargs", {})
# Initialize interaction if needed
interaction = None
interaction_kwargs = {}
if self.interaction_config_file:
interaction_kwargs = kwargs["extra_info"]["interaction_kwargs"]
if "name" not in interaction_kwargs:
raise ValueError("'name' key is required in interaction_kwargs")
interaction_name = interaction_kwargs["name"]
if interaction_name not in self.interaction_map:
raise ValueError(
f"Interaction '{interaction_name}' not found in interaction_map. Available interactions: "
f"{list(self.interaction_map.keys())}"
)
interaction = self.interaction_map[interaction_name]
await interaction.start_interaction(request_id, **interaction_kwargs)
# Create AgentData instance to encapsulate all state
agent_data = AgentData(
messages=messages,
image_data=image_data,
metrics=metrics,
request_id=request_id,
tools_kwargs=tools_kwargs,
interaction=interaction,
interaction_kwargs=interaction_kwargs,
)
# for collabllm, first generate model responses
await self._handle_pending_state(agent_data, sampling_params)
status = await self._handle_generating_state(agent_data, sampling_params)
if status == AgentState.TERMINATED:
# tell reward manager to score -1 and skip future interaction
# to avoid reward hacking with incomplete messages
num_repeats = 0
else:
# then, collect interaction rollouts
num_repeats = self.config.actor_rollout_ref.rollout.multi_turn.num_repeat_rollouts
interaction_requests = [deepcopy(agent_data) for _ in range(num_repeats)]
# messages are only used in collabllm reward manager
messages_lst = []
for _agent_data in interaction_requests:
if not is_valid_messages(_agent_data.messages[-1]):
break
prev_msg_len = len(_agent_data.messages)
await self.run_agent_data_loop(_agent_data, sampling_params, AgentState.INTERACTING)
messages_lst.append([Message(**msg) for msg in _agent_data.messages])
if interaction.config.get("enable_log"):
print(f"Assistant: ...{messages_lst[-1][prev_msg_len - 1].content[-100:]}")
print(f"User: {messages_lst[-1][prev_msg_len].content[:100]}...")
# Finalize output
response_ids = agent_data.prompt_ids[-len(agent_data.response_mask) :]
prompt_ids = agent_data.prompt_ids[: len(agent_data.prompt_ids) - len(agent_data.response_mask)]
multi_modal_data = {"image": agent_data.image_data} if agent_data.image_data is not None else {}
output = AgentLoopOutput(
prompt_ids=prompt_ids,
response_ids=response_ids[: self.response_length],
response_mask=agent_data.response_mask[: self.response_length],
multi_modal_data=multi_modal_data,
response_logprobs=agent_data.response_logprobs[: self.response_length]
if agent_data.response_logprobs
else None,
num_turns=agent_data.user_turns + agent_data.assistant_turns + 1,
metrics=agent_data.metrics,
extra_fields={
"turn_scores": agent_data.turn_scores,
"messages": {"messages": messages_lst}, # compatiable with sglang interaction
},
)
return output
async def run_agent_data_loop(self, agent_data: AgentData, sampling_params: dict[str, Any], state: AgentState):
"""
Run the agent data loop to process the agent data.
Args:
agent_data (AgentData): The agent data to process.
sampling_params (dict[str, Any]): The sampling parameters.
state (AgentState, optional): The initial state of the agent. Defaults to None.
"""
while state != AgentState.TERMINATED:
if state == AgentState.PENDING:
state = await self._handle_pending_state(agent_data, sampling_params)
elif state == AgentState.GENERATING:
state = await self._handle_generating_state(agent_data, sampling_params)
elif state == AgentState.PROCESSING_TOOLS:
state = await self._handle_processing_tools_state(agent_data)
elif state == AgentState.INTERACTING:
state = await self._handle_interacting_state(agent_data)
else:
logger.error(f"Invalid state: {state}")
state = AgentState.TERMINATED

@ -0,0 +1,10 @@
interaction:
  - name: "collabllm"
    class_name: "verl.interactions.collabllm_interation.CollabLLMInteraction"
    config: {
      "user_model": "gpt-4o-mini",
      "num_retries": 3,
      "max_tokens": 512,
      "temperature": 1.0,
      "enable_log": True
    }

@ -0,0 +1,104 @@
# Copyright 2025 CollabLLM team and/or its affiliates
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from recipe.collabllm.utils import extract_json, parse_messages
ACCURACY_PROMPT = '''You are a helpful and meticulous evaluator. Your task is to \
evaluate the *accuracy* of an AI model's answer to a target question. \
You will be given the target question, the ground truth answer, and the conversation between the AI and the user.
Provided Information:
<|The Start of Target Question and Ground Truth Answer|>
Target Question: {single_turn_prompt}
Ground Truth Answer: {ground_truth}
<|The End of Target Question and Ground Truth Answer|>
<|The Start of The Conversation|>
{chat_history}
<|The End of The Conversation|>
You should determine whether the model's final response to the target question is \
factually correct and consistent with the provided ground truth.
Rating criteria (binary):
• 1 = Correct — the response matches the ground truth.
• 0 = Incorrect — the response contradicts or misses the ground truth.
Output format (JSON):
{{
"thought": "<your reasoning here>",
"accuracy": <0 or 1>
}}
Double check if the JSON object is formatted correctly. Ensure that all fields are present and properly structured. \
Use " or """ to wrap up the thought and use single quotes inside the "thought" field to avoid JSON escape issues.
Your evaluation:
'''
async def compute_score(data_source, messages, ground_truth, extra_info, **kwargs):
# Check if litellm is available, fallback to openai if not
try:
import litellm
use_litellm = True
except ImportError:
# litellm not found, falling back to openai
import openai
use_litellm = False
chat_history = parse_messages(messages, strip_sys_prompt=True)
prompt = ACCURACY_PROMPT.format(
single_turn_prompt=extra_info["interaction_kwargs"]["single_turn_prompt"],
ground_truth=ground_truth,
chat_history=chat_history,
)
if use_litellm:
full_response = (
(
await litellm.acompletion(
messages=[{"role": "user", "content": prompt}],
**kwargs,
)
)
.choices[0]
.message.content
)
else:
client = openai.AsyncOpenAI() # Assumes API key is set in environment
full_response = (
(
await client.chat.completions.create(
messages=[{"role": "user", "content": prompt}],
**kwargs,
)
)
.choices[0]
.message.content
)
full_response = extract_json(full_response)
assert isinstance(full_response, dict), f"Expected a dict, got {type(full_response)}"
assert {"accuracy", "thought"}.issubset(full_response.keys()), (
f"Expected keys not found from {full_response.keys()}"
)
accuracy = full_response.pop("accuracy")
return float(accuracy)

@ -0,0 +1,116 @@
# Copyright 2025 CollabLLM team and/or its affiliates
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nltk.translate.bleu_score import sentence_bleu
from recipe.collabllm.utils import extract_json, parse_messages
EXTRACT_MULTITURN_COMPLETION_PROMPT = '''You are a thorough and diligent conversation analyzer. \
Your task is to extract the final and complete version of a document that was generated during \
a multiturn conversation between a user and a chat assistant. \
The extracted content should reflect the final and comprehensive response provided by the assistant \
based on the user's request.
You will be provided with the conversation:
<|The Start of The Conversation|>
{chat_history}
<|The End of The Conversation|>
Instructions for Extraction:
1. Identify the Most Up-to-Date Contents: Review the entire conversation to identify the most updated parts \
of the content provided by the assistant. This may include:
- Different sections of text (e.g., an essay, report, or article).
2. Integrate Revisions: If the assistant made revisions, updates, or added sections throughout the conversation, \
ensure that these changes are fully integrated into the final content. The goal is to extract a single, cohesive \
output that incorporates all modifications and additions made during the conversation. For example, if the assistant \
writes an introduction at the beginning and moves on to the conclusion, the final output should include both the \
introduction and the conclusion.
3. Focus on Completeness:
- For text-based documents: Ensure that the extracted content is comprehensive and represents the full document \
or section as discussed in the conversation.
You should output a JSON object with two entries:
- "thought" (str): Output your thought process when extracting the final content.
1. How do different parts of the conversation contribute to the final output?
2. How do you make sure you included the most updated and complete information?
3. How do you make sure you did not include any information that is not necessary?
- "final_completion" (str): The final and complete version of the document extracted from the conversation.
Note:
1. If there are multiple lines, you should use triple quotes (""") to wrap the content. For example, \
"final_completion": """first line.
second line.""" or "thought": """first line;
second line.""".
2. In the "final_completion" entry, replace all double quotes (") with single quotes (') to prevent JSON formatting \
issues. For example, you can output "final_completion": "'Hello World' is a common phrase."
Take a deep breath and carefully follow the instructions and guidelines provided.
'''
async def compute_score(data_source, messages, ground_truth, extra_info, **kwargs):
# Check if litellm is available, fallback to openai if not
try:
import litellm
use_litellm = True
except ImportError:
# litellm not found, falling back to openai
import openai
use_litellm = False
chat_history = parse_messages(messages, strip_sys_prompt=True)
prompt = EXTRACT_MULTITURN_COMPLETION_PROMPT.format(chat_history=chat_history)
if use_litellm:
full_response = (
(
await litellm.acompletion(
messages=[{"role": "user", "content": prompt}],
**kwargs,
)
)
.choices[0]
.message.content
)
else:
client = openai.AsyncOpenAI() # Assumes API key is set in environment
full_response = (
(
await client.chat.completions.create(
messages=[{"role": "user", "content": prompt}],
**kwargs,
)
)
.choices[0]
.message.content
)
full_response = extract_json(full_response)
assert isinstance(full_response, dict), f"Expected a dict, got {type(full_response)}"
assert {"final_completion", "thought"}.issubset(full_response.keys()), (
f"Expected keys not found from {full_response.keys()}"
)
final_completion = full_response.pop("final_completion")
bleu = sentence_bleu([ground_truth], final_completion)
return float(bleu)

@ -0,0 +1,108 @@
# Copyright 2025 CollabLLM team and/or its affiliates
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from recipe.collabllm.utils import extract_json, parse_messages
INTERACTIVITY_PROMPT = '''You are a helpful and meticulous conversation evaluator. \
Your task is to evaluate the interactivity of the responses provided by an AI assistant \
to user questions in a given conversation:
<|The Start of the Conversation to be Evaluated|>
{chat_history}
<|The End of the Conversation to be Evaluated|>
You should assess the assistant's engagement, clarity, and ability to understand the user's needs. \
Give a float number between 0 and 1.
Scoring Criteria:
- Let U = user understanding & response clarity ∈ [0,1]
- 1.0 = Fully understands the user's intent and gives a clear answer.
- 0.7 = Mostly understands and the answer is generally clear.
- 0.3 = Partially misunderstands or the answer is hard to follow.
- 0.0 = Misunderstands the intent and gives an unclear or irrelevant answer.
- Let Q = clarification in [0,1]
- 1.0 = Asks precise, necessary clarifying questions when needed.
- 0.7 = Asks somewhat helpful but incomplete clarifications.
- 0.3 = Only asks generic questions (e.g., “Does that help?”).
- 0.0 = Asks no clarifying questions when needed.
- Let S = suggestion helpfulness in [0,1]
- 1.0 = Provides useful, actionable suggestions.
- 0.7 = Suggestions are somewhat helpful but limited.
- 0.3 = Suggestions are vague or generic.
- 0.0 = No suggestions when they would clearly help.
score = average([U, Q, S])
Output format (JSON):
{{
"thought": "<How interactive is the assistant?>",
"interactivity": <score>
}}
Double check if the JSON object is formatted correctly. Ensure that all fields are present and properly structured. \
Use " or """ to wrap up the thought. You should not use other triple quotes inside the "thought" field. \
Instead you should use single quotes to avoid JSON escape issues.
Your evaluation:
'''
async def compute_score(data_source, messages, ground_truth, extra_info, **kwargs):
# Check if litellm is available, fallback to openai if not
try:
import litellm
use_litellm = True
except ImportError:
# litellm not found, falling back to openai
import openai
use_litellm = False
chat_history = parse_messages(messages, strip_sys_prompt=True)
prompt = INTERACTIVITY_PROMPT.format(chat_history=chat_history)
if use_litellm:
full_response = (
(
await litellm.acompletion(
messages=[{"role": "user", "content": prompt}],
**kwargs,
)
)
.choices[0]
.message.content
)
else:
client = openai.AsyncOpenAI() # Assumes API key is set in environment
full_response = (
(
await client.chat.completions.create(
messages=[{"role": "user", "content": prompt}],
**kwargs,
)
)
.choices[0]
.message.content
)
full_response = extract_json(full_response)
assert isinstance(full_response, dict), f"Expected a dict, got {type(full_response)}"
assert {"interactivity", "thought"}.issubset(full_response.keys()), (
f"Expected keys not found from {full_response.keys()}"
)
interactivity = full_response.pop("interactivity")
return float(interactivity)

@ -0,0 +1,139 @@
# Copyright 2025 CollabLLM team and/or its affiliates
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from bigcodebench.eval import untrusted_check
from recipe.collabllm.utils import extract_json, parse_messages
EXTRACT_MULTITURN_COMPLETION_PROMPT = '''You are a thorough and diligent conversation analyzer. \
Your task is to extract the final and complete version of a code function {entry_point} that was generated \
during a multiturn conversation between a user and a chat assistant. \
The extracted content should reflect the final and comprehensive response provided by the \
assistant based on the user's request.
You will be provided with the task and the conversation:
<|The Start of The Task|>
{single_turn_prompt}
<|The End of The Task|>
<|The Start of The Conversation|>
{chat_history}
<|The End of The Conversation|>
Instructions for Extraction:
1. Identify the Most Up-to-Date Contents: Review the entire conversation to identify the most updated parts of \
the content provided by the assistant. This may include:
- Different parts of the code snippet, function, class, or script.
2. Integrate Revisions: If the assistant made revisions, updates, or added sections throughout the conversation, \
ensure that these changes are fully integrated into the final content. The goal is to extract a single, cohesive \
output that incorporates all modifications and additions made during the conversation. For example, if the assistant \
writes a function at the beginning and changes a part, the final output should take the modification into account.
3. Focus on Completeness:
- For code: Extract a complete and functional code snippet, including all necessary components such as imports, \
functions, classes, and any other essential elements. The code should be runnable, but you do not need to \
include any testing examples including the contents after `if __name__ == "__main__":`. Only the function code \
is required.
You should output a JSON object with two entries:
- "thought" (str): Output your thought process when extracting the final content.
1. How do different parts of the conversation contribute to the final output?
2. How do you make sure you included the most updated and complete information?
3. How do you make sure you did not include any information that is not necessary?
- "final_completion" (str): The final and complete version of the code extracted from the conversation. \
Rename main function name for the task to {entry_point} if needed. Remove any comments wrapped by """.
Note:
1. If there are multiple lines, you should use triple quotes (""") to wrap the content. For example, \
"final_completion": """first line.
second line.""" or "thought": """first line;
second line.""". You should not use other triple quotes inside.
2. In the "final_completion" entry, replace all double quotes (") with single quotes (') to prevent JSON formatting \
issues. For example, you can output "final_completion": "'Hello World' is a common phrase."
Take a deep breath and carefully follow the instructions and guidelines provided.
'''
async def compute_score(data_source, messages, ground_truth, extra_info, **kwargs):
# Check if litellm is available, fallback to openai if not
try:
import litellm
use_litellm = True
except ImportError:
# litellm not found, falling back to openai
import openai
use_litellm = False
chat_history = parse_messages(messages, strip_sys_prompt=True)
prompt = EXTRACT_MULTITURN_COMPLETION_PROMPT.format(
chat_history=chat_history,
single_turn_prompt=extra_info["interaction_kwargs"]["single_turn_prompt"],
entry_point=extra_info["single_turn_metadata"]["entry_point"],
)
if use_litellm:
full_response = (
(
await litellm.acompletion(
messages=[{"role": "user", "content": prompt}],
**kwargs,
)
)
.choices[0]
.message.content
)
else:
client = openai.AsyncOpenAI() # Assumes API key is set in environment
full_response = (
(
await client.chat.completions.create(
messages=[{"role": "user", "content": prompt}],
**kwargs,
)
)
.choices[0]
.message.content
)
full_response = extract_json(full_response)
assert isinstance(full_response, dict), f"Expected a dict, got {type(full_response)}"
assert {"final_completion", "thought"}.issubset(full_response.keys()), (
f"Expected keys not found from {full_response.keys()}"
)
final_completion = full_response.pop("final_completion")
metadata = extra_info["single_turn_metadata"]
res = untrusted_check(
final_completion,
metadata["test"],
metadata["entry_point"],
max_as_limit=300 * 1024,
max_data_limit=300 * 1024,
max_stack_limit=300 * 1024,
min_time_limit=60,
gt_time_limit=60,
)
passed = res[0] == "pass"
# info = res[1] # for printing extra info
return float(passed)

@ -0,0 +1,26 @@
# Copyright 2025 CollabLLM team and/or its affiliates
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def compute_score(data_source, messages, ground_truth, extra_info, **kwargs):
    prompt = extra_info["prompt"]
    # Count tokens in the future conversation (everything after the original prompt);
    # a negative metric weight turns this count into a length penalty.
    future_conv = messages[len(prompt) :]
    # simple length estimation
    total_tokens = sum(len(m.content.split()) for m in future_conv)
    return total_tokens

@ -0,0 +1,239 @@
# Copyright 2025 CollabLLM team and/or its affiliates
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env python3
"""
# available datasets:
# math-hard(-large), medium(-large), bigcodebench(-large)
# to create your own dataset, refer to https://github.com/Wuyxin/collabllm
DATASET=math-hard-large
python recipe/collabllm/process_dataset.py \
--dataset collabllm/collabllm-multiturn-$DATASET \
--local_dir $HOME/data/collabllm-$DATASET \
--dataset_type sft
python recipe/collabllm/process_dataset.py \
--dataset collabllm/collabllm-multiturn-$DATASET \
--local_dir $HOME/data/collabllm-$DATASET \
--dataset_type rl
Preprocess collabllm/collabllm-multiturn-math-hard into (ground_truth, extra_info).
- ground_truth: picked from --prefer_field (default: single_turn_completion),
falling back to --fallback_field (default: completion)
- extra_info: a shallow copy of the original example plus bookkeeping fields
- reward_model: {"style": "rule", "ground_truth": ground_truth}
Saves one parquet per split into --local_dir and a small JSON preview.
"""
import argparse
import json
import os
import uuid
from typing import Any, Optional
from datasets import Dataset, concatenate_datasets, load_dataset
SYSTEM_PROMPT = """The assistant is designed to be helpful, proactive, and highly interactive.
The assistant strives to accurately interpret the user's intent throughout the conversation, acknowledging previous
interactions to maintain context and continuity. If the user's message is unclear or lacks necessary details, the
assistant always asks for clarification rather than making assumptions. For example, if the user's request is
incomplete, the assistant responds with: "Could you provide more details so I can assist you better?"
The assistant asks specific follow-up questions and offers suggestions based on the user's needs, avoiding vague or
generic prompts. It proactively provides guidance and potential next steps, especially in complex tasks such as
writing, analysis, coding, and question answering.
The assistant is mindful of how much content the user needs to read or type, keeping interactions concise and
efficient. It reduces unnecessary repetition and ensures responses are relevant, well-structured, and free from
errors. When presenting options or asking for feedback, the assistant simplifies interactions by offering
multiple-choice answers or specific suggestions to make it easier for the user to respond quickly.
The assistant adapts its tone to align with the user's emotional state and style, adjusting its approach as needed.
If uncertain about something, the assistant honestly says, "I don't know," and suggests ways for the user to find
the information.
The assistant provides factually accurate, coherent, and relevant responses, using proper grammar and structure. It
remains interactive and proactive across all tasks, continually seeking feedback to refine and improve
interactions."""
# Required fields: "prompt", "ground_truth", "extra_info"
# In "extra_info" dict:
# (1) Required: "single_turn_prompt", which is the specific problem used to inform the user simulator,
# (2) Optional: "task_desc" (a short task description),
# (3) Optional: other fields for customized reward computation
def collapse_example(example: dict[str, Any]) -> dict[str, Any]:
if "prompt" not in example:
raise ValueError("Missing required 'prompt' field.")
ground_truth = (
example.get("ground_truth") or example.get("single_turn_completion") or example.get("completion") or ""
)
extra_info = {}
for k, v in example.items():
if k in ("prompt", "ground_truth", "extra_info"):
continue
extra_info.setdefault(k, v) # keep extra_info values if keys overlap
# make sure extra_info has the required fields
assert "single_turn_prompt" in extra_info, "Missing 'single_turn_prompt' in extra_info."
# add system prompt as the beginning of the list
example["prompt"] = [{"role": "system", "content": SYSTEM_PROMPT}] + example["prompt"]
extra_info.setdefault("prompt", example["prompt"]) # save the original prompt
extra_info.setdefault(
"interaction_kwargs",
{
"name": "collabllm",
"single_turn_prompt": extra_info.pop("single_turn_prompt"),
"task_desc": extra_info.pop("task_desc", "general ask-for-assistance task"),
},
)
return {
"prompt": example["prompt"],
"ground_truth": ground_truth,
"raw_prompt": example["prompt"], # save the original prompt
"extra_info": extra_info,
"reward_model": {"style": "rule", "ground_truth": ground_truth},
"data_source": "collabllm",
"agent_name": "collabllm_agent",
"index": str(uuid.uuid4()),
}
# ---------- IO helpers ----------
def save_parquet(ds_split: Dataset, filename: str, out_dir: str) -> None:
os.makedirs(out_dir, exist_ok=True)
path = os.path.join(out_dir, f"{filename}.parquet")
ds_split.to_parquet(path)
print(f"[OK] Wrote {filename}.parquet → {path} ({len(ds_split)} rows)")
def maybe_copy_to_hdfs(local_dir: str, hdfs_dir: Optional[str]) -> None:
if not hdfs_dir:
return
try:
from verl.utils.hdfs_io import copy, makedirs # type: ignore
except Exception as e:
print(f"[WARN] Skipping HDFS copy (verl not available): {e}")
return
makedirs(hdfs_dir)
copy(src=local_dir, dst=hdfs_dir)
print(f"[OK] Copied {local_dir}{hdfs_dir}")
# ---------- Main ----------
def main():
ap = argparse.ArgumentParser()
ap.add_argument(
"--dataset", default="collabllm/collabllm-multiturn-math-hard", help="HF dataset path or local dir/file."
)
ap.add_argument("--task_desc", default="solving math problems", help="Task description for the dataset.")
ap.add_argument("--local_dir", default="~/data/collabllm-math-hard", help="Output directory.")
ap.add_argument("--hdfs_dir", default=None, help="Optional HDFS destination (requires verl).")
ap.add_argument(
"--validation_size", type=float, default=0.1, help="Validation split size (fraction or absolute int)."
)
ap.add_argument("--seed", type=int, default=42, help="Random seed for splitting.")
ap.add_argument("--num_proc", type=int, default=1, help="Parallel workers for map().")
ap.add_argument("--dataset_type", default="rl", choices=["rl", "sft"], help="Type of dataset (e.g., 'rl', 'sft').")
args = ap.parse_args()
out_dir = os.path.expanduser(args.local_dir)
os.makedirs(out_dir, exist_ok=True)
print(f"[INFO] Loading dataset: {args.dataset}")
ds_dict = load_dataset(args.dataset)
parts = list(ds_dict.values())
ds_all: Dataset = parts[0] if len(parts) == 1 else concatenate_datasets(parts)
# Dataset({
# features: ['prompt', 'completion', 'conv_id', 'score', 'single_turn_prompt',
# 'single_turn_completion', 'single_turn_metadata', 'turn_id', 'sessions', 'rewards'],
# num_rows: xxx
# })
if args.dataset_type == "rl":
# If multiple splits exist, merge them before collapsing/splitting.
ds_all = ds_all.map(lambda x: {"task_desc": args.task_desc}, num_proc=args.num_proc)
print(f"[INFO] Collapsing to formatted fields on {len(ds_all)} rows…")
ds_all = ds_all.map(
function=collapse_example,
remove_columns=ds_all.column_names,
num_proc=args.num_proc,
)
def dedup_by_prompt(dataset):
seen = set()
unique_rows = []
for ex in dataset:
prompt_key = json.dumps(ex["prompt"], sort_keys=True, ensure_ascii=False)
if prompt_key not in seen:
seen.add(prompt_key)
unique_rows.append(ex)
return Dataset.from_list(unique_rows)
ds_all = dedup_by_prompt(ds_all)
elif args.dataset_type == "sft":
df = ds_all.to_pandas()
# Sort so that within each conv_id the highest turn_id is first,
# and if multiple rows share the same turn_id, the highest score comes first
df = df.sort_values(["conv_id", "turn_id", "score"], ascending=[True, False, False])
# Keep only the top row per conv_id
df = df.drop_duplicates(subset="conv_id", keep="first")
# Back to HF Dataset
ds_all = Dataset.from_pandas(df, preserve_index=False)
# Append assistant response into prompt list
def append_completion(example):
example["prompt"] = (
[{"role": "system", "content": SYSTEM_PROMPT}]
+ example["prompt"]
+ [{"role": "assistant", "content": example["completion"]}]
)
return example
ds_all = ds_all.map(append_completion)
# Keep only prompt column
cols_to_remove = [col for col in ds_all.column_names if col != "prompt"]
ds_all = ds_all.remove_columns(cols_to_remove)
print(f"[INFO] Splitting with validation_size={args.validation_size}, seed={args.seed}")
split = ds_all.train_test_split(test_size=args.validation_size, seed=args.seed, shuffle=True)
train_ds, val_ds = split["train"], split["test"]
print(train_ds, val_ds)
save_parquet(train_ds, f"{args.dataset_type}_train", out_dir)
save_parquet(val_ds, f"{args.dataset_type}_validation", out_dir)
maybe_copy_to_hdfs(local_dir=out_dir, hdfs_dir=args.hdfs_dir)
print(f"[DONE] {args.dataset_type}_train.parquet and {args.dataset_type}_validation.parquet written.")
if __name__ == "__main__":
main()

@ -0,0 +1,95 @@
# Copyright 2025 CollabLLM team and/or its affiliates
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import importlib.util
import os
import sys
import litellm
import torch
async def conversation_level_reward_func(
data_source, messages, ground_truth, extra_info, metrics, **kwargs
) -> torch.Tensor:
"""
Async version of conversation-level reward function.
Apply conversation-level reward function to the future interactions between the user simulator
and policy model, which are generated from `verl/interactions/collabllm_interation.py`
"""
num_retries = kwargs.get("num_retries", 6)
rewards = {}
for metric in metrics:
current_dir = os.path.dirname(os.path.abspath(__file__))
metric_file_path = os.path.join(current_dir, f"metrics/{metric}.py")
if not os.path.exists(metric_file_path):
print(f"Error: Metric file '{metric_file_path}' not found. Assigning 0 to metric '{metric}'.")
rewards[metric] = 0.0
continue
spec = importlib.util.spec_from_file_location(f"metric_{metric}", metric_file_path)
if spec is None:
print(f"Error: Could not create spec for metric '{metric}'. Assigning 0 to metric '{metric}'.")
rewards[metric] = 0.0
continue
module = importlib.util.module_from_spec(spec)
try:
sys.modules[f"metric_{metric}"] = module
assert spec.loader is not None
spec.loader.exec_module(module)
except Exception as e:
print(f"Error loading metric module from '{metric_file_path}': {e}. Assigning 0 to metric '{metric}'.")
rewards[metric] = 0.0
continue
# Assume each metric file has a compute_score function
if not hasattr(module, "compute_score"):
print(
f"Error: Function 'compute_score' not found in '{metric_file_path}'. Assigning 0 to metric '{metric}'."
)
rewards[metric] = 0.0
continue
compute_score_fn = module.compute_score
# Retry mechanism for calling the metric function
for attempt in range(num_retries):
try:
# Call the metric function (await if it's async)
if asyncio.iscoroutinefunction(compute_score_fn):
rewards[metric] = await compute_score_fn(data_source, messages, ground_truth, extra_info, **kwargs)
else:
rewards[metric] = compute_score_fn(data_source, messages, ground_truth, extra_info, **kwargs)
break # Success, exit retry loop
except Exception as e:
if attempt == num_retries - 1: # Last attempt
print(
f"Error: Failed to compute metric '{metric}' after {num_retries} attempts. "
f"Last error: {e}. Assigning 0 to metric '{metric}'."
)
rewards[metric] = 0.0
else:
print(f"Attempt {attempt + 1} failed for metric '{metric}': {e}. Retrying...")
if isinstance(e, litellm.RateLimitError):
await asyncio.sleep(max(2**attempt, 60)) # Exponential backoff
# Return dict with metric names as keys
return {metric: torch.tensor(reward, dtype=torch.float32) for metric, reward in rewards.items()}

@ -0,0 +1,75 @@
# Usage: sh recipe/collabllm/train_rl_collabllm.sh <optional resume path>
set -x
PROJECT_DIR="$(pwd)"
export VLLM_USE_V1=1
RESUME_PATH="${1:-}"
if [ -z "$RESUME_PATH" ]; then
RESUME_PATH=null
fi
DATASET=math-hard-large
PROJECT_DIR="$(pwd)"
python3 -m verl.trainer.main_ppo \
trainer.val_before_train=False \
algorithm.adv_estimator=grpo \
data.train_files=$HOME/data/collabllm-$DATASET/rl_train.parquet \
data.val_files=$HOME/data/collabllm-$DATASET/rl_validation.parquet \
reward_model.reward_manager=collabllm \
+reward_model.reward_kwargs.metric_weights.accuracy=1 \
+reward_model.reward_kwargs.metric_weights.interactivity=1 \
+reward_model.reward_kwargs.metric_weights.token_amount=-0.0001 \
+reward_model.reward_kwargs.llm_judge_kwargs.model=gpt-4o-mini \
+reward_model.reward_kwargs.llm_judge_kwargs.max_tokens=2048 \
+reward_model.reward_kwargs.llm_judge_kwargs.temperature=0 \
data.train_batch_size=16 \
data.max_prompt_length=8196 \
data.max_response_length=2048 \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.model.path="Qwen/Qwen2.5-7B-Instruct" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=8 \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=True \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.mode=async \
actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
actor_rollout_ref.rollout.n=8 \
actor_rollout_ref.rollout.temperature=1.0 \
actor_rollout_ref.rollout.free_cache_engine=True \
actor_rollout_ref.rollout.multi_turn.enable=true \
actor_rollout_ref.rollout.multi_turn.format=hermes \
actor_rollout_ref.rollout.multi_turn.max_user_turns=2 \
actor_rollout_ref.rollout.multi_turn.max_assistant_turns=3 \
actor_rollout_ref.rollout.multi_turn.num_repeat_rollouts=3 \
actor_rollout_ref.rollout.trace.token2text=True \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger='["console", "wandb"]' \
trainer.project_name=verlxcollabllm \
trainer.experiment_name=collabllm-qwen2.5-7B-$DATASET \
trainer.nnodes=1 \
trainer.n_gpus_per_node=8 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
trainer.save_freq=100 \
trainer.test_freq=10 \
trainer.total_epochs=20 \
custom_reward_function.path=recipe/collabllm/reward_function.py \
custom_reward_function.name=conversation_level_reward_func \
actor_rollout_ref.rollout.multi_turn.interaction_config_path="$PROJECT_DIR/recipe/collabllm/config/collabllm_interaction_config.yaml" \
trainer.resume_from_path=$RESUME_PATH

@ -0,0 +1,32 @@
#!/bin/bash
set -x
if [ "$#" -lt 1 ]; then
echo "Usage: sft_train_collabllm.sh [<nproc_per_node> other_configs...]"
exit 1
fi
nproc_per_node=$1
# Shift the arguments so $@ refers to the rest
shift 1
DATASET=math-hard-large
torchrun --nnodes=1 --nproc_per_node=$nproc_per_node \
-m verl.trainer.fsdp_sft_trainer \
data.train_files=$HOME/data/collabllm-$DATASET/sft_train.parquet \
data.val_files=$HOME/data/collabllm-$DATASET/sft_validation.parquet \
data.multiturn.enable=true \
data.multiturn.messages_key=prompt \
optim.lr=1e-6 \
data.train_batch_size=64 \
data.micro_batch_size_per_gpu=2 \
data.max_length=8196 \
model.partial_pretrain=Qwen/Qwen2.5-7B-Instruct \
trainer.project_name=collabllm-sft-$DATASET \
trainer.experiment_name=collabllm-sft-qwen2.5-7B-$DATASET \
trainer.logger=console \
trainer.total_epochs=3 \
ulysses_sequence_parallel_size=1 \
use_remove_padding=true "$@"
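
For example, on a single node with 8 GPUs the SFT stage would be launched as `bash sft_train_collabllm.sh 8`, optionally followed by extra Hydra overrides that are forwarded via `"$@"` (the exact location of the script under the recipe directory is assumed here).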

recipe/collabllm/utils.py (new file)

@ -0,0 +1,280 @@
# Copyright 2025 CollabLLM team and/or its affiliates
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import re
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
def parse_messages(messages, strip_sys_prompt=True):
"""
Args:
messages: List[dict]
List of dictionaries with keys 'role' and 'content'
Example: messages = [{'role': 'user', 'content': 'Hello!'},
{'role': 'assistant', 'content': 'Hi!'}, ...]
"""
if messages is None:
return ""
if strip_sys_prompt:
messages = strip_system_prompt(messages)
chat = "\n".join(f"**{m.role.capitalize()}**: {m.content}" for m in messages)
return chat
def strip_system_prompt(messages):
"""
Args:
messages: List[dict]
List of dictionaries with keys 'role' and 'content'
Example: messages = [{'role': 'user', 'content': 'Hello!'},
{'role': 'assistant', 'content': 'Hi!'}, ...]
"""
return [msg for msg in messages if msg.role != "system"]
def extract_json(s):
def convert_value(value):
true_values = {"true": True, "false": False, "null": None}
value_lower = value.lower()
if value_lower in true_values:
return true_values[value_lower]
try:
if "." in value or "e" in value.lower():
return float(value)
else:
return int(value)
except ValueError:
return value # Return as string if not a number
def parse_number(s, pos):
start = pos
while pos < len(s) and s[pos] in "-+0123456789.eE":
pos += 1
num_str = s[start:pos]
try:
if "." in num_str or "e" in num_str.lower():
return float(num_str), pos
else:
return int(num_str), pos
except ValueError:
logger.error(f"Invalid number at position {start}: {num_str}")
raise
def skip_whitespace(s, pos):
while pos < len(s) and s[pos] in " \t\n\r":
pos += 1
return pos
def parse_string(s, pos):
quote_char = s[pos]
assert quote_char in ('"', "'")
pos += 1
result = ""
while pos < len(s):
c = s[pos]
if c == "\\":
pos += 1
if pos >= len(s):
raise ValueError("Invalid escape sequence")
c = s[pos]
escape_sequences = {"n": "\n", "t": "\t", "r": "\r", "\\": "\\", quote_char: quote_char}
result += escape_sequences.get(c, c)
elif c == quote_char:
pos += 1
# Attempt to convert to a number if possible
converted_value = convert_value(result)
return converted_value, pos
else:
result += c
pos += 1
raise ValueError("Unterminated string")
def parse_key(s, pos):
pos = skip_whitespace(s, pos)
if s[pos] in ('"', "'"):
key, pos = parse_string(s, pos)
return key, pos
else:
raise ValueError(f"Expected string for key at position {pos}")
def parse_object(s, pos):
obj = {}
assert s[pos] == "{"
pos += 1
pos = skip_whitespace(s, pos)
while pos < len(s) and s[pos] != "}":
pos = skip_whitespace(s, pos)
key, pos = parse_key(s, pos)
pos = skip_whitespace(s, pos)
if pos >= len(s) or s[pos] != ":":
raise ValueError(f'Expected ":" at position {pos}')
pos += 1
pos = skip_whitespace(s, pos)
value, pos = parse_value(s, pos)
obj[key] = value
pos = skip_whitespace(s, pos)
if pos < len(s) and s[pos] == ",":
pos += 1
pos = skip_whitespace(s, pos)
elif pos < len(s) and s[pos] == "}":
break
elif pos < len(s) and s[pos] != "}":
raise ValueError(f'Expected "," or "}}" at position {pos}')
if pos >= len(s) or s[pos] != "}":
raise ValueError(f'Expected "}}" at position {pos}')
pos += 1
return obj, pos
def parse_array(s, pos):
lst = []
assert s[pos] == "["
pos += 1
pos = skip_whitespace(s, pos)
while pos < len(s) and s[pos] != "]":
value, pos = parse_value(s, pos)
lst.append(value)
pos = skip_whitespace(s, pos)
if pos < len(s) and s[pos] == ",":
pos += 1
pos = skip_whitespace(s, pos)
elif pos < len(s) and s[pos] == "]":
break
elif pos < len(s) and s[pos] != "]":
raise ValueError(f'Expected "," or "]" at position {pos}')
if pos >= len(s) or s[pos] != "]":
raise ValueError(f'Expected "]" at position {pos}')
pos += 1
return lst, pos
def parse_triple_quoted_string(s, pos):
if s[pos : pos + 3] == "'''":
quote_str = "'''"
elif s[pos : pos + 3] == '"""':
quote_str = '"""'
else:
raise ValueError(f"Expected triple quotes at position {pos}")
pos += 3
result = ""
while pos < len(s):
if s[pos : pos + 3] == quote_str:
pos += 3
# Attempt to convert to a number if possible
converted_value = convert_value(result)
return converted_value, pos
else:
result += s[pos]
pos += 1
raise ValueError("Unterminated triple-quoted string")
def parse_value(s, pos):
pos = skip_whitespace(s, pos)
if pos >= len(s):
raise ValueError("Unexpected end of input")
if s[pos] == "{":
return parse_object(s, pos)
elif s[pos] == "[":
return parse_array(s, pos)
elif s[pos : pos + 3] in ("'''", '"""'):
return parse_triple_quoted_string(s, pos)
elif s[pos] in ('"', "'"):
return parse_string(s, pos)
elif s[pos : pos + 4].lower() == "true":
return True, pos + 4
elif s[pos : pos + 5].lower() == "false":
return False, pos + 5
elif s[pos : pos + 4].lower() == "null":
return None, pos + 4
elif s[pos] in "-+0123456789.":
return parse_number(s, pos)
else:
raise ValueError(f"Unexpected character at position {pos}: {s[pos]}")
json_start = s.index("{")
json_end = s.rfind("}")
s = s[json_start : json_end + 1]
s = s.strip()
result, pos = parse_value(s, 0)
pos = skip_whitespace(s, pos)
if pos != len(s):
raise ValueError(f"Unexpected content at position {pos}")
return result
def remove_think_block(msg: dict):
"""
Remove any <think>...</think> block from the message content.
"""
if "content" in msg and isinstance(msg["content"], str):
msg["content"] = re.sub(r"<think>.*?</think>", "", msg["content"], flags=re.DOTALL).strip()
return msg
def is_valid_messages(msg: dict) -> bool:
"""
Check whether a message is valid, i.e.:
1. every <think> is paired with a closing </think>
2. the content is non-empty both inside and outside the <think> block
3. <think> blocks are not nested, and at most one block is allowed
4. the visible content is non-empty after stripping a trailing "<|im_end|>"
"""
content = msg.get("content")
if not isinstance(content, str):
return True
# Base case: empty or whitespace-only content is invalid.
if not content.strip():
return False
num_think_open = content.count("<think>")
num_think_close = content.count("</think>")
# Rule 1: Check for paired tags.
if num_think_open != num_think_close:
return False
# Rule 3: Allow at most one think block.
if num_think_open > 1:
return False
# Case 1: No <think> blocks.
if num_think_open == 0:
visible_content = content
# Case 2: Exactly one <think> block.
else:
# Rule 2: Check for empty content inside the think block.
match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
if not match or not match.group(1).strip():
return False
# The "visible" content is what's outside the think block.
visible_content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL)
visible_content = visible_content.strip()
# Rule 4 & 2 (outside): Check if visible content is empty after handling <|im_end|>.
if visible_content.endswith("<|im_end|>"):
visible_content = visible_content[: -len("<|im_end|>")]
if not visible_content.strip():
return False
return True
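
A minimal usage sketch of the helpers above (assuming the repository root is on `PYTHONPATH`; the example strings are illustrative):

```python
# Minimal sketch exercising the helpers from recipe/collabllm/utils.py.
from recipe.collabllm.utils import extract_json, is_valid_messages, remove_think_block

raw = "Sure: {'current_answer': 'none yet', 'thought': 'ask for details', 'response': 'Which dataset?'}"
parsed = extract_json(raw)  # tolerant parser: single quotes, bare numbers/bools, triple quotes
print(parsed["response"])   # -> Which dataset?

msg = {"role": "assistant", "content": "<think>plan the reply</think>Could you clarify the budget?"}
print(is_valid_messages(msg))                    # True: one paired, non-empty <think> block plus visible text
print(remove_think_block(dict(msg))["content"])  # -> Could you clarify the budget?
```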

@ -327,27 +327,47 @@ class RewardManagerWorker:
config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
)
self.rm_executor = rm_executor
self.loop = asyncio.get_event_loop()
def compute_score(
async def compute_score(
self,
data: DataProto,
) -> dict:
"""Compute reward score for agent loop output.
NOTE: Since `reward_manager.__call__` is a blocking function, we run it in a thread pool to
compute multiple samples in parallel.
Args:
data: reward function input
Returns:
dict: Reward score and reward extra info.
"""
result = await self.loop.run_in_executor(
None,
self.reward_wrapper,
data,
True, # return_dict
)
reward_score = result["reward_tensor"].sum(dim=-1).item()
reward_extra_info = {k: v[0] for k, v in result.get("reward_extra_info", {}).items()}
return {"reward_score": reward_score, "reward_extra_info": reward_extra_info}
def reward_wrapper(self, data: DataProto, return_dict=False) -> torch.Tensor:
"""Assemble reward functions and reward model into one function and expose it to the event loop
Args:
return_dict: whether return as dict
data: DataProto from compute reward score
Returns:
torch.Tensor: Reward score tensor.
"""
if self.rm_executor is not None:
res = ray.get(self.rm_executor.submit_task.remote(data))
data = data.union(res)
result = self.reward_manager(data, return_dict=True)
reward_score = result["reward_tensor"].sum(dim=-1).item()
reward_extra_info = {k: v[0] for k, v in result.get("reward_extra_info", {}).items()}
return {"reward_score": reward_score, "reward_extra_info": reward_extra_info}
return self.reward_manager(data, return_dict)
@ray.remote
@ -597,6 +617,11 @@ class AgentLoopWorker:
**{k: np.array([v]) for k, v in kwargs.items()},
"__num_turns__": np.array([output.num_turns]),
}
extra_fields = {}
for key, val in output.extra_fields.items():
extra_fields[key] = np.array([val], dtype=object)
non_tensor_batch.update(extra_fields)
data = DataProto(
batch=batch,
non_tensor_batch=non_tensor_batch,
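
The hunk above forwards the agent loop's `extra_fields` into `non_tensor_batch`; a small illustration of the packing (field names and values are hypothetical):

```python
# Hypothetical per-sample extras from one agent-loop output. Each value is wrapped
# in a length-1 numpy object array so heterogeneous payloads (strings, dicts, lists)
# can ride along in DataProto's non_tensor_batch without tensor coercion.
import numpy as np

extra_fields = {
    "termination_reason": "user_terminated",                # hypothetical string field
    "turn_scores": {"user_turn_rewards": [0.0, 0.5, 1.0]},  # hypothetical nested dict
}
non_tensor_update = {k: np.array([v], dtype=object) for k, v in extra_fields.items()}
print({k: (v.dtype, v.shape) for k, v in non_tensor_update.items()})
# {'termination_reason': (dtype('O'), (1,)), 'turn_scores': (dtype('O'), (1,))}
```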

@ -133,7 +133,6 @@ class ToolAgentLoop(AgentLoopBase):
)
interaction = self.interaction_map[interaction_name]
await interaction.start_interaction(request_id, **interaction_kwargs)
# Create AgentData instance to encapsulate all state
agent_data = AgentData(
messages=messages,
@ -152,12 +151,10 @@ class ToolAgentLoop(AgentLoopBase):
state = await self._handle_pending_state(agent_data, sampling_params)
elif state == AgentState.GENERATING:
state = await self._handle_generating_state(agent_data, sampling_params)
agent_data.assistant_turns += 1
elif state == AgentState.PROCESSING_TOOLS:
state = await self._handle_processing_tools_state(agent_data)
elif state == AgentState.INTERACTING:
state = await self._handle_interacting_state(agent_data)
agent_data.user_turns += 1
else:
logger.error(f"Invalid state: {state}")
state = AgentState.TERMINATED
@ -209,7 +206,9 @@ class ToolAgentLoop(AgentLoopBase):
)
return AgentState.GENERATING
async def _handle_generating_state(self, agent_data: AgentData, sampling_params: dict[str, Any]) -> AgentState:
async def _handle_generating_state(
self, agent_data: AgentData, sampling_params: dict[str, Any], ignore_termination: bool = False
) -> AgentState:
"""Handle the generating state: generate model response and check for tool calls."""
add_messages: list[dict[str, Any]] = []
@ -221,6 +220,7 @@ class ToolAgentLoop(AgentLoopBase):
image_data=agent_data.image_data,
)
agent_data.assistant_turns += 1
agent_data.response_ids = output.token_ids
agent_data.prompt_ids += agent_data.response_ids
agent_data.response_mask += [1] * len(agent_data.response_ids)
@ -228,7 +228,7 @@ class ToolAgentLoop(AgentLoopBase):
agent_data.response_logprobs += output.log_probs
# Check termination conditions
if len(agent_data.response_mask) >= self.response_length:
if not ignore_termination and len(agent_data.response_mask) >= self.response_length:
return AgentState.TERMINATED
if self.max_assistant_turns and agent_data.assistant_turns >= self.max_assistant_turns:
return AgentState.TERMINATED
@ -241,7 +241,7 @@ class ToolAgentLoop(AgentLoopBase):
# Handle interaction if needed
if self.interaction_config_file:
assistant_message = await self.loop.run_in_executor(
None, lambda: self.tokenizer.decode(agent_data.response_ids)
None, lambda: self.tokenizer.decode(agent_data.response_ids, skip_special_tokens=True)
)
add_messages.append({"role": "assistant", "content": assistant_message})
agent_data.messages.extend(add_messages)
@ -368,8 +368,10 @@ class ToolAgentLoop(AgentLoopBase):
) = await agent_data.interaction.generate_response(
agent_data.request_id, agent_data.messages, **agent_data.interaction_kwargs
)
agent_data.user_turns += 1
add_messages: list[dict[str, Any]] = [{"role": "user", "content": interaction_responses}]
agent_data.messages.extend(add_messages)
if reward is not None:
agent_data.turn_scores.append(reward)
@ -400,6 +402,7 @@ class ToolAgentLoop(AgentLoopBase):
if agent_data.response_logprobs:
agent_data.response_logprobs += [0.0] * len(response_ids)
# double check prompt
# Check termination condition
if should_terminate_sequence:
return AgentState.TERMINATED

@ -0,0 +1,374 @@
# Copyright 2024 CollabLLM Ltd. and/or its affiliates
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import copy
import logging
import os
from typing import Any, Optional
from uuid import uuid4
from recipe.collabllm.utils import remove_think_block
from verl.utils.rollout_trace import rollout_trace_op
from .base import BaseInteraction
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
TERMINATION_SIGNAL = "[[TERMINATE CHAT]]"
USER_PROMPT_TEMPLATE = """You are role-playing as a human USER interacting with an AI collaborator to complete a specific task. Your goal is to generate realistic, natural responses that a user might give in this scenario.
## Input Information:
You will be provided with:
- Task Description: The type of task you are trying to accomplish.
- Complete Prompt or Reference Goal: This field may include the complete user request/query or a reference answer to the user's request. Use this field to understand the user's intent, requirements, or what would count as a satisfactory outcome.
- Chat History: The ongoing conversation between you (as the user) and the AI
Inputs:
<|The Start of Task Description (Not visible to the AI)|>
{task_desc}
<|The End of Task Description|>
<|The Start of Complete Prompt or Reference Goal (Not visible to the AI)|>
{single_turn_prompt}
<|The End of Complete Prompt or Reference Goal|>
<|The Start of Chat History|>
{chat_history}
<|The End of Chat History|>
## Guidelines:
- Stay in Character: Role-play as a human USER. You are NOT an AI. Maintain a consistent personality throughout the chat.
- Minimize Effort: IMPORTANT! As a user, avoid being too detailed in your responses. Provide vague or incomplete demands in the early stages of the conversation to minimize your effort. Let the AI ask for clarification rather than providing everything upfront.
- Knowledge Background: Reflect the user's knowledge level in the role-playing. If the user is less knowledgeable about a task, they might not notice incorrect statements. Ask questions that demonstrate your current understanding and areas of confusion.
- Occasionally Make Mistakes: Real-world users might misspell words, provide incorrect dates, give wrong information, or ask unclear questions. Simulate this behavior to reflect natural interactions.
- Mention Personal Preferences: Include preferences or constraints that might influence your requests or responses. For example, "I prefer short answers," "I need this done quickly," or "I like detailed comments in code."
- Goal-Oriented: Keep the chat focused on your intent. Avoid small talk or digressions. Redirect the chat back to the main objective if it starts to stray.
## Output Format:
You should output a JSON object with three entries:
- "current_answer" (str): Briefly summerize the AI's current solution to the task.
- "thought" (str): Output your thought process as a user deciding what to say next. Consider:
1. Have you obtained a satisfactory solution from the AI? If yes, you can terminate this chat.
2. If not, what specific part of the problem or solution are you struggling with?
3. Has the AI asked you to perform a task or answer a question? If so, how should you approach it?
4. Are you noticing any patterns or potential misunderstandings that need clarification?
5. If you're stuck, how can you phrase your question to get the most helpful response while demonstrating your current understanding?
- "response" (str): Based on your thought process, respond to the AI as the user you are role-playing. Stop immediately when the user's response is completed.
## Important Notes:
- Respond Based on Previous Messages: Your responses should be based on the context of the current chat history. Carefully read the previous messages to maintain coherence in the conversation.
- Conversation Flow: If "Current Chat History" is empty, start the conversation from scratch with an initial request. Otherwise, continue based on the existing conversation.
- Don't Copy Input Directly: Use the provided information for understanding context only. Avoid copying target queries or any provided information directly in your responses.
- Completion Signal: Use "{termination_signal}" as your response when you believe your goal has been solved or if you determine the AI cannot help further.
- Double check if the JSON object is formatted correctly. Ensure that all fields are present and properly structured.
Remember to stay in character as a user throughout your response, and follow the instructions and guidelines carefully.""" # noqa
class CollabLLMInteraction(BaseInteraction):
"""A demo interaction for calculating the reward of CollabLLM.
- `start_interaction`: start a interaction instance for a trajectory.
- `generate_response`: generate the response of the assistant.
- `calculate_score`: calculate the score of the interaction.
- `finalize_interaction`: finalize the interaction instance.
"""
def __init__(self, config: dict):
super().__init__(config)
_config = copy.deepcopy(config)
_config.pop("enable_log", None)
self.name = _config.pop("name")
self.user_model = _config.pop("user_model")
self.termination_signal = _config.pop("termination_signal", TERMINATION_SIGNAL)
self.num_retries = _config.pop("num_retries", 3)
self.user_model_kwargs = _config
self._instance_dict = {}
async def start_interaction(
self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs
) -> str:
if instance_id is None:
instance_id = str(uuid4())
self._instance_dict[instance_id] = {
"response": "",
"ground_truth": ground_truth,
"reward": 0.0,
}
self.interaction_kwargs = kwargs
assert "single_turn_prompt" in kwargs, "single_turn_prompt is required in interaction_kwargs"
return instance_id
@rollout_trace_op
async def generate_response(
self, instance_id: str, messages: list[dict[str, Any]], **kwargs
) -> tuple[bool, str, float, dict]:
assert messages[-1]["role"] in ["system", "assistant"], (
"Last message input to the user model must be from system or assistant role"
)
import litellm
chat_history = self._parse_messages(messages, strip_sys_prompt=True)
prompt = USER_PROMPT_TEMPLATE.format(
task_desc=self.interaction_kwargs.get("task_desc", "general assistance task"),
single_turn_prompt=self.interaction_kwargs["single_turn_prompt"],
chat_history=chat_history,
termination_signal=self.termination_signal,
)
response = ""
for i in range(self.num_retries):
try:
full_response = (
(
await litellm.acompletion(
model=self.user_model,
messages=[{"role": "user", "content": prompt}],
**self.user_model_kwargs,
)
)
.choices[0]
.message.content
)
except litellm.RateLimitError as e:
logger.warning(f"[CollabLLMInteraction] hit RateLimitError: {e}. Retrying...")
await asyncio.sleep(max(2**i, 60))
continue
except Exception as e:
logger.exception(f"An unexpected error occurred in CollabLLMAgentLoop: {e}")
continue
try:
if isinstance(full_response, str):
full_response = extract_json(full_response)
except Exception as e:
logger.warning(f"[CollabLLMInteraction] Error extracting JSON: {e}. Retrying...")
continue
if isinstance(full_response, dict):
keys = full_response.keys()
if {"current_answer", "thought", "response"}.issubset(keys):
response = full_response.pop("response")
if isinstance(response, str):
break
else:
logger.warning(
f"[CollabLLMInteraction] got an invaild response {response} full_response {full_response}. \
Retrying..."
)
continue
else:
logger.warning(f"[CollabLLMInteraction] Keys {keys} do not match expected keys. Retrying...")
continue
self._instance_dict[instance_id]["response"] = response
logger.debug(f"[CollabLLMInteraction] User: {response}")
should_terminate_sequence = self.termination_signal in response
reward = 0.0
return should_terminate_sequence, response, reward, {}
async def finalize_interaction(self, instance_id: str, **kwargs) -> None:
del self._instance_dict[instance_id]
def _parse_messages(self, messages, strip_sys_prompt=True):
if messages is None:
return ""
if strip_sys_prompt:
messages = [msg for msg in messages if msg["role"] != "system"]
messages = [remove_think_block(msg) for msg in messages]
chat = "\n".join(f"**{m['role'].capitalize()}**: {m['content']}" for m in messages)
return chat
def extract_json(s):
def convert_value(value):
true_values = {"true": True, "false": False, "null": None}
value_lower = value.lower()
if value_lower in true_values:
return true_values[value_lower]
try:
if "." in value or "e" in value.lower():
return float(value)
else:
return int(value)
except ValueError:
return value # Return as string if not a number
def parse_number(s, pos):
start = pos
while pos < len(s) and s[pos] in "-+0123456789.eE":
pos += 1
num_str = s[start:pos]
try:
if "." in num_str or "e" in num_str.lower():
return float(num_str), pos
else:
return int(num_str), pos
except ValueError:
logger.error(f"Invalid number at position {start}: {num_str}")
raise
def skip_whitespace(s, pos):
while pos < len(s) and s[pos] in " \t\n\r":
pos += 1
return pos
def parse_string(s, pos):
quote_char = s[pos]
assert quote_char in ('"', "'")
pos += 1
result = ""
while pos < len(s):
c = s[pos]
if c == "\\":
pos += 1
if pos >= len(s):
raise ValueError("Invalid escape sequence")
c = s[pos]
escape_sequences = {"n": "\n", "t": "\t", "r": "\r", "\\": "\\", quote_char: quote_char}
result += escape_sequences.get(c, c)
elif c == quote_char:
pos += 1
# Attempt to convert to a number if possible
converted_value = convert_value(result)
return converted_value, pos
else:
result += c
pos += 1
raise ValueError("Unterminated string")
def parse_key(s, pos):
pos = skip_whitespace(s, pos)
if s[pos] in ('"', "'"):
key, pos = parse_string(s, pos)
return key, pos
else:
raise ValueError(f"Expected string for key at position {pos}")
def parse_object(s, pos):
obj = {}
assert s[pos] == "{"
pos += 1
pos = skip_whitespace(s, pos)
while pos < len(s) and s[pos] != "}":
pos = skip_whitespace(s, pos)
key, pos = parse_key(s, pos)
pos = skip_whitespace(s, pos)
if pos >= len(s) or s[pos] != ":":
raise ValueError(f'Expected ":" at position {pos}')
pos += 1
pos = skip_whitespace(s, pos)
value, pos = parse_value(s, pos)
obj[key] = value
pos = skip_whitespace(s, pos)
if pos < len(s) and s[pos] == ",":
pos += 1
pos = skip_whitespace(s, pos)
elif pos < len(s) and s[pos] == "}":
break
elif pos < len(s) and s[pos] != "}":
raise ValueError(f'Expected "," or "}}" at position {pos}')
if pos >= len(s) or s[pos] != "}":
raise ValueError(f'Expected "}}" at position {pos}')
pos += 1
return obj, pos
def parse_array(s, pos):
lst = []
assert s[pos] == "["
pos += 1
pos = skip_whitespace(s, pos)
while pos < len(s) and s[pos] != "]":
value, pos = parse_value(s, pos)
lst.append(value)
pos = skip_whitespace(s, pos)
if pos < len(s) and s[pos] == ",":
pos += 1
pos = skip_whitespace(s, pos)
elif pos < len(s) and s[pos] == "]":
break
elif pos < len(s) and s[pos] != "]":
raise ValueError(f'Expected "," or "]" at position {pos}')
if pos >= len(s) or s[pos] != "]":
raise ValueError(f'Expected "]" at position {pos}')
pos += 1
return lst, pos
def parse_triple_quoted_string(s, pos):
if s[pos : pos + 3] == "'''":
quote_str = "'''"
elif s[pos : pos + 3] == '"""':
quote_str = '"""'
else:
raise ValueError(f"Expected triple quotes at position {pos}")
pos += 3
result = ""
while pos < len(s):
if s[pos : pos + 3] == quote_str:
pos += 3
# Attempt to convert to a number if possible
converted_value = convert_value(result)
return converted_value, pos
else:
result += s[pos]
pos += 1
raise ValueError("Unterminated triple-quoted string")
def parse_value(s, pos):
pos = skip_whitespace(s, pos)
if pos >= len(s):
raise ValueError("Unexpected end of input")
if s[pos] == "{":
return parse_object(s, pos)
elif s[pos] == "[":
return parse_array(s, pos)
elif s[pos : pos + 3] in ("'''", '"""'):
return parse_triple_quoted_string(s, pos)
elif s[pos] in ('"', "'"):
return parse_string(s, pos)
elif s[pos : pos + 4].lower() == "true":
return True, pos + 4
elif s[pos : pos + 5].lower() == "false":
return False, pos + 5
elif s[pos : pos + 4].lower() == "null":
return None, pos + 4
elif s[pos] in "-+0123456789.":
return parse_number(s, pos)
else:
raise ValueError(f"Unexpected character at position {pos}: {s[pos]}")
json_start = s.index("{")
json_end = s.rfind("}")
s = s[json_start : json_end + 1]
s = s.strip()
result, pos = parse_value(s, 0)
pos = skip_whitespace(s, pos)
if pos != len(s):
raise ValueError(f"Unexpected content at position {pos}")
return result
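
For orientation, a rough and untested sketch of driving this interaction directly, outside the agent loop. The module path, user model, and available litellm credentials are assumptions; during training the class is instead instantiated from the interaction config YAML referenced in the RL script above.

```python
# Untested sketch: exercising CollabLLMInteraction by hand. The import path is an
# assumption; litellm must be able to reach the chosen user_model.
import asyncio

from verl.interactions.collabllm_interaction import CollabLLMInteraction  # path assumed

async def main():
    interaction = CollabLLMInteraction(
        {"name": "collabllm", "user_model": "gpt-4o-mini", "temperature": 1.0}
    )
    instance_id = await interaction.start_interaction(
        ground_truth="42",
        single_turn_prompt="Compute 6 x 7 and explain briefly.",  # required kwarg
        task_desc="math question answering",
    )
    # The simulated user replies to the assistant's last message.
    terminate, user_msg, reward, _ = await interaction.generate_response(
        instance_id,
        [{"role": "assistant", "content": "Happy to help. Which two numbers should I multiply?"}],
    )
    print(terminate, reward, user_msg)
    await interaction.finalize_interaction(instance_id)

asyncio.run(main())
```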

@ -226,6 +226,7 @@ actor_rollout_ref:
use_inference_chat_template: false
tokenization_sanity_check_mode: strict
format: hermes
num_repeat_rollouts: null
calculate_log_probs: false
agent:
_target_: verl.workers.config.AgentLoopConfig

@ -213,6 +213,7 @@ actor_rollout_ref:
use_inference_chat_template: false
tokenization_sanity_check_mode: strict
format: hermes
num_repeat_rollouts: null
calculate_log_probs: false
agent:
_target_: verl.workers.config.AgentLoopConfig

@ -180,6 +180,9 @@ multi_turn:
# Format of the multi-turn interaction. Options: hermes, llama3_json, ...
format: hermes
# Number of repeat rollouts for each interaction
num_repeat_rollouts: null
# Support logging rollout log-probs for debugging purposes.
# "Truncated importance sampling" requires rollout log probs; set to True when enabling it.
calculate_log_probs: False

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main scripts.
"""
import os
@ -312,6 +312,7 @@ class TaskRunner:
)
# Initialize the workers of the trainer.
trainer.init_workers()
# Start the training process.
trainer.fit()
@ -349,7 +350,6 @@ def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=Tr
dataset_cls = DynamicGenDataset
print("Using DynamicGenDataset for data generation.")
else:
# Use the default RLHFDataset class if no custom class is specified
dataset_cls = RLHFDataset

@ -229,6 +229,7 @@ def compute_advantage(
elif adv_estimator == AdvantageEstimator.GRPO:
# Initialize the mask for GRPO calculation
grpo_calculation_mask = data.batch["response_mask"]
# Call compute_grpo_outcome_advantage with parameters matching its definition
advantages, returns = core_algos.compute_grpo_outcome_advantage(
token_level_rewards=data.batch["token_level_rewards"],
@ -981,7 +982,6 @@ class RayPPOTrainer:
if self.config.global_profiler.profile_continuous_steps
else curr_step_profile
)
batch: DataProto = DataProto.from_single_dict(batch_dict)
# add uid to batch
@ -996,7 +996,6 @@ class RayPPOTrainer:
gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
is_last_step = self.global_steps >= self.total_training_steps
with marked_timer("step", timing_raw):
# generate a batch
with marked_timer("gen", timing_raw, color="red"):
@ -1004,6 +1003,7 @@ class RayPPOTrainer:
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
else:
gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
timing_raw.update(gen_batch_output.meta_info["timing"])
gen_batch_output.meta_info.pop("timing", None)
@ -1027,7 +1027,6 @@ class RayPPOTrainer:
batch.batch["reward_baselines"] = reward_baseline_tensor
del gen_baseline_batch, gen_baseline_output
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)
@ -1109,7 +1108,6 @@ class RayPPOTrainer:
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
# compute advantages, executed on the driver process
norm_adv_by_std_in_grpo = self.config.algorithm.get(
"norm_adv_by_std_in_grpo", True
) # GRPO adv normalization factor
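
The hunks above touch the GRPO advantage path configured by `algorithm.adv_estimator=grpo`; as a reminder, GRPO's outcome advantage normalizes each rollout's scalar reward against the other rollouts of the same prompt. A generic sketch, not verl's exact implementation:

```python
# Generic GRPO outcome-advantage sketch (not verl's exact code): rewards for the
# n rollouts of one prompt are centered on the group mean and optionally divided
# by the group standard deviation (the norm_adv_by_std_in_grpo switch above).
import torch

def grpo_outcome_advantage(rewards: torch.Tensor, norm_by_std: bool = True, eps: float = 1e-6) -> torch.Tensor:
    adv = rewards - rewards.mean()
    if norm_by_std:
        adv = adv / (rewards.std() + eps)
    return adv

print(grpo_outcome_advantage(torch.tensor([1.0, 0.0, 0.5, 0.5])))
# tensor([ 1.2247, -1.2247,  0.0000,  0.0000])
```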

@ -55,6 +55,7 @@ class MultiTurnConfig(BaseConfig):
use_inference_chat_template: bool = False
tokenization_sanity_check_mode: str = "strict"
format: str = "hermes"
num_repeat_rollouts: Optional[int] = None
@dataclass

@ -14,6 +14,7 @@
from .registry import get_reward_manager_cls, register # noqa: I001
from .batch import BatchRewardManager
from .collabllm import CollabLLMRewardManager
from .dapo import DAPORewardManager
from .naive import NaiveRewardManager
from .prime import PrimeRewardManager
@ -21,6 +22,7 @@ from .prime import PrimeRewardManager
# Note(haibin.lin): no need to include all reward managers here in case of complicated dependencies
__all__ = [
"BatchRewardManager",
"CollabLLMRewardManager",
"DAPORewardManager",
"NaiveRewardManager",
"PrimeRewardManager",

@ -0,0 +1,152 @@
# Copyright 2025 CollabLLM team and/or its affiliates
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import Any, Callable, Optional
import torch
from transformers import PreTrainedTokenizer
from verl import DataProto
from verl.utils.reward_score import default_compute_score
from verl.workers.reward_manager import register
from verl.workers.reward_manager.abstract import AbstractRewardManager
TERMINATION_SIGNAL = "[[TERMINATE CHAT]]"
@register("collabllm")
class CollabLLMRewardManager(AbstractRewardManager):
"""
The Reward Manager used in https://github.com/Wuyxin/collabllm/
"""
def __init__(
self,
tokenizer: PreTrainedTokenizer,
num_examine: int,
metric_weights: dict,
llm_judge_kwargs: dict,
reward_fn_key: str = "data_source",
compute_score: Optional[Callable] = None,
normalize_by_data_source=False,
) -> None:
self.tokenizer = tokenizer
self.num_examine = num_examine # the number of batches of decoded responses to print to the console
self.compute_score = compute_score or default_compute_score
self.reward_fn_key = reward_fn_key
self.metric_weights = metric_weights
self.llm_judge_kwargs = llm_judge_kwargs
self.normalize_by_data_source = normalize_by_data_source
self.metrics = list(self.metric_weights.keys())
# force CollabLLMAgentLoop to be registered
from recipe.collabllm.collabllm_agent_loop import CollabLLMAgentLoop # noqa
def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
# If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
if "rm_scores" in data.batch.keys():
if return_dict:
return {"reward_tensor": data.batch["rm_scores"]}
else:
return data.batch["rm_scores"]
# Use thread-compatible async loop management instead of asyncio.run()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(self._compute_rewards_async(data, return_dict))
finally:
loop.close()
async def _compute_rewards_async(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
# batched scoring
prompt_ids = data.batch["prompts"]
prompt_length = prompt_ids.shape[-1]
valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum(dim=-1)
data_source = data.non_tensor_batch["data_source"]
ground_truth = data.non_tensor_batch["ground_truth"]
extra_info = data.non_tensor_batch["extra_info"]
message_lst = data.non_tensor_batch["messages"]
# Group messages by rollout index: grouped_messages[j][i] holds the j-th repeated rollout for batch item i
num_repeat_rollouts = len(message_lst[0]["messages"])
batch_size = len(data_source)
grouped_messages = [
[message_lst[i]["messages"][j] for i in range(len(message_lst))] for j in range(num_repeat_rollouts)
]
# Flatten lists for all batch items across all rollouts
flattened_data_sources = [data_source[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)]
flattened_ground_truths = [ground_truth[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)]
flattened_extra_infos = [extra_info[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)]
flattened_messages = [grouped_messages[j][i] for j in range(num_repeat_rollouts) for i in range(batch_size)]
if num_repeat_rollouts > 0:
tasks = [
self.compute_score(
flattened_data_sources[i],
flattened_messages[i],
flattened_ground_truths[i],
flattened_extra_infos[i],
self.metrics,
**self.llm_judge_kwargs,
)
for i in range(len(flattened_data_sources))
]
score_dicts = await asyncio.gather(*tasks)
# Aggregate scores for each metric across repeated rollouts
scores_by_metrics = {
metric: torch.stack([score_dict[metric] for score_dict in score_dicts])
.view(num_repeat_rollouts, -1)
.sum(dim=0)
for metric in self.metrics
}
# Apply metric-specific weights
weighted_scores_by_metrics = {
metric: torch.clamp(
scores_by_metrics[metric] * self.metric_weights[metric] / num_repeat_rollouts,
min=-1.0,
max=1.0,
)
for metric in self.metrics
}
# Compute mean of weighted scores for each metric
mean_weighted_scores_by_metrics = {
metric: weighted_scores_by_metrics[metric].mean(dim=0) for metric in self.metrics
}
# Combine weighted scores from all metrics into a single tensor
scores = torch.stack([weighted_scores_by_metrics[metric] for metric in self.metrics]).sum(dim=0)
else:
score_dicts = []
scores = torch.full((batch_size,), 0.0, dtype=torch.float32, device=prompt_ids.device)
mean_weighted_scores_by_metrics = {metric: 0.0 for metric in self.metrics}
print("Scores:", scores, mean_weighted_scores_by_metrics)
reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
for i in range(len(data)):
reward_tensor[i, valid_response_length[i].item() - 1] = scores[i]
if return_dict:
return {"reward_tensor": reward_tensor}
else:
return reward_tensor
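
A short sketch of how the `@register("collabllm")` name above ties back to `reward_model.reward_manager=collabllm` and the `reward_kwargs` overrides in the RL script; how verl threads these arguments through internally is simplified here, and the tokenizer choice mirrors the actor model.

```python
# Sketch only: resolve the manager registered above by name and construct it with
# the kwargs from the RL script. Assumes verl and the recipe package are importable.
from transformers import AutoTokenizer

from verl.workers.reward_manager import get_reward_manager_cls

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
manager_cls = get_reward_manager_cls("collabllm")
reward_manager = manager_cls(
    tokenizer=tokenizer,
    num_examine=0,
    metric_weights={"accuracy": 1, "interactivity": 1, "token_amount": -0.0001},
    llm_judge_kwargs={"model": "gpt-4o-mini", "max_tokens": 2048, "temperature": 0},
)
# Calling reward_manager(data, return_dict=True) on a DataProto batch returns
# {"reward_tensor": ...} with the conversation-level score placed on the last
# valid response token of each sample.
```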

@ -270,6 +270,7 @@ class SGLangRollout(BaseRollout):
self._function_call_parser,
) = self._initialize_tools(config, processing_class)
self.interaction_map: dict[str, BaseInteraction] = self._initialize_interactions(config)
# If `free_cache_engine` is turned on, the SGLang engine's KV cache
# will be freed after each `generate_sequences` call.
logger.info(
@ -287,7 +288,6 @@ class SGLangRollout(BaseRollout):
self._init_sampling_params(**kwargs)
self.processing_class = processing_class
try:
# This is when processing_class is a tokenizer
self.pad_token_id = self.processing_class.pad_token_id
@ -964,6 +964,9 @@ class SGLangRollout(BaseRollout):
):
_req.state = AsyncRolloutRequestStateEnum.INTERACTING
else:
# Add ending condition
finish_reason_type = FinishReasonTypeEnum.STOP
_req.state = AsyncRolloutRequestStateEnum.COMPLETED
break
elif _req.state == AsyncRolloutRequestStateEnum.INTERACTING:
user_turns += 1
@ -980,11 +983,17 @@ class SGLangRollout(BaseRollout):
)
interaction = self.interaction_map[interaction_name]
should_terminate_sequence, content, reward, metrics = await interaction.generate_response(
_req.request_id, messages, **_req.interaction_kwargs
)
user_turn_rewards.append(reward)
if should_terminate_sequence:
# Add turn check
if (
should_terminate_sequence
or user_turns > self.config.multi_turn.max_user_turns
or current_turns > self.config.multi_turn.max_assistant_turns
):
finish_reason_type = FinishReasonTypeEnum.STOP
_req.state = AsyncRolloutRequestStateEnum.COMPLETED
break
@ -1013,6 +1022,7 @@ class SGLangRollout(BaseRollout):
tool_reward_scores = dict(tool_reward_scores)
all_rewards = {**tool_reward_scores, **{"user_turn_rewards": user_turn_rewards}}
_req.finalize(self.processing_class, all_rewards, finish_reason_type)
if self.config.calculate_log_probs:
debug_sampling_params = {**self.sampling_params}
debug_sampling_params["max_new_tokens"] = 0