[misc] refactor: Add AbstractRewardManager abstract class (#2763)
### What does this PR do?

Adds a new `AbstractRewardManager` class to codify the interface for a reward manager.

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
committed by GitHub · parent ae3506dd33 · commit a24241092d
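For orientation, a minimal sketch of how the interface codified by this PR is consumed. Only `load_reward_manager`, `compute_reward`, and `AbstractRewardManager` come from this change; `config`, `tokenizer`, and `data` are placeholders, and the `verl.trainer.ppo.reward` import path is an assumption not confirmed by this diff.

```python
# Hedged sketch: config, tokenizer, and data are placeholders, and the import
# path is assumed; only the function names and annotations come from this PR.
from verl.trainer.ppo.reward import compute_reward, load_reward_manager

# load_reward_manager is now annotated to return an AbstractRewardManager
reward_fn = load_reward_manager(config, tokenizer, num_examine=0)

# A reward manager is callable on a DataProto batch ...
reward_tensor = reward_fn(data)                   # torch.Tensor
reward_dict = reward_fn(data, return_dict=True)   # dict[str, Any]

# ... and compute_reward is annotated to return (reward_tensor, extra_info)
reward_tensor, reward_extra_info = compute_reward(data, reward_fn)
```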
Changes to the mypy per-module overrides (these modules are now type-checked with `ignore_errors = false`):

```diff
@@ -81,6 +81,9 @@ ignore_errors = true
 module = [
     "verl.trainer.config.algorithm",
     "verl.trainer.ppo.core_algos",
+    "verl.trainer.ppo.reward",
+    "verl.workers.reward_manager",
+    "verl.workers.reward_manager.*",
 ]
 ignore_errors = false
 
```
Changes to the reward-loading utilities (`get_custom_reward_fn`, `load_reward_manager`, `compute_reward`, `compute_reward_async`):

```diff
@@ -12,14 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import importlib.util
 import multiprocessing
 import os
+import sys
+import warnings
 from functools import partial
+from typing import Any, Optional
 
 import ray
+import torch
+from omegaconf import DictConfig
 
 from verl import DataProto
 from verl.utils.reward_score import default_compute_score
+from verl.workers.reward_manager import get_reward_manager_cls
+from verl.workers.reward_manager.abstract import AbstractRewardManager, RawRewardFn
 
 
 def _call_with_kwargs(raw_fn, extra_kwargs, *args, **kwargs):
@@ -31,7 +39,7 @@ def _call_with_kwargs(raw_fn, extra_kwargs, *args, **kwargs):
     return raw_fn(*args, **merged_kwargs)
 
 
-def get_custom_reward_fn(config):
+def get_custom_reward_fn(config: DictConfig) -> Optional[RawRewardFn]:
     """Load and return a custom reward function from external file.
 
     Dynamically imports a reward function from a specified file path and wraps
@@ -50,8 +58,6 @@ def get_custom_reward_fn(config):
         RuntimeError: If there's an error loading the module from file.
         AttributeError: If the specified function name isn't found in the module.
     """
-    import importlib.util
-    import sys
 
     reward_fn_config = config.get("custom_reward_function") or {}
     file_path = reward_fn_config.get("path")
@@ -62,14 +68,17 @@ def get_custom_reward_fn(config):
         raise FileNotFoundError(f"Reward function file '{file_path}' not found.")
 
     spec = importlib.util.spec_from_file_location("custom_module", file_path)
+    assert spec is not None
    module = importlib.util.module_from_spec(spec)
     try:
         sys.modules["custom_module"] = module
+        assert spec.loader is not None
         spec.loader.exec_module(module)
     except Exception as e:
         raise RuntimeError(f"Error loading module from '{file_path}': {e}") from e
 
     function_name = reward_fn_config.get("name")
+    assert function_name is not None
     if not hasattr(module, function_name):
         raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.")
 
@@ -81,7 +90,9 @@ def get_custom_reward_fn(config):
     return partial(_call_with_kwargs, raw_fn, reward_kwargs)
 
 
-def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs):
+def load_reward_manager(
+    config: DictConfig, tokenizer: Any, num_examine: int, **reward_kwargs: Any
+) -> AbstractRewardManager:
     """
     Load and initialize a reward manager based on the configuration.
 
@@ -94,7 +105,6 @@ def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs):
     Returns:
         An instance of the specified reward manager class.
     """
-    from verl.workers.reward_manager import get_reward_manager_cls
 
     # The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`:
     # naive: NaiveRewardManager
@@ -138,7 +148,7 @@ def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs):
     )
 
 
-def compute_reward(data: DataProto, reward_fn):
+def compute_reward(data: DataProto, reward_fn: AbstractRewardManager) -> tuple[torch.Tensor, dict[str, Any]]:
     """
     Compute reward for a batch of data.
     Args:
@@ -169,7 +179,6 @@ def compute_reward_async(data: DataProto, config=None, tokenizer=None, reward_fn
         assert config is not None and tokenizer is not None, (
             "config and tokenizer must not be None when reward_fn is None"
         )
-        import warnings
 
        warnings.warn("using config and tokenizer with compute_reward_async is deprecated", stacklevel=2)
        reward_fn = load_reward_manager(
```
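The `assert spec is not None`, `assert spec.loader is not None`, and `assert function_name is not None` lines added to `get_custom_reward_fn` above narrow `Optional` values now that the module is type-checked. A standalone sketch of the same pattern; the file path and function name are placeholders, not values from this PR:

```python
# Standalone illustration of the Optional-narrowing pattern used above.
# "my_rewards.py" and "my_reward_fn" are placeholder names.
import importlib.util
import sys

file_path = "my_rewards.py"
spec = importlib.util.spec_from_file_location("custom_module", file_path)
assert spec is not None  # spec_from_file_location returns Optional[ModuleSpec]
module = importlib.util.module_from_spec(spec)
sys.modules["custom_module"] = module
assert spec.loader is not None  # ModuleSpec.loader is also Optional
spec.loader.exec_module(module)
reward_fn = getattr(module, "my_reward_fn")
```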
New file `verl/workers/reward_manager/abstract.py` (45 lines):

```diff
@@ -0,0 +1,45 @@
+# Copyright 2023-2025 SGLang Team
+# Copyright Amazon.com, Inc. or its affiliates.
+# Copyright 2025 ModelBest Inc. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import Any, Callable
+
+import torch
+
+from verl.protocol import DataProto
+
+RawRewardFn = Callable[..., Any]
+
+
+class AbstractRewardManager(ABC):
+    @abstractmethod
+    def __init__(
+        self,
+        tokenizer: Any,
+        num_examine: int,
+        compute_score: RawRewardFn | None,
+        reward_fn_key: str = "data_source",
+        **kwargs: Any,
+    ):
+        pass
+
+    @abstractmethod
+    def __call__(
+        self,
+        data: DataProto,
+        return_dict: bool = False,
+    ) -> torch.Tensor | dict[str, Any]:
+        pass
```
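For concreteness, a minimal sketch of a subclass satisfying this interface. Only `AbstractRewardManager`, `RawRewardFn`, and the `__init__`/`__call__` signatures come from the file above; the class name, the zero reward, and the `"responses"`/`"reward_tensor"` keys are assumptions for illustration.

```python
# Hypothetical subclass for illustration; only the base class and method
# signatures come from abstract.py above.
from typing import Any

import torch

from verl.protocol import DataProto
from verl.workers.reward_manager.abstract import AbstractRewardManager, RawRewardFn


class ZeroRewardManager(AbstractRewardManager):
    """Toy manager that assigns a zero reward to every response token."""

    def __init__(
        self,
        tokenizer: Any,
        num_examine: int,
        compute_score: RawRewardFn | None = None,
        reward_fn_key: str = "data_source",
        **kwargs: Any,
    ):
        self.tokenizer = tokenizer
        self.num_examine = num_examine
        self.compute_score = compute_score
        self.reward_fn_key = reward_fn_key

    def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
        # Token-level rewards shaped like the responses; all zeros in this toy example.
        # The "responses" batch key and the "reward_tensor" dict key mirror existing
        # managers and are not guaranteed by the abstract interface itself.
        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
        if return_dict:
            return {"reward_tensor": reward_tensor}
        return reward_tensor
```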
Changes to `BatchRewardManager`:

```diff
@@ -13,15 +13,17 @@
 # limitations under the License.
 
 from collections import defaultdict
+from typing import Any
 
 import torch
 
 from verl import DataProto
 from verl.workers.reward_manager import register
+from verl.workers.reward_manager.abstract import AbstractRewardManager, RawRewardFn
 
 
 @register("batch")
-class BatchRewardManager:
+class BatchRewardManager(AbstractRewardManager):
     """
     A batch reward manager that computes rewards for a batch of data.
 
@@ -33,7 +35,9 @@ class BatchRewardManager:
         reward_kwargs (dict): The keyword arguments to pass to the reward function.
     """
 
-    def __init__(self, tokenizer, num_examine, compute_score, reward_fn_key="data_source", **reward_kwargs):
+    def __init__(
+        self, tokenizer, num_examine, compute_score: RawRewardFn, reward_fn_key="data_source", **reward_kwargs
+    ):
         self.tokenizer = tokenizer
         self.num_examine = num_examine
         self.compute_score = compute_score
@@ -69,7 +73,7 @@ class BatchRewardManager:
 
         return scores
 
-    def __call__(self, data: DataProto, return_dict=False):
+    def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
         # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
         if "rm_scores" in data.batch.keys():
             if return_dict:
@@ -87,7 +91,7 @@ class BatchRewardManager:
 
         scores = self.verify(data)
         rewards = []
-        already_printed = {}
+        already_printed: dict[str, Any] = {}
 
         for i in range(len(data)):
             length = valid_response_lengths[i].item()
```
Changes to `DAPORewardManager`:

```diff
@@ -19,10 +19,11 @@ import torch
 from verl import DataProto
 from verl.utils.reward_score import default_compute_score
 from verl.workers.reward_manager import register
+from verl.workers.reward_manager.abstract import AbstractRewardManager
 
 
 @register("dapo")
-class DAPORewardManager:
+class DAPORewardManager(AbstractRewardManager):
     """The reward manager."""
 
     def __init__(
```
Changes to `NaiveRewardManager`:

```diff
@@ -13,16 +13,18 @@
 # limitations under the License.
 
 from collections import defaultdict
+from typing import Any
 
 import torch
 
 from verl import DataProto
 from verl.utils.reward_score import default_compute_score
 from verl.workers.reward_manager import register
+from verl.workers.reward_manager.abstract import AbstractRewardManager
 
 
 @register("naive")
-class NaiveRewardManager:
+class NaiveRewardManager(AbstractRewardManager):
     """The reward manager."""
 
     def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key="data_source") -> None:
@@ -41,7 +43,7 @@ class NaiveRewardManager:
         self.compute_score = compute_score or default_compute_score
         self.reward_fn_key = reward_fn_key  # Store the key for accessing the data source
 
-    def __call__(self, data: DataProto, return_dict=False):
+    def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
         """We will expand this function gradually based on the available datasets"""
 
         # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
```
Changes to `PrimeRewardManager`:

```diff
@@ -15,7 +15,7 @@
 import asyncio
 from concurrent.futures import ProcessPoolExecutor
 from functools import partial
-from typing import Callable, Optional
+from typing import Any, Callable, Optional
 
 import psutil
 import torch
@@ -24,6 +24,7 @@ from transformers import PreTrainedTokenizer
 from verl import DataProto
 from verl.utils.reward_score import default_compute_score
 from verl.workers.reward_manager import register
+from verl.workers.reward_manager.abstract import AbstractRewardManager
 
 
 async def single_compute_score(evaluation_func, completion, reference, task, task_extra_info, executor, timeout=300.0):
@@ -98,7 +99,7 @@ def run_reward_scoring(evaluation_func, completions, references, tasks, extra_in
 
 
 @register("prime")
-class PrimeRewardManager:
+class PrimeRewardManager(AbstractRewardManager):
     """
     The Reward Manager used in https://github.com/PRIME-RL/PRIME
     """
@@ -147,7 +148,7 @@ class PrimeRewardManager:
         data.batch["acc"] = torch.tensor(scores, dtype=torch.float32, device=prompt_ids.device)
         return scores
 
-    def __call__(self, data: DataProto, return_dict: bool = False):
+    def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
         """We will expand this function gradually based on the available datasets"""
 
         # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
```
Changes to the reward-manager registry (`register`, `get_reward_manager_cls`):

```diff
@@ -12,12 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Callable
+
+from verl.workers.reward_manager.abstract import AbstractRewardManager
+
 __all__ = ["register", "get_reward_manager_cls"]
 
-REWARD_MANAGER_REGISTRY = {}
+REWARD_MANAGER_REGISTRY: dict[str, type[AbstractRewardManager]] = {}
 
 
-def register(name):
+def register(name: str) -> Callable[[type[AbstractRewardManager]], type[AbstractRewardManager]]:
     """Decorator to register a reward manager class with a given name.
 
     Args:
@@ -25,7 +29,7 @@ def register(name):
         The name of the reward manager.
     """
 
-    def decorator(cls):
+    def decorator(cls: type[AbstractRewardManager]) -> type[AbstractRewardManager]:
         if name in REWARD_MANAGER_REGISTRY and REWARD_MANAGER_REGISTRY[name] != cls:
             raise ValueError(
                 f"Reward manager {name} has already been registered: {REWARD_MANAGER_REGISTRY[name]} vs {cls}"
@@ -36,7 +40,7 @@ def register(name):
     return decorator
 
 
-def get_reward_manager_cls(name):
+def get_reward_manager_cls(name: str) -> type[AbstractRewardManager]:
     """Get the reward manager class with a given name.
 
     Args:
```
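A short sketch of how the now-typed registry fits together with the abstract class. Only `register`, `get_reward_manager_cls`, and `AbstractRewardManager` come from this diff; the `"constant"` name, the manager class, and the `"responses"`/`"reward_tensor"` keys are illustrative assumptions.

```python
# Hypothetical registration/lookup example; only register, get_reward_manager_cls,
# and AbstractRewardManager come from this PR.
from typing import Any

import torch

from verl.protocol import DataProto
from verl.workers.reward_manager import get_reward_manager_cls, register
from verl.workers.reward_manager.abstract import AbstractRewardManager


@register("constant")  # register() now accepts and returns a type[AbstractRewardManager]
class ConstantRewardManager(AbstractRewardManager):
    def __init__(self, tokenizer: Any, num_examine: int, compute_score=None,
                 reward_fn_key: str = "data_source", **kwargs: Any):
        self.tokenizer = tokenizer
        self.num_examine = num_examine

    def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
        # Constant reward of 1.0 per response token (keys assumed from existing managers).
        reward = torch.ones_like(data.batch["responses"], dtype=torch.float32)
        return {"reward_tensor": reward} if return_dict else reward


# Lookup by the registered name; mypy now knows the result is a reward-manager class.
manager_cls = get_reward_manager_cls("constant")
assert issubclass(manager_cls, AbstractRewardManager)
```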