diff --git a/pyproject.toml b/pyproject.toml
index 9d24ae95b..1426b507b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -81,6 +81,9 @@ ignore_errors = true
 module = [
     "verl.trainer.config.algorithm",
     "verl.trainer.ppo.core_algos",
+    "verl.trainer.ppo.reward",
+    "verl.workers.reward_manager",
+    "verl.workers.reward_manager.*",
 ]
 ignore_errors = false
 
diff --git a/verl/trainer/ppo/reward.py b/verl/trainer/ppo/reward.py
index 6362f7856..a71fa8ff4 100644
--- a/verl/trainer/ppo/reward.py
+++ b/verl/trainer/ppo/reward.py
@@ -12,14 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import importlib.util
 import multiprocessing
 import os
+import sys
+import warnings
 from functools import partial
+from typing import Any, Optional
 
 import ray
+import torch
+from omegaconf import DictConfig
 
 from verl import DataProto
 from verl.utils.reward_score import default_compute_score
+from verl.workers.reward_manager import get_reward_manager_cls
+from verl.workers.reward_manager.abstract import AbstractRewardManager, RawRewardFn
 
 
 def _call_with_kwargs(raw_fn, extra_kwargs, *args, **kwargs):
@@ -31,7 +39,7 @@ def _call_with_kwargs(raw_fn, extra_kwargs, *args, **kwargs):
     return raw_fn(*args, **merged_kwargs)
 
 
-def get_custom_reward_fn(config):
+def get_custom_reward_fn(config: DictConfig) -> Optional[RawRewardFn]:
     """Load and return a custom reward function from external file.
 
     Dynamically imports a reward function from a specified file path and wraps
@@ -50,8 +58,6 @@ def get_custom_reward_fn(config):
         RuntimeError: If there's an error loading the module from file.
         AttributeError: If the specified function name isn't found in the module.
     """
-    import importlib.util
-    import sys
 
     reward_fn_config = config.get("custom_reward_function") or {}
     file_path = reward_fn_config.get("path")
@@ -62,14 +68,17 @@ def get_custom_reward_fn(config):
         raise FileNotFoundError(f"Reward function file '{file_path}' not found.")
 
     spec = importlib.util.spec_from_file_location("custom_module", file_path)
+    assert spec is not None
     module = importlib.util.module_from_spec(spec)
     try:
         sys.modules["custom_module"] = module
+        assert spec.loader is not None
         spec.loader.exec_module(module)
     except Exception as e:
         raise RuntimeError(f"Error loading module from '{file_path}': {e}") from e
 
     function_name = reward_fn_config.get("name")
+    assert function_name is not None
     if not hasattr(module, function_name):
         raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.")
 
@@ -81,7 +90,9 @@ def get_custom_reward_fn(config):
     return partial(_call_with_kwargs, raw_fn, reward_kwargs)
 
 
-def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs):
+def load_reward_manager(
+    config: DictConfig, tokenizer: Any, num_examine: int, **reward_kwargs: Any
+) -> AbstractRewardManager:
     """
     Load and initialize a reward manager based on the configuration.
 
@@ -94,7 +105,6 @@ def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs):
     Returns:
         An instance of the specified reward manager class.
     """
-    from verl.workers.reward_manager import get_reward_manager_cls
 
     # The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`:
     # naive: NaiveRewardManager
@@ -138,7 +148,7 @@ def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs):
     )
 
 
-def compute_reward(data: DataProto, reward_fn):
+def compute_reward(data: DataProto, reward_fn: AbstractRewardManager) -> tuple[torch.Tensor, dict[str, Any]]:
     """
     Compute reward for a batch of data.
     Args:
@@ -169,7 +179,6 @@ def compute_reward_async(data: DataProto, config=None, tokenizer=None, reward_fn
         assert config is not None and tokenizer is not None, (
             "config and tokenizer must not be None when reward_fn is None"
         )
-        import warnings
 
         warnings.warn("using config and tokenizer with compute_reward_async is deprecated", stacklevel=2)
         reward_fn = load_reward_manager(
diff --git a/verl/workers/reward_manager/abstract.py b/verl/workers/reward_manager/abstract.py
new file mode 100644
index 000000000..b8c7d6e39
--- /dev/null
+++ b/verl/workers/reward_manager/abstract.py
@@ -0,0 +1,45 @@
+# Copyright 2023-2025 SGLang Team
+# Copyright Amazon.com, Inc. or its affiliates.
+# Copyright 2025 ModelBest Inc. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import Any, Callable
+
+import torch
+
+from verl.protocol import DataProto
+
+RawRewardFn = Callable[..., Any]
+
+
+class AbstractRewardManager(ABC):
+    @abstractmethod
+    def __init__(
+        self,
+        tokenizer: Any,
+        num_examine: int,
+        compute_score: RawRewardFn | None,
+        reward_fn_key: str = "data_source",
+        **kwargs: Any,
+    ):
+        pass
+
+    @abstractmethod
+    def __call__(
+        self,
+        data: DataProto,
+        return_dict: bool = False,
+    ) -> torch.Tensor | dict[str, Any]:
+        pass
diff --git a/verl/workers/reward_manager/batch.py b/verl/workers/reward_manager/batch.py
index 8d1b11228..989ca14f4 100644
--- a/verl/workers/reward_manager/batch.py
+++ b/verl/workers/reward_manager/batch.py
@@ -13,15 +13,17 @@
 # limitations under the License.
 
 from collections import defaultdict
+from typing import Any
 
 import torch
 
 from verl import DataProto
 from verl.workers.reward_manager import register
+from verl.workers.reward_manager.abstract import AbstractRewardManager, RawRewardFn
 
 
 @register("batch")
-class BatchRewardManager:
+class BatchRewardManager(AbstractRewardManager):
     """
     A batch reward manager that computes rewards for a batch of data.
 
@@ -33,7 +35,9 @@ class BatchRewardManager:
         reward_kwargs (dict): The keyword arguments to pass to the reward function.
     """
 
-    def __init__(self, tokenizer, num_examine, compute_score, reward_fn_key="data_source", **reward_kwargs):
+    def __init__(
+        self, tokenizer, num_examine, compute_score: RawRewardFn, reward_fn_key="data_source", **reward_kwargs
+    ):
         self.tokenizer = tokenizer
         self.num_examine = num_examine
         self.compute_score = compute_score
@@ -69,7 +73,7 @@ class BatchRewardManager:
 
         return scores
 
-    def __call__(self, data: DataProto, return_dict=False):
+    def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
         # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
         if "rm_scores" in data.batch.keys():
             if return_dict:
@@ -87,7 +91,7 @@ class BatchRewardManager:
 
         scores = self.verify(data)
         rewards = []
-        already_printed = {}
+        already_printed: dict[str, Any] = {}
 
         for i in range(len(data)):
             length = valid_response_lengths[i].item()
diff --git a/verl/workers/reward_manager/dapo.py b/verl/workers/reward_manager/dapo.py
index cb8b5cf22..bb6e0895f 100644
--- a/verl/workers/reward_manager/dapo.py
+++ b/verl/workers/reward_manager/dapo.py
@@ -19,10 +19,11 @@ import torch
 from verl import DataProto
 from verl.utils.reward_score import default_compute_score
 from verl.workers.reward_manager import register
+from verl.workers.reward_manager.abstract import AbstractRewardManager
 
 
 @register("dapo")
-class DAPORewardManager:
+class DAPORewardManager(AbstractRewardManager):
     """The reward manager."""
 
     def __init__(
diff --git a/verl/workers/reward_manager/naive.py b/verl/workers/reward_manager/naive.py
index f6f979eef..f10bbc636 100644
--- a/verl/workers/reward_manager/naive.py
+++ b/verl/workers/reward_manager/naive.py
@@ -13,16 +13,18 @@
 # limitations under the License.
 
 from collections import defaultdict
+from typing import Any
 
 import torch
 
 from verl import DataProto
 from verl.utils.reward_score import default_compute_score
 from verl.workers.reward_manager import register
+from verl.workers.reward_manager.abstract import AbstractRewardManager
 
 
 @register("naive")
-class NaiveRewardManager:
+class NaiveRewardManager(AbstractRewardManager):
     """The reward manager."""
 
     def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key="data_source") -> None:
@@ -41,7 +43,7 @@ class NaiveRewardManager:
         self.compute_score = compute_score or default_compute_score
         self.reward_fn_key = reward_fn_key  # Store the key for accessing the data source
 
-    def __call__(self, data: DataProto, return_dict=False):
+    def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
         """We will expand this function gradually based on the available datasets"""
         # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
 
diff --git a/verl/workers/reward_manager/prime.py b/verl/workers/reward_manager/prime.py
index f2c526b63..98c094f2c 100644
--- a/verl/workers/reward_manager/prime.py
+++ b/verl/workers/reward_manager/prime.py
@@ -15,7 +15,7 @@
 import asyncio
 from concurrent.futures import ProcessPoolExecutor
 from functools import partial
-from typing import Callable, Optional
+from typing import Any, Callable, Optional
 
 import psutil
 import torch
@@ -24,6 +24,7 @@ from transformers import PreTrainedTokenizer
 from verl import DataProto
 from verl.utils.reward_score import default_compute_score
 from verl.workers.reward_manager import register
+from verl.workers.reward_manager.abstract import AbstractRewardManager
 
 
 async def single_compute_score(evaluation_func, completion, reference, task, task_extra_info, executor, timeout=300.0):
@@ -98,7 +99,7 @@ def run_reward_scoring(evaluation_func, completions, references, tasks, extra_in
 
 
 @register("prime")
-class PrimeRewardManager:
+class PrimeRewardManager(AbstractRewardManager):
     """
     The Reward Manager used in https://github.com/PRIME-RL/PRIME
     """
@@ -147,7 +148,7 @@ class PrimeRewardManager:
         data.batch["acc"] = torch.tensor(scores, dtype=torch.float32, device=prompt_ids.device)
         return scores
 
-    def __call__(self, data: DataProto, return_dict: bool = False):
+    def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
         """We will expand this function gradually based on the available datasets"""
         # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
 
diff --git a/verl/workers/reward_manager/registry.py b/verl/workers/reward_manager/registry.py
index 3fc34efaa..4e255d8ac 100644
--- a/verl/workers/reward_manager/registry.py
+++ b/verl/workers/reward_manager/registry.py
@@ -12,12 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Callable
+
+from verl.workers.reward_manager.abstract import AbstractRewardManager
+
 __all__ = ["register", "get_reward_manager_cls"]
 
-REWARD_MANAGER_REGISTRY = {}
+REWARD_MANAGER_REGISTRY: dict[str, type[AbstractRewardManager]] = {}
 
 
-def register(name):
+def register(name: str) -> Callable[[type[AbstractRewardManager]], type[AbstractRewardManager]]:
     """Decorator to register a reward manager class with a given name.
 
     Args:
@@ -25,7 +29,7 @@ def register(name):
            The name of the reward manager.
     """
 
-    def decorator(cls):
+    def decorator(cls: type[AbstractRewardManager]) -> type[AbstractRewardManager]:
         if name in REWARD_MANAGER_REGISTRY and REWARD_MANAGER_REGISTRY[name] != cls:
             raise ValueError(
                 f"Reward manager {name} has already been registered: {REWARD_MANAGER_REGISTRY[name]} vs {cls}"
@@ -36,7 +40,7 @@ def register(name):
     return decorator
 
 
-def get_reward_manager_cls(name):
+def get_reward_manager_cls(name: str) -> type[AbstractRewardManager]:
     """Get the reward manager class with a given name.
 
     Args:
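
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above). It assumes only the
# interfaces this patch introduces: AbstractRewardManager, RawRewardFn,
# register, and get_reward_manager_cls. The manager name "constant" and the
# class ConstantRewardManager are hypothetical, made up for this example of
# how a downstream reward manager would target the new typed interface.
# ---------------------------------------------------------------------------
from typing import Any

import torch

from verl import DataProto
from verl.workers.reward_manager import get_reward_manager_cls, register
from verl.workers.reward_manager.abstract import AbstractRewardManager, RawRewardFn


@register("constant")  # hypothetical name; built-ins use "naive", "batch", "dapo", "prime"
class ConstantRewardManager(AbstractRewardManager):
    """Toy manager that puts a fixed score on the last response token."""

    def __init__(
        self,
        tokenizer: Any,
        num_examine: int,
        compute_score: RawRewardFn | None = None,
        reward_fn_key: str = "data_source",
        **kwargs: Any,
    ) -> None:
        self.tokenizer = tokenizer
        self.num_examine = num_examine
        self.compute_score = compute_score
        self.reward_fn_key = reward_fn_key

    def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
        # Token-level reward tensor shaped like the responses, mirroring the built-in managers.
        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
        reward_tensor[:, -1] = 1.0  # constant score assigned to the final token
        if return_dict:
            return {"reward_tensor": reward_tensor}
        return reward_tensor


# Registration makes the class resolvable by name, the same way
# load_reward_manager() resolves reward_manager_name from the config:
assert get_reward_manager_cls("constant") is ConstantRewardManager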