[misc] refactor: Add AbstractRewardManager abstract class (#2763)

### What does this PR do?

Adds a new `AbstractRewardManager` abstract base class that codifies the
interface for reward managers, and updates the existing reward managers to
subclass it.

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```
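
A minimal, hypothetical sketch of the new interface: `MyRewardManager`, the `"my_manager"` registry name, and the toy scoring logic are made up for illustration; only `AbstractRewardManager`, `RawRewardFn`, `register`, and `get_reward_manager_cls` come from this PR.

```python
# Hypothetical example: "my_manager" / MyRewardManager are illustrative names, not part of this PR.
from typing import Any

import torch

from verl import DataProto
from verl.workers.reward_manager import get_reward_manager_cls, register
from verl.workers.reward_manager.abstract import AbstractRewardManager, RawRewardFn


@register("my_manager")
class MyRewardManager(AbstractRewardManager):
    def __init__(
        self,
        tokenizer: Any,
        num_examine: int,
        compute_score: RawRewardFn | None = None,
        reward_fn_key: str = "data_source",
        **kwargs: Any,
    ):
        self.tokenizer = tokenizer
        self.num_examine = num_examine
        self.compute_score = compute_score
        self.reward_fn_key = reward_fn_key

    def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
        # Toy scoring: zero reward per response token, purely to satisfy the interface.
        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
        if return_dict:
            return {"reward_tensor": reward_tensor}
        return reward_tensor


# Registered managers are resolved by name, which is what load_reward_manager does internally.
assert get_reward_manager_cls("my_manager") is MyRewardManager
```

Any manager registered this way satisfies the `AbstractRewardManager` return type that `load_reward_manager` now declares.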

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.
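
- A new `verl.workers.reward_manager.abstract` module defines the `AbstractRewardManager` ABC (abstract `__init__` and `__call__`) and a `RawRewardFn` type alias.
- `BatchRewardManager`, `DAPORewardManager`, `NaiveRewardManager`, and `PrimeRewardManager` now inherit from `AbstractRewardManager`.
- The reward-manager registry is typed as `dict[str, type[AbstractRewardManager]]`, with matching annotations on `register` and `get_reward_manager_cls`.
- `load_reward_manager` is annotated to return an `AbstractRewardManager`, `compute_reward` accepts one, and mypy `ignore_errors` is disabled for `verl.trainer.ppo.reward` and the `verl.workers.reward_manager` modules.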

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
Commit a24241092d (parent ae3506dd33), authored by Frederick Robinson on 2025-08-02 16:39:58 -07:00 and committed by GitHub.
8 changed files with 90 additions and 21 deletions.


@@ -81,6 +81,9 @@ ignore_errors = true
module = [
"verl.trainer.config.algorithm",
"verl.trainer.ppo.core_algos",
"verl.trainer.ppo.reward",
"verl.workers.reward_manager",
"verl.workers.reward_manager.*",
]
ignore_errors = false


@@ -12,14 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import multiprocessing
import os
import sys
import warnings
from functools import partial
from typing import Any, Optional
import ray
import torch
from omegaconf import DictConfig
from verl import DataProto
from verl.utils.reward_score import default_compute_score
from verl.workers.reward_manager import get_reward_manager_cls
from verl.workers.reward_manager.abstract import AbstractRewardManager, RawRewardFn
def _call_with_kwargs(raw_fn, extra_kwargs, *args, **kwargs):
@@ -31,7 +39,7 @@ def _call_with_kwargs(raw_fn, extra_kwargs, *args, **kwargs):
return raw_fn(*args, **merged_kwargs)
def get_custom_reward_fn(config):
def get_custom_reward_fn(config: DictConfig) -> Optional[RawRewardFn]:
"""Load and return a custom reward function from external file.
Dynamically imports a reward function from a specified file path and wraps
@@ -50,8 +58,6 @@ def get_custom_reward_fn(config):
RuntimeError: If there's an error loading the module from file.
AttributeError: If the specified function name isn't found in the module.
"""
import importlib.util
import sys
reward_fn_config = config.get("custom_reward_function") or {}
file_path = reward_fn_config.get("path")
@@ -62,14 +68,17 @@ def get_custom_reward_fn(config):
raise FileNotFoundError(f"Reward function file '{file_path}' not found.")
spec = importlib.util.spec_from_file_location("custom_module", file_path)
assert spec is not None
module = importlib.util.module_from_spec(spec)
try:
sys.modules["custom_module"] = module
assert spec.loader is not None
spec.loader.exec_module(module)
except Exception as e:
raise RuntimeError(f"Error loading module from '{file_path}': {e}") from e
function_name = reward_fn_config.get("name")
assert function_name is not None
if not hasattr(module, function_name):
raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.")
@@ -81,7 +90,9 @@ def get_custom_reward_fn(config):
return partial(_call_with_kwargs, raw_fn, reward_kwargs)
def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs):
def load_reward_manager(
config: DictConfig, tokenizer: Any, num_examine: int, **reward_kwargs: Any
) -> AbstractRewardManager:
"""
Load and initialize a reward manager based on the configuration.
@@ -94,7 +105,6 @@ def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs):
Returns:
An instance of the specified reward manager class.
"""
from verl.workers.reward_manager import get_reward_manager_cls
# The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`:
# naive: NaiveRewardManager
@@ -138,7 +148,7 @@ def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs):
)
def compute_reward(data: DataProto, reward_fn):
def compute_reward(data: DataProto, reward_fn: AbstractRewardManager) -> tuple[torch.Tensor, dict[str, Any]]:
"""
Compute reward for a batch of data.
Args:
@@ -169,7 +179,6 @@ def compute_reward_async(data: DataProto, config=None, tokenizer=None, reward_fn
assert config is not None and tokenizer is not None, (
"config and tokenizer must not be None when reward_fn is None"
)
import warnings
warnings.warn("using config and tokenizer with compute_reward_async is deprecated", stacklevel=2)
reward_fn = load_reward_manager(


@@ -0,0 +1,45 @@
# Copyright 2023-2025 SGLang Team
# Copyright Amazon.com, Inc. or its affiliates.
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Callable
import torch
from verl.protocol import DataProto
RawRewardFn = Callable[..., Any]
class AbstractRewardManager(ABC):
@abstractmethod
def __init__(
self,
tokenizer: Any,
num_examine: int,
compute_score: RawRewardFn | None,
reward_fn_key: str = "data_source",
**kwargs: Any,
):
pass
@abstractmethod
def __call__(
self,
data: DataProto,
return_dict: bool = False,
) -> torch.Tensor | dict[str, Any]:
pass


@@ -13,15 +13,17 @@
# limitations under the License.
from collections import defaultdict
from typing import Any
import torch
from verl import DataProto
from verl.workers.reward_manager import register
from verl.workers.reward_manager.abstract import AbstractRewardManager, RawRewardFn
@register("batch")
class BatchRewardManager:
class BatchRewardManager(AbstractRewardManager):
"""
A batch reward manager that computes rewards for a batch of data.
@@ -33,7 +35,9 @@ class BatchRewardManager:
reward_kwargs (dict): The keyword arguments to pass to the reward function.
"""
def __init__(self, tokenizer, num_examine, compute_score, reward_fn_key="data_source", **reward_kwargs):
def __init__(
self, tokenizer, num_examine, compute_score: RawRewardFn, reward_fn_key="data_source", **reward_kwargs
):
self.tokenizer = tokenizer
self.num_examine = num_examine
self.compute_score = compute_score
@@ -69,7 +73,7 @@ class BatchRewardManager:
return scores
def __call__(self, data: DataProto, return_dict=False):
def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
# If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
if "rm_scores" in data.batch.keys():
if return_dict:
@@ -87,7 +91,7 @@ class BatchRewardManager:
scores = self.verify(data)
rewards = []
already_printed = {}
already_printed: dict[str, Any] = {}
for i in range(len(data)):
length = valid_response_lengths[i].item()


@@ -19,10 +19,11 @@ import torch
from verl import DataProto
from verl.utils.reward_score import default_compute_score
from verl.workers.reward_manager import register
from verl.workers.reward_manager.abstract import AbstractRewardManager
@register("dapo")
class DAPORewardManager:
class DAPORewardManager(AbstractRewardManager):
"""The reward manager."""
def __init__(


@@ -13,16 +13,18 @@
# limitations under the License.
from collections import defaultdict
from typing import Any
import torch
from verl import DataProto
from verl.utils.reward_score import default_compute_score
from verl.workers.reward_manager import register
from verl.workers.reward_manager.abstract import AbstractRewardManager
@register("naive")
class NaiveRewardManager:
class NaiveRewardManager(AbstractRewardManager):
"""The reward manager."""
def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key="data_source") -> None:
@@ -41,7 +43,7 @@ class NaiveRewardManager:
self.compute_score = compute_score or default_compute_score
self.reward_fn_key = reward_fn_key # Store the key for accessing the data source
def __call__(self, data: DataProto, return_dict=False):
def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
"""We will expand this function gradually based on the available datasets"""
# If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn


@@ -15,7 +15,7 @@
import asyncio
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from typing import Callable, Optional
from typing import Any, Callable, Optional
import psutil
import torch
@@ -24,6 +24,7 @@ from transformers import PreTrainedTokenizer
from verl import DataProto
from verl.utils.reward_score import default_compute_score
from verl.workers.reward_manager import register
from verl.workers.reward_manager.abstract import AbstractRewardManager
async def single_compute_score(evaluation_func, completion, reference, task, task_extra_info, executor, timeout=300.0):
@@ -98,7 +99,7 @@ def run_reward_scoring(evaluation_func, completions, references, tasks, extra_in
@register("prime")
class PrimeRewardManager:
class PrimeRewardManager(AbstractRewardManager):
"""
The Reward Manager used in https://github.com/PRIME-RL/PRIME
"""
@@ -147,7 +148,7 @@ class PrimeRewardManager:
data.batch["acc"] = torch.tensor(scores, dtype=torch.float32, device=prompt_ids.device)
return scores
def __call__(self, data: DataProto, return_dict: bool = False):
def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
"""We will expand this function gradually based on the available datasets"""
# If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn


@@ -12,12 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable
from verl.workers.reward_manager.abstract import AbstractRewardManager
__all__ = ["register", "get_reward_manager_cls"]
REWARD_MANAGER_REGISTRY = {}
REWARD_MANAGER_REGISTRY: dict[str, type[AbstractRewardManager]] = {}
def register(name):
def register(name: str) -> Callable[[type[AbstractRewardManager]], type[AbstractRewardManager]]:
"""Decorator to register a reward manager class with a given name.
Args:
@@ -25,7 +29,7 @@ def register(name):
The name of the reward manager.
"""
def decorator(cls):
def decorator(cls: type[AbstractRewardManager]) -> type[AbstractRewardManager]:
if name in REWARD_MANAGER_REGISTRY and REWARD_MANAGER_REGISTRY[name] != cls:
raise ValueError(
f"Reward manager {name} has already been registered: {REWARD_MANAGER_REGISTRY[name]} vs {cls}"
@@ -36,7 +40,7 @@ def register(name):
return decorator
def get_reward_manager_cls(name):
def get_reward_manager_cls(name: str) -> type[AbstractRewardManager]:
"""Get the reward manager class with a given name.
Args: