[recipe] fix: move all collabllm files into recipe directory (#3706)

### What does this PR do?

Resolves issue https://github.com/volcengine/verl/issues/3606

1. Move the CollabLLM reward manager into the recipe and register it through the `custom_reward_function` file (see the sketch after this list).
2. Register the agent loop in `agent.yaml`.
3. Move `collabllm_interation.py` into the recipe directory.
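
Since `CollabLLMRewardManager` is no longer imported in `verl/workers/reward_manager/__init__.py`, its `@register("collabllm")` decorator only runs when the recipe's reward file is loaded; that is also why `load_reward_manager` now fetches the custom reward function before resolving the manager class by name (see the `load_reward_manager` hunk below). A minimal sketch of the idea, assuming a hypothetical recipe file path and simplifying verl's actual loader:

```
# Sketch only: loading the custom_reward_function file by path executes its
# module-level @register("collabllm"), after which the manager resolves by name.
import importlib.util

from verl.workers.reward_manager import get_reward_manager_cls


def load_recipe_reward_module(path: str):
    # Import the recipe file from an explicit path; the import itself runs the decorator.
    spec = importlib.util.spec_from_file_location("collabllm_custom_reward", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


# Hypothetical path for illustration; the real file lives under recipe/collabllm/ in this PR.
load_recipe_reward_module("recipe/collabllm/reward_function.py")
reward_manager_cls = get_reward_manager_cls("collabllm")
```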



### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
    - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test
```
(TaskRunner pid=52293) step:3 - global_seqlen/min:56551 - global_seqlen/max:94884 - global_seqlen/minmax_diff:38333 - global_seqlen/balanced_min:72054 -
```

### API and Usage Example

n/a

### Design & Code Changes

n/a

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
Authored by OC on 2025-10-09 18:50:37 +08:00; committed by GitHub.
Commit cf619d68d4 (parent 23877bcc64): 9 changed files with 144 additions and 164 deletions.

View File

@@ -20,7 +20,7 @@ from typing import Any
from uuid import uuid4
from recipe.collabllm.utils import is_valid_messages
from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, register
from verl.experimental.agent_loop.agent_loop import AgentLoopOutput
from verl.experimental.agent_loop.tool_agent_loop import AgentData, AgentState, ToolAgentLoop
from verl.utils.rollout_trace import rollout_trace_op
from verl.workers.rollout.schemas import Message
@@ -29,7 +29,6 @@ logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
@register("collabllm_agent")
class CollabLLMAgentLoop(ToolAgentLoop):
@rollout_trace_op
async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:

View File

@@ -21,10 +21,9 @@ from typing import Any, Optional
from uuid import uuid4
from recipe.collabllm.utils import remove_think_block
from verl.interactions.base import BaseInteraction
from verl.utils.rollout_trace import rollout_trace_op
from .base import BaseInteraction
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))

View File

@@ -0,0 +1,2 @@
- name: collabllm_agent
  _target_: recipe.collabllm.collabllm_agent_loop.CollabLLMAgentLoop
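
With this file, registration of the agent loop moves from the `@register("collabllm_agent")` decorator to a named entry whose `_target_` is a dotted import path, and the run script points `actor_rollout_ref.rollout.agent.agent_loop_config_path` at this YAML. A minimal sketch of how a dotted `_target_` path resolves to a class (generic pattern only; verl's actual loader may differ):

```
# Sketch only: resolve a dotted `_target_` string to the class it names.
import importlib


def resolve_target(dotted_path: str):
    module_path, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)  # e.g. recipe.collabllm.collabllm_agent_loop
    return getattr(module, class_name)             # e.g. CollabLLMAgentLoop


# agent_loop_cls = resolve_target("recipe.collabllm.collabllm_agent_loop.CollabLLMAgentLoop")
```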

View File

@@ -1,6 +1,6 @@
interaction:
- name: "collabllm"
class_name: "verl.interactions.collabllm_interation.CollabLLMInteraction"
class_name: "recipe.collabllm.collabllm_interation.CollabLLMInteraction"
config: {
"user_model": "gpt-4o-mini",
"num_retries": 3,

View File

@@ -17,9 +17,18 @@ import asyncio
import importlib.util
import os
import sys
from typing import Any, Callable, Optional
import litellm
import torch
from transformers import PreTrainedTokenizer
from verl import DataProto
from verl.utils.reward_score import default_compute_score
from verl.workers.reward_manager import register
from verl.workers.reward_manager.abstract import AbstractRewardManager
TERMINATION_SIGNAL = "[[TERMINATE CHAT]]"
async def conversation_level_reward_func(
@@ -93,3 +102,126 @@ async def conversation_level_reward_func(
# Return dict with metric names as keys
return {metric: torch.tensor(reward, dtype=torch.float32) for metric, reward in rewards.items()}
@register("collabllm")
class CollabLLMRewardManager(AbstractRewardManager):
"""
The Reward Manager used in https://github.com/Wuyxin/collabllm/
"""
def __init__(
self,
tokenizer: PreTrainedTokenizer,
num_examine: int,
metric_weights: dict,
llm_judge_kwargs: dict,
reward_fn_key: str = "data_source",
compute_score: Optional[Callable] = None,
normalize_by_data_source=False,
) -> None:
self.tokenizer = tokenizer
self.num_examine = num_examine # the number of batches of decoded responses to print to the console
self.compute_score = compute_score or default_compute_score
self.reward_fn_key = reward_fn_key
self.metric_weights = metric_weights
self.llm_judge_kwargs = llm_judge_kwargs
self.normalize_by_data_source = normalize_by_data_source
self.metrics = list(self.metric_weights.keys())
def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
# If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
if "rm_scores" in data.batch.keys():
if return_dict:
return {"reward_tensor": data.batch["rm_scores"]}
else:
return data.batch["rm_scores"]
# Use thread-compatible async loop management instead of asyncio.run()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(self._compute_rewards_async(data, return_dict))
finally:
loop.close()
async def _compute_rewards_async(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
# batched scoring
prompt_ids = data.batch["prompts"]
prompt_length = prompt_ids.shape[-1]
valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum(dim=-1)
data_source = data.non_tensor_batch["data_source"]
ground_truth = data.non_tensor_batch["ground_truth"]
extra_info = data.non_tensor_batch["extra_info"]
message_lst = data.non_tensor_batch["messages"]
# batch the messages into multiple
num_repeat_rollouts = len(message_lst[0]["messages"])
batch_size = len(data_source)
grouped_messages = [
[message_lst[i]["messages"][j] for i in range(len(message_lst))] for j in range(num_repeat_rollouts)
]
# Flatten lists for all batch items across all rollouts
flattened_data_sources = [data_source[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)]
flattened_ground_truths = [ground_truth[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)]
flattened_extra_infos = [extra_info[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)]
flattened_messages = [grouped_messages[j][i] for j in range(num_repeat_rollouts) for i in range(batch_size)]
if num_repeat_rollouts > 0:
tasks = [
self.compute_score(
flattened_data_sources[i],
flattened_messages[i],
flattened_ground_truths[i],
flattened_extra_infos[i],
self.metrics,
**self.llm_judge_kwargs,
)
for i in range(len(flattened_data_sources))
]
score_dicts = await asyncio.gather(*tasks)
# Aggregate scores for each metric across repeated rollouts
scores_by_metrics = {
metric: torch.stack([score_dict[metric] for score_dict in score_dicts])
.view(num_repeat_rollouts, -1)
.sum(dim=0)
for metric in self.metrics
}
# Apply metric-specific weights
weighted_scores_by_metrics = {
metric: torch.clamp(
scores_by_metrics[metric] * self.metric_weights[metric] / num_repeat_rollouts,
min=-1.0,
max=1.0,
)
for metric in self.metrics
}
# Compute mean of weighted scores for each metric
mean_weighted_scores_by_metrics = {
metric: weighted_scores_by_metrics[metric].mean(dim=0) for metric in self.metrics
}
# Combine weighted scores from all metrics into a single tensor
scores = torch.stack([weighted_scores_by_metrics[metric] for metric in self.metrics]).sum(dim=0)
else:
score_dicts = []
scores = torch.full((batch_size,), 0.0, dtype=torch.float32, device=prompt_ids.device)
mean_weighted_scores_by_metrics = {metric: 0.0 for metric in self.metrics}
print("Scores:", scores, mean_weighted_scores_by_metrics)
reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
for i in range(len(data)):
reward_tensor[i, valid_response_length[i].item() - 1] = scores[i]
if return_dict:
return {"reward_tensor": reward_tensor}
else:
return reward_tensor
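
One detail worth noting in `_compute_rewards_async`: the flattened lists are built rollout-major (outer loop over rollouts, inner loop over batch items), so `.view(num_repeat_rollouts, -1)` puts one rollout per row and `.sum(dim=0)` yields a per-batch-item total across rollouts. A toy check with made-up numbers (illustration only, not the real scores):

```
# Sketch only: verify the flatten -> view -> sum(dim=0) aggregation order.
import torch

batch_size, num_repeat_rollouts = 3, 2
# scores[j][i] = toy score of batch item i under rollout j
scores = [[0.1, 0.2, 0.3],   # rollout 0
          [0.4, 0.5, 0.6]]   # rollout 1

# Rollout-major flattening, matching the list comprehensions above.
flattened = torch.tensor(
    [scores[j][i] for j in range(num_repeat_rollouts) for i in range(batch_size)]
)
per_item = flattened.view(num_repeat_rollouts, -1).sum(dim=0)
print(per_item)  # ~[0.5, 0.7, 0.9]: per-batch-item sums across the two rollouts
```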

View File

@@ -13,6 +13,7 @@ fi
DATASET=math-hard-large
PROJECT_DIR="$(pwd)"
AGENTLOOP_CONFIG_PATH="$PROJECT_DIR/recipe/collabllm/config/agent.yaml"
python3 -m verl.trainer.main_ppo \
@@ -56,7 +57,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.rollout.multi_turn.max_user_turns=2 \
actor_rollout_ref.rollout.multi_turn.max_assistant_turns=3 \
actor_rollout_ref.rollout.multi_turn.num_repeat_rollouts=3 \
actor_rollout_ref.rollout.trace.token2text=True \
actor_rollout_ref.rollout.agent.agent_loop_config_path=$AGENTLOOP_CONFIG_PATH \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \

View File

@@ -109,6 +109,11 @@ def load_reward_manager(
An instance of the specified reward manager class.
"""
# Try to get a custom reward function based on the configuration
# user defined reward manager can be registered in custom_reward_fn
compute_score = get_custom_reward_fn(config)
final_compute_score = compute_score
# The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`:
# naive: NaiveRewardManager
# prime: PrimeRewardManager
@@ -120,10 +125,6 @@
reward_manager_name = config.reward_model.get("reward_manager", "naive")
reward_manager_cls = get_reward_manager_cls(reward_manager_name)
# Try to get a custom reward function based on the configuration
compute_score = get_custom_reward_fn(config)
final_compute_score = compute_score
if compute_score is None:
sandbox_config = config.reward_model.get("sandbox_fusion")
sandbox_url = sandbox_config.get("url") if sandbox_config else None

View File

@@ -14,7 +14,6 @@
from .registry import get_reward_manager_cls, register # noqa: I001
from .batch import BatchRewardManager
from .collabllm import CollabLLMRewardManager
from .dapo import DAPORewardManager
from .naive import NaiveRewardManager
from .prime import PrimeRewardManager
@@ -22,7 +21,6 @@ from .prime import PrimeRewardManager
# Note(haibin.lin): no need to include all reward managers here in case of complicated dependencies
__all__ = [
"BatchRewardManager",
"CollabLLMRewardManager",
"DAPORewardManager",
"NaiveRewardManager",
"PrimeRewardManager",

View File

@@ -1,152 +0,0 @@
# Copyright 2025 CollabLLM team and/or its affiliates
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import Any, Callable, Optional
import torch
from transformers import PreTrainedTokenizer
from verl import DataProto
from verl.utils.reward_score import default_compute_score
from verl.workers.reward_manager import register
from verl.workers.reward_manager.abstract import AbstractRewardManager
TERMINATION_SIGNAL = "[[TERMINATE CHAT]]"
@register("collabllm")
class CollabLLMRewardManager(AbstractRewardManager):
"""
The Reward Manager used in https://github.com/Wuyxin/collabllm/
"""
def __init__(
self,
tokenizer: PreTrainedTokenizer,
num_examine: int,
metric_weights: dict,
llm_judge_kwargs: dict,
reward_fn_key: str = "data_source",
compute_score: Optional[Callable] = None,
normalize_by_data_source=False,
) -> None:
self.tokenizer = tokenizer
self.num_examine = num_examine # the number of batches of decoded responses to print to the console
self.compute_score = compute_score or default_compute_score
self.reward_fn_key = reward_fn_key
self.metric_weights = metric_weights
self.llm_judge_kwargs = llm_judge_kwargs
self.normalize_by_data_source = normalize_by_data_source
self.metrics = list(self.metric_weights.keys())
# force CollabLLMAgentLoop to be registered
from recipe.collabllm.collabllm_agent_loop import CollabLLMAgentLoop # noqa
def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
# If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
if "rm_scores" in data.batch.keys():
if return_dict:
return {"reward_tensor": data.batch["rm_scores"]}
else:
return data.batch["rm_scores"]
# Use thread-compatible async loop management instead of asyncio.run()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(self._compute_rewards_async(data, return_dict))
finally:
loop.close()
async def _compute_rewards_async(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
# batched scoring
prompt_ids = data.batch["prompts"]
prompt_length = prompt_ids.shape[-1]
valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum(dim=-1)
data_source = data.non_tensor_batch["data_source"]
ground_truth = data.non_tensor_batch["ground_truth"]
extra_info = data.non_tensor_batch["extra_info"]
message_lst = data.non_tensor_batch["messages"]
# batch the messages into multiple
num_repeat_rollouts = len(message_lst[0]["messages"])
batch_size = len(data_source)
grouped_messages = [
[message_lst[i]["messages"][j] for i in range(len(message_lst))] for j in range(num_repeat_rollouts)
]
# Flatten lists for all batch items across all rollouts
flattened_data_sources = [data_source[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)]
flattened_ground_truths = [ground_truth[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)]
flattened_extra_infos = [extra_info[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)]
flattened_messages = [grouped_messages[j][i] for j in range(num_repeat_rollouts) for i in range(batch_size)]
if num_repeat_rollouts > 0:
tasks = [
self.compute_score(
flattened_data_sources[i],
flattened_messages[i],
flattened_ground_truths[i],
flattened_extra_infos[i],
self.metrics,
**self.llm_judge_kwargs,
)
for i in range(len(flattened_data_sources))
]
score_dicts = await asyncio.gather(*tasks)
# Aggregate scores for each metric across repeated rollouts
scores_by_metrics = {
metric: torch.stack([score_dict[metric] for score_dict in score_dicts])
.view(num_repeat_rollouts, -1)
.sum(dim=0)
for metric in self.metrics
}
# Apply metric-specific weights
weighted_scores_by_metrics = {
metric: torch.clamp(
scores_by_metrics[metric] * self.metric_weights[metric] / num_repeat_rollouts,
min=-1.0,
max=1.0,
)
for metric in self.metrics
}
# Compute mean of weighted scores for each metric
mean_weighted_scores_by_metrics = {
metric: weighted_scores_by_metrics[metric].mean(dim=0) for metric in self.metrics
}
# Combine weighted scores from all metrics into a single tensor
scores = torch.stack([weighted_scores_by_metrics[metric] for metric in self.metrics]).sum(dim=0)
else:
score_dicts = []
scores = torch.full((batch_size,), 0.0, dtype=torch.float32, device=prompt_ids.device)
mean_weighted_scores_by_metrics = {metric: 0.0 for metric in self.metrics}
print("Scores:", scores, mean_weighted_scores_by_metrics)
reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
for i in range(len(data)):
reward_tensor[i, valid_response_length[i].item() - 1] = scores[i]
if return_dict:
return {"reward_tensor": reward_tensor}
else:
return reward_tensor