[CI] feat: add mypy to pre-commit (#2614)

committed by GitHub
parent dc8b5076c3
commit f407887414
.pre-commit-config.yaml
@@ -7,6 +7,11 @@ repos:
         exclude: ^.*\.(ipynb)$
       - id: ruff-format
 
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: 'v1.17.0'
+    hooks:
+      - id: mypy
+
   - repo: local
     hooks:
       - id: autogen-trainer-cfg
@@ -29,4 +34,4 @@ repos:
       name: Check license
       entry: python3 tests/special_sanity/check_license.py --directory .
       language: python
-      pass_filenames: false
+      pass_filenames: false
pyproject.toml
@@ -65,6 +65,25 @@ ignore = [
     "UP035",
 ]
 
+# -------------------------------
+# tool.mypy - typechecking config
+# -------------------------------
+[tool.mypy]
+pretty = true
+ignore_missing_imports = true
+explicit_package_bases = true
+follow_imports = "skip"
+
+# Blanket silence
+ignore_errors = true
+
+[[tool.mypy.overrides]]
+module = [
+    "verl.trainer.config.algorithm",
+    "verl.trainer.ppo.core_algos",
+]
+ignore_errors = false
+
 # -------------------------------
 # tool.setuptools - Additional config
 # -------------------------------
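The blanket ignore_errors = true keeps mypy quiet across the codebase, while the [[tool.mypy.overrides]] block turns real checking back on for verl.trainer.config.algorithm and verl.trainer.ppo.core_algos only. A minimal sketch of the kind of mistake the new gate would surface in one of those whitelisted modules; the function below is hypothetical and not part of this commit, and the same code would still pass unchecked in any module covered by the blanket silence.

# Hypothetical example, not from the commit: if this function lived in
# verl/trainer/ppo/core_algos.py (whitelisted via [[tool.mypy.overrides]]),
# running the new hook (pre-commit run mypy --all-files) would report the bad
# return type; in a blanket-silenced module it would pass without complaint.
def scale_advantages(advantages: list[float], coef: float) -> list[float]:
    # error: Incompatible return value type (got "float", expected "list[float]")
    return sum(advantages) * coef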
verl/trainer/ppo/core_algos.py
@@ -22,18 +22,31 @@ __all__ = ["register_adv_est", "get_adv_estimator_fn", "AdvantageEstimator"]
 
 from collections import defaultdict
 from enum import Enum
-from typing import Optional
+from typing import Any, Callable, Optional
 
 import numpy as np
 import torch
 from omegaconf import DictConfig
 
 import verl.utils.torch_functional as verl_F
 from verl.trainer.config import AlgoConfig
 
-POLICY_LOSS_REGISTRY = {}
+PolicyLossFn = Callable[
+    [
+        torch.Tensor,  # old_log_prob
+        torch.Tensor,  # log_prob
+        torch.Tensor,  # advantages
+        torch.Tensor,  # response_mask
+        str,  # loss_agg_mode
+        Optional[DictConfig | AlgoConfig],  # config
+    ],
+    tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
+]
+
+POLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = {}
 
 
-def register_policy_loss(name):
+def register_policy_loss(name: str) -> Callable[[PolicyLossFn], PolicyLossFn]:
     """Register a policy loss function with the given name.
 
     Args:
@@ -43,7 +56,7 @@ def register_policy_loss(name):
         function: Decorator function that registers the policy loss function.
     """
 
-    def decorator(func):
+    def decorator(func: PolicyLossFn) -> PolicyLossFn:
        POLICY_LOSS_REGISTRY[name] = func
        return func
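With POLICY_LOSS_REGISTRY now typed as dict[str, PolicyLossFn], anything registered through register_policy_loss is expected to match the six-argument signature of the new alias. A minimal sketch of registering a conforming custom loss; the name "my_naive_pg" and the loss math are illustrative assumptions, only the decorator, the registry, and the PolicyLossFn shape come from this commit.

# Illustrative sketch only: "my_naive_pg" and the loss math below are made up;
# the decorator, registry, and PolicyLossFn alias are the ones added in this commit.
from typing import Optional

import torch
from omegaconf import DictConfig

from verl.trainer.config import AlgoConfig
from verl.trainer.ppo.core_algos import POLICY_LOSS_REGISTRY, register_policy_loss


@register_policy_loss("my_naive_pg")
def my_naive_pg_loss(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[DictConfig | AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    # Plain policy-gradient surrogate, averaged over response tokens.
    masked = -log_prob * advantages * response_mask
    pg_loss = masked.sum() / response_mask.sum().clamp(min=1)
    zero = torch.zeros((), device=log_prob.device)
    # Return the 4-tuple shape that PolicyLossFn promises (loss plus three metrics).
    return pg_loss, zero, zero, zero


loss_fn = POLICY_LOSS_REGISTRY["my_naive_pg"]  # the entry the decorator just registered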
@@ -68,10 +81,30 @@ def get_policy_loss_fn(name):
     return POLICY_LOSS_REGISTRY[loss_name]
 
 
-ADV_ESTIMATOR_REGISTRY = {}
+class AdvantageEstimator(str, Enum):
+    """Using an enumeration class to avoid spelling errors in adv_estimator.
+
+    Note(haibin.lin): this enum class is immutable after creation. Extending this
+    enum for new estimators may not be necessary since users can always just call
+    `verl.trainer.ppo.core_algos.register` with string name for a custom advantage
+    estimator instead.
+    """
+
+    GAE = "gae"
+    GRPO = "grpo"
+    REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
+    REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
+    REMAX = "remax"
+    RLOO = "rloo"
+    OPO = "opo"
+    GRPO_PASSK = "grpo_passk"
+    GPG = "gpg"
+
+
+ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {}
 
 
-def register_adv_est(name_or_enum):
+def register_adv_est(name_or_enum: str | AdvantageEstimator) -> Any:
     """Decorator to register a advantage estimator function with a given name.
 
     Args:
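The enum's own docstring notes that new estimators do not require extending AdvantageEstimator; a plain string passed to the register decorator is enough. A short sketch under that reading; the estimator's argument names, shapes, and return convention are not shown in this diff, so the signature and body below are illustrative guesses, not verl's API.

# Illustrative only: "mean_centered", the argument names, and the (advantages, returns)
# return shape are assumptions; register_adv_est itself is the decorator from this commit.
import torch

from verl.trainer.ppo.core_algos import register_adv_est


@register_adv_est("mean_centered")  # plain string key, no new enum member needed
def compute_mean_centered_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, **kwargs):
    # Toy estimator: each sequence's return minus the batch-mean return, broadcast over tokens.
    returns = (token_level_rewards * response_mask).sum(dim=-1)
    advantages = (returns - returns.mean()).unsqueeze(-1) * response_mask
    return advantages, returns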
@@ -108,26 +141,6 @@ def get_adv_estimator_fn(name_or_enum):
     return ADV_ESTIMATOR_REGISTRY[name]
 
 
-class AdvantageEstimator(str, Enum):
-    """Using an enumeration class to avoid spelling errors in adv_estimator.
-
-    Note(haibin.lin): this enum class is immutable after creation. Extending this
-    enum for new estimators may not be necessary since users can always just call
-    `verl.trainer.ppo.core_algos.register` with string name for a custom advantage
-    estimator instead.
-    """
-
-    GAE = "gae"
-    GRPO = "grpo"
-    REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
-    REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
-    REMAX = "remax"
-    RLOO = "rloo"
-    OPO = "opo"
-    GRPO_PASSK = "grpo_passk"
-    GPG = "gpg"
-
-
 class AdaptiveKLController:
     """
     Adaptive KL controller described in the paper:
@@ -822,7 +835,7 @@ def compute_policy_loss_clip_cov(
     advantages: torch.Tensor,
     response_mask: torch.Tensor,
     loss_agg_mode: str = "token-mean",
-    config: Optional[AlgoConfig] = None,
+    config: Optional[DictConfig | AlgoConfig] = None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Compute the clipped policy objective and related metrics for Clip-Cov.
@@ -855,6 +868,10 @@ def compute_policy_loss_clip_cov(
         clip_cov_ub (float, optional):
             Upper bound for clipping covariance. Defaults to 5.0.
     """
+    assert config is not None
+    assert not isinstance(config, AlgoConfig), "passing AlgoConfig not supported yet"
+    assert config.policy_loss is not None
+
     clip_cov_ratio = config.policy_loss.clip_cov_ratio if config.policy_loss.clip_cov_ratio is not None else 0.0002
     cliprange = config.clip_ratio
     cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else cliprange
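Together with the widened annotation, the new asserts mean this loss currently accepts only an OmegaConf DictConfig ("passing AlgoConfig not supported yet") that carries a policy_loss sub-config. A minimal sketch of such a config, limited to the fields this hunk actually reads; any other keys the full function uses are deliberately omitted here.

# Sketch of the config shape compute_policy_loss_clip_cov reads in this hunk; only the
# fields visible in the diff are set, everything else is omitted on purpose.
from omegaconf import OmegaConf

clip_cov_cfg = OmegaConf.create(
    {
        "clip_ratio": 0.2,
        "clip_ratio_low": None,  # None falls back to clip_ratio, per the hunk above
        "policy_loss": {
            "clip_cov_ratio": None,  # None falls back to the 0.0002 default
        },
    }
)
# The asserts added in this commit hold: the object is a DictConfig, not an AlgoConfig,
# and clip_cov_cfg.policy_loss is not None.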
@@ -912,7 +929,7 @@ def compute_policy_loss_kl_cov(
     advantages: torch.Tensor,
     response_mask: torch.Tensor,
     loss_agg_mode: str = "token-mean",
-    config: Optional[AlgoConfig] = None,
+    config: Optional[DictConfig | AlgoConfig] = None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Compute the clipped policy objective and related metrics for Clip-Cov.
@@ -936,6 +953,10 @@ def compute_policy_loss_kl_cov(
         ppo_kl_coef (float, optional):
             Coefficient for the KL penalty term in the loss. Defaults to 1.
     """
+    assert config is not None
+    assert not isinstance(config, AlgoConfig), "passing AlgoConfig not supported yet"
+    assert config.policy_loss is not None
+
     kl_cov_ratio = config.policy_loss.kl_cov_ratio if config.policy_loss.kl_cov_ratio is not None else 0.0002
     ppo_kl_coef = config.policy_loss.ppo_kl_coef if config.policy_loss.ppo_kl_coef is not None else 1.0
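The KL-Cov variant gets the same DictConfig-only guard; it reads policy_loss.kl_cov_ratio and policy_loss.ppo_kl_coef instead of the clip-cov fields. A matching hedged sketch, again restricted to what this hunk reads:

# Sketch only; the values shown equal the fallbacks named in the hunk, not recommendations.
from omegaconf import OmegaConf

kl_cov_cfg = OmegaConf.create(
    {
        "policy_loss": {
            "kl_cov_ratio": 0.0002,
            "ppo_kl_coef": 1.0,
        },
    }
)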