[CI] feat: add mypy to pre-commit (#2614)

Author: Frederick Robinson
Date: 2025-07-24 20:36:34 -07:00
Committed by: GitHub
Parent: dc8b5076c3
Commit: f407887414
3 changed files with 74 additions and 29 deletions

.pre-commit-config.yaml

@@ -7,6 +7,11 @@ repos:
        exclude: ^.*\.(ipynb)$
      - id: ruff-format
  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: 'v1.17.0'
    hooks:
      - id: mypy
  - repo: local
    hooks:
      - id: autogen-trainer-cfg
@@ -29,4 +34,4 @@ repos:
        name: Check license
        entry: python3 tests/special_sanity/check_license.py --directory .
        language: python
        pass_filenames: false
        pass_filenames: false

pyproject.toml

@@ -65,6 +65,25 @@ ignore = [
"UP035",
]
# -------------------------------
# tool.mypy - typechecking config
# -------------------------------
[tool.mypy]
pretty = true
ignore_missing_imports = true
explicit_package_bases = true
follow_imports = "skip"
# Blanket silence
ignore_errors = true
[[tool.mypy.overrides]]
module = [
"verl.trainer.config.algorithm",
"verl.trainer.ppo.core_algos",
]
ignore_errors = false
# -------------------------------
# tool.setuptools - Additional config
# -------------------------------
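The net effect of the [tool.mypy] block above: mypy is silenced repo-wide (ignore_errors = true) and re-enabled only for the two listed modules. A minimal sketch of what that buys, assuming an illustrative function (not real code) inside one of the opted-in modules:

# Illustrative only: under the blanket ignore_errors = true, this mismatch is
# ignored everywhere except the modules listed in [[tool.mypy.overrides]]
# (e.g. verl.trainer.ppo.core_algos), where mypy reports it.
def scale(x: float) -> float:
    return str(x)  # error: Incompatible return value type (got "str", expected "float")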

verl/trainer/ppo/core_algos.py

@@ -22,18 +22,31 @@ __all__ = ["register_adv_est", "get_adv_estimator_fn", "AdvantageEstimator"]
from collections import defaultdict
from enum import Enum
from typing import Optional
from typing import Any, Callable, Optional
import numpy as np
import torch
from omegaconf import DictConfig
import verl.utils.torch_functional as verl_F
from verl.trainer.config import AlgoConfig
POLICY_LOSS_REGISTRY = {}
PolicyLossFn = Callable[
    [
        torch.Tensor,  # old_log_prob
        torch.Tensor,  # log_prob
        torch.Tensor,  # advantages
        torch.Tensor,  # response_mask
        str,  # loss_agg_mode
        Optional[DictConfig | AlgoConfig],  # config
    ],
    tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
]
POLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = {}
def register_policy_loss(name):
def register_policy_loss(name: str) -> Callable[[PolicyLossFn], PolicyLossFn]:
"""Register a policy loss function with the given name.
Args:
@ -43,7 +56,7 @@ def register_policy_loss(name):
function: Decorator function that registers the policy loss function.
"""
def decorator(func):
def decorator(func: PolicyLossFn) -> PolicyLossFn:
POLICY_LOSS_REGISTRY[name] = func
return func
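A hedged usage sketch of the registry and the PolicyLossFn contract above; the name "my_policy_loss" and the loss math are made up for illustration, and only the decorator, the get_policy_loss_fn accessor, and the four-tensor return shape come from this file:

import torch
from verl.trainer.ppo.core_algos import get_policy_loss_fn, register_policy_loss

@register_policy_loss("my_policy_loss")  # hypothetical name
def my_policy_loss(old_log_prob, log_prob, advantages, response_mask, loss_agg_mode, config=None):
    # Plain (unclipped) policy-gradient surrogate, just to satisfy the signature.
    ratio = torch.exp(log_prob - old_log_prob)
    loss = -(ratio * advantages * response_mask).sum() / response_mask.sum().clamp(min=1.0)
    zero = torch.zeros((), device=loss.device)
    # PolicyLossFn promises four tensors back (the loss plus three metrics).
    return loss, zero, zero, zero

loss_fn = get_policy_loss_fn("my_policy_loss")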
@@ -68,10 +81,30 @@ def get_policy_loss_fn(name):
    return POLICY_LOSS_REGISTRY[loss_name]
ADV_ESTIMATOR_REGISTRY = {}
class AdvantageEstimator(str, Enum):
"""Using an enumeration class to avoid spelling errors in adv_estimator.
Note(haibin.lin): this enum class is immutable after creation. Extending this
enum for new estimators may not be necessary since users can always just call
`verl.trainer.ppo.core_algos.register` with string name for a custom advantage
estimator instead.
"""
GAE = "gae"
GRPO = "grpo"
REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
REMAX = "remax"
RLOO = "rloo"
OPO = "opo"
GRPO_PASSK = "grpo_passk"
GPG = "gpg"
def register_adv_est(name_or_enum):
ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {}
def register_adv_est(name_or_enum: str | AdvantageEstimator) -> Any:
"""Decorator to register a advantage estimator function with a given name.
Args:
@@ -108,26 +141,6 @@ def get_adv_estimator_fn(name_or_enum):
    return ADV_ESTIMATOR_REGISTRY[name]
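As the enum docstring notes, new estimators can be registered under a plain string without extending AdvantageEstimator. A hedged sketch, assuming a deliberately simplified estimator signature ("my_adv" and the parameter list are illustrative; only the registry decorator comes from this file):

import torch
from verl.trainer.ppo.core_algos import register_adv_est

@register_adv_est("my_adv")  # hypothetical name, not one of the enum values
def compute_my_adv_estimator(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, **kwargs):
    # Toy estimator: masked rewards double as both advantages and returns.
    advantages = token_level_rewards * response_mask
    return advantages, advantages.clone()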
class AdvantageEstimator(str, Enum):
"""Using an enumeration class to avoid spelling errors in adv_estimator.
Note(haibin.lin): this enum class is immutable after creation. Extending this
enum for new estimators may not be necessary since users can always just call
`verl.trainer.ppo.core_algos.register` with string name for a custom advantage
estimator instead.
"""
GAE = "gae"
GRPO = "grpo"
REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
REMAX = "remax"
RLOO = "rloo"
OPO = "opo"
GRPO_PASSK = "grpo_passk"
GPG = "gpg"
class AdaptiveKLController:
"""
Adaptive KL controller described in the paper:
@@ -822,7 +835,7 @@ def compute_policy_loss_clip_cov(
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[AlgoConfig] = None,
    config: Optional[DictConfig | AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the clipped policy objective and related metrics for Clip-Cov.
@@ -855,6 +868,10 @@ def compute_policy_loss_clip_cov(
        clip_cov_ub (float, optional):
            Upper bound for clipping covariance. Defaults to 5.0.
    """
    assert config is not None
    assert not isinstance(config, AlgoConfig), "passing AlgoConfig not supported yet"
    assert config.policy_loss is not None

    clip_cov_ratio = config.policy_loss.clip_cov_ratio if config.policy_loss.clip_cov_ratio is not None else 0.0002
    cliprange = config.clip_ratio
    cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else cliprange
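Because the new guards reject AlgoConfig for now, the config reaching this path is expected to be an OmegaConf DictConfig with a policy_loss sub-section. A minimal sketch carrying only the fields read above (values are illustrative or the documented defaults, not taken from verl's shipped configs):

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "clip_ratio": 0.2,       # illustrative value for cliprange
        "clip_ratio_low": None,  # falls back to clip_ratio above
        "policy_loss": {"clip_cov_ratio": 0.0002, "clip_cov_ub": 5.0},
    }
)
assert cfg.policy_loss is not None  # satisfies the new guard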
@@ -912,7 +929,7 @@ def compute_policy_loss_kl_cov(
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[AlgoConfig] = None,
    config: Optional[DictConfig | AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the clipped policy objective and related metrics for KL-Cov.
@@ -936,6 +953,10 @@ def compute_policy_loss_kl_cov(
        ppo_kl_coef (float, optional):
            Coefficient for the KL penalty term in the loss. Defaults to 1.
    """
    assert config is not None
    assert not isinstance(config, AlgoConfig), "passing AlgoConfig not supported yet"
    assert config.policy_loss is not None

    kl_cov_ratio = config.policy_loss.kl_cov_ratio if config.policy_loss.kl_cov_ratio is not None else 0.0002
    ppo_kl_coef = config.policy_loss.ppo_kl_coef if config.policy_loss.ppo_kl_coef is not None else 1.0
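The same pattern applies to the KL-Cov path; a hedged sketch of a DictConfig carrying just the two fields read above, using the documented defaults:

from omegaconf import OmegaConf

# Illustrative config for the KL-Cov branch; only the two fields read above.
cfg = OmegaConf.create({"policy_loss": {"kl_cov_ratio": 0.0002, "ppo_kl_coef": 1.0}})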