Mirror of https://github.com/volcengine/verl.git (synced 2025-10-20 13:43:50 +08:00)
[recipe] fix: move all collabllm files into recipe directory (#3706)
### What does this PR do?

Resolves issue https://github.com/volcengine/verl/issues/3606:

1. Move the reward manager and register it in the custom_reward_function file (a registry lookup sketch follows this description).
2. Register the agent loop in agent.yaml.
3. Move collabllm_interation.py into the recipe directory.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

```
(TaskRunner pid=52293) step:3 - global_seqlen/min:56551 - global_seqlen/max:94884 - global_seqlen/minmax_diff:38333 - global_seqlen/balanced_min:72054 -
```

### API and Usage Example

n/a

### Design & Code Changes

n/a

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
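For reference, a minimal sketch (not verl's exact wiring) of how the relocated reward manager is found by name once its defining module has been imported: the module-level `@register("collabllm")` decorator must run before `get_reward_manager_cls("collabllm")` is called. The import path below is illustrative only; the actual file lives under `recipe/collabllm/`.

```
# Minimal sketch, assuming the recipe module that carries @register("collabllm")
# is importable; the dotted path here is a placeholder, not the real file name.
import importlib

from verl.workers.reward_manager import get_reward_manager_cls

# Importing the module executes its @register("collabllm") decorator.
importlib.import_module("recipe.collabllm.reward_function")  # hypothetical path

reward_manager_cls = get_reward_manager_cls("collabllm")
print(reward_manager_cls.__name__)  # expected: CollabLLMRewardManager
```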
@@ -20,7 +20,7 @@ from typing import Any
 from uuid import uuid4
 
 from recipe.collabllm.utils import is_valid_messages
-from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, register
+from verl.experimental.agent_loop.agent_loop import AgentLoopOutput
 from verl.experimental.agent_loop.tool_agent_loop import AgentData, AgentState, ToolAgentLoop
 from verl.utils.rollout_trace import rollout_trace_op
 from verl.workers.rollout.schemas import Message
@@ -29,7 +29,6 @@ logger = logging.getLogger(__file__)
 logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
 
 
-@register("collabllm_agent")
 class CollabLLMAgentLoop(ToolAgentLoop):
     @rollout_trace_op
     async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
@@ -21,10 +21,9 @@ from typing import Any, Optional
 from uuid import uuid4
 
 from recipe.collabllm.utils import remove_think_block
+from verl.interactions.base import BaseInteraction
 from verl.utils.rollout_trace import rollout_trace_op
 
-from .base import BaseInteraction
-
 logger = logging.getLogger(__name__)
 logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
 
recipe/collabllm/config/agent.yaml (new file)
@@ -0,0 +1,2 @@
+- name: collabllm_agent
+  _target_: recipe.collabllm.collabllm_agent_loop.CollabLLMAgentLoop
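For context, a minimal sketch of how a dotted `_target_` string such as the one above (and the `class_name` entry in the next hunk) can be resolved to a class at runtime. verl's actual agent-loop loader may differ, so treat this as illustrative only.

```
# Illustrative only: resolving a dotted "_target_" / "class_name" string to a class.
import importlib


def resolve_dotted_path(dotted_path: str) -> type:
    """Split 'pkg.module.ClassName' into module and attribute, then import."""
    module_path, _, attr_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)


agent_loop_cls = resolve_dotted_path("recipe.collabllm.collabllm_agent_loop.CollabLLMAgentLoop")
```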
@@ -1,6 +1,6 @@
 interaction:
   - name: "collabllm"
-    class_name: "verl.interactions.collabllm_interation.CollabLLMInteraction"
+    class_name: "recipe.collabllm.collabllm_interation.CollabLLMInteraction"
     config: {
       "user_model": "gpt-4o-mini",
       "num_retries": 3,
@@ -17,9 +17,18 @@ import asyncio
 import importlib.util
 import os
 import sys
+from typing import Any, Callable, Optional
 
 import litellm
 import torch
+from transformers import PreTrainedTokenizer
 
+from verl import DataProto
+from verl.utils.reward_score import default_compute_score
+from verl.workers.reward_manager import register
+from verl.workers.reward_manager.abstract import AbstractRewardManager
+
+TERMINATION_SIGNAL = "[[TERMINATE CHAT]]"
+
 
 async def conversation_level_reward_func(
@@ -93,3 +102,126 @@ async def conversation_level_reward_func(
 
     # Return dict with metric names as keys
     return {metric: torch.tensor(reward, dtype=torch.float32) for metric, reward in rewards.items()}
+
+
+@register("collabllm")
+class CollabLLMRewardManager(AbstractRewardManager):
+    """
+    The Reward Manager used in https://github.com/Wuyxin/collabllm/
+    """
+
+    def __init__(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        num_examine: int,
+        metric_weights: dict,
+        llm_judge_kwargs: dict,
+        reward_fn_key: str = "data_source",
+        compute_score: Optional[Callable] = None,
+        normalize_by_data_source=False,
+    ) -> None:
+        self.tokenizer = tokenizer
+        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console
+        self.compute_score = compute_score or default_compute_score
+        self.reward_fn_key = reward_fn_key
+
+        self.metric_weights = metric_weights
+        self.llm_judge_kwargs = llm_judge_kwargs
+        self.normalize_by_data_source = normalize_by_data_source
+
+        self.metrics = list(self.metric_weights.keys())
+
+    def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
+        # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
+        if "rm_scores" in data.batch.keys():
+            if return_dict:
+                return {"reward_tensor": data.batch["rm_scores"]}
+            else:
+                return data.batch["rm_scores"]
+        # Use thread-compatible async loop management instead of asyncio.run()
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+            return loop.run_until_complete(self._compute_rewards_async(data, return_dict))
+        finally:
+            loop.close()
+
+    async def _compute_rewards_async(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
+        # batched scoring
+        prompt_ids = data.batch["prompts"]
+        prompt_length = prompt_ids.shape[-1]
+        valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum(dim=-1)
+
+        data_source = data.non_tensor_batch["data_source"]
+        ground_truth = data.non_tensor_batch["ground_truth"]
+        extra_info = data.non_tensor_batch["extra_info"]
+        message_lst = data.non_tensor_batch["messages"]
+
+        # batch the messages into multiple
+        num_repeat_rollouts = len(message_lst[0]["messages"])
+        batch_size = len(data_source)
+
+        grouped_messages = [
+            [message_lst[i]["messages"][j] for i in range(len(message_lst))] for j in range(num_repeat_rollouts)
+        ]
+
+        # Flatten lists for all batch items across all rollouts
+        flattened_data_sources = [data_source[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)]
+        flattened_ground_truths = [ground_truth[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)]
+        flattened_extra_infos = [extra_info[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)]
+        flattened_messages = [grouped_messages[j][i] for j in range(num_repeat_rollouts) for i in range(batch_size)]
+
+        if num_repeat_rollouts > 0:
+            tasks = [
+                self.compute_score(
+                    flattened_data_sources[i],
+                    flattened_messages[i],
+                    flattened_ground_truths[i],
+                    flattened_extra_infos[i],
+                    self.metrics,
+                    **self.llm_judge_kwargs,
+                )
+                for i in range(len(flattened_data_sources))
+            ]
+            score_dicts = await asyncio.gather(*tasks)
+
+            # Aggregate scores for each metric across repeated rollouts
+            scores_by_metrics = {
+                metric: torch.stack([score_dict[metric] for score_dict in score_dicts])
+                .view(num_repeat_rollouts, -1)
+                .sum(dim=0)
+                for metric in self.metrics
+            }
+
+            # Apply metric-specific weights
+            weighted_scores_by_metrics = {
+                metric: torch.clamp(
+                    scores_by_metrics[metric] * self.metric_weights[metric] / num_repeat_rollouts,
+                    min=-1.0,
+                    max=1.0,
+                )
+                for metric in self.metrics
+            }
+            # Compute mean of weighted scores for each metric
+            mean_weighted_scores_by_metrics = {
+                metric: weighted_scores_by_metrics[metric].mean(dim=0) for metric in self.metrics
+            }
+
+            # Combine weighted scores from all metrics into a single tensor
+            scores = torch.stack([weighted_scores_by_metrics[metric] for metric in self.metrics]).sum(dim=0)
+        else:
+            score_dicts = []
+            scores = torch.full((batch_size,), 0.0, dtype=torch.float32, device=prompt_ids.device)
+            mean_weighted_scores_by_metrics = {metric: 0.0 for metric in self.metrics}
+
+        print("Scores:", scores, mean_weighted_scores_by_metrics)
+
+        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
+
+        for i in range(len(data)):
+            reward_tensor[i, valid_response_length[i].item() - 1] = scores[i]
+
+        if return_dict:
+            return {"reward_tensor": reward_tensor}
+        else:
+            return reward_tensor
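To make the aggregation in `_compute_rewards_async` concrete, here is a standalone sketch with made-up numbers (two metrics, `num_repeat_rollouts=3`, batch size 2): per-metric scores are summed across repeated rollouts, scaled by the metric weight and divided by the rollout count, clamped to [-1, 1], and finally summed across metrics into one scalar per sample. The metric names and weights are illustrative only.

```
# Standalone sketch of the aggregation in CollabLLMRewardManager, with toy numbers.
import torch

metric_weights = {"accuracy": 1.0, "interactivity": 0.5}  # illustrative weights
num_repeat_rollouts = 3  # three repeated rollouts, batch size 2 -> 6 score dicts

# One score dict per (rollout, sample) pair, flattened rollout-major as in the code above.
score_dicts = [
    {"accuracy": torch.tensor(1.0), "interactivity": torch.tensor(0.2)},
    {"accuracy": torch.tensor(0.0), "interactivity": torch.tensor(0.4)},
    {"accuracy": torch.tensor(1.0), "interactivity": torch.tensor(0.1)},
    {"accuracy": torch.tensor(1.0), "interactivity": torch.tensor(0.3)},
    {"accuracy": torch.tensor(0.0), "interactivity": torch.tensor(0.5)},
    {"accuracy": torch.tensor(1.0), "interactivity": torch.tensor(0.2)},
]

# Sum each metric over the repeated rollouts (rows), keeping one value per sample.
scores_by_metric = {
    m: torch.stack([d[m] for d in score_dicts]).view(num_repeat_rollouts, -1).sum(dim=0)
    for m in metric_weights
}
# Weight, normalize by rollout count, clamp to [-1, 1].
weighted = {
    m: torch.clamp(scores_by_metric[m] * w / num_repeat_rollouts, min=-1.0, max=1.0)
    for m, w in metric_weights.items()
}
# One scalar per sample; the manager places it on the last valid response token.
scores = torch.stack(list(weighted.values())).sum(dim=0)
print(scores)  # shape (2,)
```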
@@ -13,6 +13,7 @@ fi
 
 DATASET=math-hard-large
 PROJECT_DIR="$(pwd)"
+AGENTLOOP_CONFIG_PATH="$PROJECT_DIR/recipe/collabllm/config/agent.yaml"
 
 
 python3 -m verl.trainer.main_ppo \
@@ -56,7 +57,7 @@ python3 -m verl.trainer.main_ppo \
     actor_rollout_ref.rollout.multi_turn.max_user_turns=2 \
     actor_rollout_ref.rollout.multi_turn.max_assistant_turns=3 \
     actor_rollout_ref.rollout.multi_turn.num_repeat_rollouts=3 \
-    actor_rollout_ref.rollout.trace.token2text=True \
+    actor_rollout_ref.rollout.agent.agent_loop_config_path=$AGENTLOOP_CONFIG_PATH \
     actor_rollout_ref.ref.fsdp_config.param_offload=True \
     algorithm.use_kl_in_reward=False \
     trainer.critic_warmup=0 \
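The new override follows the same Hydra/OmegaConf dot-list syntax as the surrounding flags. A small sketch of how such an override becomes a nested config entry; the value is simply the path defined above, and the exact merge into verl's trainer config is handled by Hydra, not by this snippet.

```
# Sketch: how a dot-list override like the one added above maps to nested config keys.
from omegaconf import OmegaConf

overrides = OmegaConf.from_dotlist(
    ["actor_rollout_ref.rollout.agent.agent_loop_config_path=recipe/collabllm/config/agent.yaml"]
)
print(overrides.actor_rollout_ref.rollout.agent.agent_loop_config_path)
# -> recipe/collabllm/config/agent.yaml
```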
@@ -109,6 +109,11 @@ def load_reward_manager(
         An instance of the specified reward manager class.
     """
 
+    # Try to get a custom reward function based on the configuration
+    # user defined reward manager can be registered in custom_reward_fn
+    compute_score = get_custom_reward_fn(config)
+    final_compute_score = compute_score
+
     # The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`:
     # naive: NaiveRewardManager
     # prime: PrimeRewardManager
@@ -120,10 +125,6 @@ def load_reward_manager(
     reward_manager_name = config.reward_model.get("reward_manager", "naive")
     reward_manager_cls = get_reward_manager_cls(reward_manager_name)
 
-    # Try to get a custom reward function based on the configuration
-    compute_score = get_custom_reward_fn(config)
-    final_compute_score = compute_score
-
     if compute_score is None:
         sandbox_config = config.reward_model.get("sandbox_fusion")
         sandbox_url = sandbox_config.get("url") if sandbox_config else None
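The reordering above matters because `get_custom_reward_fn(config)` loads the user's custom reward file before `get_reward_manager_cls(...)` runs, so a module-level `@register("collabllm")` in that file is executed first and the lookup can succeed. A minimal sketch of that file-loading step; the helper and module names are placeholders, not verl's actual implementation.

```
# Illustrative sketch of loading a user-supplied reward file so that any
# module-level @register(...) decorators run before the reward manager lookup.
import importlib.util


def load_custom_reward_fn(file_path: str, fn_name: str):
    """Load `fn_name` from a Python file at `file_path` (hypothetical helper)."""
    spec = importlib.util.spec_from_file_location("custom_reward_module", file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # registration side effects happen here
    return getattr(module, fn_name)
```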
@ -14,7 +14,6 @@
|
|||||||
|
|
||||||
from .registry import get_reward_manager_cls, register # noqa: I001
|
from .registry import get_reward_manager_cls, register # noqa: I001
|
||||||
from .batch import BatchRewardManager
|
from .batch import BatchRewardManager
|
||||||
from .collabllm import CollabLLMRewardManager
|
|
||||||
from .dapo import DAPORewardManager
|
from .dapo import DAPORewardManager
|
||||||
from .naive import NaiveRewardManager
|
from .naive import NaiveRewardManager
|
||||||
from .prime import PrimeRewardManager
|
from .prime import PrimeRewardManager
|
||||||
@ -22,7 +21,6 @@ from .prime import PrimeRewardManager
|
|||||||
# Note(haibin.lin): no need to include all reward managers here in case of complicated dependencies
|
# Note(haibin.lin): no need to include all reward managers here in case of complicated dependencies
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BatchRewardManager",
|
"BatchRewardManager",
|
||||||
"CollabLLMRewardManager",
|
|
||||||
"DAPORewardManager",
|
"DAPORewardManager",
|
||||||
"NaiveRewardManager",
|
"NaiveRewardManager",
|
||||||
"PrimeRewardManager",
|
"PrimeRewardManager",
|
||||||
|
@@ -1,152 +0,0 @@
-# Copyright 2025 CollabLLM team and/or its affiliates
-# Copyright 2025 Bytedance Ltd. and/or its affiliates
-
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import asyncio
-from typing import Any, Callable, Optional
-
-import torch
-from transformers import PreTrainedTokenizer
-
-from verl import DataProto
-from verl.utils.reward_score import default_compute_score
-from verl.workers.reward_manager import register
-from verl.workers.reward_manager.abstract import AbstractRewardManager
-
-TERMINATION_SIGNAL = "[[TERMINATE CHAT]]"
-
-
-@register("collabllm")
-class CollabLLMRewardManager(AbstractRewardManager):
-    """
-    The Reward Manager used in https://github.com/Wuyxin/collabllm/
-    """
-
-    def __init__(
-        self,
-        tokenizer: PreTrainedTokenizer,
-        num_examine: int,
-        metric_weights: dict,
-        llm_judge_kwargs: dict,
-        reward_fn_key: str = "data_source",
-        compute_score: Optional[Callable] = None,
-        normalize_by_data_source=False,
-    ) -> None:
-        self.tokenizer = tokenizer
-        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console
-        self.compute_score = compute_score or default_compute_score
-        self.reward_fn_key = reward_fn_key
-
-        self.metric_weights = metric_weights
-        self.llm_judge_kwargs = llm_judge_kwargs
-        self.normalize_by_data_source = normalize_by_data_source
-
-        self.metrics = list(self.metric_weights.keys())
-        # force CollabLLMAgentLoop to be registered
-        from recipe.collabllm.collabllm_agent_loop import CollabLLMAgentLoop  # noqa
-
-    def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
-        # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
-        if "rm_scores" in data.batch.keys():
-            if return_dict:
-                return {"reward_tensor": data.batch["rm_scores"]}
-            else:
-                return data.batch["rm_scores"]
-        # Use thread-compatible async loop management instead of asyncio.run()
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-        try:
-            return loop.run_until_complete(self._compute_rewards_async(data, return_dict))
-        finally:
-            loop.close()
-
-    async def _compute_rewards_async(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
-        # batched scoring
-        prompt_ids = data.batch["prompts"]
-        prompt_length = prompt_ids.shape[-1]
-        valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum(dim=-1)
-
-        data_source = data.non_tensor_batch["data_source"]
-        ground_truth = data.non_tensor_batch["ground_truth"]
-        extra_info = data.non_tensor_batch["extra_info"]
-        message_lst = data.non_tensor_batch["messages"]
-
-        # batch the messages into multiple
-        num_repeat_rollouts = len(message_lst[0]["messages"])
-        batch_size = len(data_source)
-
-        grouped_messages = [
-            [message_lst[i]["messages"][j] for i in range(len(message_lst))] for j in range(num_repeat_rollouts)
-        ]
-
-        # Flatten lists for all batch items across all rollouts
-        flattened_data_sources = [data_source[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)]
-        flattened_ground_truths = [ground_truth[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)]
-        flattened_extra_infos = [extra_info[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)]
-        flattened_messages = [grouped_messages[j][i] for j in range(num_repeat_rollouts) for i in range(batch_size)]
-
-        if num_repeat_rollouts > 0:
-            tasks = [
-                self.compute_score(
-                    flattened_data_sources[i],
-                    flattened_messages[i],
-                    flattened_ground_truths[i],
-                    flattened_extra_infos[i],
-                    self.metrics,
-                    **self.llm_judge_kwargs,
-                )
-                for i in range(len(flattened_data_sources))
-            ]
-            score_dicts = await asyncio.gather(*tasks)
-
-            # Aggregate scores for each metric across repeated rollouts
-            scores_by_metrics = {
-                metric: torch.stack([score_dict[metric] for score_dict in score_dicts])
-                .view(num_repeat_rollouts, -1)
-                .sum(dim=0)
-                for metric in self.metrics
-            }
-
-            # Apply metric-specific weights
-            weighted_scores_by_metrics = {
-                metric: torch.clamp(
-                    scores_by_metrics[metric] * self.metric_weights[metric] / num_repeat_rollouts,
-                    min=-1.0,
-                    max=1.0,
-                )
-                for metric in self.metrics
-            }
-            # Compute mean of weighted scores for each metric
-            mean_weighted_scores_by_metrics = {
-                metric: weighted_scores_by_metrics[metric].mean(dim=0) for metric in self.metrics
-            }
-
-            # Combine weighted scores from all metrics into a single tensor
-            scores = torch.stack([weighted_scores_by_metrics[metric] for metric in self.metrics]).sum(dim=0)
-        else:
-            score_dicts = []
-            scores = torch.full((batch_size,), 0.0, dtype=torch.float32, device=prompt_ids.device)
-            mean_weighted_scores_by_metrics = {metric: 0.0 for metric in self.metrics}
-
-        print("Scores:", scores, mean_weighted_scores_by_metrics)
-
-        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
-
-        for i in range(len(data)):
-            reward_tensor[i, valid_response_length[i].item() - 1] = scores[i]
-
-        if return_dict:
-            return {"reward_tensor": reward_tensor}
-        else:
-            return reward_tensor