Mirror of https://github.com/volcengine/verl.git (synced 2025-10-20 05:33:49 +08:00)
[misc] refactor: Add AbstractRewardManager abstract class (#2763)
### What does this PR do?

Adds a new `AbstractRewardManager` class to codify the interface for a reward manager.

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
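As a usage sketch for the "API and Usage Example" section above (not part of the PR; `config`, `tokenizer`, and `data` are placeholders assumed to come from the trainer setup, only the typed entry points are from this change):

```python
# Sketch only: `config` (DictConfig), `tokenizer`, and `data` (DataProto) are assumed
# to already exist on the trainer side.
from verl.trainer.ppo.reward import compute_reward, load_reward_manager
from verl.workers.reward_manager.abstract import AbstractRewardManager

reward_fn = load_reward_manager(config, tokenizer, num_examine=0)
assert isinstance(reward_fn, AbstractRewardManager)  # any registered manager satisfies the new ABC

# Every manager exposes the same __call__ contract:
reward_tensor = reward_fn(data)               # torch.Tensor when return_dict=False (default)
detailed = reward_fn(data, return_dict=True)  # dict[str, Any] variant

# compute_reward returns (reward_tensor, extra_info) per the new annotation.
reward_tensor, extra_info = compute_reward(data, reward_fn)
```

Because the return types are now pinned down by the abstract class, call sites like the above can be type-checked under the stricter mypy settings this PR enables.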
Committed by GitHub
parent ae3506dd33
commit a24241092d
```diff
@@ -81,6 +81,9 @@ ignore_errors = true
 module = [
     "verl.trainer.config.algorithm",
     "verl.trainer.ppo.core_algos",
+    "verl.trainer.ppo.reward",
+    "verl.workers.reward_manager",
+    "verl.workers.reward_manager.*",
 ]
 ignore_errors = false
 
```
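These overrides (presumably in `pyproject.toml`) turn real type checking on for the reward modules, so mypy will now flag reward-manager code in those modules that drifts from the `AbstractRewardManager` signatures. A hypothetical sketch of the kind of override error it reports (not code from the repo):

```python
from typing import Any

import torch

from verl.protocol import DataProto
from verl.workers.reward_manager.abstract import AbstractRewardManager


class DriftingRewardManager(AbstractRewardManager):
    def __init__(self, tokenizer: Any, num_examine: int, compute_score: Any = None, **kwargs: Any):
        self.tokenizer = tokenizer
        self.num_examine = num_examine

    # With ignore_errors = false, mypy reports an incompatible override here:
    # the base class declares `-> torch.Tensor | dict[str, Any]`, not `list[float]`.
    def __call__(self, data: DataProto, return_dict: bool = False) -> list[float]:
        return [0.0] * len(data)
```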
```diff
@@ -12,14 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import importlib.util
 import multiprocessing
 import os
+import sys
+import warnings
 from functools import partial
+from typing import Any, Optional
 
 import ray
+import torch
+from omegaconf import DictConfig
 
 from verl import DataProto
 from verl.utils.reward_score import default_compute_score
+from verl.workers.reward_manager import get_reward_manager_cls
+from verl.workers.reward_manager.abstract import AbstractRewardManager, RawRewardFn
 
 
 def _call_with_kwargs(raw_fn, extra_kwargs, *args, **kwargs):
@@ -31,7 +39,7 @@ def _call_with_kwargs(raw_fn, extra_kwargs, *args, **kwargs):
     return raw_fn(*args, **merged_kwargs)
 
 
-def get_custom_reward_fn(config):
+def get_custom_reward_fn(config: DictConfig) -> Optional[RawRewardFn]:
     """Load and return a custom reward function from external file.
 
     Dynamically imports a reward function from a specified file path and wraps
@@ -50,8 +58,6 @@ def get_custom_reward_fn(config)
         RuntimeError: If there's an error loading the module from file.
         AttributeError: If the specified function name isn't found in the module.
     """
-    import importlib.util
-    import sys
 
     reward_fn_config = config.get("custom_reward_function") or {}
     file_path = reward_fn_config.get("path")
@@ -62,14 +68,17 @@ def get_custom_reward_fn(config)
         raise FileNotFoundError(f"Reward function file '{file_path}' not found.")
 
     spec = importlib.util.spec_from_file_location("custom_module", file_path)
+    assert spec is not None
     module = importlib.util.module_from_spec(spec)
     try:
         sys.modules["custom_module"] = module
+        assert spec.loader is not None
         spec.loader.exec_module(module)
     except Exception as e:
         raise RuntimeError(f"Error loading module from '{file_path}': {e}") from e
 
     function_name = reward_fn_config.get("name")
+    assert function_name is not None
     if not hasattr(module, function_name):
         raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.")
 
@@ -81,7 +90,9 @@ def get_custom_reward_fn(config):
     return partial(_call_with_kwargs, raw_fn, reward_kwargs)
 
 
-def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs):
+def load_reward_manager(
+    config: DictConfig, tokenizer: Any, num_examine: int, **reward_kwargs: Any
+) -> AbstractRewardManager:
     """
     Load and initialize a reward manager based on the configuration.
 
@@ -94,7 +105,6 @@ def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs):
     Returns:
         An instance of the specified reward manager class.
     """
-    from verl.workers.reward_manager import get_reward_manager_cls
 
     # The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`:
     # naive: NaiveRewardManager
@@ -138,7 +148,7 @@ def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs):
     )
 
 
-def compute_reward(data: DataProto, reward_fn):
+def compute_reward(data: DataProto, reward_fn: AbstractRewardManager) -> tuple[torch.Tensor, dict[str, Any]]:
     """
     Compute reward for a batch of data.
     Args:
@@ -169,7 +179,6 @@ def compute_reward_async(data: DataProto, config=None, tokenizer=None, reward_fn
         assert config is not None and tokenizer is not None, (
             "config and tokenizer must not be None when reward_fn is None"
         )
-        import warnings
 
         warnings.warn("using config and tokenizer with compute_reward_async is deprecated", stacklevel=2)
         reward_fn = load_reward_manager(
```
verl/workers/reward_manager/abstract.py (new file, 45 lines)

```diff
@@ -0,0 +1,45 @@
+# Copyright 2023-2025 SGLang Team
+# Copyright Amazon.com, Inc. or its affiliates.
+# Copyright 2025 ModelBest Inc. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import Any, Callable
+
+import torch
+
+from verl.protocol import DataProto
+
+RawRewardFn = Callable[..., Any]
+
+
+class AbstractRewardManager(ABC):
+    @abstractmethod
+    def __init__(
+        self,
+        tokenizer: Any,
+        num_examine: int,
+        compute_score: RawRewardFn | None,
+        reward_fn_key: str = "data_source",
+        **kwargs: Any,
+    ):
+        pass
+
+    @abstractmethod
+    def __call__(
+        self,
+        data: DataProto,
+        return_dict: bool = False,
+    ) -> torch.Tensor | dict[str, Any]:
+        pass
```
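For orientation, a minimal concrete subclass showing what the two abstract methods commit an implementation to. This is hypothetical, not part of the PR: a real manager decodes responses with the tokenizer and dispatches to `compute_score`, and the batch key and dict keys below mirror the existing managers rather than anything the ABC mandates.

```python
from typing import Any

import torch

from verl.protocol import DataProto
from verl.workers.reward_manager.abstract import AbstractRewardManager, RawRewardFn


class ZeroRewardManager(AbstractRewardManager):
    """Toy manager: assigns a zero reward to every response token."""

    def __init__(
        self,
        tokenizer: Any,
        num_examine: int,
        compute_score: RawRewardFn | None = None,
        reward_fn_key: str = "data_source",
        **kwargs: Any,
    ):
        self.tokenizer = tokenizer
        self.num_examine = num_examine
        self.compute_score = compute_score
        self.reward_fn_key = reward_fn_key

    def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
        responses = data.batch["responses"]  # assumed key: response token ids, shape (bsz, resp_len)
        reward_tensor = torch.zeros_like(responses, dtype=torch.float32)
        if return_dict:
            return {"reward_tensor": reward_tensor, "reward_extra_info": {}}
        return reward_tensor
```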
```diff
@@ -13,15 +13,17 @@
 # limitations under the License.
 
 from collections import defaultdict
+from typing import Any
 
 import torch
 
 from verl import DataProto
 from verl.workers.reward_manager import register
+from verl.workers.reward_manager.abstract import AbstractRewardManager, RawRewardFn
 
 
 @register("batch")
-class BatchRewardManager:
+class BatchRewardManager(AbstractRewardManager):
     """
     A batch reward manager that computes rewards for a batch of data.
 
@@ -33,7 +35,9 @@ class BatchRewardManager:
         reward_kwargs (dict): The keyword arguments to pass to the reward function.
     """
 
-    def __init__(self, tokenizer, num_examine, compute_score, reward_fn_key="data_source", **reward_kwargs):
+    def __init__(
+        self, tokenizer, num_examine, compute_score: RawRewardFn, reward_fn_key="data_source", **reward_kwargs
+    ):
         self.tokenizer = tokenizer
         self.num_examine = num_examine
         self.compute_score = compute_score
@@ -69,7 +73,7 @@ class BatchRewardManager:
 
         return scores
 
-    def __call__(self, data: DataProto, return_dict=False):
+    def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
         # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
         if "rm_scores" in data.batch.keys():
             if return_dict:
@@ -87,7 +91,7 @@ class BatchRewardManager:
 
         scores = self.verify(data)
         rewards = []
-        already_printed = {}
+        already_printed: dict[str, Any] = {}
 
         for i in range(len(data)):
             length = valid_response_lengths[i].item()
```
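With `__call__` now annotated as `torch.Tensor | dict[str, Any]`, callers that are themselves type-checked need to narrow the union. A sketch of one way to do that (the `"reward_tensor"` key is how the existing managers name the dense reward, treated here as an assumption rather than part of the abstract interface):

```python
from typing import Any

import torch

from verl.protocol import DataProto
from verl.workers.reward_manager.abstract import AbstractRewardManager


def dense_reward(manager: AbstractRewardManager, data: DataProto) -> torch.Tensor:
    """Call a reward manager and narrow the Tensor-or-dict union to a plain tensor."""
    result: torch.Tensor | dict[str, Any] = manager(data, return_dict=True)
    if isinstance(result, dict):
        return result["reward_tensor"]  # assumed key used by the built-in managers
    return result  # already a torch.Tensor
```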
```diff
@@ -19,10 +19,11 @@ import torch
 from verl import DataProto
 from verl.utils.reward_score import default_compute_score
 from verl.workers.reward_manager import register
+from verl.workers.reward_manager.abstract import AbstractRewardManager
 
 
 @register("dapo")
-class DAPORewardManager:
+class DAPORewardManager(AbstractRewardManager):
     """The reward manager."""
 
     def __init__(
```
```diff
@@ -13,16 +13,18 @@
 # limitations under the License.
 
 from collections import defaultdict
+from typing import Any
 
 import torch
 
 from verl import DataProto
 from verl.utils.reward_score import default_compute_score
 from verl.workers.reward_manager import register
+from verl.workers.reward_manager.abstract import AbstractRewardManager
 
 
 @register("naive")
-class NaiveRewardManager:
+class NaiveRewardManager(AbstractRewardManager):
     """The reward manager."""
 
     def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key="data_source") -> None:
@@ -41,7 +43,7 @@ class NaiveRewardManager:
         self.compute_score = compute_score or default_compute_score
         self.reward_fn_key = reward_fn_key  # Store the key for accessing the data source
 
-    def __call__(self, data: DataProto, return_dict=False):
+    def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
         """We will expand this function gradually based on the available datasets"""
 
         # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
```
```diff
@@ -15,7 +15,7 @@
 import asyncio
 from concurrent.futures import ProcessPoolExecutor
 from functools import partial
-from typing import Callable, Optional
+from typing import Any, Callable, Optional
 
 import psutil
 import torch
@@ -24,6 +24,7 @@ from transformers import PreTrainedTokenizer
 from verl import DataProto
 from verl.utils.reward_score import default_compute_score
 from verl.workers.reward_manager import register
+from verl.workers.reward_manager.abstract import AbstractRewardManager
 
 
 async def single_compute_score(evaluation_func, completion, reference, task, task_extra_info, executor, timeout=300.0):
@@ -98,7 +99,7 @@ def run_reward_scoring(evaluation_func, completions, references, tasks, extra_in
 
 
 @register("prime")
-class PrimeRewardManager:
+class PrimeRewardManager(AbstractRewardManager):
     """
     The Reward Manager used in https://github.com/PRIME-RL/PRIME
     """
@@ -147,7 +148,7 @@ class PrimeRewardManager:
         data.batch["acc"] = torch.tensor(scores, dtype=torch.float32, device=prompt_ids.device)
         return scores
 
-    def __call__(self, data: DataProto, return_dict: bool = False):
+    def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
         """We will expand this function gradually based on the available datasets"""
 
         # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
```
```diff
@@ -12,12 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Callable
+
+from verl.workers.reward_manager.abstract import AbstractRewardManager
+
 __all__ = ["register", "get_reward_manager_cls"]
 
-REWARD_MANAGER_REGISTRY = {}
+REWARD_MANAGER_REGISTRY: dict[str, type[AbstractRewardManager]] = {}
 
 
-def register(name):
+def register(name: str) -> Callable[[type[AbstractRewardManager]], type[AbstractRewardManager]]:
     """Decorator to register a reward manager class with a given name.
 
     Args:
@@ -25,7 +29,7 @@ def register(name)
             The name of the reward manager.
     """
 
-    def decorator(cls):
+    def decorator(cls: type[AbstractRewardManager]) -> type[AbstractRewardManager]:
         if name in REWARD_MANAGER_REGISTRY and REWARD_MANAGER_REGISTRY[name] != cls:
             raise ValueError(
                 f"Reward manager {name} has already been registered: {REWARD_MANAGER_REGISTRY[name]} vs {cls}"
@@ -36,7 +40,7 @@ def register(name):
     return decorator
 
 
-def get_reward_manager_cls(name):
+def get_reward_manager_cls(name: str) -> type[AbstractRewardManager]:
     """Get the reward manager class with a given name.
 
     Args:
```
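Registration itself is unchanged; the decorator and lookup are just typed now. A sketch of registering and resolving a custom manager (the name `"dummy"` and the class are hypothetical):

```python
from typing import Any

import torch

from verl.protocol import DataProto
from verl.workers.reward_manager import get_reward_manager_cls, register
from verl.workers.reward_manager.abstract import AbstractRewardManager


@register("dummy")  # re-registering the same name with a *different* class raises ValueError
class DummyRewardManager(AbstractRewardManager):
    def __init__(self, tokenizer: Any, num_examine: int, compute_score: Any = None, **kwargs: Any):
        self.tokenizer = tokenizer
        self.num_examine = num_examine

    def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
        return torch.zeros(len(data), dtype=torch.float32)  # toy: one scalar reward per sample


manager_cls: type[AbstractRewardManager] = get_reward_manager_cls("dummy")
manager = manager_cls(tokenizer=None, num_examine=0, compute_score=None)
```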