[cfg] refactor: add ActorConfig, EngineConfig, and ActorWorker unit test, refactor validation code (#2621)

As initially mentioned in
https://github.com/volcengine/verl/discussions/1941, having structured
configuration classes in verl makes argument passing easier for testing
and validation.

This is an extended thread on the current implementation of the configuration schema in verl. Related PRs:
- https://github.com/volcengine/verl/pull/2117
- https://github.com/volcengine/verl/pull/2621

# Motivation
By moving from loose `omegaconf.DictConfig`-based parameters to structured dataclasses, we gain:
- Type safety & IDE support when accessing fields (e.g. `cfg.optim.lr`).
- Validation hooks via `__post_init__` in each class.
- Immutable defaults with controlled mutability (e.g., an `extra` field).
- Seamless Hydra/OmegaConf integration and easy per-recipe extension.

# Core: BaseConfig

Hydra natively supports converting a `DictConfig` to a dataclass, but dataclasses do not support accessing attributes via `get()`. We introduce a base class to provide backward compatibility and make the change less abrupt for existing users.

All config dataclasses inherit from `BaseConfig`, which:
- Implements `collections.abc.Mapping` → dict-like iteration/access.
- Freezes attributes once set, unless listed in `_mutable_fields`.
- Provides an `extra: dict[str, Any]` field for unchecked extensions.

```python
import collections
from dataclasses import FrozenInstanceError, dataclass, field, fields
from typing import Any, ClassVar

@dataclass
class BaseConfig(collections.abc.Mapping):
    """Dict-like, frozen dataclass with opt-in mutability."""
    _mutable_fields: ClassVar[set[str]] = {"extra"}  # ClassVar: not a dataclass field
    extra: dict[str, Any] = field(default_factory=dict)

    def __setattr__(self, name: str, value):
        # Freeze a field after its first (init-time) assignment.
        if name in self.__dict__ and name not in self._mutable_fields:
            raise FrozenInstanceError(f"Field '{name}' is frozen")
        super().__setattr__(name, value)

    # Minimal Mapping interface: get, __getitem__, __iter__, __len__
    def get(self, key, default=None):
        return getattr(self, key, default)

    def __getitem__(self, key):
        return getattr(self, key)

    def __iter__(self):
        return iter(f.name for f in fields(self))

    def __len__(self):
        return len(fields(self))
```
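
A quick sketch of the resulting behavior (`DemoConfig` is a hypothetical subclass used only for illustration):
```python
from dataclasses import FrozenInstanceError, dataclass

@dataclass
class DemoConfig(BaseConfig):
    lr: float = 1e-5

cfg = DemoConfig(lr=1e-4)
assert cfg.get("lr") == 1e-4   # dict-style access with a default
assert cfg["lr"] == 1e-4       # Mapping indexing
cfg.extra["tag"] = "demo"      # allowed: `extra` is in _mutable_fields
try:
    cfg.lr = 2e-4              # second assignment to a frozen field
except FrozenInstanceError:
    print("lr is frozen after __init__")
```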

# Example Config Classes (verl/trainer/config)

Each sub-component of the trainer has its own dataclass inheriting from `BaseConfig`.
```yaml
critic:
  checkpoint:
    _target_: verl.trainer.config.CheckpointConfig
    save_contents: ["model","optimizer","extra"]
    load_contents: ["model","optimizer","extra"]
    async_save: false
```
Definition: 
```python
from dataclasses import dataclass, field

@dataclass
class CheckpointConfig(BaseConfig):
    """What to save/load and async behavior."""
    save_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
    load_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
    async_save: bool = False

    def __post_init__(self):
        # validation checks go here after initialization
        assert isinstance(self.save_contents, list) and isinstance(self.load_contents, list)


ckpt_cfg = CheckpointConfig(async_save=True)
print(ckpt_cfg.save_contents)                    # attribute access
print(ckpt_cfg.get("save_contents", ["model"]))  # dict-style get with a default
print(ckpt_cfg["save_contents"])                 # dict-style indexing

# converting a hydra-generated omegaconf.DictConfig to the dataclass config:
from verl.utils.config import omega_conf_to_dataclass
ckpt_cfg_from_cli = omega_conf_to_dataclass(config.critic.checkpoint)
```

# Extending existing config classes
Because configs are now structured, unexpected keys raise exceptions. There are two ways to add new keys:
## Explicit class extension
```python
from dataclasses import dataclass

from verl.workers.config import FSDPActorConfig

@dataclass
class SPPOActorConfig(FSDPActorConfig):
    """Add SPPO-specific temperature/penalty."""
    sppo_eta: float = 1.0
```
When using YAML or the command line, update the target config class:
```yaml
hydra:
  searchpath:
    - file://verl/trainer/config
defaults:
  - ppo_trainer      # base trainer config
  - _self_               # then apply these overrides

actor_rollout_ref:
  actor:
    _target_: recipe.sppo.config.SPPOActorConfig  # new target dataclass required for extension
    sppo_eta: 1.0
```
or directly from the command line:
```bash
python main_sppo.py \
  actor_rollout_ref.actor._target_=recipe.sppo.config.SPPOActorConfig \
  actor_rollout_ref.actor.sppo_eta=1.0
```
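
With `_target_` pointing at the subclass, `omega_conf_to_dataclass` resolves and instantiates it, so downstream code receives the extended config (a sketch; `cfg` is assumed to be the composed Hydra config):
```python
from verl.utils.config import omega_conf_to_dataclass

actor_cfg = omega_conf_to_dataclass(cfg.actor_rollout_ref.actor)
assert isinstance(actor_cfg, SPPOActorConfig)  # class chosen via _target_
print(actor_cfg.sppo_eta)  # 1.0, alongside all inherited FSDPActorConfig fields
```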

## Leverage the `extra` field
Adding new keys under the `extra` field of any dataclass that inherits from `BaseConfig` also works; this way there is no need to define your own dataclass in Python:
```yaml
hydra:
  searchpath:
    - file://verl/trainer/config
defaults:
  - ppo_trainer      # base trainer config
  - _self_               # then apply these overrides

actor_rollout_ref:
  actor:
    extra:
        sppo_eta: 1.0  
```
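
Values stashed in `extra` are then read back with plain dict access (a sketch; names are illustrative):
```python
actor_cfg = omega_conf_to_dataclass(cfg.actor_rollout_ref.actor)
# `extra` is an untyped dict, so fall back to a default when the key is unset
sppo_eta = actor_cfg.extra.get("sppo_eta", 1.0)
```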

# Declaring mutable fields
For historical reasons, some config fields are mutated in place in the codebase, such as batch sizes for data/sequence parallelism. We are in the process of deprecating this behavior. However, if you need to intentionally mutate a field, declare it in the `_mutable_fields` attribute:
```python
@dataclass
class CheckpointConfig(BaseConfig):
    """What to save/load and async behavior."""
    _mutable_fields = BaseConfig._mutable_fields | {"save_contents"} # mark save_contents as mutable.

    save_contents: list[str] = field(default_factory=lambda: ["model","optimizer","extra"])
    load_contents: list[str] = field(default_factory=lambda: ["model","optimizer","extra"])
    async_save: bool = False
```
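
After this declaration, in-place mutation is allowed only for the listed field (illustrative sketch, assuming the class above):
```python
from dataclasses import FrozenInstanceError

ckpt_cfg = CheckpointConfig()
ckpt_cfg.save_contents = ["model"]  # OK: declared in _mutable_fields
try:
    ckpt_cfg.async_save = True      # not declared, still frozen
except FrozenInstanceError:
    print("async_save is frozen")
```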

# Other helpful resources
verl's default trainer config composes the following config files, specified in the `defaults` list of
https://github.com/volcengine/verl/blob/main/verl/trainer/config/ppo_trainer.yaml#L1-L36:
- verl/trainer/config/ppo_trainer.yaml  # main config for entrypoint 
- verl/trainer/config/actor/dp_actor.yaml 
- verl/trainer/config/critic/dp_critic.yaml 
- verl/trainer/config/reward_model/dp_reward_model.yaml 
- verl/trainer/config/rollout/rollout.yaml 

To quickly inspect the full default config in a single file, check the auto-generated config at
https://github.com/volcengine/verl/blob/main/verl/trainer/config/_generated_ppo_trainer.yaml
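
The same composed config can also be obtained programmatically, mirroring what the tests in this PR do:
```python
import os
from hydra import compose, initialize_config_dir

with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
    cfg = compose(config_name="ppo_trainer")  # composes all the files above
print(cfg.critic.checkpoint)
```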

# Change log and impact on existing code
This PR converts the following fields to structured dataclasses in the training pipeline. More can be done in future PRs (contributions from the community are welcome):
- [x] actor_rollout_ref.actor
- [x] critic 
- [ ] actor_rollout_ref.rollout
- [ ] actor_rollout_ref.ref
- [ ] reward_model
- [ ] data
- [ ] trainer

Changes needed for existing code that added new fields to the config:
- see recipe/sppo for an example
- `OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))` now has to be changed manually to `self.config.model.get("override_config", {})`, because `OmegaConf.to_container` expects a `DictConfig` while `config.model.override_config` is already a plain dict.
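
Concretely, the migration pattern looks like this (a sketch with a hypothetical `get_override` helper; see the worker diffs below for the exact in-repo variant):
```python
from omegaconf import OmegaConf

def get_override(model_cfg) -> dict:
    # Before: override_config arrived as an omegaconf.DictConfig
    #   return OmegaConf.to_container(model_cfg.get("override_config", OmegaConf.create()))
    # After: the structured config already stores a plain dict
    return model_cfg.get("override_config", {})
```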

# Other Breaking Changes
- `critic.optim.lr` for Megatron changed from 1e-6 to 1e-5.
- `actor_rollout_ref.actor.optim.lr_warmup_steps` and `critic.optim.lr_warmup_steps` defaults changed from None to -1.

---------

Signed-off-by: ShareLer <ShareLe@163.com>
Co-authored-by: devin-ai-integration[bot] <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Joel <wuxibin@bytedance.com>
Co-authored-by: Cheetah <1659275352@qq.com>
Co-authored-by: 杨睿 <yangruipis@163.com>
Co-authored-by: X. HU <huxiaobo@zju.edu.cn>
Co-authored-by: Le Xue <48175490+ShareLer@users.noreply.github.com>
Co-authored-by: Ziheng Jiang <ziheng@apache.org>
Co-authored-by: Blue Space <57280232+ETOgaosion@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Commit 4de3ecf0f0 (parent 8fdc4d3f20), authored by H, committed by GitHub on 2025-07-23 11:45:14 -07:00.
66 changed files with 2082 additions and 597 deletions


@@ -94,7 +94,10 @@ jobs:
pip3 install cupy-cuda12x
- name: Run all GPU unit tests
run: |
-pytest -s -x --ignore-glob="*test_linear_cross_entropy_tp.py" --ignore-glob='*on_cpu.py' --ignore-glob="*test_vllm*" --ignore-glob="*_sglang*" --ignore-glob="*_hf_rollout*" --ignore-glob="tests/models/" --ignore-glob='tests/special*' --ignore-glob="tests/experimental" tests/
+pytest -s -x --ignore-glob="*test_special_*.py" --ignore-glob='*on_cpu.py' --ignore-glob="*test_vllm*" --ignore-glob="*_sglang*" --ignore-glob="*_hf_rollout*" --ignore-glob="tests/models/" --ignore-glob='tests/special*' --ignore-glob="tests/experimental" tests/
- name: Testing LinearCrossEntropyTP Correctness, Computation Time and Memory Consumption
run: |
-LOW_MEMORY=True torchrun --standalone --nnodes=1 --nproc-per-node=8 tests/utils/test_linear_cross_entropy_tp.py
+LOW_MEMORY=True torchrun --standalone --nnodes=1 --nproc-per-node=8 tests/utils/test_special_linear_cross_entropy_tp.py
+- name: Testing FSDP actor functionality
+run: |
+torchrun --standalone --nnodes=1 --nproc-per-node=2 tests/workers/actor/test_special_dp_actor.py


@@ -26,7 +26,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
-+actor_rollout_ref.actor.fsdp_config.use_orig_params=True \
+actor_rollout_ref.actor.fsdp_config.use_orig_params=True \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \


@@ -34,7 +34,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.fsdp_config.param_offload=$OFFLOAD \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=$OFFLOAD \
-+actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
+actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$MICRO_BATCH_SIZE \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=sglang \


@@ -139,7 +139,7 @@ class RolloutWorker(ActorRolloutRefWorker):
def init_model(self):
# This is used to import external_lib into the huggingface systems
import_external_libs(self.config.model.get("external_lib", None))
-override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
use_shm = self.config.model.get("use_shm", False)
local_path = copy_to_local(self.config.model.path, use_shm=use_shm)


@@ -18,7 +18,7 @@ import os
import torch
import torch.distributed
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils.debug import (
@@ -119,11 +119,9 @@ class RolloutWorker(ActorRolloutRefWorker):
importlib.import_module(self.config.model.external_lib)
-from omegaconf import OmegaConf
from verl.utils.torch_dtypes import PrecisionType
-override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
override_transformer_config = {}
self.param_dtype = torch.bfloat16
self.dtype = PrecisionType.to_dtype(self.param_dtype)


@@ -41,7 +41,6 @@ reward_model:
fsdp_config:
min_num_params: 0
param_offload: ${actor_rollout_ref.actor.fsdp_config.param_offload}
-# grad_offload: ${actor_rollout_ref.actor.fsdp_config.grad_offload}
optimizer_offload: ${actor_rollout_ref.actor.fsdp_config.optimizer_offload}
update: before # ``before`` for double-forward, ``after`` for single-forward
optim:


@@ -17,6 +17,7 @@ import warnings
import torch
import torch.distributed
+from omegaconf import OmegaConf
from torch.distributed.device_mesh import init_device_mesh
from verl import DataProto
@ -98,9 +99,7 @@ class PRIMERewardModelWorker(Worker):
tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path)
self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False))
from omegaconf import OmegaConf
override_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
override_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
override_config_kwargs = {
"bos_token_id": self.tokenizer.bos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,


@@ -22,7 +22,7 @@ import psutil
import torch
import torch.distributed
from codetiming import Timer
-from omegaconf import open_dict
+from omegaconf import OmegaConf, open_dict
from torch.distributed.device_mesh import init_device_mesh
import verl.utils.torch_functional as verl_F
@@ -83,10 +83,7 @@ class SPINRolloutRefWorker(ActorRolloutRefWorker):
# This is used to import external_lib into the huggingface systems
import_external_libs(self.config.model.get("external_lib", None))
-from omegaconf import OmegaConf
-override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
use_remove_padding = self.config.model.get("use_remove_padding", False)
use_fused_kernels = self.config.model.get("use_fused_kernels", False)

recipe/sppo/config.py (new file, 22 lines)

@@ -0,0 +1,22 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from verl.workers.config import FSDPActorConfig
@dataclass
class SPPOActorConfig(FSDPActorConfig):
sppo_eta: float = 1.0


@@ -10,7 +10,17 @@ defaults:
actor_rollout_ref:
  actor:
+    _target_: recipe.sppo.config.SPPOActorConfig
+    # sppo_eta is an additional hyperparameter for SPPO, not available in
+    # verl core. Specifying _target_ with SPPOActorConfig is needed to
+    # extend the verl ActorConfig with custom fields.
+    # Additionally, it is also possible to use the `extra` field natively supported
+    # by all verl core dataclasses, without having to define SPPOActorConfig:
+    # extra:
+    #   sppo_eta: 1.0
+    sppo_eta: 1.0
optim:
lr_warmup_steps: 15
rollout:


@@ -16,7 +16,7 @@
import logging
import os
-from omegaconf import open_dict
+from omegaconf import OmegaConf, open_dict
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
@@ -43,10 +43,7 @@ class SPPOActorRolloutRefWorker(ActorRolloutRefWorker):
# This is used to import external_lib into the huggingface systems
import_external_libs(self.config.model.get("external_lib", None))
-from omegaconf import OmegaConf
-override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
use_remove_padding = self.config.model.get("use_remove_padding", False)
use_fused_kernels = self.config.model.get("use_fused_kernels", False)


@@ -33,7 +33,16 @@ def init_config() -> DictConfig:
from hydra import compose, initialize_config_dir
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
-config = compose(config_name="ppo_trainer")
+config = compose(
+    config_name="ppo_trainer",
+    overrides=[
+        "actor_rollout_ref.actor.use_dynamic_bsz=true",
+        # test sleep/wake_up with fsdp offload
+        "actor_rollout_ref.actor.fsdp_config.param_offload=True",
+        "actor_rollout_ref.actor.fsdp_config.optimizer_offload=True",
+    ],
+)
model_path = "Qwen/Qwen2.5-1.5B-Instruct"
config.actor_rollout_ref.model.path = model_path
config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
@@ -43,10 +43,6 @@ def init_config() -> DictConfig:
config.actor_rollout_ref.rollout.n = 4
config.actor_rollout_ref.rollout.agent.num_workers = 2
-# test sleep/wake_up with fsdp offload
-config.actor_rollout_ref.actor.fsdp_config.param_offload = True
-config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True
return config


@@ -40,7 +40,7 @@ python3 -m recipe.sppo.main_sppo \
trainer.critic_warmup=0 \
trainer.logger=console \
trainer.val_before_train=False \
-trainer.n_gpus_per_node=8 \
+trainer.n_gpus_per_node=$NUM_GPUS \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.total_training_steps=1 \


@@ -62,6 +62,8 @@ def test_trainer_config_doc():
"verl/trainer/config/ppo_trainer.yaml",
"verl/trainer/config/actor/actor.yaml",
"verl/trainer/config/actor/dp_actor.yaml",
+"verl/trainer/config/critic/critic.yaml",
+"verl/trainer/config/critic/dp_critic.yaml",
"verl/trainer/config/ref/ref.yaml",
"verl/trainer/config/ref/dp_ref.yaml",
"verl/trainer/config/rollout/rollout.yaml",


@@ -258,7 +258,7 @@ actor_rollout_ref:
calculate_log_probs: False
# Nsight system profiler configs
profiler:
-# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.utils.profiler.ProfilerConfig
discrete: False
all_ranks: False
@@ -336,7 +336,7 @@ critic:
load_contents: ${critic.checkpoint.save_contents}
# Nsight system profiler configs
profiler:
-# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.utils.profiler.ProfilerConfig
discrete: False
all_ranks: False
@@ -379,7 +379,7 @@ reward_model:
memory_limit_mb: 1024 # Max memory limit for each sandbox process in MB
# Nsight system profiler configs
profiler:
-# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.utils.profiler.ProfilerConfig
discrete: False
all_ranks: False
@@ -390,7 +390,7 @@ custom_reward_function:
name: compute_score
algorithm:
-# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.trainer.config.AlgoConfig
gamma: 1.0
lam: 1.0
@@ -399,7 +399,7 @@ algorithm:
use_kl_in_reward: False
kl_penalty: kl # how to estimate kl divergence
kl_ctrl:
-# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.trainer.config.KLControlConfig
type: fixed
kl_coef: 0.001
@@ -407,8 +407,6 @@ algorithm:
target_kl: 0.1
use_pf_ppo: False
pf_ppo:
-# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
-_target_: verl.trainer.config.PFPPOConfig
reweight_method: pow # ["pow", "max_min", "max_random"]
weight_pow: 2.0


@@ -579,7 +579,7 @@ actor_rollout_ref:
# profiler configs
profiler:
-# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.utils.profiler.ProfilerConfig
# True for each task has its own database, False for all tasks in one training step share one database.
@@ -744,7 +744,7 @@ critic:
# the corresponding dataclass is verl.utils.profiler.ProfilerConfig.
profiler:
-# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.utils.profiler.ProfilerConfig
# True for each task has its own database, False for all tasks in one training step share one database.
@@ -858,7 +858,7 @@ reward_model:
# profiler configs
profiler:
-# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.utils.profiler.ProfilerConfig
# True for each task has its own database, False for all tasks in one training step share one database.
@@ -883,7 +883,7 @@ custom_reward_function:
# config for the algorithm
algorithm:
-# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.trainer.config.AlgoConfig
# Discount factor for future rewards
@@ -907,7 +907,7 @@ algorithm:
# KL control configuration
kl_ctrl:
-# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.trainer.config.KLControlConfig
# KL control type: "fixed" or "adaptive"
@@ -928,9 +928,6 @@ algorithm:
# Preference feedback PPO settings
pf_ppo:
-# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
-_target_: verl.trainer.config.PFPPOConfig
# Method for reweighting samples: "pow", "max_min", or "max_random"
reweight_method: pow


@@ -18,7 +18,7 @@ import numpy as np
import torch
from omegaconf import OmegaConf
-from verl.trainer.config import AlgoConfig, KLControlConfig, PFPPOConfig
+from verl.trainer.config import AlgoConfig, KLControlConfig
from verl.trainer.ppo.core_algos import (
compute_gae_advantage_return,
compute_grpo_outcome_advantage,
@@ -49,7 +49,7 @@ class TestAlgoConfig(unittest.TestCase):
"target_kl": 0.05,
},
"use_pf_ppo": True,
-"pf_ppo": {"_target_": "verl.trainer.config.PFPPOConfig", "reweight_method": "max_min", "weight_pow": 3.0},
+"pf_ppo": {"reweight_method": "max_min", "weight_pow": 3.0},
}
self.omega_config = OmegaConf.create(self.config_dict)
@@ -86,9 +86,8 @@ class TestAlgoConfig(unittest.TestCase):
self.assertEqual(config.kl_ctrl.target_kl, 0.05)
# Test PF PPO config
-self.assertIsInstance(config.pf_ppo, PFPPOConfig)
-self.assertEqual(config.pf_ppo.reweight_method, "max_min")
-self.assertEqual(config.pf_ppo.weight_pow, 3.0)
+self.assertEqual(config.pf_ppo.get("reweight_method"), "max_min")
+self.assertEqual(config.pf_ppo.get("weight_pow"), 3.0)
def test_default_values(self):
"""Test that default values are properly set."""
@@ -123,7 +122,7 @@ class TestAlgoConfig(unittest.TestCase):
# Check that nested configs are initialized
self.assertIsNotNone(minimal_config.kl_ctrl)
self.assertIsInstance(minimal_config.kl_ctrl, KLControlConfig)
-self.assertIsNone(minimal_config.pf_ppo)
+assert not minimal_config.pf_ppo
def test_config_init_from_yaml(self):
import os
@@ -133,10 +132,9 @@ class TestAlgoConfig(unittest.TestCase):
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
cfg = compose(config_name="ppo_trainer")
algo_config = omega_conf_to_dataclass(cfg.algorithm)
-from verl.trainer.config import AlgoConfig, PFPPOConfig
+from verl.trainer.config import AlgoConfig
assert isinstance(algo_config, AlgoConfig)
-assert isinstance(algo_config.pf_ppo, PFPPOConfig)
class TestAlgoCompute(unittest.TestCase):
@@ -153,7 +151,7 @@ class TestAlgoCompute(unittest.TestCase):
kl_penalty="kl",
kl_ctrl=KLControlConfig(type="adaptive", kl_coef=0.002, horizon=5000, target_kl=0.05),
use_pf_ppo=True,
-pf_ppo=PFPPOConfig(reweight_method="max_min", weight_pow=3.0),
+pf_ppo={"reweight_method": "max_min", "weight_pow": 3.0},
)
def test_advantage_estimator_with_cfg(self):


@@ -1,170 +0,0 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
import pytest
from hydra import compose, initialize_config_dir
from verl.trainer.config.config import CriticConfig, FSDPCriticConfig, MegatronCriticConfig
from verl.utils.config import omega_conf_to_dataclass
class TestCriticConfig:
"""Test suite for critic configuration dataclasses."""
@pytest.fixture
def config_dir(self):
"""Get the path to the config directory."""
return Path(__file__).parent.parent.parent.parent / "verl" / "trainer" / "config" / "critic"
def test_megatron_critic_config_instantiation_from_yaml(self, config_dir):
"""Test that MegatronCriticConfig can be instantiated from megatron_critic.yaml."""
yaml_path = config_dir / "megatron_critic.yaml"
assert yaml_path.exists(), f"Config file not found: {yaml_path}"
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config/critic")):
test_config = compose(config_name="megatron_critic")
megatron_config_obj = omega_conf_to_dataclass(test_config)
assert isinstance(megatron_config_obj, MegatronCriticConfig)
assert isinstance(megatron_config_obj, CriticConfig)
expected_attrs = [
"strategy",
"rollout_n",
"optim",
"model",
"ppo_mini_batch_size",
"ppo_max_token_len_per_gpu",
"cliprange_value",
"get",
"nccl_timeout",
"megatron",
"load_weight",
]
for attr in expected_attrs:
assert hasattr(megatron_config_obj, attr), f"Missing attribute: {attr}"
assert callable(megatron_config_obj.get)
assert megatron_config_obj.strategy == "megatron"
def test_fsdp_critic_config_instantiation_from_yaml(self, config_dir):
"""Test that FSDPCriticConfig can be instantiated from dp_critic.yaml."""
yaml_path = config_dir / "dp_critic.yaml"
assert yaml_path.exists(), f"Config file not found: {yaml_path}"
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config/critic")):
test_config = compose(config_name="dp_critic")
fsdp_config_obj = omega_conf_to_dataclass(test_config)
assert isinstance(fsdp_config_obj, FSDPCriticConfig)
assert isinstance(fsdp_config_obj, CriticConfig)
expected_attrs = [
"strategy",
"rollout_n",
"optim",
"model",
"ppo_mini_batch_size",
"ppo_max_token_len_per_gpu",
"cliprange_value",
"get",
"forward_micro_batch_size",
"forward_micro_batch_size_per_gpu",
"ulysses_sequence_parallel_size",
"grad_clip",
]
for attr in expected_attrs:
assert hasattr(fsdp_config_obj, attr), f"Missing attribute: {attr}"
assert callable(fsdp_config_obj.get)
assert fsdp_config_obj.strategy == "fsdp"
def test_config_inheritance_hierarchy(self):
"""Test that the inheritance hierarchy is correct."""
megatron_config = MegatronCriticConfig()
assert isinstance(megatron_config, CriticConfig)
assert isinstance(megatron_config, MegatronCriticConfig)
fsdp_config = FSDPCriticConfig()
assert isinstance(fsdp_config, CriticConfig)
assert isinstance(fsdp_config, FSDPCriticConfig)
critic_config = CriticConfig()
assert isinstance(critic_config, CriticConfig)
assert not isinstance(critic_config, MegatronCriticConfig)
assert not isinstance(critic_config, FSDPCriticConfig)
def test_config_dict_interface(self):
"""Test that configs provide dict-like interface from BaseConfig."""
config = CriticConfig()
assert "strategy" in config
assert config["strategy"] == "fsdp"
assert config.get("strategy") == "fsdp"
assert config.get("nonexistent_key", "default") == "default"
keys = list(config)
assert "strategy" in keys
assert "rollout_n" in keys
assert len(config) > 0
def test_frozen_fields_immutability(self):
"""Test that frozen fields raise exceptions when modified after creation."""
critic_config = CriticConfig()
frozen_fields = ["rollout_n", "strategy", "cliprange_value"]
for field_name in frozen_fields:
with pytest.raises((AttributeError, TypeError, ValueError)):
setattr(critic_config, field_name, "modified_value")
megatron_config = MegatronCriticConfig()
megatron_frozen_fields = ["nccl_timeout", "load_weight", "data_loader_seed"]
for field_name in megatron_frozen_fields:
with pytest.raises((AttributeError, TypeError, ValueError)):
setattr(megatron_config, field_name, "modified_value")
fsdp_config = FSDPCriticConfig()
fsdp_frozen_fields = ["ulysses_sequence_parallel_size", "grad_clip"]
for field_name in fsdp_frozen_fields:
with pytest.raises((AttributeError, TypeError, ValueError)):
setattr(fsdp_config, field_name, "modified_value")
def test_batch_size_fields_modifiable(self):
"""Test that batch size fields can be modified after creation."""
critic_config = CriticConfig()
critic_config.ppo_mini_batch_size = 8
critic_config.ppo_micro_batch_size = 4
critic_config.ppo_micro_batch_size_per_gpu = 2
assert critic_config.ppo_mini_batch_size == 8
assert critic_config.ppo_micro_batch_size == 4
assert critic_config.ppo_micro_batch_size_per_gpu == 2
fsdp_config = FSDPCriticConfig()
fsdp_config.forward_micro_batch_size = 16
fsdp_config.forward_micro_batch_size_per_gpu = 8
assert fsdp_config.forward_micro_batch_size == 16
assert fsdp_config.forward_micro_batch_size_per_gpu == 8


@@ -20,6 +20,12 @@ from hydra import compose, initialize_config_dir
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf
+_BREAKING_CHANGES = [
+    "critic.optim.lr",  # mcore critic lr init value 1e-6 -> 1e-5
+    "actor_rollout_ref.actor.optim.lr_warmup_steps",  # None -> -1
+    "critic.optim.lr_warmup_steps",  # None -> -1
+]
class TestConfigComparison(unittest.TestCase):
"""Test that current configs match their legacy counterparts exactly."""
@@ -81,7 +87,7 @@ class TestConfigComparison(unittest.TestCase):
)
for i, (current_item, legacy_item) in enumerate(zip(current_config, legacy_config, strict=True)):
self._compare_configs_recursively(current_item, legacy_item, f"{path}[{i}]")
-else:
+elif path not in _BREAKING_CHANGES:
self.assertEqual(
current_config,
legacy_config,


@@ -13,7 +13,7 @@
# limitations under the License.
import unittest
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from omegaconf import OmegaConf
@@ -30,13 +30,16 @@ class TestDataclass:
class TestTrainConfig:
batch_size: int
model: TestDataclass
+override_config: dict = field(default_factory=dict)
_cfg_str = """train_config:
_target_: tests.utils.test_config_on_cpu.TestTrainConfig
batch_size: 32
model:
hidden_size: 768
-activation: relu"""
+activation: relu
+override_config: {}"""
class TestConfigOnCPU(unittest.TestCase):


@@ -57,7 +57,7 @@ class TestProfilerConfig(unittest.TestCase):
from verl.utils.profiler.config import ProfilerConfig
# Create a new ProfilerConfig instance
-config = ProfilerConfig(discrete=True, all_ranks=False, ranks=[0])
+config = ProfilerConfig(discrete=True, all_ranks=False, ranks=[0], extra={"key": "value"})
# Test direct attribute assignment
with self.assertRaises(FrozenInstanceError):
@@ -79,7 +79,9 @@ class TestProfilerConfig(unittest.TestCase):
with self.assertRaises(TypeError):
config["ranks"] = [1, 2, 3]
config["extra"]["key"] = "value"
assert config["extra"]["key"] == "value"
+config["extra"]["key"] = "value2"
+assert config["extra"]["key"] == "value2"
class TestNsightSystemsProfiler(unittest.TestCase):


@@ -0,0 +1,294 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from unittest.mock import patch
import torch
import torch.nn as nn
from tensordict import TensorDict
from transformers import AutoModelForCausalLM, Qwen3Config
from verl import DataProto
from verl.workers.actor.dp_actor import DataParallelPPOActor
from verl.workers.config import FSDPActorConfig, OptimizerConfig
class MockTransformerModel(nn.Module):
"""Mock transformer model for testing DataParallelPPOActor"""
def __init__(self, vocab_size=1000, hidden_size=64):
super().__init__()
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.embedding = nn.Embedding(vocab_size, hidden_size)
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=hidden_size, nhead=4, batch_first=True), num_layers=2
)
self.lm_head = nn.Linear(hidden_size, vocab_size)
def forward(self, input_ids, attention_mask=None, position_ids=None, use_cache=False, **kwargs):
batch_size, seq_len = input_ids.shape
embeddings = self.embedding(input_ids)
hidden_states = self.transformer(embeddings)
logits = self.lm_head(hidden_states)
class MockOutput:
def __init__(self, logits):
self.logits = logits
return MockOutput(logits)
class TestDataParallelPPOActor(unittest.TestCase):
"""Test DataParallelPPOActor compute_log_prob and update_policy methods"""
def setUp(self):
"""Set up test fixtures"""
import os
import torch.distributed
backend = "cpu:gloo,cuda:nccl" if torch.cuda.is_available() else "gloo"
if not torch.distributed.is_initialized():
rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
torch.distributed.init_process_group(
backend=backend,
rank=rank,
world_size=world_size,
init_method=os.environ.get("DIST_INIT_METHOD", "env://"),
)
self.mock_memory_info_patcher = patch("verl.utils.profiler.performance._get_current_mem_info")
self.mock_memory_info = self.mock_memory_info_patcher.start()
self.mock_memory_info.return_value = ("0.00", "0.00", "0.00", "0.00")
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.config = FSDPActorConfig(
strategy="fsdp",
ppo_mini_batch_size=4,
ppo_micro_batch_size_per_gpu=2,
ppo_epochs=1,
clip_ratio=0.2,
entropy_coeff=0.01,
grad_clip=1.0,
use_dynamic_bsz=False,
use_torch_compile=False, # Disable torch.compile for testing
ulysses_sequence_parallel_size=1,
optim=OptimizerConfig(lr=1e-6),
)
self.mock_model = MockTransformerModel(vocab_size=1000, hidden_size=64).to(self.device)
self.mock_optimizer = torch.optim.Adam(self.mock_model.parameters(), lr=1e-4)
self.actor = DataParallelPPOActor(
config=self.config, actor_module=self.mock_model, actor_optimizer=self.mock_optimizer
)
def tearDown(self):
"""Clean up test fixtures"""
# repeated init and destroy seems unstable with torch
# torch.distributed.destroy_process_group()
self.mock_memory_info_patcher.stop()
def _create_test_data_for_compute_log_prob(self):
"""Create test DataProto for compute_log_prob method"""
batch_size = 2
prompt_length = 8
response_length = 4
total_length = prompt_length + response_length
vocab_size = 1000
input_ids = torch.randint(0, vocab_size, (batch_size, total_length)).to(self.device)
attention_mask = torch.ones(batch_size, total_length).to(self.device)
position_ids = torch.arange(total_length).unsqueeze(0).expand(batch_size, -1).to(self.device)
responses = input_ids[:, -response_length:] # Last part is the response
tensor_dict = TensorDict(
{
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
"responses": responses,
},
batch_size=[batch_size],
)
meta_info = {"micro_batch_size": batch_size, "temperature": 1.0, "use_dynamic_bsz": False}
return DataProto(batch=tensor_dict, meta_info=meta_info)
def _create_test_data_for_update_policy(self):
"""Create test DataProto for update_policy method"""
batch_size = 4 # Must match ppo_mini_batch_size
prompt_length = 8
response_length = 4
total_length = prompt_length + response_length
vocab_size = 1000
input_ids = torch.randint(0, vocab_size, (batch_size, total_length)).to(self.device)
attention_mask = torch.ones(batch_size, total_length).to(self.device)
position_ids = torch.arange(total_length).unsqueeze(0).expand(batch_size, -1).to(self.device)
responses = input_ids[:, -response_length:]
response_mask = torch.ones(batch_size, response_length).to(self.device)
old_log_probs = torch.randn(batch_size, response_length).to(self.device) * 0.1 # Small values
advantages = torch.randn(batch_size, response_length).to(self.device) * 0.5
tensor_dict = TensorDict(
{
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
"responses": responses,
"response_mask": response_mask,
"old_log_probs": old_log_probs,
"advantages": advantages,
},
batch_size=[batch_size],
)
meta_info = {"temperature": 1.0}
return DataProto(batch=tensor_dict, meta_info=meta_info)
def test_compute_log_prob(self):
"""Test compute_log_prob method"""
data = self._create_test_data_for_compute_log_prob()
log_probs, entropies = self.actor.compute_log_prob(data, calculate_entropy=True)
batch_size = data.batch["responses"].shape[0]
response_length = data.batch["responses"].shape[1]
self.assertIsInstance(log_probs, torch.Tensor)
self.assertEqual(log_probs.shape, (batch_size, response_length))
self.assertTrue(torch.all(torch.isfinite(log_probs)))
self.assertIsInstance(entropies, torch.Tensor)
self.assertEqual(entropies.shape, (batch_size, response_length))
self.assertTrue(torch.all(torch.isfinite(entropies)))
self.assertTrue(torch.all(entropies >= 0)) # Entropy should be non-negative
def test_compute_log_prob_without_entropy(self):
"""Test compute_log_prob method without entropy calculation"""
data = self._create_test_data_for_compute_log_prob()
log_probs, entropies = self.actor.compute_log_prob(data, calculate_entropy=False)
batch_size = data.batch["responses"].shape[0]
response_length = data.batch["responses"].shape[1]
self.assertIsInstance(log_probs, torch.Tensor)
self.assertEqual(log_probs.shape, (batch_size, response_length))
self.assertTrue(torch.all(torch.isfinite(log_probs)))
self.assertIsNone(entropies)
def test_update_policy(self):
"""Test update_policy method"""
data = self._create_test_data_for_update_policy()
metrics = self.actor.update_policy(data)
self.assertIsInstance(metrics, dict)
expected_metric_keys = [
"actor/pg_loss",
"actor/pg_clipfrac",
"actor/ppo_kl",
"actor/pg_clipfrac_lower",
"actor/grad_norm",
]
for key in expected_metric_keys:
self.assertIn(key, metrics)
if isinstance(metrics[key], list):
self.assertTrue(all(torch.isfinite(torch.tensor(v)) for v in metrics[key]))
else:
self.assertIsInstance(metrics[key], (float, int))
self.assertTrue(torch.isfinite(torch.tensor(metrics[key])))
def test_dataparallelppoactor_initialization(self):
"""Test DataParallelPPOActor initialization"""
self.assertIsNotNone(self.actor.actor_module)
self.assertIsNotNone(self.actor.actor_optimizer)
self.assertEqual(self.actor.config, self.config)
self.assertEqual(self.actor.config.strategy, "fsdp")
self.assertEqual(self.actor.config.ppo_mini_batch_size, 4)
self.assertEqual(self.actor.config.clip_ratio, 0.2)
def test_dataparallelppoactor_with_qwen3_model(self):
"""Test DataParallelPPOActor with real Qwen3ForCausalLM model"""
qwen_config = Qwen3Config(
vocab_size=1000,
hidden_size=64,
intermediate_size=128,
num_hidden_layers=2,
num_attention_heads=4,
num_key_value_heads=2,
max_position_embeddings=512,
torch_dtype=torch.float32,
use_cache=False,
)
with torch.device(self.device):
qwen_model = AutoModelForCausalLM.from_config(config=qwen_config, torch_dtype=torch.float32).to(self.device)
qwen_optimizer = torch.optim.Adam(qwen_model.parameters(), lr=1e-4)
qwen_actor = DataParallelPPOActor(config=self.config, actor_module=qwen_model, actor_optimizer=qwen_optimizer)
data = self._create_test_data_for_compute_log_prob()
log_probs, entropies = qwen_actor.compute_log_prob(data, calculate_entropy=True)
batch_size = data.batch["responses"].shape[0]
response_length = data.batch["responses"].shape[1]
self.assertIsInstance(log_probs, torch.Tensor)
self.assertEqual(log_probs.shape, (batch_size, response_length))
self.assertTrue(torch.all(torch.isfinite(log_probs)))
self.assertIsInstance(entropies, torch.Tensor)
self.assertEqual(entropies.shape, (batch_size, response_length))
self.assertTrue(torch.all(torch.isfinite(entropies)))
self.assertTrue(torch.all(entropies >= 0))
policy_data = self._create_test_data_for_update_policy()
metrics = qwen_actor.update_policy(policy_data)
self.assertIsInstance(metrics, dict)
expected_metric_keys = [
"actor/pg_loss",
"actor/pg_clipfrac",
"actor/ppo_kl",
"actor/pg_clipfrac_lower",
"actor/grad_norm",
]
for key in expected_metric_keys:
self.assertIn(key, metrics)
if isinstance(metrics[key], list):
self.assertTrue(all(torch.isfinite(torch.tensor(v)) for v in metrics[key]))
else:
self.assertIsInstance(metrics[key], (float, int))
self.assertTrue(torch.isfinite(torch.tensor(metrics[key])))
if __name__ == "__main__":
unittest.main()


@@ -0,0 +1,240 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from verl.utils.config import omega_conf_to_dataclass
from verl.workers.config import ActorConfig, FSDPActorConfig, McoreActorConfig, OptimizerConfig
class TestActorConfig(unittest.TestCase):
"""Test the ActorConfig dataclass and its variants."""
def test_config_inheritance(self):
"""Test that the inheritance hierarchy works correctly."""
megatron_dict = {
"_target_": "verl.workers.config.McoreActorConfig",
"strategy": "megatron",
"ppo_mini_batch_size": 256,
"ppo_micro_batch_size_per_gpu": 256,
"clip_ratio": 0.2,
"optim": {
"_target_": "verl.workers.config.OptimizerConfig",
"lr": 0.1,
},
}
fsdp_dict = {
"_target_": "verl.workers.config.FSDPActorConfig",
"strategy": "fsdp",
"ppo_mini_batch_size": 256,
"ppo_micro_batch_size_per_gpu": 256,
"clip_ratio": 0.2,
"optim": {
"_target_": "verl.workers.config.OptimizerConfig",
"lr": 0.1,
},
}
megatron_config = omega_conf_to_dataclass(megatron_dict)
fsdp_config = omega_conf_to_dataclass(fsdp_dict)
self.assertIsInstance(megatron_config, ActorConfig)
self.assertIsInstance(fsdp_config, ActorConfig)
self.assertEqual(megatron_config.ppo_mini_batch_size, fsdp_config.ppo_mini_batch_size)
self.assertEqual(megatron_config.clip_ratio, fsdp_config.clip_ratio)
def test_actor_config_from_yaml(self):
"""Test creating ActorConfig from YAML file."""
from hydra import compose, initialize_config_dir
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config/actor")):
cfg = compose(config_name="actor", overrides=["strategy=fsdp", "ppo_micro_batch_size_per_gpu=128"])
config = omega_conf_to_dataclass(cfg)
self.assertIsInstance(config, ActorConfig)
self.assertEqual(config.strategy, "fsdp")
def test_fsdp_actor_config_from_yaml(self):
"""Test creating FSDPActorConfig from YAML file."""
from hydra import compose, initialize_config_dir
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config/actor")):
cfg = compose(config_name="dp_actor", overrides=["strategy=fsdp2", "ppo_micro_batch_size_per_gpu=128"])
config = omega_conf_to_dataclass(cfg)
self.assertIsInstance(config, FSDPActorConfig)
self.assertEqual(config.strategy, "fsdp2")
def test_megatron_actor_config_from_yaml(self):
"""Test creating McoreActorConfig from YAML file."""
from hydra import compose, initialize_config_dir
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config/actor")):
cfg = compose(config_name="megatron_actor", overrides=["ppo_micro_batch_size_per_gpu=128"])
config = omega_conf_to_dataclass(cfg)
self.assertIsInstance(config, McoreActorConfig)
self.assertEqual(config.strategy, "megatron")
def test_config_get_method(self):
"""Test the get method for backward compatibility."""
config_dict = {
"_target_": "verl.workers.config.ActorConfig",
"strategy": "fsdp",
"ppo_mini_batch_size": 256,
"ppo_micro_batch_size_per_gpu": 256,
"optim": {
"_target_": "verl.workers.config.OptimizerConfig",
"lr": 0.1,
},
}
config = omega_conf_to_dataclass(config_dict)
self.assertEqual(config.get("strategy"), "fsdp")
self.assertEqual(config.get("ppo_mini_batch_size"), 256)
self.assertIsNone(config.get("non_existing"))
self.assertEqual(config.get("non_existing", "default"), "default")
def test_config_dict_like_access(self):
"""Test dictionary-like access to config fields."""
config_dict = {
"_target_": "verl.workers.config.ActorConfig",
"strategy": "fsdp",
"ppo_mini_batch_size": 256,
"ppo_micro_batch_size_per_gpu": 256,
"optim": {
"_target_": "verl.workers.config.OptimizerConfig",
"lr": 0.1,
},
}
config = omega_conf_to_dataclass(config_dict)
self.assertEqual(config["strategy"], "fsdp")
self.assertEqual(config["ppo_mini_batch_size"], 256)
field_names = list(config)
self.assertIn("strategy", field_names)
self.assertIn("ppo_mini_batch_size", field_names)
self.assertGreater(len(config), 0)
def test_frozen_fields_modification_raises_exception(self):
"""Test that modifying frozen fields raises an exception."""
config_dict = {
"_target_": "verl.workers.config.ActorConfig",
"strategy": "fsdp",
"ppo_mini_batch_size": 256,
"ppo_micro_batch_size_per_gpu": 256,
"optim": {
"_target_": "verl.workers.config.OptimizerConfig",
"lr": 0.1,
},
}
config = omega_conf_to_dataclass(config_dict)
with self.assertRaises(AttributeError):
config.strategy = "megatron"
with self.assertRaises(AttributeError):
config.clip_ratio = 0.5
config.ppo_mini_batch_size = 512 # This should work since it's not in frozen fields anymore
self.assertEqual(config.ppo_mini_batch_size, 512)
def test_actor_config_validation_exceptions(self):
"""Test that ActorConfig.__post_init__ raises appropriate validation exceptions."""
optim = OptimizerConfig(lr=0.1)
with self.assertRaises((ValueError, AssertionError)) as cm:
ActorConfig(
strategy="fsdp",
loss_agg_mode="invalid-mode",
use_dynamic_bsz=True,
optim=optim,
ppo_micro_batch_size_per_gpu=4,
)
self.assertIn("Invalid loss_agg_mode", str(cm.exception))
with self.assertRaises((ValueError, AssertionError)) as cm:
ActorConfig(
strategy="fsdp",
use_dynamic_bsz=False,
ppo_micro_batch_size=4,
ppo_micro_batch_size_per_gpu=2,
optim=optim,
)
self.assertIn("You have set both", str(cm.exception))
with self.assertRaises((ValueError, AssertionError)) as cm:
ActorConfig(
strategy="fsdp",
use_dynamic_bsz=False,
ppo_micro_batch_size=None,
ppo_micro_batch_size_per_gpu=None,
optim=optim,
)
self.assertIn("Please set at least one", str(cm.exception))
config = ActorConfig(
strategy="fsdp",
use_dynamic_bsz=True,
ppo_micro_batch_size=None,
ppo_micro_batch_size_per_gpu=None,
optim=optim,
)
self.assertIsNotNone(config) # Should not raise an exception
def test_fsdp_actor_config_validation_exceptions(self):
"""Test that FSDPActorConfig.validate() raises appropriate validation exceptions."""
optim = OptimizerConfig(lr=0.1)
config = FSDPActorConfig(
strategy="fsdp",
ulysses_sequence_parallel_size=2,
use_dynamic_bsz=True, # Skip batch size validation to focus on FSDP validation
optim=optim,
)
model_config = {"use_remove_padding": False}
with self.assertRaises(ValueError) as cm:
config.validate(n_gpus=8, train_batch_size=256, model_config=model_config)
self.assertIn("you must enable `use_remove_padding`", str(cm.exception))
def test_actor_config_validate_method_exceptions(self):
"""Test that ActorConfig.validate() raises appropriate validation exceptions."""
optim = OptimizerConfig(lr=0.1)
config = ActorConfig(
strategy="fsdp",
use_dynamic_bsz=False,
ppo_mini_batch_size=256,
ppo_micro_batch_size=8,
ppo_micro_batch_size_per_gpu=None, # Ensure only one batch size setting is used
optim=optim,
)
with self.assertRaises(ValueError) as cm:
config.validate(n_gpus=8, train_batch_size=128)
self.assertIn("train_batch_size", str(cm.exception))
with self.assertRaises(ValueError) as cm:
config.validate(n_gpus=16, train_batch_size=512)
self.assertIn("must be >= n_gpus", str(cm.exception))
if __name__ == "__main__":
unittest.main()


@@ -0,0 +1,308 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
import pytest
from hydra import compose, initialize_config_dir
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.profiler import ProfilerConfig
from verl.workers.config import (
CriticConfig,
FSDPCriticConfig,
McoreCriticConfig,
OptimizerConfig,
)
class TestCriticConfig:
"""Test suite for critic configuration dataclasses."""
@pytest.fixture
def config_dir(self):
"""Get the path to the config directory."""
return Path(__file__).parent.parent.parent.parent / "verl" / "trainer" / "config" / "critic"
def test_megatron_critic_config_instantiation_from_yaml(self, config_dir):
"""Test that McoreCriticConfig can be instantiated from megatron_critic.yaml."""
yaml_path = config_dir / "megatron_critic.yaml"
assert yaml_path.exists(), f"Config file not found: {yaml_path}"
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config/critic")):
test_config = compose(config_name="megatron_critic", overrides=["ppo_micro_batch_size_per_gpu=1"])
megatron_config_obj = omega_conf_to_dataclass(test_config)
assert isinstance(megatron_config_obj, McoreCriticConfig)
assert isinstance(megatron_config_obj, CriticConfig)
expected_attrs = [
"strategy",
"rollout_n",
"optim",
"model",
"ppo_mini_batch_size",
"ppo_max_token_len_per_gpu",
"cliprange_value",
"get",
"nccl_timeout",
"megatron",
"load_weight",
]
for attr in expected_attrs:
assert hasattr(megatron_config_obj, attr), f"Missing attribute: {attr}"
assert callable(megatron_config_obj.get)
assert megatron_config_obj.strategy == "megatron"
def test_fsdp_critic_config_instantiation_from_yaml(self, config_dir):
"""Test that FSDPCriticConfig can be instantiated from dp_critic.yaml."""
yaml_path = config_dir / "dp_critic.yaml"
assert yaml_path.exists(), f"Config file not found: {yaml_path}"
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config/critic")):
test_config = compose(config_name="dp_critic", overrides=["ppo_micro_batch_size_per_gpu=1"])
fsdp_config_obj = omega_conf_to_dataclass(test_config)
assert isinstance(fsdp_config_obj, FSDPCriticConfig)
assert isinstance(fsdp_config_obj, CriticConfig)
expected_attrs = [
"strategy",
"rollout_n",
"optim",
"model",
"ppo_mini_batch_size",
"ppo_max_token_len_per_gpu",
"cliprange_value",
"get",
"forward_micro_batch_size",
"forward_micro_batch_size_per_gpu",
"ulysses_sequence_parallel_size",
"grad_clip",
]
for attr in expected_attrs:
assert hasattr(fsdp_config_obj, attr), f"Missing attribute: {attr}"
assert callable(fsdp_config_obj.get)
assert fsdp_config_obj.strategy == "fsdp"
def test_config_inheritance_hierarchy(self):
"""Test that the inheritance hierarchy is correct."""
optim = OptimizerConfig(lr=0.1)
megatron_config = McoreCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=optim)
assert isinstance(megatron_config, CriticConfig)
assert isinstance(megatron_config, McoreCriticConfig)
fsdp_config = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=optim)
assert isinstance(fsdp_config, CriticConfig)
assert isinstance(fsdp_config, FSDPCriticConfig)
critic_config = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy="fsdp2", optim=optim)
assert isinstance(critic_config, CriticConfig)
assert not isinstance(critic_config, McoreCriticConfig)
assert not isinstance(critic_config, FSDPCriticConfig)
def test_config_dict_interface(self):
"""Test that configs provide dict-like interface from BaseConfig."""
optim = OptimizerConfig(lr=0.1)
config = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy="fsdp2", optim=optim)
assert "strategy" in config
assert config["strategy"] == "fsdp2"
assert config.get("strategy") == "fsdp2"
assert config.get("nonexistent_key", "default") == "default"
keys = list(config)
assert "strategy" in keys
assert "rollout_n" in keys
assert len(config) > 0
def test_frozen_fields_immutability(self):
"""Test that frozen fields raise exceptions when modified after creation."""
optim = OptimizerConfig(lr=0.1)
critic_config = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy="fsdp2", optim=optim)
frozen_fields = ["rollout_n", "strategy", "cliprange_value"]
for field_name in frozen_fields:
with pytest.raises((AttributeError, TypeError, ValueError)):
setattr(critic_config, field_name, "modified_value")
megatron_config = McoreCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=optim)
megatron_frozen_fields = ["nccl_timeout", "load_weight", "data_loader_seed"]
for field_name in megatron_frozen_fields:
with pytest.raises((AttributeError, TypeError, ValueError)):
setattr(megatron_config, field_name, "modified_value")
fsdp_config = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=optim)
fsdp_frozen_fields = ["ulysses_sequence_parallel_size", "grad_clip"]
for field_name in fsdp_frozen_fields:
with pytest.raises((AttributeError, TypeError, ValueError)):
setattr(fsdp_config, field_name, "modified_value")
def test_batch_size_fields_modifiable(self):
"""Test that batch size fields can be modified after creation."""
optim = OptimizerConfig(lr=0.1)
critic_config = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy="fsdp2", optim=optim)
critic_config.ppo_mini_batch_size = 8
critic_config.ppo_micro_batch_size = 4
critic_config.ppo_micro_batch_size_per_gpu = 2
assert critic_config.ppo_mini_batch_size == 8
assert critic_config.ppo_micro_batch_size == 4
assert critic_config.ppo_micro_batch_size_per_gpu == 2
fsdp_config = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=optim)
fsdp_config.forward_micro_batch_size = 16
fsdp_config.forward_micro_batch_size_per_gpu = 8
assert fsdp_config.forward_micro_batch_size == 16
assert fsdp_config.forward_micro_batch_size_per_gpu == 8
def test_profiler_config_type_validation(self):
"""Test that profiler field has correct type and validation."""
optim = OptimizerConfig(lr=0.1)
critic_config = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy="fsdp2", optim=optim)
assert isinstance(critic_config.profiler, ProfilerConfig)
assert critic_config.profiler.discrete is False
assert critic_config.profiler.all_ranks is False
assert critic_config.profiler.ranks == []
custom_profiler = ProfilerConfig(discrete=True, all_ranks=True, ranks=[0, 1])
critic_config_custom = CriticConfig(
profiler=custom_profiler, ppo_micro_batch_size_per_gpu=1, strategy="fsdp2", optim=optim
)
assert isinstance(critic_config_custom.profiler, ProfilerConfig)
assert critic_config_custom.profiler.discrete is True
assert critic_config_custom.profiler.all_ranks is True
assert critic_config_custom.profiler.ranks == [0, 1]
profiler1 = ProfilerConfig(discrete=True, ranks=[0, 1])
profiler2 = ProfilerConfig(all_ranks=True, ranks=[1, 2])
union_result = profiler1.union(profiler2)
assert union_result.discrete is True
assert union_result.all_ranks is True
assert set(union_result.ranks) == {0, 1, 2}
intersect_result = profiler1.intersect(profiler2)
assert intersect_result.discrete is False
assert intersect_result.all_ranks is False
assert intersect_result.ranks == [1]
def test_critic_config_validation_logic(self):
"""Test the __post_init__ validation logic for CriticConfig."""
optim = OptimizerConfig(lr=0.1)
valid_config = CriticConfig(
strategy="fsdp2", ppo_micro_batch_size_per_gpu=2, use_dynamic_bsz=False, optim=optim
)
assert valid_config.ppo_micro_batch_size_per_gpu == 2
valid_config2 = CriticConfig(
strategy="fsdp2",
ppo_micro_batch_size_per_gpu=None,
ppo_micro_batch_size=4,
ppo_mini_batch_size=8,
use_dynamic_bsz=False,
optim=optim,
)
assert valid_config2.ppo_micro_batch_size == 4
dynamic_config = CriticConfig(
strategy="fsdp2", ppo_micro_batch_size_per_gpu=2, use_dynamic_bsz=True, optim=optim
)
assert dynamic_config.use_dynamic_bsz is True
with pytest.raises(ValueError, match="You have set both.*micro_batch_size.*AND.*micro_batch_size_per_gpu"):
CriticConfig(
strategy="fsdp2",
ppo_micro_batch_size=4,
ppo_micro_batch_size_per_gpu=2,
use_dynamic_bsz=False,
optim=optim,
)
with pytest.raises(
ValueError, match="Please set at least one of.*micro_batch_size.*or.*micro_batch_size_per_gpu"
):
CriticConfig(
strategy="fsdp2",
ppo_micro_batch_size=None,
ppo_micro_batch_size_per_gpu=None,
use_dynamic_bsz=False,
optim=optim,
)
def test_micro_batch_size_divisibility_validation(self):
"""Test micro batch size divisibility validation in __post_init__."""
optim = OptimizerConfig(lr=0.1)
valid_config = CriticConfig(
strategy="fsdp2", ppo_micro_batch_size_per_gpu=2, ppo_mini_batch_size=8, use_dynamic_bsz=False, optim=optim
)
assert valid_config.ppo_mini_batch_size == 8
assert valid_config.ppo_micro_batch_size_per_gpu == 2
valid_config_with_mbs = CriticConfig(
strategy="fsdp2", ppo_mini_batch_size=8, ppo_micro_batch_size=4, use_dynamic_bsz=False, optim=optim
)
assert valid_config_with_mbs.ppo_mini_batch_size == 8
assert valid_config_with_mbs.ppo_micro_batch_size == 4
with pytest.raises(ValueError, match="ppo_mini_batch_size.*must be divisible by.*ppo_micro_batch_size"):
CriticConfig(
strategy="fsdp2", ppo_mini_batch_size=7, ppo_micro_batch_size=4, use_dynamic_bsz=False, optim=optim
)
dynamic_config = CriticConfig(
strategy="fsdp2", ppo_mini_batch_size=7, ppo_micro_batch_size=4, use_dynamic_bsz=True, optim=optim
)
assert dynamic_config.use_dynamic_bsz is True
def test_fsdp_sequence_parallelism_validation(self):
"""Test FSDP sequence parallelism validation in FSDPCriticConfig.__post_init__."""
optim = OptimizerConfig(lr=0.1)
valid_config = FSDPCriticConfig(
ppo_micro_batch_size_per_gpu=2,
ulysses_sequence_parallel_size=2,
model={"use_remove_padding": True},
optim=optim,
)
assert valid_config.ulysses_sequence_parallel_size == 2
with pytest.raises(
ValueError, match="When using sequence parallelism for critic, you must enable.*use_remove_padding"
):
FSDPCriticConfig(
ppo_micro_batch_size_per_gpu=2,
ulysses_sequence_parallel_size=2,
model={"use_remove_padding": False},
optim=optim,
)
valid_config_no_sp = FSDPCriticConfig(
ppo_micro_batch_size_per_gpu=2,
ulysses_sequence_parallel_size=1,
model={"use_remove_padding": False},
optim=optim,
)
assert valid_config_no_sp.ulysses_sequence_parallel_size == 1
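Taken together, these tests pin down the mutability contract: every field is frozen once `__init__` has run, except those whitelisted in `_mutable_fields` (the batch-size fields, which are still mutated per role for legacy reasons). A minimal sketch of that behavior, assuming `CriticConfig` and `OptimizerConfig` are importable from `verl.workers.config` as elsewhere in this PR:

```python
from dataclasses import FrozenInstanceError

from verl.workers.config import CriticConfig, OptimizerConfig

cfg = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy="fsdp2", optim=OptimizerConfig(lr=0.1))

# batch-size fields are whitelisted in _mutable_fields, so runtime mutation is allowed
cfg.ppo_mini_batch_size = 8

# every other field is frozen once set in __init__
try:
    cfg.strategy = "megatron"
except FrozenInstanceError as e:
    print(e)  # Field 'strategy' is frozen and cannot be modified
```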

View File

@ -0,0 +1,67 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from verl.workers.config.engine import FSDPEngineConfig, McoreEngineConfig
class TestMcoreEngineConfig:
def test_default_values(self):
config = McoreEngineConfig()
assert config.tensor_model_parallel_size == 1
assert config.sequence_parallel is False # Should be auto-corrected
assert config.seed == 42
def test_post_init_validation(self):
# Test TP size 1 forces sequence_parallel=False
config = McoreEngineConfig(tensor_model_parallel_size=1)
assert config.sequence_parallel is False
# Test TP >1 keeps sequence_parallel=True
config = McoreEngineConfig(tensor_model_parallel_size=2)
assert config.sequence_parallel is True
def test_mutable_fields(self):
config = McoreEngineConfig()
config.sequence_parallel = True # Should be mutable
with pytest.raises(AttributeError):
config.tensor_model_parallel_size = 2 # Frozen field
@pytest.mark.parametrize("offload_field", ["param_offload", "grad_offload", "optimizer_offload"])
def test_offload_flags(self, offload_field):
config = McoreEngineConfig(**{offload_field: True})
assert getattr(config, offload_field) is True
class TestFSDPEngineConfigCPU:
def test_default_values(self):
config = FSDPEngineConfig()
assert config.param_offload is False
assert config.optimizer_offload is False
assert config.fsdp_size == -1
@pytest.mark.parametrize(
"offload_params",
[{"param_offload": True}, {"optimizer_offload": True}, {"param_offload": True, "optimizer_offload": True}],
)
def test_offload_combinations(self, offload_params):
config = FSDPEngineConfig(**offload_params)
assert config.param_offload == offload_params.get("param_offload", False)
assert config.optimizer_offload == offload_params.get("optimizer_offload", False)
def test_wrap_policy_configuration(self):
test_policy = {"layer_class": "TransformerBlock"}
config = FSDPEngineConfig(wrap_policy=test_policy)
assert config.wrap_policy == test_policy

View File

@ -0,0 +1,39 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from verl.workers.config.optimizer import FSDPOptimizerConfig
class TestFSDPOptimizerConfigCPU:
def test_default_configuration(self):
config = FSDPOptimizerConfig(lr=0.1)
assert config.min_lr_ratio is None
assert config.warmup_style == "constant"
assert config.num_cycles == 0.5
@pytest.mark.parametrize("warmup_style", ["constant", "cosine"])
def test_valid_warmup_styles(self, warmup_style):
config = FSDPOptimizerConfig(warmup_style=warmup_style, lr=0.1)
assert config.warmup_style == warmup_style
def test_invalid_warmup_style(self):
with pytest.raises((ValueError, AssertionError)):
FSDPOptimizerConfig(warmup_style="invalid_style", lr=0.1)
@pytest.mark.parametrize("num_cycles", [0.1, 1.0, 2.5])
def test_num_cycles_configuration(self, num_cycles):
config = FSDPOptimizerConfig(num_cycles=num_cycles, lr=0.1)
assert config.num_cycles == num_cycles

View File

@ -51,7 +51,14 @@ def init_config(n_gpus_per_node) -> DictConfig:
from hydra import compose, initialize_config_dir
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
config = compose(config_name="ppo_trainer")
config = compose(
config_name="ppo_trainer",
overrides=[
"actor_rollout_ref.actor.use_dynamic_bsz=true",
"actor_rollout_ref.actor.fsdp_config.param_offload=True",
"actor_rollout_ref.actor.fsdp_config.optimizer_offload=True",
],
)
config.trainer.n_gpus_per_node = n_gpus_per_node
config.data.train_batch_size = 128
config.data.return_raw_chat = True
@ -64,10 +71,6 @@ def init_config(n_gpus_per_node) -> DictConfig:
config.actor_rollout_ref.rollout.response_length = 4096
config.actor_rollout_ref.rollout.n = 16
# test sleep/wake_up with fsdp offload
config.actor_rollout_ref.actor.fsdp_config.param_offload = True
config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True
return config

View File

@ -33,7 +33,14 @@ def init_config() -> DictConfig:
from hydra import compose, initialize_config_dir
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
config = compose(config_name="ppo_trainer")
config = compose(
config_name="ppo_trainer",
overrides=[
"actor_rollout_ref.actor.use_dynamic_bsz=true",
"actor_rollout_ref.actor.fsdp_config.param_offload=True",
"actor_rollout_ref.actor.fsdp_config.optimizer_offload=True",
],
)
model_path = "Qwen/Qwen2.5-1.5B-Instruct"
config.actor_rollout_ref.model.path = model_path
config.actor_rollout_ref.rollout.mode = "async"
@ -41,10 +48,6 @@ def init_config() -> DictConfig:
config.actor_rollout_ref.rollout.prompt_length = 4096
config.actor_rollout_ref.rollout.response_length = 4096
# test sleep/wake_up with fsdp offload
config.actor_rollout_ref.actor.fsdp_config.param_offload = True
config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True
return config

View File

@ -13,33 +13,28 @@
# limitations under the License.
import collections
from dataclasses import (
dataclass,
field,
fields, # Import the fields function to inspect dataclass fields
)
from dataclasses import FrozenInstanceError, dataclass, field, fields
from typing import Any
# BaseConfig class inherits from collections.abc.Mapping, which means it can act like a dictionary
@dataclass
class BaseConfig(collections.abc.Mapping):
"""The BaseConfig provides omegaconf DictConfig-like interface for a dataclass config.
"""The BaseConfig provides dict-like interface for a dataclass config.
The BaseConfig class implements the Mapping Abstract Base Class.
By default all fields in the config are immutable, unless specified in
"_mutable_fields". The BaseConfig class implements the Mapping Abstract Base Class.
This allows instances of this class to be used like dictionaries.
"""
_mutable_fields = {"extra"}
extra: dict[str, Any] = field(default_factory=dict)
def __setattr__(self, name: str, value):
# if the field already exists (i.e. was set in __init__)
# and is in our frozen list, block assignment
if hasattr(self, "_frozen_fields") and name in self._frozen_fields and name in self.__dict__:
from dataclasses import FrozenInstanceError
"""Set the value of an attribute. Check if the attr is mutable before setting the value."""
# If the field already exists, it's considered frozen unless it's in _mutable_fields
if name in self.__dict__ and name not in getattr(self, "_mutable_fields", set()):
raise FrozenInstanceError(f"Field '{name}' is frozen and cannot be modified")
# otherwise do the normal thing
super().__setattr__(name, value)
def get(self, key: str, default: Any = None) -> Any:
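The net effect of the rewritten `__setattr__` is that mutability becomes an opt-in, per-class whitelist. A sketch of how a subclass would extend it; the `RolloutConfig` name and its fields are hypothetical, used only for illustration:

```python
from dataclasses import FrozenInstanceError, dataclass

from verl.base_config import BaseConfig


@dataclass
class RolloutConfig(BaseConfig):  # hypothetical subclass for illustration
    # inherit {"extra"} from BaseConfig and additionally allow runtime tuning of temperature
    _mutable_fields = BaseConfig._mutable_fields | {"temperature"}

    temperature: float = 1.0
    top_p: float = 1.0


cfg = RolloutConfig()
cfg.temperature = 0.7  # OK: whitelisted in _mutable_fields
try:
    cfg.top_p = 0.9  # frozen once set in __init__
except FrozenInstanceError:
    pass
```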

View File

@ -900,6 +900,50 @@ class DataProto:
meta_info=self.meta_info,
)
def get_data_info(self) -> str:
"""Return formatted information about stored data with nested type details.
Returns:
str: Formatted string showing tensor details and recursive metadata types
"""
info = ["batch"]
for key, tensor in self.batch.items():
if hasattr(tensor, "shape") and hasattr(tensor, "dtype") and hasattr(tensor, "device"):
info.append(f" {key}: {tuple(tensor.shape)} ({tensor.dtype}) {tensor.device}")
elif hasattr(tensor, "shape") and hasattr(tensor, "dtype"):
info.append(f" {key}: {tuple(tensor.shape)} ({tensor.dtype})")
else:
info.append(f" {key}: {type(tensor).__name__}")
info.append("non_tensor_batch")
for key, array in self.non_tensor_batch.items():
info.append(f" {key}: ndarray{array.shape} ({array.dtype})")
info.append("meta_info")
for k, v in self.meta_info.items():
type_info = self._get_type_info(v)
info.append(f" {k}: {type_info}")
return "\n".join(info)
def _get_type_info(self, value):
"""Recursively get type information for nested structures"""
if isinstance(value, list):
elem_types = {self._get_type_info(v) for v in value[:3]}
return f"list[{'|'.join(elem_types) if elem_types else '...'}]"
if isinstance(value, tuple):
elem_types = [self._get_type_info(v) for v in value]
return f"tuple({', '.join(elem_types)})"
if isinstance(value, dict):
if not value:
return "dict"
k, v = next(iter(value.items()))
return f"dict[{self._get_type_info(k)}: {self._get_type_info(v)}]"
if isinstance(value, np.ndarray):
return f"ndarray{value.shape} ({value.dtype})"
return type(value).__name__
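For a rough illustration of the output format, here is a sketch assuming `DataProto.from_dict` accepts tensors, non-tensors, and meta info as in verl:

```python
import numpy as np
import torch

from verl import DataProto

proto = DataProto.from_dict(
    tensors={"input_ids": torch.zeros(2, 8, dtype=torch.long)},
    non_tensors={"uid": np.array(["a", "b"], dtype=object)},
    meta_info={"temperature": 0.7, "stop": ["</s>"]},
)
print(proto.get_data_info())
# batch
#   input_ids: (2, 8) (torch.int64) cpu
# non_tensor_batch
#   uid: ndarray(2,) (object)
# meta_info
#   temperature: float
#   stop: list[str]
```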
@dataclass
class DataProtoFuture:

View File

@ -12,15 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .algorithm import AlgoConfig, FilterGroupsConfig, KLControlConfig, PFPPOConfig
from .config import CriticConfig, FSDPCriticConfig, MegatronCriticConfig
from .algorithm import * # noqa
from .config import * # noqa
from . import config, algorithm
__all__ = [
"AlgoConfig",
"CriticConfig",
"FilterGroupsConfig",
"FSDPCriticConfig",
"KLControlConfig",
"MegatronCriticConfig",
"PFPPOConfig",
]
__all__ = config.__all__ + algorithm.__all__

View File

@ -5,6 +5,7 @@
actor_rollout_ref:
actor:
_target_: verl.workers.config.McoreActorConfig
strategy: megatron
ppo_mini_batch_size: 256
ppo_micro_batch_size: null
@ -15,6 +16,7 @@ actor_rollout_ref:
clip_ratio_low: 0.2
clip_ratio_high: 0.2
policy_loss:
_target_: verl.workers.config.PolicyLossConfig
loss_mode: vanilla
clip_cov_ratio: 0.0002
clip_cov_lb: 1.0
@ -31,6 +33,7 @@ actor_rollout_ref:
ppo_epochs: 1
shuffle: false
checkpoint:
_target_: verl.trainer.config.CheckpointConfig
save_contents:
- model
- optimizer
@ -42,10 +45,11 @@ actor_rollout_ref:
lr_warmup_steps_ratio: 0.0
total_training_steps: -1
weight_decay: 0.01
lr_warmup_steps: -1
_target_: verl.workers.config.McoreOptimizerConfig
optimizer: adam
clip_grad: 1.0
lr_warmup_init: 0.0
lr_warmup_steps: null
lr_decay_steps: null
lr_decay_style: constant
min_lr: 0.0
@ -53,9 +57,11 @@ actor_rollout_ref:
lr_wsd_decay_style: exponential
lr_wsd_decay_steps: null
use_checkpoint_opt_param_scheduler: false
use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false}
data_loader_seed: null
load_weight: true
megatron:
_target_: verl.workers.config.McoreEngineConfig
param_offload: false
grad_offload: false
optimizer_offload: false
@ -92,6 +98,7 @@ actor_rollout_ref:
log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}
log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}
megatron:
_target_: verl.workers.config.MegatronEngineConfig
param_offload: false
tensor_model_parallel_size: 1
expert_model_parallel_size: 1
@ -279,17 +286,20 @@ data:
path: null
name: null
critic:
_target_: verl.workers.config.McoreCriticConfig
rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
strategy: megatron
enable: null
optim:
lr: 1.0e-05
lr_warmup_steps_ratio: 0.0
total_training_steps: -1
weight_decay: 0.01
lr_warmup_steps: -1
_target_: verl.workers.config.McoreOptimizerConfig
optimizer: adam
lr: 1.0e-06
clip_grad: 1.0
lr_warmup_init: 0.0
lr_warmup_steps: null
lr_decay_steps: null
lr_decay_style: linear
min_lr: 0.0
@ -306,6 +316,7 @@ critic:
freeze_moe_router: false
external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null}
trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false}
_target_: verl.trainer.config.BaseModelConfig
ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256}
ppo_micro_batch_size: null
ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null}
@ -317,6 +328,7 @@ critic:
cliprange_value: 0.5
loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean}
checkpoint:
_target_: verl.trainer.config.CheckpointConfig
save_contents:
- model
- optimizer
@ -328,9 +340,9 @@ critic:
discrete: false
all_ranks: false
ranks: []
_target_: verl.trainer.config.MegatronCriticConfig
nccl_timeout: 600
megatron:
_target_: verl.workers.config.McoreEngineConfig
param_offload: false
grad_offload: false
optimizer_offload: false
@ -376,6 +388,7 @@ reward_model:
ranks: []
nccl_timeout: 600
megatron:
_target_: verl.workers.config.MegatronEngineConfig
param_offload: false
tensor_model_parallel_size: 1
expert_model_parallel_size: 1
@ -410,7 +423,6 @@ algorithm:
target_kl: 0.1
use_pf_ppo: false
pf_ppo:
_target_: verl.trainer.config.PFPPOConfig
reweight_method: pow
weight_pow: 2.0
ray_init:

View File

@ -5,6 +5,7 @@
actor_rollout_ref:
actor:
_target_: verl.workers.config.FSDPActorConfig
strategy: fsdp
ppo_mini_batch_size: 256
ppo_micro_batch_size: null
@ -15,6 +16,7 @@ actor_rollout_ref:
clip_ratio_low: 0.2
clip_ratio_high: 0.2
policy_loss:
_target_: verl.workers.config.PolicyLossConfig
loss_mode: vanilla
clip_cov_ratio: 0.0002
clip_cov_lb: 1.0
@ -31,25 +33,30 @@ actor_rollout_ref:
ppo_epochs: 1
shuffle: false
checkpoint:
_target_: verl.trainer.config.CheckpointConfig
save_contents:
- model
- optimizer
- extra
load_contents: ${.save_contents}
async_save: false
optim:
lr: 1.0e-06
lr_warmup_steps_ratio: 0.0
total_training_steps: -1
weight_decay: 0.01
lr_warmup_steps: -1
_target_: verl.workers.config.FSDPOptimizerConfig
min_lr_ratio: 0.0
num_cycles: 0.5
warmup_style: constant
use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false}
grad_clip: 1.0
ulysses_sequence_parallel_size: 1
entropy_from_logits_with_chunking: false
entropy_checkpointing: false
fsdp_config:
_target_: verl.workers.config.FSDPEngineConfig
wrap_policy:
min_num_params: 0
param_offload: false
@ -58,6 +65,7 @@ actor_rollout_ref:
reshard_after_forward: true
fsdp_size: -1
forward_prefetch: false
use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false}
ref:
strategy: ${actor_rollout_ref.actor.strategy}
use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true}
@ -66,11 +74,12 @@ actor_rollout_ref:
log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}
log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}
fsdp_config:
_target_: verl.workers.config.FSDPEngineConfig
wrap_policy:
min_num_params: 0
param_offload: false
reshard_after_forward: true
forward_prefetch: false
wrap_policy:
min_num_params: 0
ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1}
entropy_from_logits_with_chunking: false
entropy_checkpointing: false
@ -249,13 +258,17 @@ data:
path: null
name: null
critic:
_target_: verl.workers.config.FSDPCriticConfig
rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
strategy: fsdp
enable: null
optim:
lr: 1.0e-05
lr_warmup_steps_ratio: 0.0
total_training_steps: -1
weight_decay: 0.01
lr: 1.0e-05
lr_warmup_steps: -1
_target_: verl.workers.config.FSDPOptimizerConfig
min_lr_ratio: null
warmup_style: constant
model:
@ -264,11 +277,13 @@ critic:
override_config: {}
external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null}
trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false}
_target_: verl.workers.config.FSDPCriticModelCfg
use_shm: false
enable_gradient_checkpointing: true
enable_activation_offload: false
use_remove_padding: false
fsdp_config:
_target_: verl.workers.config.FSDPEngineConfig
param_offload: false
optimizer_offload: false
offload_policy: false
@ -291,17 +306,18 @@ critic:
cliprange_value: 0.5
loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean}
checkpoint:
_target_: verl.trainer.config.CheckpointConfig
save_contents:
- model
- optimizer
- extra
load_contents: ${.save_contents}
async_save: false
profiler:
_target_: verl.utils.profiler.ProfilerConfig
discrete: false
all_ranks: false
ranks: []
_target_: verl.trainer.config.FSDPCriticConfig
forward_micro_batch_size: ${oc.select:.ppo_micro_batch_size,null}
forward_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size_per_gpu,null}
ulysses_sequence_parallel_size: 1
@ -318,6 +334,7 @@ reward_model:
use_remove_padding: false
use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}
fsdp_config:
_target_: verl.workers.config.FSDPEngineConfig
wrap_policy:
min_num_params: 0
param_offload: false
@ -360,7 +377,6 @@ algorithm:
target_kl: 0.1
use_pf_ppo: false
pf_ppo:
_target_: verl.trainer.config.PFPPOConfig
reweight_method: pow
weight_pow: 2.0
ray_init:

View File

@ -4,6 +4,9 @@
# 3. Inline comments (after a field on the same line) are not allowed.
# 4. Indentation level is respected for nested fields.
# Target class for this configuration
_target_: verl.workers.config.ActorConfig
# the abstract actor configs
# fsdp, fsdp2 or megatron. must be set.
strategy: ???
@ -38,6 +41,9 @@ clip_ratio_high: 0.2
# policy loss config
policy_loss:
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.workers.config.PolicyLossConfig
# Loss function mode: vanilla / clip-cov / kl-cov / gpg from https://arxiv.org/abs/2505.22617
loss_mode: "vanilla"
@ -87,6 +93,9 @@ shuffle: false
# checkpoint configs
checkpoint:
# Target dataclass for this configuration
_target_: verl.trainer.config.CheckpointConfig
# What to include in saved checkpoints
# with 'hf_model' you can save the whole model in HF format; for now only sharded model checkpoints are used, to save space
save_contents: ['model', 'optimizer', 'extra']
@ -95,6 +104,9 @@ checkpoint:
# .xxx refers to the local variable xxx at the same level of the hierarchy, similar to a python package
load_contents: ${.save_contents}
# Whether to save checkpoints asynchronously. Only effective for Megatron as of now.
async_save: False
# optimizer configs
optim:
@ -109,3 +121,10 @@ optim:
# Weight decay
weight_decay: 0.01
# Prioritized. None, 0, or negative values mean delegating to lr_warmup_steps_ratio.
lr_warmup_steps: -1
# Whether to use custom fused kernels (e.g., FlashAttention, fused MLP)
use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false}

View File

@ -13,6 +13,9 @@ defaults:
# load the reference default config, then apply the fields in the current yaml
- _self_
# Target class for this configuration
_target_: verl.workers.config.FSDPActorConfig
# TODO(haibin.lin): switch to fsdp2
strategy: fsdp
@ -32,8 +35,8 @@ entropy_checkpointing: False
# optimizer configs
optim:
# Warmup steps; negative value delegates to lr_warmup_steps_ratio
lr_warmup_steps: -1
# Target class for this configuration
_target_: verl.workers.config.FSDPOptimizerConfig
# Minimum LR ratio for cosine schedule
min_lr_ratio: 0.0
@ -47,6 +50,9 @@ optim:
# configs for FSDP
fsdp_config:
# Target class for this configuration
_target_: verl.workers.config.FSDPEngineConfig
# policy for wrapping the model
wrap_policy:
@ -71,3 +77,6 @@ fsdp_config:
# Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather
# before the current forward computation.
forward_prefetch: False
# Whether to remove padding tokens in inputs during training
use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false}

View File

@ -4,18 +4,16 @@ defaults:
# load the reference default config, then apply the fields in the current yaml
- _self_
_target_: verl.workers.config.McoreActorConfig
strategy: megatron
data_loader_seed: null
load_weight: True
checkpoint:
async_save: False
optim:
_target_: verl.workers.config.McoreOptimizerConfig
optimizer: adam
clip_grad: 1.0
@ -23,9 +21,6 @@ optim:
# initial learning rate for warmup, default to 0.0
lr_warmup_init: 0.0
# Prioritized. None, 0, or negative values mean delegating to lr_warmup_steps_ratio.
lr_warmup_steps: null
lr_decay_steps: null
# select from constant/linear/cosine/inverse_square_root
@ -47,10 +42,16 @@ optim:
megatron:
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.workers.config.McoreEngineConfig
# Whether to offload model parameters to CPU
param_offload: False
# Whether to offload gradients to CPU
grad_offload: False
# Whether to offload optimizer state to CPU
optimizer_offload: False
tensor_model_parallel_size: 1
@ -104,6 +105,7 @@ megatron:
# profile the actor model in `update_policy`
profile:
# turn it on when you want to profile the actor model
use_profile: False

View File

@ -13,10 +13,12 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from typing import Any, Optional
from verl.base_config import BaseConfig
__all__ = ["AlgoConfig", "FilterGroupsConfig", "KLControlConfig"]
@dataclass
class KLControlConfig(BaseConfig):
@ -31,29 +33,12 @@ class KLControlConfig(BaseConfig):
target_kl (float): Target KL divergence for adaptive controller.
"""
_frozen_fields = ["type", "kl_coef", "horizon", "target_kl"]
type: str = "fixed"
kl_coef: float = 0.001
horizon: int = 10000
target_kl: float = 0.1
@dataclass
class PFPPOConfig(BaseConfig):
"""Configuration for preference feedback PPO.
The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.
Args:
reweight_method (str): Method for reweighting samples. Can be "pow", "max_min", or "max_random".
weight_pow (float): Power used for weight scaling in "pow" method.
"""
_frozen_fields = ["reweight_method", "weight_pow"]
reweight_method: str = "pow"
weight_pow: float = 2.0
@dataclass
class FilterGroupsConfig(BaseConfig):
"""Configuration for filter groups (used in DAPO and Entropy).
@ -66,8 +51,6 @@ class FilterGroupsConfig(BaseConfig):
max_num_gen_batches (int): Non-positive values mean no upper limit.
"""
_frozen_fields = ["enable", "metric", "max_num_gen_batches"]
enable: bool = False
metric: Optional[str] = None
max_num_gen_batches: int = 0
@ -88,20 +71,10 @@ class AlgoConfig(BaseConfig):
kl_penalty (str): How to estimate KL divergence: "kl", "abs", "mse", "low_var_kl", or "full".
kl_ctrl (KLControlConfig): KL control configuration.
use_pf_ppo (bool): Whether to enable preference feedback PPO.
pf_ppo (Optional[PFPPOConfig]): Preference feedback PPO settings.
pf_ppo (dict[str, Any]): Preference feedback PPO settings.
filter_groups (Optional[FilterGroupsConfig]): Filter groups configuration, used in DAPO and Entropy
"""
_frozen_fields = [
"gamma",
"lam",
"adv_estimator",
"norm_adv_by_std_in_grpo",
"use_kl_in_reward",
"kl_penalty",
"use_pf_ppo",
]
gamma: float = 1.0
lam: float = 1.0
adv_estimator: str = "gae"
@ -110,5 +83,5 @@ class AlgoConfig(BaseConfig):
kl_penalty: str = "kl"
kl_ctrl: KLControlConfig = field(default_factory=KLControlConfig)
use_pf_ppo: bool = False
pf_ppo: Optional[PFPPOConfig] = None
pf_ppo: dict[str, Any] = field(default_factory=dict)
filter_groups: Optional[FilterGroupsConfig] = None
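Since `pf_ppo` is now a plain dict rather than a dedicated dataclass, call sites read it with `get()`; a minimal sketch:

```python
from verl.trainer.config import AlgoConfig

algo = AlgoConfig(use_pf_ppo=True, pf_ppo={"reweight_method": "pow", "weight_pow": 2.0})

# mirrors the access pattern in compute_advantage further down in this PR
method = algo.pf_ppo.get("reweight_method")  # "pow"
power = algo.pf_ppo.get("weight_pow")        # 2.0
```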

View File

@ -17,110 +17,63 @@ from typing import Any, Optional
from verl.base_config import BaseConfig
__all__ = ["CheckpointConfig", "ProfileConfig", "BaseModelConfig"]
@dataclass
class CriticConfig(BaseConfig):
"""Configuration for critic model training.
class CheckpointConfig(BaseConfig):
"""Configuration for model checkpointing.
The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.
Args:
rollout_n (int): Number of rollouts per update (mirrors actor rollout_n).
strategy (str): Strategy used for critic model training (fsdp, fsdp2, megatron).
optim (Dict[str, Any]): Optimizer configuration including lr, weight_decay, etc.
model (Dict[str, Any]): Model configuration including path, tokenizer_path, etc.
ppo_mini_batch_size (int): PPO mini-batch size per update.
ppo_micro_batch_size (Optional[int]): Global micro batch size (deprecated).
ppo_micro_batch_size_per_gpu (Optional[int]): Local per-GPU micro batch size.
use_dynamic_bsz (bool): Whether to automatically adjust batch size at runtime.
ppo_max_token_len_per_gpu (int): Max tokens per GPU in one PPO batch.
forward_max_token_len_per_gpu (int): Max token length per GPU in forward pass.
ppo_epochs (int): Number of PPO epochs per batch.
shuffle (bool): Shuffle training data across PPO epochs.
cliprange_value (float): PPO value function clipping range.
loss_agg_mode (str): Loss aggregation mode.
checkpoint (Dict[str, Any]): Checkpoint configuration.
profiler (Dict[str, Any]): Profiler configuration.
save_contents (list[str]): What to include in saved checkpoints.
Options: 'model', 'optimizer', 'extra', 'hf_model'.
load_contents (list[str]): Contents to load from checkpoint. Defaults to same as save_contents.
async_save (bool): Whether to save checkpoints asynchronously. Only implemented for Megatron as of now.
"""
# For legacy reasons, configs related to batch_size are mutated in each role.
# In the future they will be added to the frozen fields instead.
_frozen_fields = [
"rollout_n",
"strategy",
"use_dynamic_bsz",
"ppo_max_token_len_per_gpu",
"forward_max_token_len_per_gpu",
"ppo_epochs",
"shuffle",
"cliprange_value",
"loss_agg_mode",
]
rollout_n: int = 1
strategy: str = "fsdp"
optim: dict[str, Any] = field(default_factory=dict)
model: dict[str, Any] = field(default_factory=dict)
ppo_mini_batch_size: int = 1
ppo_micro_batch_size: Optional[int] = None
ppo_micro_batch_size_per_gpu: Optional[int] = None
use_dynamic_bsz: bool = False
ppo_max_token_len_per_gpu: int = 32768
forward_max_token_len_per_gpu: int = 32768
ppo_epochs: int = 1
shuffle: bool = True
cliprange_value: float = 0.5
loss_agg_mode: str = "token-mean"
checkpoint: dict[str, Any] = field(default_factory=dict)
profiler: dict[str, Any] = field(default_factory=dict)
save_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
load_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
async_save: bool = False
@dataclass
class MegatronCriticConfig(CriticConfig):
"""Configuration for Megatron-based critic model training.
class ProfileConfig(BaseConfig):
"""Configuration for profiling.
The inheritance from CriticConfig provides all base critic configuration plus Megatron-specific settings.
The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.
Args:
nccl_timeout (int): NCCL timeout in seconds for distributed operations.
megatron (Dict[str, Any]): Megatron-specific parallelism settings.
load_weight (bool): Whether to load initial weights.
data_loader_seed (Optional[int]): Seed for data loader.
use_profile (bool): Whether to enable profiling.
profile_ranks (Optional[list[int]]): List of ranks to profile. None means all ranks.
step_start (int): Starting step for profiling.
step_end (int): Ending step for profiling.
save_path (Optional[str]): Path to save profiling results.
"""
_frozen_fields = CriticConfig._frozen_fields + [
"nccl_timeout",
"load_weight",
"data_loader_seed",
]
strategy: str = "megatron"
nccl_timeout: int = 600
megatron: dict[str, Any] = field(default_factory=dict)
load_weight: bool = True
data_loader_seed: Optional[int] = None
use_profile: bool = False
profile_ranks: Optional[list[int]] = None
step_start: int = -1
step_end: int = -1
save_path: Optional[str] = None
@dataclass
class FSDPCriticConfig(CriticConfig):
"""Configuration for FSDP-based critic model training.
The inheritance from CriticConfig provides all base critic configuration plus FSDP-specific settings.
class BaseModelConfig(BaseConfig):
"""Base configuration for a model.
Contains core settings for loading and initializing a pretrained model checkpoint.
Args:
forward_micro_batch_size (int): Forward-only batch size during inference (global).
forward_micro_batch_size_per_gpu (int): Forward-only batch size during inference (per GPU).
ulysses_sequence_parallel_size (int): Sequence parallelism size for Ulysses-style model parallelism.
grad_clip (float): Gradient clipping for critic updates.
path (str): Path to pretrained model weights.
tokenizer_path (Optional[str]): Tokenizer path (defaults to actor's model path if not set).
override_config (dict): Hugging Face config override.
external_lib (Optional[str]): External model implementation (optional).
trust_remote_code (bool): Whether to trust remote code from Hugging Face models.
"""
_frozen_fields = CriticConfig._frozen_fields + [
"ulysses_sequence_parallel_size",
"grad_clip",
]
strategy: str = "fsdp"
forward_micro_batch_size: int = 1
forward_micro_batch_size_per_gpu: int = 1
ulysses_sequence_parallel_size: int = 1
grad_clip: float = 1.0
path: str = "~/models/deepseek-llm-7b-chat"
tokenizer_path: Optional[str] = None
override_config: dict[str, Any] = field(default_factory=dict)
external_lib: Optional[str] = None
trust_remote_code: bool = False

View File

@ -1,12 +1,23 @@
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.workers.config.CriticConfig
# Number of rollouts per update (mirrors actor rollout_n)
rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
# fsdp or fsdp2 strategy used for critic model training
strategy: ???
# whether to enable the critic worker.
# by default it is only enabled if advantage estimator is gae
# set it to True manually if you always want to enable critic worker
enable: null
# optimizer configs
optim:
# Learning rate
lr: 1e-5
# Warmup steps ratio; total steps will be injected at runtime
lr_warmup_steps_ratio: 0.0
@ -16,6 +27,10 @@ optim:
# Weight decay
weight_decay: 0.01
# Prioritized. None, 0, or negative values mean delegating to lr_warmup_steps_ratio.
lr_warmup_steps: -1
# model config for the critic
model:
@ -67,6 +82,9 @@ loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean}
# checkpoint configs
checkpoint:
# Target dataclass for this configuration
_target_: verl.trainer.config.CheckpointConfig
# What to include in saved checkpoints
# with 'hf_model' you can save the whole model in HF format; for now only sharded model checkpoints are used, to save space
save_contents: ['model', 'optimizer', 'extra']
@ -74,11 +92,14 @@ checkpoint:
# What to include when loading checkpoints
load_contents: ${.save_contents}
# Whether to save checkpoints asynchronously. Only effective for Megatron as of now.
async_save: False
# profiler configs
# the corresponding dataclass is verl.utils.profiler.ProfilerConfig.
profiler:
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.utils.profiler.ProfilerConfig
# True: each task has its own database; False: all tasks in one training step share one database.
@ -89,6 +110,3 @@ profiler:
# The ranks that will be profiled. [] or [0,1,...]
ranks: []
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
_target_: verl.trainer.config.CriticConfig

View File

@ -13,13 +13,17 @@ defaults:
# load the reference default config, then apply the fields in the current yaml
- _self_
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.workers.config.FSDPCriticConfig
# distribution strategy. Options: fsdp (being deprecated), fsdp2
strategy: fsdp
# optimizer configs
optim:
# Learning rate
lr: 1e-5
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.workers.config.FSDPOptimizerConfig
# Minimum LR ratio for cosine schedule
min_lr_ratio: null
@ -30,6 +34,9 @@ optim:
# model config for the critic
model:
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.workers.config.FSDPCriticModelCfg
# Whether to use shared memory for loading the model
use_shm: False
@ -45,6 +52,9 @@ model:
# FSDP-specific config
fsdp_config:
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.workers.config.FSDPEngineConfig
# Whether to offload model parameters to CPU
param_offload: False
@ -90,6 +100,3 @@ ulysses_sequence_parallel_size: 1
# Gradient clipping for critic updates
grad_clip: 1.0
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
_target_: verl.trainer.config.FSDPCriticConfig

View File

@ -7,6 +7,9 @@ defaults:
# load the reference default config, then apply the fields in the current yaml
- _self_
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.workers.config.McoreCriticConfig
strategy: megatron
# Timeout in seconds; torch's default is 10 minutes. Set it to a larger value for long-running operations, e.g. 32B or 72B models with Megatron.
@ -15,21 +18,18 @@ nccl_timeout: 600
# optimizer configs
optim:
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.workers.config.McoreOptimizerConfig
# select optimizer, default is Adam
optimizer: adam
# Learning rate
lr: 1e-6
# Clip gradients norm
clip_grad: 1.0
# initial learning rate for warmup, default to 0.0
lr_warmup_init: 0.0
# Prioritized. None, 0, or negative values mean delegating to lr_warmup_steps_ratio.
lr_warmup_steps: null
lr_decay_steps: null
# select from constant/linear/cosine/inverse_square_root
@ -53,6 +53,9 @@ optim:
# model config for the critic
model:
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.trainer.config.BaseModelConfig
# override default empty mapping
override_config:
@ -65,6 +68,9 @@ model:
# megatron-specific parallelism settings
megatron:
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.workers.config.McoreEngineConfig
# Whether to offload model parameters to CPU
param_offload: False
@ -121,10 +127,3 @@ load_weight: True
# seed for data loader
data_loader_seed: ${oc.select:actor_rollout_ref.actor.data_loader_seed,null}
# Asynchronous checkpoint saving
checkpoint:
async_save: False
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
_target_: verl.trainer.config.MegatronCriticConfig

View File

@ -68,7 +68,7 @@ custom_reward_function:
name: compute_score
algorithm:
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.trainer.config.AlgoConfig
gamma: 1.0
lam: 1.0
@ -77,7 +77,7 @@ algorithm:
use_kl_in_reward: False
kl_penalty: kl # how to estimate kl divergence
kl_ctrl:
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.trainer.config.KLControlConfig
type: fixed
kl_coef: 0.001
@ -85,8 +85,6 @@ algorithm:
target_kl: 0.1
use_pf_ppo: False
pf_ppo:
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
_target_: verl.trainer.config.PFPPOConfig
reweight_method: pow # ["pow", "max_min", "max_random"]
weight_pow: 2.0

View File

@ -112,7 +112,7 @@ actor_rollout_ref:
# profiler configs
profiler:
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.utils.profiler.ProfilerConfig
# True: each task has its own database; False: all tasks in one training step share one database.
@ -137,7 +137,7 @@ custom_reward_function:
# config for the algorithm
algorithm:
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.trainer.config.AlgoConfig
# Discount factor for future rewards
@ -161,7 +161,7 @@ algorithm:
# KL control configuration
kl_ctrl:
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.trainer.config.KLControlConfig
# KL control type: "fixed" or "adaptive"
@ -182,9 +182,6 @@ algorithm:
# Preference feedback PPO settings
pf_ppo:
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
_target_: verl.trainer.config.PFPPOConfig
# Method for reweighting samples: "pow", "max_min", or "max_random"
reweight_method: pow

View File

@ -10,6 +10,15 @@ defaults:
# config for FSDP strategy
fsdp_config:
# Target class for this configuration
_target_: verl.workers.config.FSDPEngineConfig
# the wrap policy for FSDP model
wrap_policy:
# minimum number of params in a wrapped module
min_num_params: 0
# whether to offload parameters in FSDP
param_offload: False
@ -21,12 +30,6 @@ fsdp_config:
# before the current forward computation.
forward_prefetch: False
# the wrap policy for FSDP model
wrap_policy:
# minimum number of params in a wrapped module
min_num_params: 0
# sequence parallel size
# same as actor_rollout_ref.actor.ulysses_sequence_parallel_size if it exists, otherwise 1
ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1}

View File

@ -7,45 +7,27 @@ defaults:
strategy: megatron
megatron:
_target_: verl.workers.config.MegatronEngineConfig
param_offload: False
tensor_model_parallel_size: 1
expert_model_parallel_size: 1
expert_tensor_parallel_size: None
pipeline_model_parallel_size: 1
virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests
context_parallel_size: 1
sequence_parallel: True
use_distributed_optimizer: False
use_dist_checkpointing: False
dist_checkpointing_path: null
seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42}
override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}}
use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False}
profile:
use_profile: False
profile_ranks: null
step_start: -1
step_end: -1
save_path: null
load_weight: True

View File

@ -29,8 +29,12 @@ model:
# FSDP-specific config
fsdp_config:
# Target configuration dataclass
_target_: verl.workers.config.FSDPEngineConfig
# Policy for wrapping layers with FSDP
wrap_policy:
# Minimum number of parameters to trigger wrapping
min_num_params: 0

View File

@ -15,6 +15,10 @@ nccl_timeout: 600
# Megatron parallelism & checkpointing config
megatron:
# Target configuration dataclass
_target_: verl.workers.config.MegatronEngineConfig
# Whether to offload model parameters to CPU
param_offload: False

View File

@ -12,8 +12,8 @@ strategy: ???
# model config for reward scoring
model:
# Input tokenizer. If the reward models chat template is inconsistent with the policy,
# we need to first decode to plaintext, then apply the rms chat_template.
# Input tokenizer. If the reward model's chat template is inconsistent with the policy,
# we need to first decode to plaintext, then apply the rm's chat_template.
# Then score with RM. If chat_templates are consistent, it can be set to null.
# set this to null if the chat template is identical
input_tokenizer: ${actor_rollout_ref.model.path}
@ -68,7 +68,7 @@ sandbox_fusion:
# profiler configs
profiler:
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
# hint for the target config dataclass
_target_: verl.utils.profiler.ProfilerConfig
# True: each task has its own database; False: all tasks in one training step share one database.

View File

@ -58,7 +58,6 @@ trainer:
total_training_steps: null
logger: [ 'console', 'wandb' ]
seed: 1
save_freq: -1
test_freq: -1
nnodes: 1

View File

@ -21,6 +21,7 @@ This trainer supports model-agnostic model initialization with huggingface
import json
import os
import uuid
import warnings
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field
@ -53,6 +54,7 @@ from verl.trainer.ppo.metric_utils import (
)
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.debug import marked_timer
from verl.utils.metric import (
reduce_metrics,
@ -256,8 +258,8 @@ def compute_advantage(
if config.get("use_pf_ppo", False):
data = core_algos.compute_pf_ppo_reweight_data(
data,
config.pf_ppo.reweight_method,
config.pf_ppo.weight_pow,
config.pf_ppo.get("reweight_method"),
config.pf_ppo.get("weight_pow"),
)
elif adv_estimator == AdvantageEstimator.GRPO:
# Initialize the mask for GRPO calculation
@ -369,21 +371,17 @@ class RayPPOTrainer:
if self.config.algorithm.use_kl_in_reward:
self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl)
if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE:
if config.critic.enable is not None:
self.use_critic = bool(config.critic.enable)
elif self.config.algorithm.adv_estimator == AdvantageEstimator.GAE:
self.use_critic = True
elif self.config.algorithm.adv_estimator in [
AdvantageEstimator.GRPO,
AdvantageEstimator.GRPO_PASSK,
AdvantageEstimator.REINFORCE_PLUS_PLUS,
AdvantageEstimator.REMAX,
AdvantageEstimator.RLOO,
AdvantageEstimator.OPO,
AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE,
AdvantageEstimator.GPG,
]:
self.use_critic = False
else:
raise NotImplementedError
warnings.warn(
"Disabled critic as algorithm.adv_estimator != gae. "
"If it is not intended, please set critic.enable=True",
stacklevel=2,
)
self.use_critic = False
self._validate_config()
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
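With this change the critic is no longer implied solely by the advantage estimator: `critic.enable`, when set, takes precedence. If a non-GAE estimator is used but a critic is still wanted, it can be forced on via the same override mechanism shown in the updated test fixtures; a hypothetical sketch:

```python
# hypothetical hydra overrides, mirroring the compose() pattern above
overrides = [
    "algorithm.adv_estimator=grpo",  # a non-GAE estimator would normally disable the critic
    "critic.enable=True",            # keep the critic anyway
]
```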
@ -434,8 +432,6 @@ class RayPPOTrainer:
ValueError: If both parameters are set or neither is set.
"""
settings = {
"actor_rollout_ref.actor": "micro_batch_size",
"critic": "micro_batch_size",
"reward_model": "micro_batch_size",
"actor_rollout_ref.ref": "log_prob_micro_batch_size",
"actor_rollout_ref.rollout": "log_prob_micro_batch_size",
@ -456,14 +452,11 @@ class RayPPOTrainer:
f"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated)."
)
if not config.actor_rollout_ref.actor.use_dynamic_bsz:
# actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.actor.ppo_micro_batch_size,
config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu,
"actor_rollout_ref.actor",
)
# Actor validation done in ActorConfig.__post_init__ and validate()
actor_config = omega_conf_to_dataclass(config.actor_rollout_ref.actor)
actor_config.validate(n_gpus, config.data.train_batch_size, config.actor_rollout_ref.model)
if not config.actor_rollout_ref.actor.use_dynamic_bsz:
if self.use_reference_policy:
# reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(
@ -479,66 +472,19 @@ class RayPPOTrainer:
"actor_rollout_ref.rollout",
)
if self.use_critic and not config.critic.use_dynamic_bsz:
# Check for critic micro-batch size conflicts
check_mutually_exclusive(
config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic"
)
# Check for reward model micro-batch size conflicts
if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:
check_mutually_exclusive(
config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model"
)
# Actor
# check if train_batch_size is larger than ppo_mini_batch_size
# if NOT dynamic_bsz, we must ensure:
# ppo_mini_batch_size is divisible by ppo_micro_batch_size
# ppo_micro_batch_size * sequence_parallel_size >= n_gpus
if not config.actor_rollout_ref.actor.use_dynamic_bsz:
assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size
sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1)
if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None:
assert (
config.actor_rollout_ref.actor.ppo_mini_batch_size
% config.actor_rollout_ref.actor.ppo_micro_batch_size
== 0
)
assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus
assert config.actor_rollout_ref.actor.loss_agg_mode in [
"token-mean",
"seq-mean-token-sum",
"seq-mean-token-mean",
"seq-mean-token-sum-norm",
], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}"
if self.config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss:
print("NOTICE: You have both enabled in-reward kl and kl loss.")
# critic
if self.use_critic and not config.critic.use_dynamic_bsz:
assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size
sp_size = config.critic.get("ulysses_sequence_parallel_size", 1)
if config.critic.ppo_micro_batch_size is not None:
assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0
assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus
# Check if use_remove_padding is enabled when using sequence parallelism for fsdp
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"} and (
config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1
or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1
):
assert config.actor_rollout_ref.model.use_remove_padding, (
"When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`."
)
if self.use_critic and config.critic.strategy in {"fsdp", "fsdp2"}:
if config.critic.get("ulysses_sequence_parallel_size", 1) > 1:
assert config.critic.model.use_remove_padding, (
"When using sequence parallelism for critic, you must enable `use_remove_padding`."
)
if self.use_critic:
critic_config = omega_conf_to_dataclass(config.critic)
critic_config.validate(n_gpus, config.data.train_batch_size)
if config.data.get("val_batch_size", None) is not None:
print(
@ -850,7 +796,8 @@ class RayPPOTrainer:
# create critic
if self.use_critic:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic)
critic_cfg = omega_conf_to_dataclass(self.config.critic)
critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg)
self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls
# create reference policy if needed

View File

@ -22,17 +22,17 @@ import torch.distributed
from omegaconf import DictConfig
from transformers import PreTrainedTokenizer, ProcessorMixin
from verl.trainer.config import CheckpointConfig
from verl.utils.device import get_device_name, get_torch_device
class BaseCheckpointManager:
"""
A checkpoint manager that saves and loads
A checkpoint manager that saves and loads the following states in an SPMD way:
- model
- optimizer
- lr_scheduler
- extra_states
in a SPMD way.
We save
- sharded model states and optimizer states
@ -46,7 +46,7 @@ class BaseCheckpointManager:
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None,
processing_class: PreTrainedTokenizer | ProcessorMixin = None,
checkpoint_config: DictConfig = None,
checkpoint_config: DictConfig | CheckpointConfig = None,
):
self.checkpoint_config = checkpoint_config
checkpoint_load_contents = checkpoint_config.get("load_contents", None) if checkpoint_config else None

View File

@ -41,8 +41,9 @@ def omega_conf_to_dataclass(config: DictConfig | dict, dataclass_type: Optional[
if dataclass_type is None:
assert "_target_" in config, (
"When dataclass_type is not provided, config must contain _target_."
"See trainer/config/ppo_trainer.yaml algorithm section for an example."
"When dataclass_type is not provided, config must contain _target_. "
"See trainer/config/ppo_trainer.yaml algorithm section for an example. "
f"Got config: {config}"
)
from hydra.utils import instantiate
@ -51,6 +52,9 @@ def omega_conf_to_dataclass(config: DictConfig | dict, dataclass_type: Optional[
if not is_dataclass(dataclass_type):
raise ValueError(f"{dataclass_type} must be a dataclass")
cfg = OmegaConf.create(config) # in case it's a dict
# pop _target_ to avoid a hydra instantiate error, as most dataclasses do not have a _target_ field
if "_target_" in cfg:
cfg.pop("_target_")
cfg_from_dataclass = OmegaConf.structured(dataclass_type)
# let cfg override the existing vals in `cfg_from_dataclass`
cfg_merged = OmegaConf.merge(cfg_from_dataclass, cfg)
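A sketch of the two entry paths through `omega_conf_to_dataclass` under the behavior shown in this hunk, assuming the merged config is converted back into a dataclass instance:

```python
from omegaconf import OmegaConf

from verl.trainer.config import CheckpointConfig
from verl.utils.config import omega_conf_to_dataclass

# _target_ present and no dataclass_type: hydra instantiates the named class
cfg = OmegaConf.create({"_target_": "verl.trainer.config.CheckpointConfig", "async_save": True})
ckpt = omega_conf_to_dataclass(cfg)

# explicit dataclass_type: a stray _target_ is now popped before the structured merge
ckpt2 = omega_conf_to_dataclass({"async_save": True}, dataclass_type=CheckpointConfig)
print(type(ckpt).__name__, ckpt2.async_save)  # CheckpointConfig True
```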

View File

@ -42,23 +42,25 @@ def get_megatron_optimizer_param_scheduler(
"""
Get the optimizer parameter scheduler for Megatron.
"""
lr_decay_steps = config.lr_decay_steps
lr_warmup_steps = config.lr_warmup_steps
if config.get("lr_decay_steps", None) is None:
config.lr_decay_steps = config.total_training_steps
lr_decay_steps = config.total_training_steps
wsd_decay_steps = None
if config.get("lr_wsd_decay_steps", None) is not None:
wsd_decay_steps = config.lr_wsd_decay_steps
if config.get("lr_warmup_steps_ratio", None) is not None and (
config.get("lr_warmup_steps", None) is None or config.lr_warmup_steps <= 0
):
config.lr_warmup_steps = int(config.lr_warmup_steps_ratio * config.lr_decay_steps)
lr_warmup_steps = int(config.lr_warmup_steps_ratio * lr_decay_steps)
opt_param_scheduler = OptimizerParamScheduler(
optimizer,
init_lr=config.lr_warmup_init,
max_lr=config.lr,
min_lr=config.min_lr,
lr_warmup_steps=config.lr_warmup_steps,
lr_decay_steps=config.lr_decay_steps,
lr_warmup_steps=lr_warmup_steps,
lr_decay_steps=lr_decay_steps,
lr_decay_style=config.lr_decay_style,
start_wd=config.weight_decay,
end_wd=config.weight_decay,
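For concreteness, when `lr_warmup_steps` is unset (or non-positive) the warmup length is now derived locally from the ratio instead of being written back into the config. A worked example with assumed numbers:

```python
# assumed values, for illustration only
total_training_steps = 1000   # becomes lr_decay_steps when config.lr_decay_steps is None
lr_warmup_steps_ratio = 0.1
lr_warmup_steps = None        # unset, so the ratio takes over

lr_decay_steps = total_training_steps
if lr_warmup_steps is None or lr_warmup_steps <= 0:
    lr_warmup_steps = int(lr_warmup_steps_ratio * lr_decay_steps)

print(lr_warmup_steps)  # 100
```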

View File

@ -13,7 +13,6 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import ClassVar
from verl.base_config import BaseConfig
@ -31,13 +30,8 @@ class ProfilerConfig(BaseConfig):
ranks (list[int]): The ranks that will be profiled. Defaults to [].
"""
# the fields expected to be frozen
_frozen_fields: ClassVar[set[str]] = {"discrete", "all_ranks", "ranks"}
discrete: bool = False
all_ranks: bool = False
ranks: list[int] = field(default_factory=list)
def union(self, other: "ProfilerConfig") -> "ProfilerConfig":

View File

@ -17,6 +17,7 @@ from typing import Callable, Optional
import torch
import torch.distributed
from omegaconf import DictConfig, OmegaConf
from .config import ProfilerConfig
@ -40,6 +41,8 @@ class Profiler:
def __init__(self, config):
# note: if use_profile is not set, it will be None, so all functions will be skipped
if not isinstance(config, DictConfig):
config = OmegaConf.create(config)
self.config = config
self.skip_prof = False
self.saved = False

View File

@ -35,6 +35,7 @@ from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_b
from verl.utils.torch_functional import logprobs_from_logits
from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs
from verl.workers.actor import BasePPOActor
from verl.workers.config import ActorConfig
if is_cuda_available:
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
@ -49,18 +50,27 @@ logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
class DataParallelPPOActor(BasePPOActor):
def __init__(self, config, actor_module: nn.Module, actor_optimizer: torch.optim.Optimizer = None):
"""FSDP DataParallel PPO Actor or Ref worker
Args:
config (ActorConfig): Actor config
actor_module (nn.Module): Actor or ref module
actor_optimizer (torch.optim.Optimizer, optional): Actor optimizer. Defaults to None.
"""
def __init__(self, config: ActorConfig, actor_module: nn.Module, actor_optimizer: torch.optim.Optimizer = None):
"""When optimizer is None, it is Reference Policy"""
super().__init__(config)
self.actor_module = actor_module
self.actor_optimizer = actor_optimizer
role = "Ref" if actor_optimizer is None else "Actor"
self.use_remove_padding = self.config.get("use_remove_padding", False)
if torch.distributed.get_rank() == 0:
print(f"Actor use_remove_padding={self.use_remove_padding}")
print(f"{role} use_remove_padding={self.use_remove_padding}")
self.use_fused_kernels = self.config.get("use_fused_kernels", False)
if torch.distributed.get_rank() == 0:
print(f"Actor use_fused_kernels={self.use_fused_kernels}")
print(f"{role} use_fused_kernels={self.use_fused_kernels}")
self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size
self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1

View File

@ -0,0 +1,21 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .critic import * # noqa
from .actor import * # noqa
from .engine import * # noqa
from .optimizer import * # noqa
from . import actor, critic, engine, optimizer
__all__ = actor.__all__ + critic.__all__ + engine.__all__ + optimizer.__all__
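With the wildcard re-exports aggregated into `__all__`, any worker config resolves from the package root, e.g.:

```python
# all of these resolve through verl/workers/config/__init__.py
from verl.workers.config import (
    ActorConfig,
    FSDPEngineConfig,
    McoreEngineConfig,
    OptimizerConfig,
)
```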

View File

@ -0,0 +1,234 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any, Optional
from omegaconf import MISSING
from verl.base_config import BaseConfig
from verl.trainer.config import CheckpointConfig
from .engine import FSDPEngineConfig, McoreEngineConfig
from .optimizer import OptimizerConfig
__all__ = ["PolicyLossConfig", "ActorConfig", "FSDPActorConfig", "McoreActorConfig"]
@dataclass
class PolicyLossConfig(BaseConfig):
"""Configuration for policy loss computation.
The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.
Args:
loss_mode (str): Loss function mode. Options: 'vanilla', 'clip-cov', 'kl-cov', 'gpg'.
clip_cov_ratio (float): Ratio of tokens to be clipped for clip-cov loss.
clip_cov_lb (float): Lower bound for clip-cov loss.
clip_cov_ub (float): Upper bound for clip-cov loss.
kl_cov_ratio (float): Ratio of tokens to be applied KL penalty for kl-cov loss.
ppo_kl_coef (float): KL divergence penalty coefficient.
"""
loss_mode: str = "vanilla"
clip_cov_ratio: float = 0.0002
clip_cov_lb: float = 1.0
clip_cov_ub: float = 5.0
kl_cov_ratio: float = 0.0002
ppo_kl_coef: float = 0.1
@dataclass
class ActorConfig(BaseConfig):
"""Configuration for actor model training.
The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.
Args:
strategy (str): Training strategy. Must be specified.
ppo_mini_batch_size (int): Mini-batch size for PPO training.
ppo_micro_batch_size (Optional[int]): Micro-batch size for PPO training.
If None, uses ppo_micro_batch_size_per_gpu.
ppo_micro_batch_size_per_gpu (Optional[int]): Micro-batch size per GPU for PPO training.
use_dynamic_bsz (bool): Whether to use dynamic batch sizing.
ppo_max_token_len_per_gpu (int): Maximum token length per GPU for PPO training.
clip_ratio (float): PPO clipping ratio for policy loss.
clip_ratio_low (float): Lower bound for PPO clipping ratio.
clip_ratio_high (float): Upper bound for PPO clipping ratio.
policy_loss (PolicyLossConfig): Configuration for policy loss computation.
        clip_ratio_c (float): Lower bound constant C for dual-clip PPO.
        loss_agg_mode (str): Loss aggregation mode. Options: 'token-mean', 'seq-mean-token-sum',
            'seq-mean-token-mean', 'seq-mean-token-sum-norm'.
entropy_coeff (float): Entropy coefficient for regularization.
use_kl_loss (bool): Whether to use KL divergence loss.
use_torch_compile (bool): Whether to use torch.compile for optimization.
kl_loss_coef (float): KL divergence loss coefficient.
kl_loss_type (str): Type of KL loss to use.
ppo_epochs (int): Number of PPO epochs per training step.
shuffle (bool): Whether to shuffle data during training.
checkpoint (CheckpointConfig): Configuration for checkpointing.
optim (OptimizerConfig): Configuration for optimizer.
use_fused_kernels (bool): Whether to use custom fused kernels (e.g., FlashAttention, fused MLP).
"""
_mutable_fields = BaseConfig._mutable_fields | {
"ppo_mini_batch_size",
"ppo_micro_batch_size",
"ppo_micro_batch_size_per_gpu",
}
strategy: str = MISSING
ppo_mini_batch_size: int = 256
ppo_micro_batch_size: Optional[int] = None
ppo_micro_batch_size_per_gpu: Optional[int] = None
use_dynamic_bsz: bool = False
ppo_max_token_len_per_gpu: int = 16384
clip_ratio: float = 0.2
clip_ratio_low: float = 0.2
clip_ratio_high: float = 0.2
policy_loss: PolicyLossConfig = field(default_factory=PolicyLossConfig)
clip_ratio_c: float = 3.0
loss_agg_mode: str = "token-mean"
entropy_coeff: float = 0
use_kl_loss: bool = False
use_torch_compile: bool = True
kl_loss_coef: float = 0.001
kl_loss_type: str = "low_var_kl"
ppo_epochs: int = 1
shuffle: bool = False
checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig)
optim: OptimizerConfig = field(default_factory=OptimizerConfig)
use_fused_kernels: bool = False
def __post_init__(self):
"""Validate actor configuration parameters."""
assert self.strategy != MISSING
if not self.use_dynamic_bsz:
if self.ppo_micro_batch_size is not None and self.ppo_micro_batch_size_per_gpu is not None:
raise ValueError(
"[actor] You have set both 'actor.ppo_micro_batch_size' AND 'actor.ppo_micro_batch_size_per_gpu'. "
"Please remove 'actor.ppo_micro_batch_size' because only '*_ppo_micro_batch_size_per_gpu' is "
"supported (the former is deprecated)."
)
else:
assert not (self.ppo_micro_batch_size is None and self.ppo_micro_batch_size_per_gpu is None), (
"[actor] Please set at least one of 'actor.ppo_micro_batch_size' or "
"'actor.ppo_micro_batch_size_per_gpu' if use_dynamic_bsz is not enabled."
)
valid_loss_agg_modes = [
"token-mean",
"seq-mean-token-sum",
"seq-mean-token-mean",
"seq-mean-token-sum-norm",
]
if self.loss_agg_mode not in valid_loss_agg_modes:
raise ValueError(f"Invalid loss_agg_mode: {self.loss_agg_mode}")
def validate(self, n_gpus: int, train_batch_size: int, model_config: dict = None):
"""Validate actor configuration with runtime parameters."""
if not self.use_dynamic_bsz:
if train_batch_size < self.ppo_mini_batch_size:
raise ValueError(
f"train_batch_size ({train_batch_size}) must be >= "
f"actor.ppo_mini_batch_size ({self.ppo_mini_batch_size})"
)
sp_size = getattr(self, "ulysses_sequence_parallel_size", 1)
if self.ppo_micro_batch_size is not None:
if self.ppo_mini_batch_size % self.ppo_micro_batch_size != 0:
raise ValueError(
f"ppo_mini_batch_size ({self.ppo_mini_batch_size}) must be divisible by "
f"ppo_micro_batch_size ({self.ppo_micro_batch_size})"
)
if self.ppo_micro_batch_size * sp_size < n_gpus:
raise ValueError(
f"ppo_micro_batch_size ({self.ppo_micro_batch_size}) * "
f"ulysses_sequence_parallel_size ({sp_size}) must be >= n_gpus ({n_gpus})"
)
@staticmethod
def _check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
"""Validate mutually exclusive micro batch size configuration options."""
param = "ppo_micro_batch_size"
param_per_gpu = f"{param}_per_gpu"
if mbs is None and mbs_per_gpu is None:
raise ValueError(f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.")
if mbs is not None and mbs_per_gpu is not None:
raise ValueError(
f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove "
f"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated)."
)
@dataclass
class McoreActorConfig(ActorConfig):
"""Configuration for Megatron actor models.
    The inheritance from ActorConfig provides all base actor settings plus Megatron-specific fields.
Args:
strategy (str): Training strategy set to 'megatron' for Megatron parallelism.
data_loader_seed (Optional[int]): Seed for data loader. If None, uses global seed.
load_weight (bool): Whether to load model weights from checkpoint.
        megatron (McoreEngineConfig): Configuration for Megatron parallelism settings.
profile (dict[str, Any]): Configuration for profiling settings.
"""
strategy: str = "megatron"
data_loader_seed: Optional[int] = None
load_weight: bool = True
megatron: McoreEngineConfig = field(default_factory=McoreEngineConfig)
profile: dict[str, Any] = field(default_factory=dict)
@dataclass
class FSDPActorConfig(ActorConfig):
"""Configuration for FSDP actor models.
    The inheritance from ActorConfig provides all base actor settings plus FSDP-specific fields.
Args:
strategy (str): Training strategy set to 'fsdp' for Fully Sharded Data Parallel.
grad_clip (float): Gradient clipping threshold.
ulysses_sequence_parallel_size (int): Ulysses sequence parallel size for long sequences.
entropy_from_logits_with_chunking (bool): Whether to compute entropy from logits
with chunking for memory efficiency.
entropy_checkpointing (bool): Whether to use gradient checkpointing for entropy computation.
        fsdp_config (FSDPEngineConfig): Configuration for FSDP settings.
        use_remove_padding (bool): Whether to remove padding tokens in inputs during training.
"""
strategy: str = "fsdp"
grad_clip: float = 1.0
ulysses_sequence_parallel_size: int = 1
entropy_from_logits_with_chunking: bool = False
entropy_checkpointing: bool = False
fsdp_config: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)
use_remove_padding: bool = False
def __post_init__(self):
"""Validate FSDP actor configuration parameters."""
super().__post_init__()
def validate(self, n_gpus: int, train_batch_size: int, model_config: dict = None):
"""Validate FSDP actor configuration with runtime parameters."""
super().validate(n_gpus, train_batch_size, model_config)
if self.strategy in {"fsdp", "fsdp2"} and self.ulysses_sequence_parallel_size > 1:
if model_config and not model_config.get("use_remove_padding", False):
raise ValueError(
"When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`."
)

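To illustrate the two-stage validation above (static checks in `__post_init__`, cluster-dependent checks in `validate`), here is a hedged sketch with made-up sizes:

```python
from verl.workers.config import FSDPActorConfig

cfg = FSDPActorConfig(
    strategy="fsdp",
    ppo_mini_batch_size=64,
    ppo_micro_batch_size_per_gpu=4,  # the deprecated global field stays None
    ulysses_sequence_parallel_size=2,
)
# Checks that need runtime information (GPU count, data batch size) live in validate():
cfg.validate(n_gpus=8, train_batch_size=128, model_config={"use_remove_padding": True})

cfg.ppo_mini_batch_size //= 2  # allowed: listed in _mutable_fields
# cfg.clip_ratio = 0.3         # would raise: the field is frozen
```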

@@ -0,0 +1,231 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from omegaconf import MISSING
from verl.base_config import BaseConfig
from verl.trainer.config import BaseModelConfig, CheckpointConfig
from verl.utils.profiler import ProfilerConfig
from .engine import FSDPEngineConfig, McoreEngineConfig
from .optimizer import OptimizerConfig
__all__ = ["CriticConfig", "FSDPCriticConfig", "McoreCriticConfig", "FSDPCriticModelCfg"]
@dataclass
class CriticConfig(BaseConfig):
"""Configuration for critic model training.
The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.
Args:
strategy (str): Strategy used for critic model training (fsdp, fsdp2, megatron).
        ppo_micro_batch_size_per_gpu (Optional[int]): Local per-GPU micro batch size.
rollout_n (int): Number of rollouts per update (mirrors actor rollout_n).
        optim (OptimizerConfig): Optimizer configuration including lr, weight_decay, etc.
        model (BaseModelConfig): Model configuration including path, tokenizer_path, etc.
ppo_mini_batch_size (int): PPO mini-batch size per update.
ppo_micro_batch_size (Optional[int]): Global micro batch size (deprecated).
use_dynamic_bsz (bool): Whether to automatically adjust batch size at runtime.
ppo_max_token_len_per_gpu (int): Max tokens per GPU in one PPO batch.
forward_max_token_len_per_gpu (int): Max token length per GPU in forward pass.
ppo_epochs (int): Number of PPO epochs per batch.
shuffle (bool): Shuffle training data across PPO epochs.
cliprange_value (float): PPO value function clipping range.
loss_agg_mode (str): Loss aggregation mode.
        checkpoint (CheckpointConfig): Checkpoint configuration.
        profiler (ProfilerConfig): Profiler configuration.
enable (Optional[bool]): Whether to enable the critic.
"""
_mutable_fields = BaseConfig._mutable_fields | {
"ppo_micro_batch_size_per_gpu",
"ppo_mini_batch_size",
"ppo_micro_batch_size",
}
strategy: str = MISSING
ppo_micro_batch_size_per_gpu: Optional[int] = None
enable: Optional[bool] = None
rollout_n: int = 1
ppo_mini_batch_size: int = 1
use_dynamic_bsz: bool = False
ppo_max_token_len_per_gpu: int = 32768
forward_max_token_len_per_gpu: int = 32768
ppo_epochs: int = 1
shuffle: bool = True
cliprange_value: float = 0.5
loss_agg_mode: str = "token-mean"
ppo_micro_batch_size: Optional[int] = None
optim: OptimizerConfig = field(default_factory=OptimizerConfig)
model: BaseModelConfig = field(default_factory=BaseModelConfig)
checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig)
profiler: ProfilerConfig = field(default_factory=ProfilerConfig)
def __post_init__(self):
"""Validate critic configuration parameters."""
assert self.strategy != MISSING
if not self.use_dynamic_bsz:
self._check_mutually_exclusive(self.ppo_micro_batch_size, self.ppo_micro_batch_size_per_gpu, "critic")
if self.ppo_micro_batch_size is not None:
if self.ppo_mini_batch_size % self.ppo_micro_batch_size != 0:
raise ValueError(
f"[critic] ppo_mini_batch_size ({self.ppo_mini_batch_size}) must be divisible by "
f"ppo_micro_batch_size ({self.ppo_micro_batch_size})"
)
def validate(self, n_gpus: int, train_batch_size: int):
"""Validate critic configuration with runtime parameters.
Args:
n_gpus: Total number of GPUs available
train_batch_size: Training batch size from data config
"""
if not self.use_dynamic_bsz:
if train_batch_size < self.ppo_mini_batch_size:
raise ValueError(
f"train_batch_size ({train_batch_size}) must be >= "
f"critic.ppo_mini_batch_size ({self.ppo_mini_batch_size})"
)
@staticmethod
def _check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
"""Validate mutually exclusive micro batch size configuration options.
Ensures that users don't set both deprecated micro_batch_size and
the new micro_batch_size_per_gpu parameters simultaneously.
Args:
mbs: Deprecated micro batch size parameter value.
mbs_per_gpu: New micro batch size per GPU parameter value.
name (str): Configuration section name for error messages.
Raises:
ValueError: If both parameters are set or neither is set.
"""
param = "micro_batch_size"
param_per_gpu = f"{param}_per_gpu"
if mbs is None and mbs_per_gpu is None:
raise ValueError(f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.")
if mbs is not None and mbs_per_gpu is not None:
raise ValueError(
f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove "
f"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated)."
)
@dataclass
class McoreCriticConfig(CriticConfig):
"""Configuration for Megatron-based critic model training.
The inheritance from CriticConfig provides all base critic configuration plus Megatron-specific settings.
Args:
nccl_timeout (int): NCCL timeout in seconds for distributed operations.
        megatron (McoreEngineConfig): Megatron-specific parallelism settings.
load_weight (bool): Whether to load initial weights.
data_loader_seed (Optional[int]): Seed for data loader.
"""
strategy: str = "megatron"
nccl_timeout: int = 600
megatron: McoreEngineConfig = field(default_factory=McoreEngineConfig)
load_weight: bool = True
data_loader_seed: Optional[int] = None
def validate(self, n_gpus: int, train_batch_size: int):
"""Validate Megatron critic configuration with runtime parameters."""
super().validate(n_gpus, train_batch_size)
@dataclass
class FSDPCriticConfig(CriticConfig):
"""Configuration for FSDP-based critic model training.
The inheritance from CriticConfig provides all base critic configuration plus FSDP-specific settings.
Args:
forward_micro_batch_size (int): Forward-only batch size during inference (global).
forward_micro_batch_size_per_gpu (int): Forward-only batch size during inference (per GPU).
ulysses_sequence_parallel_size (int): Sequence parallelism size for Ulysses-style model parallelism.
grad_clip (float): Gradient clipping for critic updates.
"""
_mutable_fields = CriticConfig._mutable_fields | {
"forward_micro_batch_size",
"forward_micro_batch_size_per_gpu",
}
strategy: str = "fsdp"
forward_micro_batch_size: int = 1
forward_micro_batch_size_per_gpu: int = 1
ulysses_sequence_parallel_size: int = 1
grad_clip: float = 1.0
def __post_init__(self):
"""Validate FSDP critic configuration parameters."""
super().__post_init__()
if self.strategy in {"fsdp", "fsdp2"}:
if self.ulysses_sequence_parallel_size > 1:
if not self.model.get("use_remove_padding", False):
raise ValueError(
"When using sequence parallelism for critic, you must enable `use_remove_padding`."
)
def validate(self, n_gpus: int, train_batch_size: int):
"""Validate FSDP critic configuration with runtime parameters."""
super().validate(n_gpus, train_batch_size)
if not self.use_dynamic_bsz:
sp_size = self.ulysses_sequence_parallel_size
if self.ppo_micro_batch_size is not None:
if self.ppo_micro_batch_size * sp_size < n_gpus:
raise ValueError(
f"critic.ppo_micro_batch_size ({self.ppo_micro_batch_size}) * "
f"ulysses_sequence_parallel_size ({sp_size}) must be >= n_gpus ({n_gpus})"
)
@dataclass
class FSDPCriticModelCfg(BaseModelConfig):
"""FSDP-enabled critic model configuration.
Inherits base critic settings and adds distributed-memory and LoRA options.
Args:
use_shm (bool): Whether to use shared memory for loading the model.
enable_activation_offload (bool): Offload activations to CPU to reduce GPU memory usage.
use_remove_padding (bool): Use remove-padding optimization (saves compute).
enable_gradient_checkpointing (bool): Enable gradient checkpointing for memory efficiency.
fsdp_config (FSDPEngineConfig): FSDP-specific configuration block.
lora_rank (int): Set to positive value to enable LoRA (e.g., 32).
lora_alpha (int): LoRA scaling factor.
target_modules (Union[str, List[str]]): LoRA target modules: "all-linear" or list of layer names.
"""
use_shm: bool = False
enable_activation_offload: bool = False
use_remove_padding: bool = False
enable_gradient_checkpointing: bool = True
fsdp_config: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)
lora_rank: int = 0
lora_alpha: int = 16
target_modules: str | list[str] = "all-linear"

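With the critic config above, the mutually-exclusive batch-size check now fails fast at construction time instead of deep inside the training loop. Sketch (intentionally misconfigured):

```python
from verl.workers.config import FSDPCriticConfig

try:
    # Both the deprecated global field and the per-GPU field are set:
    FSDPCriticConfig(ppo_micro_batch_size=8, ppo_micro_batch_size_per_gpu=1)
except ValueError as e:
    print(e)  # [critic] You have set both ... (the former is deprecated)
```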

@@ -0,0 +1,105 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass, field
from typing import Any, Optional
from verl.base_config import BaseConfig
__all__ = ["FSDPEngineConfig", "McoreEngineConfig"]
@dataclass
class McoreEngineConfig(BaseConfig):
"""Configuration for Megatron parallelism.
The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.
Args:
param_offload (bool): Whether to offload parameters to CPU.
grad_offload (bool): Whether to offload gradients to CPU.
optimizer_offload (bool): Whether to offload optimizer states to CPU.
tensor_model_parallel_size (int): Tensor model parallel size.
expert_model_parallel_size (int): Expert model parallel size for MoE models.
expert_tensor_parallel_size (Optional[int]): Expert tensor parallel size for MoE models.
pipeline_model_parallel_size (int): Pipeline model parallel size.
virtual_pipeline_model_parallel_size (Optional[int]): Virtual pipeline model parallel size
for interleaved scheduling.
context_parallel_size (int): Context parallel size for long sequences.
sequence_parallel (bool): Whether to enable sequence parallelism.
use_distributed_optimizer (bool): Whether to use distributed optimizer.
use_dist_checkpointing (bool): Whether to use distributed checkpointing.
dist_checkpointing_path (Optional[str]): Path for distributed checkpointing.
seed (int): Random seed for reproducibility.
override_ddp_config (dict[str, Any]): Override configuration for DDP.
override_transformer_config (dict[str, Any]): Override configuration for transformer.
use_mbridge (bool): Whether to use MBridge for communication.
"""
    # sequence_parallel is kept mutable so that __post_init__ can auto-correct it below
_mutable_fields = BaseConfig._mutable_fields | {"sequence_parallel"}
param_offload: bool = False
grad_offload: bool = False
optimizer_offload: bool = False
tensor_model_parallel_size: int = 1
expert_model_parallel_size: int = 1
expert_tensor_parallel_size: Optional[int] = None
pipeline_model_parallel_size: int = 1
virtual_pipeline_model_parallel_size: Optional[int] = None
context_parallel_size: int = 1
sequence_parallel: bool = True
use_distributed_optimizer: bool = True
use_dist_checkpointing: bool = False
dist_checkpointing_path: Optional[str] = None
seed: int = 42
override_ddp_config: dict[str, Any] = field(default_factory=dict)
override_transformer_config: dict[str, Any] = field(default_factory=dict)
use_mbridge: bool = False
    def __post_init__(self) -> None:
        """Validate and auto-correct Megatron parallelism settings."""
if self.tensor_model_parallel_size == 1:
warnings.warn("set sequence parallel to false as TP size is 1", stacklevel=2)
self.sequence_parallel = False
@dataclass
class FSDPEngineConfig(BaseConfig):
"""Configuration for FSDP (Fully Sharded Data Parallel).
The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.
Args:
        wrap_policy (dict[str, Any]): Configuration for the FSDP wrap policy.
        param_offload (bool): Whether to offload parameters to CPU. Default False.
        optimizer_offload (bool): Whether to offload optimizer states to CPU. Default False.
        offload_policy (bool): Whether to offload the policy model parameters. Default False.
        reshard_after_forward (bool): Whether to reshard parameters after the forward pass. Default True.
        fsdp_size (int): FSDP group size. -1 means use all available GPUs.
        forward_prefetch (bool): Whether to prefetch parameters for the next forward pass. Default False.
        model_dtype (str): Model data type used to initialize the transformers model. Default "fp32".
        use_orig_params (bool): Whether to use original parameters when initializing FSDP1. Default False.
"""
wrap_policy: dict[str, Any] = field(default_factory=dict)
param_offload: bool = False
optimizer_offload: bool = False
offload_policy: bool = False
reshard_after_forward: bool = True
fsdp_size: int = -1
forward_prefetch: bool = False
model_dtype: str = "fp32"
use_orig_params: bool = False

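`McoreEngineConfig.__post_init__` absorbs the sequence-parallel auto-correction that workers used to do by hand (see the `MegatronPPOCritic` hunk below). A minimal sketch:

```python
from verl.workers.config import McoreEngineConfig

# TP size 1 makes Megatron sequence parallelism a no-op, so __post_init__
# warns and flips it off (sequence_parallel is an opt-in mutable field):
cfg = McoreEngineConfig(tensor_model_parallel_size=1)
assert cfg.sequence_parallel is False
```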

@@ -0,0 +1,94 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
from omegaconf import MISSING
from verl.base_config import BaseConfig
__all__ = ["OptimizerConfig", "FSDPOptimizerConfig", "McoreOptimizerConfig"]
@dataclass
class OptimizerConfig(BaseConfig):
"""Base optimizer configuration.
Args:
        lr (float): Learning rate. Must be specified.
        lr_warmup_steps_ratio (float): Warmup steps ratio; total steps will be injected at runtime.
        total_training_steps (int): Total training steps (must be overridden at runtime).
        weight_decay (float): Weight decay factor.
        lr_warmup_steps (Optional[int]): Number of warmup steps; None or a negative value delegates to
            lr_warmup_steps_ratio.
"""
lr: float = MISSING
lr_warmup_steps_ratio: float = 0.0
total_training_steps: int = -1
weight_decay: float = 0.01
lr_warmup_steps: Optional[int] = -1
def __post_init__(self):
assert self.lr != MISSING
@dataclass
class FSDPOptimizerConfig(OptimizerConfig):
"""FSDP optimizer configuration extending base OptimizerConfig.
Args:
lr (float): Learning rate.
min_lr_ratio (Optional[float]): Minimum LR ratio for cosine schedule.
warmup_style (str): LR warmup style: "constant" or "cosine".
num_cycles (float): Number of cosine cycles in LR schedule.
"""
min_lr_ratio: Optional[float] = None
warmup_style: str = "constant"
num_cycles: float = 0.5
def __post_init__(self):
assert self.warmup_style in ["constant", "cosine"]
return super().__post_init__()
@dataclass
class McoreOptimizerConfig(OptimizerConfig):
"""Mcore optimizer configuration extending base OptimizerConfig.
Args:
optimizer (str): Optimizer name; default is "adam".
lr (float): Learning rate.
clip_grad (float): Gradient clipping norm.
lr_warmup_init (float): Initial learning rate for warmup; defaults to 0.0.
lr_decay_steps (Optional[int]): Number of decay steps.
lr_decay_style (str): LR decay style: "constant", "linear", "cosine", or "inverse_square_root".
min_lr (float): Minimum learning rate.
weight_decay_incr_style (str): Weight decay increment style: "constant" or "cosine".
        lr_wsd_decay_style (str): Decay style for the warmup-stable-decay (WSD) schedule:
            "constant", "exponential", or "cosine".
        lr_wsd_decay_steps (Optional[int]): Number of decay steps for the WSD schedule.
use_checkpoint_opt_param_scheduler (bool): Whether to use checkpoint optimizer parameter scheduler.
"""
optimizer: str = "adam"
clip_grad: float = 1.0
lr_warmup_init: float = 0.0
lr_decay_steps: Optional[int] = None
lr_decay_style: str = "linear"
min_lr: float = 0.0
weight_decay_incr_style: str = "constant"
lr_wsd_decay_style: str = "exponential"
lr_wsd_decay_steps: Optional[int] = None
use_checkpoint_opt_param_scheduler: bool = False

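`lr` stays `MISSING` unless provided, so a forgotten learning rate fails at construction rather than mid-run. Sketch:

```python
from verl.workers.config import FSDPOptimizerConfig

# FSDPOptimizerConfig()  # would fail the lr assertion in __post_init__
cfg = FSDPOptimizerConfig(lr=1e-6, warmup_style="cosine", min_lr_ratio=0.1)
assert cfg.num_cycles == 0.5  # default cosine cycle count
```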

@@ -84,9 +84,6 @@ class MegatronPPOCritic(BasePPOCritic):
         assert config.get("ulysses_sequence_parallel_size", 1) == 1
         if config.shuffle:
             assert config.data_loader_seed is not None, "If shuffle dataloader, seed must be manually set"
-        if config.megatron.tensor_model_parallel_size == 1:
-            print("[Warining] Because critic tp size == 1, set sp to False")
-            config.megatron.sequence_parallel = False
         self.config = config

     @GPUMemoryLogger("megatron critic", logger=logger)


@@ -195,7 +195,7 @@ class FSDPEngine(BaseEngine):
         else:
             self.tokenizer.chat_template = self.config.model.custom_chat_template
-        override_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+        override_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
         override_config_kwargs = {
             "bos_token_id": self.tokenizer.bos_token_id,
             "eos_token_id": self.tokenizer.eos_token_id,

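This `OmegaConf.create(...)` wrapping recurs in every worker below: once parent configs are dataclasses, `model.get("override_config", {})` may return a plain `dict` rather than a `DictConfig`, and wrapping it first keeps `to_container` happy with both. A sketch with illustrative keys:

```python
from omegaconf import OmegaConf

for raw in ({}, {"max_position_embeddings": 32768}, OmegaConf.create({"rope_theta": 1e6})):
    # OmegaConf.create accepts dicts and DictConfigs alike, so to_container
    # no longer assumes the old DictConfig-only code path.
    assert isinstance(OmegaConf.to_container(OmegaConf.create(raw)), dict)
```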

@@ -72,6 +72,7 @@ from verl.utils.model import compute_position_id_with_mask
 from verl.utils.profiler import DistProfiler, DistProfilerExtension, log_gpu_memory_usage, simple_timer
 from verl.utils.profiler.performance import reduce_timing
 from verl.utils.py_functional import convert_to_regular_types
+from verl.workers.config import FSDPCriticConfig, FSDPEngineConfig
 from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

 logger = logging.getLogger(__file__)
@@ -209,7 +210,7 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
     def _build_model_optimizer(
         self,
         model_path,
-        fsdp_config,
+        fsdp_config: FSDPEngineConfig,
         optim_config,
         override_model_config,
         use_remove_padding=False,
@@ -378,8 +379,8 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
                 mixed_precision=mixed_precision,
                 sync_module_states=True,
                 device_mesh=self.device_mesh,
-                use_orig_params=self.config.actor.fsdp_config.get("use_orig_params", False),
-                forward_prefetch=self.config.actor.fsdp_config.get("forward_prefetch", False),
+                use_orig_params=fsdp_config.get("use_orig_params", False),
+                forward_prefetch=fsdp_config.get("forward_prefetch", False),
             )
         elif fsdp_strategy == "fsdp2":
             assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
@@ -566,8 +567,7 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
         # This is used to import external_lib into the huggingface systems
         import_external_libs(self.config.model.get("external_lib", None))

-        override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
         use_remove_padding = self.config.model.get("use_remove_padding", False)
         use_shm = self.config.model.get("use_shm", False)
         use_fused_kernels = self.config.model.get("use_fused_kernels", False)
@@ -576,10 +576,10 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
         # we need the model for actor and rollout
         if self._is_actor:
             optim_config = self.config.actor.optim
-            fsdp_config = self.config.actor.fsdp_config
+            fsdp_config = omega_conf_to_dataclass(self.config.actor.fsdp_config)
         else:
             optim_config = None
-            fsdp_config = OmegaConf.create()
+            fsdp_config = FSDPEngineConfig()

         local_path = copy_to_local(self.config.model.path, use_shm=use_shm)
         (
@@ -614,12 +614,9 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
             log_gpu_memory_usage("After offload actor optimizer during init", logger=logger)

         if self._is_actor:
-            OmegaConf.set_struct(self.config.actor, True)
-            with open_dict(self.config.actor):
-                self.config.actor.use_remove_padding = use_remove_padding
-                self.config.actor.use_fused_kernels = use_fused_kernels
+            actor_cfg = omega_conf_to_dataclass(self.config.actor)
             self.actor = DataParallelPPOActor(
-                config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer
+                config=actor_cfg, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer
             )

         if self._is_rollout:
@@ -631,7 +628,7 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
             local_path = copy_to_local(self.config.model.path, use_shm=use_shm)
             self.ref_module_fsdp = self._build_model_optimizer(
                 model_path=local_path,
-                fsdp_config=self.config.ref.fsdp_config,
+                fsdp_config=omega_conf_to_dataclass(self.config.ref.fsdp_config),
                 optim_config=None,
                 override_model_config=override_model_config,
                 use_remove_padding=use_remove_padding,
@@ -916,18 +913,16 @@
 class CriticWorker(Worker, DistProfilerExtension):
-    def __init__(self, config):
+    def __init__(self, config: FSDPCriticConfig):
         Worker.__init__(self)
-        DistProfilerExtension.__init__(
-            self, DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler")))
-        )
+        DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=config.get("profiler")))

         import torch.distributed

         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group(
                 backend=get_nccl_backend(), init_method=os.environ.get("DIST_INIT_METHOD", None)
             )
-        self.config = config
+        self.config: FSDPCriticConfig = config

         # build device mesh for Ulysses Sequence Parallel
         world_size = torch.distributed.get_world_size()
@@ -996,8 +991,7 @@ class CriticWorker(Worker, DistProfilerExtension):
             self.processor.chat_template = self.config.model.custom_chat_template
         else:
             self.tokenizer.chat_template = self.config.model.custom_chat_template
-        override_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+        override_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
         override_config_kwargs = {
             "bos_token_id": self.tokenizer.bos_token_id,
             "eos_token_id": self.tokenizer.eos_token_id,
@@ -1163,8 +1157,14 @@ class CriticWorker(Worker, DistProfilerExtension):
                 optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps
             )
         elif warmup_style == "cosine":
+            min_lr_ratio = config.optim.get("min_lr_ratio", 0.0)
+            num_cycles = config.optim.get("num_cycles", 0.5)
             critic_lr_scheduler = get_cosine_schedule_with_warmup(
-                optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps
+                optimizer=critic_optimizer,
+                num_warmup_steps=num_warmup_steps,
+                num_training_steps=total_steps,
+                min_lr_ratio=min_lr_ratio,
+                num_cycles=num_cycles,
             )
         else:
             raise NotImplementedError(f"Warmup style {warmup_style} is not supported")

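`omega_conf_to_dataclass` is the bridge between the Hydra-generated `DictConfig` and the typed config the actor receives. A hedged sketch, assuming the `_target_` convention shown earlier lets the helper infer the dataclass type on its own:

```python
from omegaconf import OmegaConf
from verl.utils.config import omega_conf_to_dataclass
from verl.workers.config import FSDPActorConfig

raw = OmegaConf.create({
    "_target_": "verl.workers.config.FSDPActorConfig",
    "strategy": "fsdp",
    "ppo_micro_batch_size_per_gpu": 2,
})
actor_cfg = omega_conf_to_dataclass(raw)
assert isinstance(actor_cfg, FSDPActorConfig)
```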

@@ -26,7 +26,7 @@ import torch
 import torch.distributed
 from codetiming import Timer
 from megatron.core import parallel_state as mpu
-from omegaconf import DictConfig, OmegaConf, open_dict
+from omegaconf import DictConfig, OmegaConf

 from verl import DataProto
 from verl.single_controller.base.decorator import Dispatch, register
@@ -53,6 +53,7 @@ from verl.utils.profiler import (
 )
 from verl.utils.profiler.performance import reduce_timing
 from verl.workers.actor.megatron_actor import MegatronPPOActor
+from verl.workers.config import McoreCriticConfig
 from verl.workers.critic.megatron_critic import MegatronPPOCritic
 from verl.workers.reward_model.megatron.reward_model import MegatronRewardModel
@@ -202,7 +203,7 @@ class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):
             return parallel_model

         override_ddp_config = OmegaConf.to_container(
-            self.config.actor.megatron.get("override_ddp_config", OmegaConf.create()), resolve=True
+            OmegaConf.create(self.config.actor.megatron.get("override_ddp_config", {}))
         )
         return get_model(
             megatron_actor_model_provider,
@@ -390,14 +391,14 @@ class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):
         from verl.utils.torch_dtypes import PrecisionType

-        override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
         if self._is_actor:
             override_transformer_config = OmegaConf.to_container(
-                self.config.actor.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True
+                OmegaConf.create(self.config.actor.megatron.get("override_transformer_config", {}))
             )
         elif self._is_ref:
             override_transformer_config = OmegaConf.to_container(
-                self.config.ref.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True
+                OmegaConf.create(self.config.ref.megatron.get("override_transformer_config", {}))
             )
         else:
             override_transformer_config = {}
@@ -427,12 +428,9 @@ class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):
             log_gpu_memory_usage("After offload actor optimizer during init", logger=logger)

         if self._is_actor:
-            OmegaConf.set_struct(self.config.actor, True)
-            with open_dict(self.config.actor):
-                use_fused_kernels = self.config.model.get("use_fused_kernels", False)
-                self.config.actor.use_fused_kernels = use_fused_kernels
+            actor_cfg = omega_conf_to_dataclass(self.config.actor)
             self.actor = MegatronPPOActor(
-                config=self.config.actor,
+                config=actor_cfg,
                 model_config=self.actor_model_config,
                 hf_config=self.hf_config,
                 tf_config=self.tf_config,
@@ -712,12 +710,10 @@ class AsyncActorRolloutRefWorker(ActorRolloutRefWorker):

 class CriticWorker(MegatronWorker, DistProfilerExtension):
-    def __init__(self, config):
+    def __init__(self, config: McoreCriticConfig):
         MegatronWorker.__init__(self)
-        DistProfilerExtension.__init__(
-            self, DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler")))
-        )
-        self.config = config
+        DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=config.get("profiler")))
+        self.config: McoreCriticConfig = config

         # NOTE(sgm): We utilize colocate WorkerGroup by default.
         # As a result, Workers for different model share the same process.
@@ -807,7 +803,7 @@ class CriticWorker(MegatronWorker, DistProfilerExtension):
             return parallel_model

         override_ddp_config = OmegaConf.to_container(
-            self.config.megatron.get("override_ddp_config", OmegaConf.create()), resolve=True
+            OmegaConf.create(self.config.megatron.get("override_ddp_config", {}))
         )
         # Step 3: initialize the megatron model
         critic_module = get_model(
@@ -861,9 +857,9 @@ class CriticWorker(MegatronWorker, DistProfilerExtension):
             import importlib

             importlib.import_module(self.config.model.external_lib)
-        override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
         override_transformer_config = OmegaConf.to_container(
-            self.config.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True
+            OmegaConf.create(self.config.megatron.get("override_transformer_config", {}))
         )
         self.param_dtype = torch.bfloat16
         self.dtype = PrecisionType.to_dtype(self.param_dtype)
@@ -1108,9 +1104,9 @@ class RewardModelWorker(MegatronWorker, DistProfilerExtension):
             import importlib

             importlib.import_module(self.config.model.external_lib)
-        override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
         override_transformer_config = OmegaConf.to_container(
-            self.config.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True
+            OmegaConf.create(self.config.megatron.get("override_transformer_config", {}))
         )
         use_shm = self.config.model.get("use_shm", False)
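
Dropping the per-worker `omega_conf_to_dataclass(config.get("profiler"))` call is safe because the critic config arriving here is already a dataclass, and Mapping-style `get()` returns the typed sub-config directly. A sketch, assuming defaults for the other fields:

```python
from verl.utils.profiler import ProfilerConfig
from verl.workers.config import McoreCriticConfig

cfg = McoreCriticConfig(ppo_micro_batch_size_per_gpu=1)
assert isinstance(cfg.get("profiler"), ProfilerConfig)  # no DictConfig conversion needed
```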