Mirror of https://github.com/volcengine/verl.git, synced 2025-10-20 13:43:50 +08:00
[cfg] refactor: add ActorConfig, EngineConfig, and ActorWorker unit test, refactor validation code (#2621)
As initially mentioned in https://github.com/volcengine/verl/discussions/1941, having structured configuration classes in verl makes argument passing easier for testing and validation. This is an extended thread on the current implementation of the configuration schema in verl.

Related PRs:
- https://github.com/volcengine/verl/pull/2117
- https://github.com/volcengine/verl/pull/2621

# Motivation

By moving from loose `omegaconf.DictConfig`-based parameters to structured dataclasses, we gain:

- Type safety and IDE support when accessing fields (e.g. `cfg.optim.lr`).
- Validation hooks via `__post_init__` in each class.
- Immutable defaults with controlled mutability (e.g., an `extra` field).
- Seamless Hydra/OmegaConf integration and easy per-recipe extension.

# Core: BaseConfig

Hydra natively supports converting a DictConfig into a dataclass, but a dataclass does not support accessing attributes via `get()`. We introduce a base class to provide backward compatibility and make the change less abrupt for existing users. All config dataclasses inherit from `BaseConfig`, which:

- implements `collections.abc.Mapping` for dict-like iteration and access;
- freezes attributes once set, unless they are listed in `_mutable_fields`;
- provides an `extra: dict[str, Any]` field for unchecked extensions.

```python
import collections.abc
from dataclasses import FrozenInstanceError, dataclass, field
from typing import Any


@dataclass
class BaseConfig(collections.abc.Mapping):
    """Dict-like, frozen dataclass with opt-in mutability."""

    # Plain class attribute (not a dataclass field): names that may be reassigned.
    _mutable_fields = {"extra"}

    extra: dict[str, Any] = field(default_factory=dict)

    def __setattr__(self, name: str, value):
        if name in self.__dict__ and name not in self._mutable_fields:
            raise FrozenInstanceError(f"Field '{name}' is frozen")
        super().__setattr__(name, value)

    # Mapping methods: get, __getitem__, __iter__, __len__ ...
```

# Example config classes (verl/trainer/config)

Each sub-component of the trainer has its own dataclass, inheriting from `BaseConfig`.

```yaml
critic:
  checkpoint:
    _target_: verl.trainer.config.CheckpointConfig
    save_contents: ["model", "optimizer", "extra"]
    load_contents: ["model", "optimizer", "extra"]
    async_save: false
```

Definition:

```python
@dataclass
class CheckpointConfig(BaseConfig):
    """What to save/load and async behavior."""

    save_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
    load_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
    async_save: bool = False

    def __post_init__(self):
        ...  # validation checks go here; they run right after initialization


ckpt_cfg = CheckpointConfig(async_save=True)
print(ckpt_cfg.save_contents)
print(ckpt_cfg.get("save_contents", default_value))  # dict-style get with a default
print(ckpt_cfg["save_contents"])

# Converting a hydra-generated omegaconf.DictConfig to the dataclass config:
from verl.utils.config import omega_conf_to_dataclass

ckpt_cfg_from_cli = omega_conf_to_dataclass(config.critic.checkpoint)
```
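To make the dict-like and frozen behavior concrete, here is a small usage sketch. It is ours, not part of the PR; it only assumes the `BaseConfig` semantics described above and reuses `CheckpointConfig` from the previous block:

```python
cfg = CheckpointConfig(async_save=True)

# Mapping protocol: dict-style iteration, membership, and length all work.
for key in cfg:
    print(key, "=", cfg[key])
assert "async_save" in cfg and len(cfg) > 0

# Frozen by default: reassigning an already-set field raises.
try:
    cfg.async_save = False
except FrozenInstanceError as err:
    print("blocked:", err)

# `extra` is whitelisted in _mutable_fields and its dict stays mutable,
# so unchecked extension keys can be stashed here.
cfg.extra["note"] = "recipe-specific value"
```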
# Extending existing config classes

Because configs are now structured, unexpected keys raise exceptions. To add new keys, there are two ways.

## Explicit class extension

```python
from dataclasses import dataclass

from verl.workers.config import FSDPActorConfig


@dataclass
class SPPOActorConfig(FSDPActorConfig):
    """Add SPPO-specific temperature/penalty."""

    sppo_eta: float = 1.0
```

When using YAML or the command line, update the target config class:

```yaml
hydra:
  searchpath:
    - file://verl/trainer/config

defaults:
  - ppo_trainer  # base trainer config
  - _self_       # then apply these overrides

actor_rollout_ref:
  actor:
    _target_: recipe.sppo.config.SPPOActorConfig  # new target dataclass required for the extension
    sppo_eta: 1.0
```

or directly from the command line:

```bash
python main_sppo.py \
    actor_rollout_ref.actor._target_=recipe.sppo.config.SPPOActorConfig \
    actor_rollout_ref.actor.sppo_eta=1.0
```

## Leverage the `extra` field

Adding new keys under the `extra` field of any dataclass that inherits from `BaseConfig` also works. This way there is no need to define your own dataclass in Python:

```yaml
hydra:
  searchpath:
    - file://verl/trainer/config

defaults:
  - ppo_trainer  # base trainer config
  - _self_       # then apply these overrides

actor_rollout_ref:
  actor:
    extra:
      sppo_eta: 1.0
```

# Declaring mutable fields

For historical reasons, some config fields are mutated in place in the codebase, such as batch sizes for data/sequence parallelism. We are in the process of deprecating this behavior. However, if you want to intentionally mutate a field, declare it in the `_mutable_fields` attribute:

```python
@dataclass
class CheckpointConfig(BaseConfig):
    """What to save/load and async behavior."""

    # Mark save_contents as mutable.
    _mutable_fields = BaseConfig._mutable_fields | {"save_contents"}

    save_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
    load_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
    async_save: bool = False
```

# Other helpful resources

verl's default trainer config combines the following config files, specified in the `defaults` list: https://github.com/volcengine/verl/blob/main/verl/trainer/config/ppo_trainer.yaml#L1-L36

- verl/trainer/config/ppo_trainer.yaml  # main config for the entrypoint
- verl/trainer/config/actor/dp_actor.yaml
- verl/trainer/config/critic/dp_critic.yaml
- verl/trainer/config/reward_model/dp_reward_model.yaml
- verl/trainer/config/rollout/rollout.yaml

To quickly peek at the full default config in a single file, check the auto-generated https://github.com/volcengine/verl/blob/main/verl/trainer/config/_generated_ppo_trainer.yaml

# Change log and impact on existing code

This PR converts the following fields to structured dataclasses in the training pipeline. More can be done in future PRs (contributions from the community are welcome):

- [x] actor_rollout_ref.actor
- [x] critic
- [ ] actor_rollout_ref.rollout
- [ ] actor_rollout_ref.ref
- [ ] reward_model
- [ ] data
- [ ] trainer

Changes needed for existing code that added new fields to the config:

- See recipe/sppo for an example.
- `OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))` now has to be changed manually to `self.config.model.get("override_config", {})`, because `OmegaConf.to_container` expects a DictConfig while `config.model.override_config` is already a plain dict. A sketch of the migration is shown below.
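To make that migration concrete, here is a minimal sketch. The helper name `get_override_config` is ours, purely for illustration; both the old and the new call pattern appear verbatim in the diff below:

```python
from omegaconf import OmegaConf


def get_override_config(model_config) -> dict:
    """Fetch override_config whether it is a DictConfig or already a plain dict."""
    # Old pattern -- assumed override_config was itself a DictConfig:
    #   OmegaConf.to_container(model_config.get("override_config", OmegaConf.create()))
    # New pattern (applied throughout this PR) -- re-wrap with OmegaConf.create
    # so plain dicts coming from structured dataclass configs are also accepted:
    return OmegaConf.to_container(OmegaConf.create(model_config.get("override_config", {})))
```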
# Other Breaking Changes

`critic.optim.lr` for megatron changed from 1e-6 to 1e-5.

---------

Signed-off-by: ShareLer <ShareLe@163.com>
Co-authored-by: devin-ai-integration[bot] <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Joel <wuxibin@bytedance.com>
Co-authored-by: Cheetah <1659275352@qq.com>
Co-authored-by: 杨睿 <yangruipis@163.com>
Co-authored-by: X. HU <huxiaobo@zju.edu.cn>
Co-authored-by: Le Xue <48175490+ShareLer@users.noreply.github.com>
Co-authored-by: Ziheng Jiang <ziheng@apache.org>
Co-authored-by: Blue Space <57280232+ETOgaosion@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
**.github/workflows/gpu_unit_tests.yml** (7 changed lines)
```diff
@@ -94,7 +94,10 @@ jobs:
         pip3 install cupy-cuda12x
     - name: Run all GPU unit tests
       run: |
-        pytest -s -x --ignore-glob="*test_linear_cross_entropy_tp.py" --ignore-glob='*on_cpu.py' --ignore-glob="*test_vllm*" --ignore-glob="*_sglang*" --ignore-glob="*_hf_rollout*" --ignore-glob="tests/models/" --ignore-glob='tests/special*' --ignore-glob="tests/experimental" tests/
+        pytest -s -x --ignore-glob="*test_special_*.py" --ignore-glob='*on_cpu.py' --ignore-glob="*test_vllm*" --ignore-glob="*_sglang*" --ignore-glob="*_hf_rollout*" --ignore-glob="tests/models/" --ignore-glob='tests/special*' --ignore-glob="tests/experimental" tests/
     - name: Testing LinearCrossEntropyTP Correctness, Computation Time and Memory Consumption
       run: |
-        LOW_MEMORY=True torchrun --standalone --nnodes=1 --nproc-per-node=8 tests/utils/test_linear_cross_entropy_tp.py
+        LOW_MEMORY=True torchrun --standalone --nnodes=1 --nproc-per-node=8 tests/utils/test_special_linear_cross_entropy_tp.py
+    - name: Testing FSDP actor functionality
+      run: |
+        torchrun --standalone --nnodes=1 --nproc-per-node=2 tests/workers/actor/test_special_dp_actor.py
```
Two example scripts also drop a stray leading `+` that had been committed into an override line:

```diff
@@ -26,7 +26,7 @@ python3 -m verl.trainer.main_ppo \
     actor_rollout_ref.actor.kl_loss_type=low_var_kl \
     actor_rollout_ref.actor.fsdp_config.param_offload=False \
     actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
-    +actor_rollout_ref.actor.fsdp_config.use_orig_params=True \
+    actor_rollout_ref.actor.fsdp_config.use_orig_params=True \
     actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
     actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
     actor_rollout_ref.rollout.name=vllm \
```

```diff
@@ -34,7 +34,7 @@ python3 -m verl.trainer.main_ppo \
     actor_rollout_ref.actor.entropy_coeff=0 \
     actor_rollout_ref.actor.fsdp_config.param_offload=$OFFLOAD \
     actor_rollout_ref.actor.fsdp_config.optimizer_offload=$OFFLOAD \
-    +actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
+    actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
     actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$MICRO_BATCH_SIZE \
     actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
     actor_rollout_ref.rollout.name=sglang \
```
```diff
@@ -139,7 +139,7 @@ class RolloutWorker(ActorRolloutRefWorker):
     def init_model(self):
         # This is used to import external_lib into the huggingface systems
         import_external_libs(self.config.model.get("external_lib", None))
-        override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))

         use_shm = self.config.model.get("use_shm", False)
         local_path = copy_to_local(self.config.model.path, use_shm=use_shm)
```

```diff
@@ -18,7 +18,7 @@ import os

 import torch
 import torch.distributed
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf

 from verl.single_controller.base.decorator import Dispatch, register
 from verl.utils.debug import (
```
```diff
@@ -119,11 +119,9 @@ class RolloutWorker(ActorRolloutRefWorker):
             importlib.import_module(self.config.model.external_lib)

-        from omegaconf import OmegaConf
-
         from verl.utils.torch_dtypes import PrecisionType

-        override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
         override_transformer_config = {}
         self.param_dtype = torch.bfloat16
         self.dtype = PrecisionType.to_dtype(self.param_dtype)
```

```diff
@@ -41,7 +41,6 @@ reward_model:
   fsdp_config:
     min_num_params: 0
     param_offload: ${actor_rollout_ref.actor.fsdp_config.param_offload}
-    # grad_offload: ${actor_rollout_ref.actor.fsdp_config.grad_offload}
     optimizer_offload: ${actor_rollout_ref.actor.fsdp_config.optimizer_offload}
   update: before  # ``before`` for double-forward, ``after`` for single-forward
   optim:
```

```diff
@@ -17,6 +17,7 @@ import warnings

 import torch
 import torch.distributed
+from omegaconf import OmegaConf
 from torch.distributed.device_mesh import init_device_mesh

 from verl import DataProto
```

```diff
@@ -98,9 +99,7 @@ class PRIMERewardModelWorker(Worker):
         tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path)
         self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False))

-        from omegaconf import OmegaConf
-
-        override_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+        override_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
         override_config_kwargs = {
             "bos_token_id": self.tokenizer.bos_token_id,
             "eos_token_id": self.tokenizer.eos_token_id,
```

```diff
@@ -22,7 +22,7 @@ import psutil
 import torch
 import torch.distributed
 from codetiming import Timer
-from omegaconf import open_dict
+from omegaconf import OmegaConf, open_dict
 from torch.distributed.device_mesh import init_device_mesh

 import verl.utils.torch_functional as verl_F
```

```diff
@@ -83,10 +83,7 @@ class SPINRolloutRefWorker(ActorRolloutRefWorker):
         # This is used to import external_lib into the huggingface systems
         import_external_libs(self.config.model.get("external_lib", None))

-        from omegaconf import OmegaConf
-
-        override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
-
+        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
         use_remove_padding = self.config.model.get("use_remove_padding", False)
         use_fused_kernels = self.config.model.get("use_fused_kernels", False)
```
**recipe/sppo/config.py** (new file, 22 lines)
```python
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

from verl.workers.config import FSDPActorConfig


@dataclass
class SPPOActorConfig(FSDPActorConfig):
    sppo_eta: float = 1.0
```
```diff
@@ -10,7 +10,17 @@ defaults:

 actor_rollout_ref:
   actor:
+    _target_: recipe.sppo.config.SPPOActorConfig
+
+    # sppo_eta is an additional hyperparameter for SPPO, not available in
+    # verl core. specifying _target_ with SPPOActorConfig is needed to
+    # extend verl ActorConfig with custom fields.
+    # additional, it is also possible to use the `extra` field natively supported
+    # by all verl core dataclasses, without having to define SPPOActorConfig
+    # extra:
+    #   sppo_eta: 1.0
     sppo_eta: 1.0

     optim:
       lr_warmup_steps: 15
   rollout:
```
```diff
@@ -16,7 +16,7 @@
 import logging
 import os

-from omegaconf import open_dict
+from omegaconf import OmegaConf, open_dict

 from verl.single_controller.base.decorator import Dispatch, register
 from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
```

```diff
@@ -43,10 +43,7 @@ class SPPOActorRolloutRefWorker(ActorRolloutRefWorker):
         # This is used to import external_lib into the huggingface systems
         import_external_libs(self.config.model.get("external_lib", None))

-        from omegaconf import OmegaConf
-
-        override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
-
+        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
         use_remove_padding = self.config.model.get("use_remove_padding", False)
         use_fused_kernels = self.config.model.get("use_fused_kernels", False)
```
```diff
@@ -33,7 +33,16 @@ def init_config() -> DictConfig:
     from hydra import compose, initialize_config_dir

     with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
-        config = compose(config_name="ppo_trainer")
+        config = compose(
+            config_name="ppo_trainer",
+            overrides=[
+                "actor_rollout_ref.actor.use_dynamic_bsz=true",
+                # test sleep/wake_up with fsdp offload
+                "actor_rollout_ref.actor.fsdp_config.param_offload=True",
+                "actor_rollout_ref.actor.fsdp_config.optimizer_offload=True",
+            ],
+        )

     model_path = "Qwen/Qwen2.5-1.5B-Instruct"
     config.actor_rollout_ref.model.path = model_path
     config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
@@ -43,10 +52,6 @@ def init_config() -> DictConfig:
     config.actor_rollout_ref.rollout.n = 4
     config.actor_rollout_ref.rollout.agent.num_workers = 2

-    # test sleep/wake_up with fsdp offload
-    config.actor_rollout_ref.actor.fsdp_config.param_offload = True
-    config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True
-
     return config
```
```diff
@@ -40,7 +40,7 @@ python3 -m recipe.sppo.main_sppo \
     trainer.critic_warmup=0 \
     trainer.logger=console \
     trainer.val_before_train=False \
-    trainer.n_gpus_per_node=8 \
+    trainer.n_gpus_per_node=$NUM_GPUS \
     trainer.nnodes=1 \
     trainer.save_freq=-1 \
     trainer.total_training_steps=1 \
```
```diff
@@ -62,6 +62,8 @@ def test_trainer_config_doc():
         "verl/trainer/config/ppo_trainer.yaml",
         "verl/trainer/config/actor/actor.yaml",
         "verl/trainer/config/actor/dp_actor.yaml",
         "verl/trainer/config/critic/critic.yaml",
         "verl/trainer/config/critic/dp_critic.yaml",
         "verl/trainer/config/ref/ref.yaml",
         "verl/trainer/config/ref/dp_ref.yaml",
         "verl/trainer/config/rollout/rollout.yaml",
```
```diff
@@ -258,7 +258,7 @@ actor_rollout_ref:
     calculate_log_probs: False
   # Nsight system profiler configs
   profiler:
-    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
     _target_: verl.utils.profiler.ProfilerConfig
     discrete: False
     all_ranks: False
@@ -336,7 +336,7 @@ critic:
     load_contents: ${critic.checkpoint.save_contents}
   # Nsight system profiler configs
   profiler:
-    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
     _target_: verl.utils.profiler.ProfilerConfig
     discrete: False
     all_ranks: False
@@ -379,7 +379,7 @@ reward_model:
     memory_limit_mb: 1024  # Max memory limit for each sandbox process in MB
   # Nsight system profiler configs
   profiler:
-    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
     _target_: verl.utils.profiler.ProfilerConfig
     discrete: False
     all_ranks: False
@@ -390,7 +390,7 @@ custom_reward_function:
   name: compute_score

 algorithm:
-  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
   _target_: verl.trainer.config.AlgoConfig
   gamma: 1.0
   lam: 1.0
@@ -399,7 +399,7 @@ algorithm:
   use_kl_in_reward: False
   kl_penalty: kl  # how to estimate kl divergence
   kl_ctrl:
-    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
    _target_: verl.trainer.config.KLControlConfig
    type: fixed
    kl_coef: 0.001
@@ -407,8 +407,6 @@ algorithm:
     target_kl: 0.1
   use_pf_ppo: False
   pf_ppo:
-    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
-    _target_: verl.trainer.config.PFPPOConfig
     reweight_method: pow  # ["pow", "max_min", "max_random"]
     weight_pow: 2.0
```
```diff
@@ -579,7 +579,7 @@ actor_rollout_ref:
   # profiler configs
   profiler:

-    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
     _target_: verl.utils.profiler.ProfilerConfig

     # True for each task has its own database, False for all tasks in one training step share one database.
@@ -744,7 +744,7 @@ critic:
   # the corresponding dataclass is verl.utils.profiler.ProfilerConfig.
   profiler:

-    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
     _target_: verl.utils.profiler.ProfilerConfig

     # True for each task has its own database, False for all tasks in one training step share one database.
@@ -858,7 +858,7 @@ reward_model:
   # profiler configs
   profiler:

-    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
     _target_: verl.utils.profiler.ProfilerConfig

     # True for each task has its own database, False for all tasks in one training step share one database.
@@ -883,7 +883,7 @@ custom_reward_function:
 # config for the algorithm
 algorithm:

-  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
   _target_: verl.trainer.config.AlgoConfig

   # Discount factor for future rewards
@@ -907,7 +907,7 @@ algorithm:
   # KL control configuration
   kl_ctrl:

-    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
     _target_: verl.trainer.config.KLControlConfig

     # KL control type: "fixed" or "adaptive"
@@ -928,9 +928,6 @@ algorithm:
   # Preference feedback PPO settings
   pf_ppo:

-    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
-    _target_: verl.trainer.config.PFPPOConfig
-
     # Method for reweighting samples: "pow", "max_min", or "max_random"
     reweight_method: pow
```
```diff
@@ -18,7 +18,7 @@ import numpy as np
 import torch
 from omegaconf import OmegaConf

-from verl.trainer.config import AlgoConfig, KLControlConfig, PFPPOConfig
+from verl.trainer.config import AlgoConfig, KLControlConfig
 from verl.trainer.ppo.core_algos import (
     compute_gae_advantage_return,
     compute_grpo_outcome_advantage,
@@ -49,7 +49,7 @@ class TestAlgoConfig(unittest.TestCase):
                 "target_kl": 0.05,
             },
             "use_pf_ppo": True,
-            "pf_ppo": {"_target_": "verl.trainer.config.PFPPOConfig", "reweight_method": "max_min", "weight_pow": 3.0},
+            "pf_ppo": {"reweight_method": "max_min", "weight_pow": 3.0},
         }
         self.omega_config = OmegaConf.create(self.config_dict)

@@ -86,9 +86,8 @@ class TestAlgoConfig(unittest.TestCase):
         self.assertEqual(config.kl_ctrl.target_kl, 0.05)

         # Test PF PPO config
-        self.assertIsInstance(config.pf_ppo, PFPPOConfig)
-        self.assertEqual(config.pf_ppo.reweight_method, "max_min")
-        self.assertEqual(config.pf_ppo.weight_pow, 3.0)
+        self.assertEqual(config.pf_ppo.get("reweight_method"), "max_min")
+        self.assertEqual(config.pf_ppo.get("weight_pow"), 3.0)

     def test_default_values(self):
         """Test that default values are properly set."""
@@ -123,7 +122,7 @@ class TestAlgoConfig(unittest.TestCase):
         # Check that nested configs are initialized
         self.assertIsNotNone(minimal_config.kl_ctrl)
         self.assertIsInstance(minimal_config.kl_ctrl, KLControlConfig)
-        self.assertIsNone(minimal_config.pf_ppo)
+        assert not minimal_config.pf_ppo

     def test_config_init_from_yaml(self):
         import os
@@ -133,10 +132,9 @@ class TestAlgoConfig(unittest.TestCase):
         with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
             cfg = compose(config_name="ppo_trainer")
         algo_config = omega_conf_to_dataclass(cfg.algorithm)
-        from verl.trainer.config import AlgoConfig, PFPPOConfig
+        from verl.trainer.config import AlgoConfig

         assert isinstance(algo_config, AlgoConfig)
-        assert isinstance(algo_config.pf_ppo, PFPPOConfig)


 class TestAlgoCompute(unittest.TestCase):
@@ -153,7 +151,7 @@ class TestAlgoCompute(unittest.TestCase):
             kl_penalty="kl",
             kl_ctrl=KLControlConfig(type="adaptive", kl_coef=0.002, horizon=5000, target_kl=0.05),
             use_pf_ppo=True,
-            pf_ppo=PFPPOConfig(reweight_method="max_min", weight_pow=3.0),
+            pf_ppo={"reweight_method": "max_min", "weight_pow": 3.0},
         )

     def test_advantage_estimator_with_cfg(self):
```
**Deleted file** (170 lines; an earlier critic config test whose updated contents reappear in tests/workers/config/test_critic_config_on_cpu.py below)

```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from pathlib import Path

import pytest
from hydra import compose, initialize_config_dir

from verl.trainer.config.config import CriticConfig, FSDPCriticConfig, MegatronCriticConfig
from verl.utils.config import omega_conf_to_dataclass


class TestCriticConfig:
    """Test suite for critic configuration dataclasses."""

    @pytest.fixture
    def config_dir(self):
        """Get the path to the config directory."""
        return Path(__file__).parent.parent.parent.parent / "verl" / "trainer" / "config" / "critic"

    def test_megatron_critic_config_instantiation_from_yaml(self, config_dir):
        """Test that MegatronCriticConfig can be instantiated from megatron_critic.yaml."""
        yaml_path = config_dir / "megatron_critic.yaml"
        assert yaml_path.exists(), f"Config file not found: {yaml_path}"

        with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config/critic")):
            test_config = compose(config_name="megatron_critic")

        megatron_config_obj = omega_conf_to_dataclass(test_config)

        assert isinstance(megatron_config_obj, MegatronCriticConfig)
        assert isinstance(megatron_config_obj, CriticConfig)

        expected_attrs = [
            "strategy",
            "rollout_n",
            "optim",
            "model",
            "ppo_mini_batch_size",
            "ppo_max_token_len_per_gpu",
            "cliprange_value",
            "get",
            "nccl_timeout",
            "megatron",
            "load_weight",
        ]
        for attr in expected_attrs:
            assert hasattr(megatron_config_obj, attr), f"Missing attribute: {attr}"

        assert callable(megatron_config_obj.get)
        assert megatron_config_obj.strategy == "megatron"

    def test_fsdp_critic_config_instantiation_from_yaml(self, config_dir):
        """Test that FSDPCriticConfig can be instantiated from dp_critic.yaml."""
        yaml_path = config_dir / "dp_critic.yaml"
        assert yaml_path.exists(), f"Config file not found: {yaml_path}"

        with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config/critic")):
            test_config = compose(config_name="dp_critic")

        fsdp_config_obj = omega_conf_to_dataclass(test_config)

        assert isinstance(fsdp_config_obj, FSDPCriticConfig)
        assert isinstance(fsdp_config_obj, CriticConfig)

        expected_attrs = [
            "strategy",
            "rollout_n",
            "optim",
            "model",
            "ppo_mini_batch_size",
            "ppo_max_token_len_per_gpu",
            "cliprange_value",
            "get",
            "forward_micro_batch_size",
            "forward_micro_batch_size_per_gpu",
            "ulysses_sequence_parallel_size",
            "grad_clip",
        ]
        for attr in expected_attrs:
            assert hasattr(fsdp_config_obj, attr), f"Missing attribute: {attr}"

        assert callable(fsdp_config_obj.get)
        assert fsdp_config_obj.strategy == "fsdp"

    def test_config_inheritance_hierarchy(self):
        """Test that the inheritance hierarchy is correct."""
        megatron_config = MegatronCriticConfig()
        assert isinstance(megatron_config, CriticConfig)
        assert isinstance(megatron_config, MegatronCriticConfig)

        fsdp_config = FSDPCriticConfig()
        assert isinstance(fsdp_config, CriticConfig)
        assert isinstance(fsdp_config, FSDPCriticConfig)

        critic_config = CriticConfig()
        assert isinstance(critic_config, CriticConfig)
        assert not isinstance(critic_config, MegatronCriticConfig)
        assert not isinstance(critic_config, FSDPCriticConfig)

    def test_config_dict_interface(self):
        """Test that configs provide dict-like interface from BaseConfig."""
        config = CriticConfig()

        assert "strategy" in config
        assert config["strategy"] == "fsdp"

        assert config.get("strategy") == "fsdp"
        assert config.get("nonexistent_key", "default") == "default"

        keys = list(config)
        assert "strategy" in keys
        assert "rollout_n" in keys

        assert len(config) > 0

    def test_frozen_fields_immutability(self):
        """Test that frozen fields raise exceptions when modified after creation."""
        critic_config = CriticConfig()
        frozen_fields = ["rollout_n", "strategy", "cliprange_value"]

        for field_name in frozen_fields:
            with pytest.raises((AttributeError, TypeError, ValueError)):
                setattr(critic_config, field_name, "modified_value")

        megatron_config = MegatronCriticConfig()
        megatron_frozen_fields = ["nccl_timeout", "load_weight", "data_loader_seed"]

        for field_name in megatron_frozen_fields:
            with pytest.raises((AttributeError, TypeError, ValueError)):
                setattr(megatron_config, field_name, "modified_value")

        fsdp_config = FSDPCriticConfig()
        fsdp_frozen_fields = ["ulysses_sequence_parallel_size", "grad_clip"]

        for field_name in fsdp_frozen_fields:
            with pytest.raises((AttributeError, TypeError, ValueError)):
                setattr(fsdp_config, field_name, "modified_value")

    def test_batch_size_fields_modifiable(self):
        """Test that batch size fields can be modified after creation."""
        critic_config = CriticConfig()

        critic_config.ppo_mini_batch_size = 8
        critic_config.ppo_micro_batch_size = 4
        critic_config.ppo_micro_batch_size_per_gpu = 2

        assert critic_config.ppo_mini_batch_size == 8
        assert critic_config.ppo_micro_batch_size == 4
        assert critic_config.ppo_micro_batch_size_per_gpu == 2

        fsdp_config = FSDPCriticConfig()

        fsdp_config.forward_micro_batch_size = 16
        fsdp_config.forward_micro_batch_size_per_gpu = 8

        assert fsdp_config.forward_micro_batch_size == 16
        assert fsdp_config.forward_micro_batch_size_per_gpu == 8
```
```diff
@@ -20,6 +20,12 @@ from hydra import compose, initialize_config_dir
 from hydra.core.global_hydra import GlobalHydra
 from omegaconf import OmegaConf

+_BREAKING_CHANGES = [
+    "critic.optim.lr",  # mcore critic lr init value 1e-6 -> 1e-5
+    "actor_rollout_ref.actor.optim.lr_warmup_steps",  # None -> -1
+    "critic.optim.lr_warmup_steps",  # None -> -1
+]
+

 class TestConfigComparison(unittest.TestCase):
     """Test that current configs match their legacy counterparts exactly."""
@@ -81,7 +87,7 @@ class TestConfigComparison(unittest.TestCase):
             )
             for i, (current_item, legacy_item) in enumerate(zip(current_config, legacy_config, strict=True)):
                 self._compare_configs_recursively(current_item, legacy_item, f"{path}[{i}]")
-        else:
+        elif path not in _BREAKING_CHANGES:
             self.assertEqual(
                 current_config,
                 legacy_config,
```
```diff
@@ -13,7 +13,7 @@
 # limitations under the License.

 import unittest
-from dataclasses import dataclass
+from dataclasses import dataclass, field

 from omegaconf import OmegaConf

@@ -30,13 +30,16 @@ class TestDataclass:
 class TestTrainConfig:
     batch_size: int
     model: TestDataclass
+    override_config: dict = field(default_factory=dict)


 _cfg_str = """train_config:
   _target_: tests.utils.test_config_on_cpu.TestTrainConfig
   batch_size: 32
   model:
     hidden_size: 768
-    activation: relu"""
+    activation: relu
+  override_config: {}"""


 class TestConfigOnCPU(unittest.TestCase):
```
```diff
@@ -57,7 +57,7 @@ class TestProfilerConfig(unittest.TestCase):
         from verl.utils.profiler.config import ProfilerConfig

         # Create a new ProfilerConfig instance
-        config = ProfilerConfig(discrete=True, all_ranks=False, ranks=[0])
+        config = ProfilerConfig(discrete=True, all_ranks=False, ranks=[0], extra={"key": "value"})

         # Test direct attribute assignment
         with self.assertRaises(FrozenInstanceError):
@@ -79,7 +79,9 @@ class TestProfilerConfig(unittest.TestCase):
         with self.assertRaises(TypeError):
             config["ranks"] = [1, 2, 3]

         config["extra"]["key"] = "value"
         assert config["extra"]["key"] == "value"
+        config["extra"]["key"] = "value2"
+        assert config["extra"]["key"] == "value2"


 class TestNsightSystemsProfiler(unittest.TestCase):
```
**tests/workers/actor/test_special_dp_actor.py** (new file, 294 lines)
```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from unittest.mock import patch

import torch
import torch.nn as nn
from tensordict import TensorDict
from transformers import AutoModelForCausalLM, Qwen3Config

from verl import DataProto
from verl.workers.actor.dp_actor import DataParallelPPOActor
from verl.workers.config import FSDPActorConfig, OptimizerConfig


class MockTransformerModel(nn.Module):
    """Mock transformer model for testing DataParallelPPOActor"""

    def __init__(self, vocab_size=1000, hidden_size=64):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_size, nhead=4, batch_first=True), num_layers=2
        )
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids, attention_mask=None, position_ids=None, use_cache=False, **kwargs):
        batch_size, seq_len = input_ids.shape

        embeddings = self.embedding(input_ids)
        hidden_states = self.transformer(embeddings)
        logits = self.lm_head(hidden_states)

        class MockOutput:
            def __init__(self, logits):
                self.logits = logits

        return MockOutput(logits)


class TestDataParallelPPOActor(unittest.TestCase):
    """Test DataParallelPPOActor compute_log_prob and update_policy methods"""

    def setUp(self):
        """Set up test fixtures"""
        import os

        import torch.distributed

        backend = "cpu:gloo,cuda:nccl" if torch.cuda.is_available() else "gloo"
        if not torch.distributed.is_initialized():
            rank = int(os.environ.get("RANK", 0))
            world_size = int(os.environ.get("WORLD_SIZE", 1))
            torch.distributed.init_process_group(
                backend=backend,
                rank=rank,
                world_size=world_size,
                init_method=os.environ.get("DIST_INIT_METHOD", "env://"),
            )

        self.mock_memory_info_patcher = patch("verl.utils.profiler.performance._get_current_mem_info")
        self.mock_memory_info = self.mock_memory_info_patcher.start()
        self.mock_memory_info.return_value = ("0.00", "0.00", "0.00", "0.00")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.config = FSDPActorConfig(
            strategy="fsdp",
            ppo_mini_batch_size=4,
            ppo_micro_batch_size_per_gpu=2,
            ppo_epochs=1,
            clip_ratio=0.2,
            entropy_coeff=0.01,
            grad_clip=1.0,
            use_dynamic_bsz=False,
            use_torch_compile=False,  # Disable torch.compile for testing
            ulysses_sequence_parallel_size=1,
            optim=OptimizerConfig(lr=1e-6),
        )

        self.mock_model = MockTransformerModel(vocab_size=1000, hidden_size=64).to(self.device)
        self.mock_optimizer = torch.optim.Adam(self.mock_model.parameters(), lr=1e-4)

        self.actor = DataParallelPPOActor(
            config=self.config, actor_module=self.mock_model, actor_optimizer=self.mock_optimizer
        )

    def tearDown(self):
        """Clean up test fixtures"""
        # repeated init and destroy seems unstable with torch
        # torch.distributed.destroy_process_group()
        self.mock_memory_info_patcher.stop()

    def _create_test_data_for_compute_log_prob(self):
        """Create test DataProto for compute_log_prob method"""
        batch_size = 2
        prompt_length = 8
        response_length = 4
        total_length = prompt_length + response_length
        vocab_size = 1000

        input_ids = torch.randint(0, vocab_size, (batch_size, total_length)).to(self.device)
        attention_mask = torch.ones(batch_size, total_length).to(self.device)
        position_ids = torch.arange(total_length).unsqueeze(0).expand(batch_size, -1).to(self.device)
        responses = input_ids[:, -response_length:]  # Last part is the response

        tensor_dict = TensorDict(
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "position_ids": position_ids,
                "responses": responses,
            },
            batch_size=[batch_size],
        )

        meta_info = {"micro_batch_size": batch_size, "temperature": 1.0, "use_dynamic_bsz": False}

        return DataProto(batch=tensor_dict, meta_info=meta_info)

    def _create_test_data_for_update_policy(self):
        """Create test DataProto for update_policy method"""
        batch_size = 4  # Must match ppo_mini_batch_size
        prompt_length = 8
        response_length = 4
        total_length = prompt_length + response_length
        vocab_size = 1000

        input_ids = torch.randint(0, vocab_size, (batch_size, total_length)).to(self.device)
        attention_mask = torch.ones(batch_size, total_length).to(self.device)
        position_ids = torch.arange(total_length).unsqueeze(0).expand(batch_size, -1).to(self.device)
        responses = input_ids[:, -response_length:]
        response_mask = torch.ones(batch_size, response_length).to(self.device)
        old_log_probs = torch.randn(batch_size, response_length).to(self.device) * 0.1  # Small values
        advantages = torch.randn(batch_size, response_length).to(self.device) * 0.5

        tensor_dict = TensorDict(
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "position_ids": position_ids,
                "responses": responses,
                "response_mask": response_mask,
                "old_log_probs": old_log_probs,
                "advantages": advantages,
            },
            batch_size=[batch_size],
        )

        meta_info = {"temperature": 1.0}

        return DataProto(batch=tensor_dict, meta_info=meta_info)

    def test_compute_log_prob(self):
        """Test compute_log_prob method"""
        data = self._create_test_data_for_compute_log_prob()

        log_probs, entropies = self.actor.compute_log_prob(data, calculate_entropy=True)

        batch_size = data.batch["responses"].shape[0]
        response_length = data.batch["responses"].shape[1]

        self.assertIsInstance(log_probs, torch.Tensor)
        self.assertEqual(log_probs.shape, (batch_size, response_length))
        self.assertTrue(torch.all(torch.isfinite(log_probs)))

        self.assertIsInstance(entropies, torch.Tensor)
        self.assertEqual(entropies.shape, (batch_size, response_length))
        self.assertTrue(torch.all(torch.isfinite(entropies)))
        self.assertTrue(torch.all(entropies >= 0))  # Entropy should be non-negative

    def test_compute_log_prob_without_entropy(self):
        """Test compute_log_prob method without entropy calculation"""
        data = self._create_test_data_for_compute_log_prob()

        log_probs, entropies = self.actor.compute_log_prob(data, calculate_entropy=False)

        batch_size = data.batch["responses"].shape[0]
        response_length = data.batch["responses"].shape[1]

        self.assertIsInstance(log_probs, torch.Tensor)
        self.assertEqual(log_probs.shape, (batch_size, response_length))
        self.assertTrue(torch.all(torch.isfinite(log_probs)))

        self.assertIsNone(entropies)

    def test_update_policy(self):
        """Test update_policy method"""
        data = self._create_test_data_for_update_policy()

        metrics = self.actor.update_policy(data)

        self.assertIsInstance(metrics, dict)

        expected_metric_keys = [
            "actor/pg_loss",
            "actor/pg_clipfrac",
            "actor/ppo_kl",
            "actor/pg_clipfrac_lower",
            "actor/grad_norm",
        ]

        for key in expected_metric_keys:
            self.assertIn(key, metrics)
            if isinstance(metrics[key], list):
                self.assertTrue(all(torch.isfinite(torch.tensor(v)) for v in metrics[key]))
            else:
                self.assertIsInstance(metrics[key], (float, int))
                self.assertTrue(torch.isfinite(torch.tensor(metrics[key])))

    def test_dataparallelppoactor_initialization(self):
        """Test DataParallelPPOActor initialization"""
        self.assertIsNotNone(self.actor.actor_module)
        self.assertIsNotNone(self.actor.actor_optimizer)
        self.assertEqual(self.actor.config, self.config)

        self.assertEqual(self.actor.config.strategy, "fsdp")
        self.assertEqual(self.actor.config.ppo_mini_batch_size, 4)
        self.assertEqual(self.actor.config.clip_ratio, 0.2)

    def test_dataparallelppoactor_with_qwen3_model(self):
        """Test DataParallelPPOActor with real Qwen3ForCausalLM model"""
        qwen_config = Qwen3Config(
            vocab_size=1000,
            hidden_size=64,
            intermediate_size=128,
            num_hidden_layers=2,
            num_attention_heads=4,
            num_key_value_heads=2,
            max_position_embeddings=512,
            torch_dtype=torch.float32,
            use_cache=False,
        )

        with torch.device(self.device):
            qwen_model = AutoModelForCausalLM.from_config(config=qwen_config, torch_dtype=torch.float32).to(self.device)

        qwen_optimizer = torch.optim.Adam(qwen_model.parameters(), lr=1e-4)

        qwen_actor = DataParallelPPOActor(config=self.config, actor_module=qwen_model, actor_optimizer=qwen_optimizer)

        data = self._create_test_data_for_compute_log_prob()
        log_probs, entropies = qwen_actor.compute_log_prob(data, calculate_entropy=True)

        batch_size = data.batch["responses"].shape[0]
        response_length = data.batch["responses"].shape[1]

        self.assertIsInstance(log_probs, torch.Tensor)
        self.assertEqual(log_probs.shape, (batch_size, response_length))
        self.assertTrue(torch.all(torch.isfinite(log_probs)))

        self.assertIsInstance(entropies, torch.Tensor)
        self.assertEqual(entropies.shape, (batch_size, response_length))
        self.assertTrue(torch.all(torch.isfinite(entropies)))
        self.assertTrue(torch.all(entropies >= 0))

        policy_data = self._create_test_data_for_update_policy()
        metrics = qwen_actor.update_policy(policy_data)

        self.assertIsInstance(metrics, dict)

        expected_metric_keys = [
            "actor/pg_loss",
            "actor/pg_clipfrac",
            "actor/ppo_kl",
            "actor/pg_clipfrac_lower",
            "actor/grad_norm",
        ]

        for key in expected_metric_keys:
            self.assertIn(key, metrics)
            if isinstance(metrics[key], list):
                self.assertTrue(all(torch.isfinite(torch.tensor(v)) for v in metrics[key]))
            else:
                self.assertIsInstance(metrics[key], (float, int))
                self.assertTrue(torch.isfinite(torch.tensor(metrics[key])))


if __name__ == "__main__":
    unittest.main()
```
**tests/workers/config/test_actor_config_on_cpu.py** (new file, 240 lines)
```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

from verl.utils.config import omega_conf_to_dataclass
from verl.workers.config import ActorConfig, FSDPActorConfig, McoreActorConfig, OptimizerConfig


class TestActorConfig(unittest.TestCase):
    """Test the ActorConfig dataclass and its variants."""

    def test_config_inheritance(self):
        """Test that the inheritance hierarchy works correctly."""
        megatron_dict = {
            "_target_": "verl.workers.config.McoreActorConfig",
            "strategy": "megatron",
            "ppo_mini_batch_size": 256,
            "ppo_micro_batch_size_per_gpu": 256,
            "clip_ratio": 0.2,
            "optim": {
                "_target_": "verl.workers.config.OptimizerConfig",
                "lr": 0.1,
            },
        }
        fsdp_dict = {
            "_target_": "verl.workers.config.FSDPActorConfig",
            "strategy": "fsdp",
            "ppo_mini_batch_size": 256,
            "ppo_micro_batch_size_per_gpu": 256,
            "clip_ratio": 0.2,
            "optim": {
                "_target_": "verl.workers.config.OptimizerConfig",
                "lr": 0.1,
            },
        }

        megatron_config = omega_conf_to_dataclass(megatron_dict)
        fsdp_config = omega_conf_to_dataclass(fsdp_dict)

        self.assertIsInstance(megatron_config, ActorConfig)
        self.assertIsInstance(fsdp_config, ActorConfig)

        self.assertEqual(megatron_config.ppo_mini_batch_size, fsdp_config.ppo_mini_batch_size)
        self.assertEqual(megatron_config.clip_ratio, fsdp_config.clip_ratio)

    def test_actor_config_from_yaml(self):
        """Test creating ActorConfig from YAML file."""
        from hydra import compose, initialize_config_dir

        with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config/actor")):
            cfg = compose(config_name="actor", overrides=["strategy=fsdp", "ppo_micro_batch_size_per_gpu=128"])

        config = omega_conf_to_dataclass(cfg)

        self.assertIsInstance(config, ActorConfig)
        self.assertEqual(config.strategy, "fsdp")

    def test_fsdp_actor_config_from_yaml(self):
        """Test creating FSDPActorConfig from YAML file."""
        from hydra import compose, initialize_config_dir

        with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config/actor")):
            cfg = compose(config_name="dp_actor", overrides=["strategy=fsdp2", "ppo_micro_batch_size_per_gpu=128"])

        config = omega_conf_to_dataclass(cfg)

        self.assertIsInstance(config, FSDPActorConfig)
        self.assertEqual(config.strategy, "fsdp2")

    def test_megatron_actor_config_from_yaml(self):
        """Test creating McoreActorConfig from YAML file."""
        from hydra import compose, initialize_config_dir

        with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config/actor")):
            cfg = compose(config_name="megatron_actor", overrides=["ppo_micro_batch_size_per_gpu=128"])

        config = omega_conf_to_dataclass(cfg)

        self.assertIsInstance(config, McoreActorConfig)
        self.assertEqual(config.strategy, "megatron")

    def test_config_get_method(self):
        """Test the get method for backward compatibility."""
        config_dict = {
            "_target_": "verl.workers.config.ActorConfig",
            "strategy": "fsdp",
            "ppo_mini_batch_size": 256,
            "ppo_micro_batch_size_per_gpu": 256,
            "optim": {
                "_target_": "verl.workers.config.OptimizerConfig",
                "lr": 0.1,
            },
        }
        config = omega_conf_to_dataclass(config_dict)

        self.assertEqual(config.get("strategy"), "fsdp")
        self.assertEqual(config.get("ppo_mini_batch_size"), 256)

        self.assertIsNone(config.get("non_existing"))
        self.assertEqual(config.get("non_existing", "default"), "default")

    def test_config_dict_like_access(self):
        """Test dictionary-like access to config fields."""
        config_dict = {
            "_target_": "verl.workers.config.ActorConfig",
            "strategy": "fsdp",
            "ppo_mini_batch_size": 256,
            "ppo_micro_batch_size_per_gpu": 256,
            "optim": {
                "_target_": "verl.workers.config.OptimizerConfig",
                "lr": 0.1,
            },
        }
        config = omega_conf_to_dataclass(config_dict)

        self.assertEqual(config["strategy"], "fsdp")
        self.assertEqual(config["ppo_mini_batch_size"], 256)

        field_names = list(config)
        self.assertIn("strategy", field_names)
        self.assertIn("ppo_mini_batch_size", field_names)

        self.assertGreater(len(config), 0)

    def test_frozen_fields_modification_raises_exception(self):
        """Test that modifying frozen fields raises an exception."""
        config_dict = {
            "_target_": "verl.workers.config.ActorConfig",
            "strategy": "fsdp",
            "ppo_mini_batch_size": 256,
            "ppo_micro_batch_size_per_gpu": 256,
            "optim": {
                "_target_": "verl.workers.config.OptimizerConfig",
                "lr": 0.1,
            },
        }
        config = omega_conf_to_dataclass(config_dict)

        with self.assertRaises(AttributeError):
            config.strategy = "megatron"

        with self.assertRaises(AttributeError):
            config.clip_ratio = 0.5

        config.ppo_mini_batch_size = 512  # This should work since it's not in frozen fields anymore
        self.assertEqual(config.ppo_mini_batch_size, 512)

    def test_actor_config_validation_exceptions(self):
        """Test that ActorConfig.__post_init__ raises appropriate validation exceptions."""
        optim = OptimizerConfig(lr=0.1)
        with self.assertRaises((ValueError, AssertionError)) as cm:
            ActorConfig(
                strategy="fsdp",
                loss_agg_mode="invalid-mode",
                use_dynamic_bsz=True,
                optim=optim,
                ppo_micro_batch_size_per_gpu=4,
            )
        self.assertIn("Invalid loss_agg_mode", str(cm.exception))

        with self.assertRaises((ValueError, AssertionError)) as cm:
            ActorConfig(
                strategy="fsdp",
                use_dynamic_bsz=False,
                ppo_micro_batch_size=4,
                ppo_micro_batch_size_per_gpu=2,
                optim=optim,
            )
        self.assertIn("You have set both", str(cm.exception))

        with self.assertRaises((ValueError, AssertionError)) as cm:
            ActorConfig(
                strategy="fsdp",
                use_dynamic_bsz=False,
                ppo_micro_batch_size=None,
                ppo_micro_batch_size_per_gpu=None,
                optim=optim,
            )
        self.assertIn("Please set at least one", str(cm.exception))

        config = ActorConfig(
            strategy="fsdp",
            use_dynamic_bsz=True,
            ppo_micro_batch_size=None,
            ppo_micro_batch_size_per_gpu=None,
            optim=optim,
        )
        self.assertIsNotNone(config)  # Should not raise an exception

    def test_fsdp_actor_config_validation_exceptions(self):
        """Test that FSDPActorConfig.validate() raises appropriate validation exceptions."""
        optim = OptimizerConfig(lr=0.1)
        config = FSDPActorConfig(
            strategy="fsdp",
            ulysses_sequence_parallel_size=2,
            use_dynamic_bsz=True,  # Skip batch size validation to focus on FSDP validation
            optim=optim,
        )

        model_config = {"use_remove_padding": False}
        with self.assertRaises(ValueError) as cm:
            config.validate(n_gpus=8, train_batch_size=256, model_config=model_config)
        self.assertIn("you must enable `use_remove_padding`", str(cm.exception))

    def test_actor_config_validate_method_exceptions(self):
        """Test that ActorConfig.validate() raises appropriate validation exceptions."""
        optim = OptimizerConfig(lr=0.1)
        config = ActorConfig(
            strategy="fsdp",
            use_dynamic_bsz=False,
            ppo_mini_batch_size=256,
            ppo_micro_batch_size=8,
            ppo_micro_batch_size_per_gpu=None,  # Ensure only one batch size setting is used
            optim=optim,
        )

        with self.assertRaises(ValueError) as cm:
            config.validate(n_gpus=8, train_batch_size=128)
        self.assertIn("train_batch_size", str(cm.exception))

        with self.assertRaises(ValueError) as cm:
            config.validate(n_gpus=16, train_batch_size=512)
        self.assertIn("must be >= n_gpus", str(cm.exception))


if __name__ == "__main__":
    unittest.main()
```
**tests/workers/config/test_critic_config_on_cpu.py** (new file, 308 lines; listing truncated here)
@ -0,0 +1,308 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from pathlib import Path

import pytest
from hydra import compose, initialize_config_dir

from verl.utils.config import omega_conf_to_dataclass
from verl.utils.profiler import ProfilerConfig
from verl.workers.config import (
    CriticConfig,
    FSDPCriticConfig,
    McoreCriticConfig,
    OptimizerConfig,
)


class TestCriticConfig:
    """Test suite for critic configuration dataclasses."""

    @pytest.fixture
    def config_dir(self):
        """Get the path to the config directory."""
        return Path(__file__).parent.parent.parent.parent / "verl" / "trainer" / "config" / "critic"

    def test_megatron_critic_config_instantiation_from_yaml(self, config_dir):
        """Test that McoreCriticConfig can be instantiated from megatron_critic.yaml."""
        yaml_path = config_dir / "megatron_critic.yaml"
        assert yaml_path.exists(), f"Config file not found: {yaml_path}"

        with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config/critic")):
            test_config = compose(config_name="megatron_critic", overrides=["ppo_micro_batch_size_per_gpu=1"])

        megatron_config_obj = omega_conf_to_dataclass(test_config)

        assert isinstance(megatron_config_obj, McoreCriticConfig)
        assert isinstance(megatron_config_obj, CriticConfig)

        expected_attrs = [
            "strategy",
            "rollout_n",
            "optim",
            "model",
            "ppo_mini_batch_size",
            "ppo_max_token_len_per_gpu",
            "cliprange_value",
            "get",
            "nccl_timeout",
            "megatron",
            "load_weight",
        ]
        for attr in expected_attrs:
            assert hasattr(megatron_config_obj, attr), f"Missing attribute: {attr}"

        assert callable(megatron_config_obj.get)
        assert megatron_config_obj.strategy == "megatron"

    def test_fsdp_critic_config_instantiation_from_yaml(self, config_dir):
        """Test that FSDPCriticConfig can be instantiated from dp_critic.yaml."""
        yaml_path = config_dir / "dp_critic.yaml"
        assert yaml_path.exists(), f"Config file not found: {yaml_path}"

        with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config/critic")):
            test_config = compose(config_name="dp_critic", overrides=["ppo_micro_batch_size_per_gpu=1"])

        fsdp_config_obj = omega_conf_to_dataclass(test_config)

        assert isinstance(fsdp_config_obj, FSDPCriticConfig)
        assert isinstance(fsdp_config_obj, CriticConfig)

        expected_attrs = [
            "strategy",
            "rollout_n",
            "optim",
            "model",
            "ppo_mini_batch_size",
            "ppo_max_token_len_per_gpu",
            "cliprange_value",
            "get",
            "forward_micro_batch_size",
            "forward_micro_batch_size_per_gpu",
            "ulysses_sequence_parallel_size",
            "grad_clip",
        ]
        for attr in expected_attrs:
            assert hasattr(fsdp_config_obj, attr), f"Missing attribute: {attr}"

        assert callable(fsdp_config_obj.get)
        assert fsdp_config_obj.strategy == "fsdp"

    def test_config_inheritance_hierarchy(self):
        """Test that the inheritance hierarchy is correct."""
        optim = OptimizerConfig(lr=0.1)
        megatron_config = McoreCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=optim)
        assert isinstance(megatron_config, CriticConfig)
        assert isinstance(megatron_config, McoreCriticConfig)

        fsdp_config = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=optim)
        assert isinstance(fsdp_config, CriticConfig)
        assert isinstance(fsdp_config, FSDPCriticConfig)

        critic_config = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy="fsdp2", optim=optim)
        assert isinstance(critic_config, CriticConfig)
        assert not isinstance(critic_config, McoreCriticConfig)
        assert not isinstance(critic_config, FSDPCriticConfig)

    def test_config_dict_interface(self):
        """Test that configs provide dict-like interface from BaseConfig."""
        optim = OptimizerConfig(lr=0.1)
        config = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy="fsdp2", optim=optim)

        assert "strategy" in config
        assert config["strategy"] == "fsdp2"

        assert config.get("strategy") == "fsdp2"
        assert config.get("nonexistent_key", "default") == "default"

        keys = list(config)
        assert "strategy" in keys
        assert "rollout_n" in keys

        assert len(config) > 0

    def test_frozen_fields_immutability(self):
        """Test that frozen fields raise exceptions when modified after creation."""
        optim = OptimizerConfig(lr=0.1)
        critic_config = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy="fsdp2", optim=optim)
        frozen_fields = ["rollout_n", "strategy", "cliprange_value"]

        for field_name in frozen_fields:
            with pytest.raises((AttributeError, TypeError, ValueError)):
                setattr(critic_config, field_name, "modified_value")

        megatron_config = McoreCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=optim)
        megatron_frozen_fields = ["nccl_timeout", "load_weight", "data_loader_seed"]

        for field_name in megatron_frozen_fields:
            with pytest.raises((AttributeError, TypeError, ValueError)):
                setattr(megatron_config, field_name, "modified_value")

        fsdp_config = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=optim)
        fsdp_frozen_fields = ["ulysses_sequence_parallel_size", "grad_clip"]

        for field_name in fsdp_frozen_fields:
            with pytest.raises((AttributeError, TypeError, ValueError)):
                setattr(fsdp_config, field_name, "modified_value")

    def test_batch_size_fields_modifiable(self):
        """Test that batch size fields can be modified after creation."""
        optim = OptimizerConfig(lr=0.1)
        critic_config = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy="fsdp2", optim=optim)

        critic_config.ppo_mini_batch_size = 8
        critic_config.ppo_micro_batch_size = 4
        critic_config.ppo_micro_batch_size_per_gpu = 2

        assert critic_config.ppo_mini_batch_size == 8
        assert critic_config.ppo_micro_batch_size == 4
        assert critic_config.ppo_micro_batch_size_per_gpu == 2

        fsdp_config = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=optim)

        fsdp_config.forward_micro_batch_size = 16
        fsdp_config.forward_micro_batch_size_per_gpu = 8

        assert fsdp_config.forward_micro_batch_size == 16
        assert fsdp_config.forward_micro_batch_size_per_gpu == 8

    def test_profiler_config_type_validation(self):
        """Test that the profiler field has the correct type and validation."""
        optim = OptimizerConfig(lr=0.1)
        critic_config = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy="fsdp2", optim=optim)
        assert isinstance(critic_config.profiler, ProfilerConfig)
        assert critic_config.profiler.discrete is False
        assert critic_config.profiler.all_ranks is False
        assert critic_config.profiler.ranks == []

        custom_profiler = ProfilerConfig(discrete=True, all_ranks=True, ranks=[0, 1])
        critic_config_custom = CriticConfig(
            profiler=custom_profiler, ppo_micro_batch_size_per_gpu=1, strategy="fsdp2", optim=optim
        )
        assert isinstance(critic_config_custom.profiler, ProfilerConfig)
        assert critic_config_custom.profiler.discrete is True
        assert critic_config_custom.profiler.all_ranks is True
        assert critic_config_custom.profiler.ranks == [0, 1]

        profiler1 = ProfilerConfig(discrete=True, ranks=[0, 1])
        profiler2 = ProfilerConfig(all_ranks=True, ranks=[1, 2])

        union_result = profiler1.union(profiler2)
        assert union_result.discrete is True
        assert union_result.all_ranks is True
        assert set(union_result.ranks) == {0, 1, 2}

        intersect_result = profiler1.intersect(profiler2)
        assert intersect_result.discrete is False
        assert intersect_result.all_ranks is False
        assert intersect_result.ranks == [1]

    def test_critic_config_validation_logic(self):
        """Test the __post_init__ validation logic for CriticConfig."""
        optim = OptimizerConfig(lr=0.1)
        valid_config = CriticConfig(
            strategy="fsdp2", ppo_micro_batch_size_per_gpu=2, use_dynamic_bsz=False, optim=optim
        )
        assert valid_config.ppo_micro_batch_size_per_gpu == 2

        valid_config2 = CriticConfig(
            strategy="fsdp2",
            ppo_micro_batch_size_per_gpu=None,
            ppo_micro_batch_size=4,
            ppo_mini_batch_size=8,
            use_dynamic_bsz=False,
            optim=optim,
        )
        assert valid_config2.ppo_micro_batch_size == 4

        dynamic_config = CriticConfig(
            strategy="fsdp2", ppo_micro_batch_size_per_gpu=2, use_dynamic_bsz=True, optim=optim
        )
        assert dynamic_config.use_dynamic_bsz is True

        with pytest.raises(ValueError, match="You have set both.*micro_batch_size.*AND.*micro_batch_size_per_gpu"):
            CriticConfig(
                strategy="fsdp2",
                ppo_micro_batch_size=4,
                ppo_micro_batch_size_per_gpu=2,
                use_dynamic_bsz=False,
                optim=optim,
            )

        with pytest.raises(
            ValueError, match="Please set at least one of.*micro_batch_size.*or.*micro_batch_size_per_gpu"
        ):
            CriticConfig(
                strategy="fsdp2",
                ppo_micro_batch_size=None,
                ppo_micro_batch_size_per_gpu=None,
                use_dynamic_bsz=False,
                optim=optim,
            )

    def test_micro_batch_size_divisibility_validation(self):
        """Test micro batch size divisibility validation in __post_init__."""
        optim = OptimizerConfig(lr=0.1)
        valid_config = CriticConfig(
            strategy="fsdp2", ppo_micro_batch_size_per_gpu=2, ppo_mini_batch_size=8, use_dynamic_bsz=False, optim=optim
        )
        assert valid_config.ppo_mini_batch_size == 8
        assert valid_config.ppo_micro_batch_size_per_gpu == 2

        valid_config_with_mbs = CriticConfig(
            strategy="fsdp2", ppo_mini_batch_size=8, ppo_micro_batch_size=4, use_dynamic_bsz=False, optim=optim
        )
        assert valid_config_with_mbs.ppo_mini_batch_size == 8
        assert valid_config_with_mbs.ppo_micro_batch_size == 4

        with pytest.raises(ValueError, match="ppo_mini_batch_size.*must be divisible by.*ppo_micro_batch_size"):
            CriticConfig(
                strategy="fsdp2", ppo_mini_batch_size=7, ppo_micro_batch_size=4, use_dynamic_bsz=False, optim=optim
            )

        dynamic_config = CriticConfig(
            strategy="fsdp2", ppo_mini_batch_size=7, ppo_micro_batch_size=4, use_dynamic_bsz=True, optim=optim
        )
        assert dynamic_config.use_dynamic_bsz is True

    def test_fsdp_sequence_parallelism_validation(self):
        """Test FSDP sequence parallelism validation in FSDPCriticConfig.__post_init__."""
        optim = OptimizerConfig(lr=0.1)
        valid_config = FSDPCriticConfig(
            ppo_micro_batch_size_per_gpu=2,
            ulysses_sequence_parallel_size=2,
            model={"use_remove_padding": True},
            optim=optim,
        )
        assert valid_config.ulysses_sequence_parallel_size == 2

        with pytest.raises(
            ValueError, match="When using sequence parallelism for critic, you must enable.*use_remove_padding"
        ):
            FSDPCriticConfig(
                ppo_micro_batch_size_per_gpu=2,
                ulysses_sequence_parallel_size=2,
                model={"use_remove_padding": False},
                optim=optim,
            )

        valid_config_no_sp = FSDPCriticConfig(
            ppo_micro_batch_size_per_gpu=2,
            ulysses_sequence_parallel_size=1,
            model={"use_remove_padding": False},
            optim=optim,
        )
        assert valid_config_no_sp.ulysses_sequence_parallel_size == 1
tests/workers/config/test_engine_config_on_cpu.py (new file, 67 lines)
@@ -0,0 +1,67 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from verl.workers.config.engine import FSDPEngineConfig, McoreEngineConfig


class TestMcoreEngineConfig:
    def test_default_values(self):
        config = McoreEngineConfig()
        assert config.tensor_model_parallel_size == 1
        assert config.sequence_parallel is False  # Should be auto-corrected
        assert config.seed == 42

    def test_post_init_validation(self):
        # Test TP size 1 forces sequence_parallel=False
        config = McoreEngineConfig(tensor_model_parallel_size=1)
        assert config.sequence_parallel is False

        # Test TP > 1 keeps sequence_parallel=True
        config = McoreEngineConfig(tensor_model_parallel_size=2)
        assert config.sequence_parallel is True

    def test_mutable_fields(self):
        config = McoreEngineConfig()
        config.sequence_parallel = True  # Should be mutable
        with pytest.raises(AttributeError):
            config.tensor_model_parallel_size = 2  # Frozen field

    @pytest.mark.parametrize("offload_field", ["param_offload", "grad_offload", "optimizer_offload"])
    def test_offload_flags(self, offload_field):
        config = McoreEngineConfig(**{offload_field: True})
        assert getattr(config, offload_field) is True


class TestFSDPEngineConfigCPU:
    def test_default_values(self):
        config = FSDPEngineConfig()
        assert config.param_offload is False
        assert config.optimizer_offload is False
        assert config.fsdp_size == -1

    @pytest.mark.parametrize(
        "offload_params",
        [{"param_offload": True}, {"optimizer_offload": True}, {"param_offload": True, "optimizer_offload": True}],
    )
    def test_offload_combinations(self, offload_params):
        config = FSDPEngineConfig(**offload_params)
        assert config.param_offload == offload_params.get("param_offload", False)
        assert config.optimizer_offload == offload_params.get("optimizer_offload", False)

    def test_wrap_policy_configuration(self):
        test_policy = {"layer_class": "TransformerBlock"}
        config = FSDPEngineConfig(wrap_policy=test_policy)
        assert config.wrap_policy == test_policy
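The TP-size/sequence-parallel interaction tested above suggests a `__post_init__` along these lines — an illustrative sketch of the presumed validation, not the verbatim implementation from `verl.workers.config.engine`:

```python
from dataclasses import dataclass


@dataclass
class McoreEngineConfigSketch:
    """Illustrative sketch only; the real class lives in verl.workers.config.engine."""

    _mutable_fields = {"sequence_parallel"}  # must stay writable: __post_init__ may flip it
    tensor_model_parallel_size: int = 1
    sequence_parallel: bool = True
    seed: int = 42

    def __post_init__(self):
        # Sequence parallelism shards activations across TP ranks,
        # so it is meaningless (and auto-disabled) when TP size is 1.
        if self.tensor_model_parallel_size == 1:
            self.sequence_parallel = False
```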
tests/workers/config/test_optim_config_on_cpu.py (new file, 39 lines)
@@ -0,0 +1,39 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from verl.workers.config.optimizer import FSDPOptimizerConfig


class TestFSDPOptimizerConfigCPU:
    def test_default_configuration(self):
        config = FSDPOptimizerConfig(lr=0.1)
        assert config.min_lr_ratio is None
        assert config.warmup_style == "constant"
        assert config.num_cycles == 0.5

    @pytest.mark.parametrize("warmup_style", ["constant", "cosine"])
    def test_valid_warmup_styles(self, warmup_style):
        config = FSDPOptimizerConfig(warmup_style=warmup_style, lr=0.1)
        assert config.warmup_style == warmup_style

    def test_invalid_warmup_style(self):
        with pytest.raises((ValueError, AssertionError)):
            FSDPOptimizerConfig(warmup_style="invalid_style", lr=0.1)

    @pytest.mark.parametrize("num_cycles", [0.1, 1.0, 2.5])
    def test_num_cycles_configuration(self, num_cycles):
        config = FSDPOptimizerConfig(num_cycles=num_cycles, lr=0.1)
        assert config.num_cycles == num_cycles
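The `warmup_style` check above is the kind of validation that now fires at config-construction time rather than deep inside training. For instance, a cosine schedule would be selected like this (a sketch; the field values are illustrative):

```python
from verl.workers.config.optimizer import FSDPOptimizerConfig

# cosine decay down to 10% of the peak LR, one half-cosine over training
optim_cfg = FSDPOptimizerConfig(lr=1e-6, warmup_style="cosine", min_lr_ratio=0.1, num_cycles=0.5)
```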
@@ -51,7 +51,14 @@ def init_config(n_gpus_per_node) -> DictConfig:
     from hydra import compose, initialize_config_dir

     with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
-        config = compose(config_name="ppo_trainer")
+        config = compose(
+            config_name="ppo_trainer",
+            overrides=[
+                "actor_rollout_ref.actor.use_dynamic_bsz=true",
+                "actor_rollout_ref.actor.fsdp_config.param_offload=True",
+                "actor_rollout_ref.actor.fsdp_config.optimizer_offload=True",
+            ],
+        )
     config.trainer.n_gpus_per_node = n_gpus_per_node
     config.data.train_batch_size = 128
     config.data.return_raw_chat = True
@@ -64,10 +71,6 @@ def init_config(n_gpus_per_node) -> DictConfig:
     config.actor_rollout_ref.rollout.response_length = 4096
     config.actor_rollout_ref.rollout.n = 16

-    # test sleep/wake_up with fsdp offload
-    config.actor_rollout_ref.actor.fsdp_config.param_offload = True
-    config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True
-
     return config
@@ -33,7 +33,14 @@ def init_config() -> DictConfig:
     from hydra import compose, initialize_config_dir

     with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
-        config = compose(config_name="ppo_trainer")
+        config = compose(
+            config_name="ppo_trainer",
+            overrides=[
+                "actor_rollout_ref.actor.use_dynamic_bsz=true",
+                "actor_rollout_ref.actor.fsdp_config.param_offload=True",
+                "actor_rollout_ref.actor.fsdp_config.optimizer_offload=True",
+            ],
+        )
     model_path = "Qwen/Qwen2.5-1.5B-Instruct"
     config.actor_rollout_ref.model.path = model_path
     config.actor_rollout_ref.rollout.mode = "async"
@@ -41,10 +48,6 @@ def init_config() -> DictConfig:
     config.actor_rollout_ref.rollout.prompt_length = 4096
     config.actor_rollout_ref.rollout.response_length = 4096

-    # test sleep/wake_up with fsdp offload
-    config.actor_rollout_ref.actor.fsdp_config.param_offload = True
-    config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True
-
    return config
@@ -13,33 +13,28 @@
 # limitations under the License.

 import collections
-from dataclasses import (
-    dataclass,
-    field,
-    fields,  # Import the fields function to inspect dataclass fields
-)
+from dataclasses import FrozenInstanceError, dataclass, field, fields
 from typing import Any


 # BaseConfig class inherits from collections.abc.Mapping, which means it can act like a dictionary
 @dataclass
 class BaseConfig(collections.abc.Mapping):
-    """The BaseConfig provides omegaconf DictConfig-like interface for a dataclass config.
+    """The BaseConfig provides a dict-like interface for a dataclass config.

-    The BaseConfig class implements the Mapping Abstract Base Class.
+    By default all fields in the config are immutable, unless listed in
+    "_mutable_fields". The BaseConfig class implements the Mapping Abstract Base Class.
     This allows instances of this class to be used like dictionaries.
     """

     _mutable_fields = {"extra"}
     extra: dict[str, Any] = field(default_factory=dict)

     def __setattr__(self, name: str, value):
-        # if the field already exists (i.e. was set in __init__)
-        # and is in our frozen list, block assignment
-        if hasattr(self, "_frozen_fields") and name in self._frozen_fields and name in self.__dict__:
-            from dataclasses import FrozenInstanceError
-
+        """Set the value of an attribute. Check if the attr is mutable before setting the value."""
+        # If the field already exists, it is considered frozen unless it's in _mutable_fields
+        if name in self.__dict__ and name not in getattr(self, "_mutable_fields", set()):
             raise FrozenInstanceError(f"Field '{name}' is frozen and cannot be modified")
-        # otherwise do the normal thing
         super().__setattr__(name, value)

     def get(self, key: str, default: Any = None) -> Any:
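A minimal sketch of how the new frozen-field behavior surfaces to users, assuming the `CheckpointConfig` shown elsewhere in this PR:

```python
from dataclasses import FrozenInstanceError

from verl.trainer.config import CheckpointConfig

ckpt = CheckpointConfig(async_save=True)
try:
    ckpt.async_save = False  # frozen field: reassignment after __init__ is blocked
except FrozenInstanceError as e:
    print(e)  # Field 'async_save' is frozen and cannot be modified

ckpt.extra["note"] = "extra is mutable by design"  # 'extra' is in _mutable_fields
```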
@@ -900,6 +900,50 @@ class DataProto:
             meta_info=self.meta_info,
         )

+    def get_data_info(self) -> str:
+        """Return formatted information about stored data with nested type details.
+
+        Returns:
+            str: Formatted string showing tensor details and recursive metadata types
+        """
+        info = ["batch"]
+
+        for key, tensor in self.batch.items():
+            if hasattr(tensor, "shape") and hasattr(tensor, "dtype") and hasattr(tensor, "device"):
+                info.append(f"  {key}: {tuple(tensor.shape)} ({tensor.dtype}) {tensor.device}")
+            elif hasattr(tensor, "shape") and hasattr(tensor, "dtype"):
+                info.append(f"  {key}: {tuple(tensor.shape)} ({tensor.dtype})")
+            else:
+                info.append(f"  {key}: {type(tensor).__name__}")
+
+        info.append("non_tensor_batch")
+        for key, array in self.non_tensor_batch.items():
+            info.append(f"  {key}: ndarray{array.shape} ({array.dtype})")
+
+        info.append("meta_info")
+        for k, v in self.meta_info.items():
+            type_info = self._get_type_info(v)
+            info.append(f"  {k}: {type_info}")
+
+        return "\n".join(info)
+
+    def _get_type_info(self, value):
+        """Recursively get type information for nested structures"""
+        if isinstance(value, list):
+            elem_types = {self._get_type_info(v) for v in value[:3]}
+            return f"list[{'|'.join(elem_types) if elem_types else '...'}]"
+        if isinstance(value, tuple):
+            elem_types = [self._get_type_info(v) for v in value]
+            return f"tuple({', '.join(elem_types)})"
+        if isinstance(value, dict):
+            if not value:
+                return "dict"
+            k, v = next(iter(value.items()))
+            return f"dict[{self._get_type_info(k)}: {self._get_type_info(v)}]"
+        if isinstance(value, np.ndarray):
+            return f"ndarray{value.shape} ({value.dtype})"
+        return type(value).__name__
+

 @dataclass
 class DataProtoFuture:
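A hedged usage sketch for the new helper (the keys and shapes below are illustrative, not a fixed schema):

```python
# Given any DataProto instance `batch` (e.g. one produced during rollout):
#     print(batch.get_data_info())
# the helper emits a summary along these lines:
#
# batch
#   input_ids: (512, 2048) (torch.int64) cuda:0
#   attention_mask: (512, 2048) (torch.int64) cuda:0
# non_tensor_batch
#   raw_prompt: ndarray(512,) (object)
# meta_info
#   global_steps: int
#   temperature: float
```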
@@ -12,15 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .algorithm import AlgoConfig, FilterGroupsConfig, KLControlConfig, PFPPOConfig
-from .config import CriticConfig, FSDPCriticConfig, MegatronCriticConfig
+from .algorithm import *  # noqa
+from .config import *  # noqa
+from . import config, algorithm

-__all__ = [
-    "AlgoConfig",
-    "CriticConfig",
-    "FilterGroupsConfig",
-    "FSDPCriticConfig",
-    "KLControlConfig",
-    "MegatronCriticConfig",
-    "PFPPOConfig",
-]
+__all__ = config.__all__ + algorithm.__all__
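The star re-export keeps the public import surface in sync with whatever each submodule declares in its `__all__`, so downstream imports keep working without the package listing every name by hand, e.g.:

```python
# Resolved via config.__all__ + algorithm.__all__ after the refactor:
from verl.trainer.config import AlgoConfig, CheckpointConfig
```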
@@ -5,6 +5,7 @@
 actor_rollout_ref:
   actor:
+    _target_: verl.workers.config.McoreActorConfig
     strategy: megatron
     ppo_mini_batch_size: 256
     ppo_micro_batch_size: null
@@ -15,6 +16,7 @@ actor_rollout_ref:
     clip_ratio_low: 0.2
     clip_ratio_high: 0.2
     policy_loss:
+      _target_: verl.workers.config.PolicyLossConfig
       loss_mode: vanilla
       clip_cov_ratio: 0.0002
       clip_cov_lb: 1.0
@@ -31,6 +33,7 @@ actor_rollout_ref:
     ppo_epochs: 1
     shuffle: false
     checkpoint:
+      _target_: verl.trainer.config.CheckpointConfig
       save_contents:
       - model
       - optimizer
@@ -42,10 +45,11 @@ actor_rollout_ref:
       lr_warmup_steps_ratio: 0.0
       total_training_steps: -1
       weight_decay: 0.01
-      lr_warmup_steps: -1
+      _target_: verl.workers.config.McoreOptimizerConfig
       optimizer: adam
       clip_grad: 1.0
       lr_warmup_init: 0.0
+      lr_warmup_steps: null
       lr_decay_steps: null
       lr_decay_style: constant
       min_lr: 0.0
@@ -53,9 +57,11 @@ actor_rollout_ref:
       lr_wsd_decay_style: exponential
       lr_wsd_decay_steps: null
       use_checkpoint_opt_param_scheduler: false
+    use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false}
     data_loader_seed: null
     load_weight: true
     megatron:
+      _target_: verl.workers.config.McoreEngineConfig
      param_offload: false
       grad_offload: false
       optimizer_offload: false
@@ -92,6 +98,7 @@ actor_rollout_ref:
     log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}
     log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}
     megatron:
+      _target_: verl.workers.config.MegatronEngineConfig
       param_offload: false
       tensor_model_parallel_size: 1
       expert_model_parallel_size: 1
@@ -279,17 +286,20 @@ data:
   path: null
   name: null
 critic:
+  _target_: verl.workers.config.McoreCriticConfig
   rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
   strategy: megatron
   enable: null
   optim:
-    lr: 1.0e-05
     lr_warmup_steps_ratio: 0.0
     total_training_steps: -1
     weight_decay: 0.01
-    lr_warmup_steps: -1
+    _target_: verl.workers.config.McoreOptimizerConfig
+    optimizer: adam
+    lr: 1.0e-06
     clip_grad: 1.0
     lr_warmup_init: 0.0
+    lr_warmup_steps: null
     lr_decay_steps: null
     lr_decay_style: linear
     min_lr: 0.0
@@ -306,6 +316,7 @@ critic:
     freeze_moe_router: false
     external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null}
     trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false}
+    _target_: verl.trainer.config.BaseModelConfig
   ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256}
   ppo_micro_batch_size: null
   ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null}
@@ -317,6 +328,7 @@ critic:
   cliprange_value: 0.5
   loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean}
   checkpoint:
+    _target_: verl.trainer.config.CheckpointConfig
     save_contents:
     - model
     - optimizer
@@ -328,9 +340,9 @@ critic:
     discrete: false
     all_ranks: false
     ranks: []
-  _target_: verl.trainer.config.MegatronCriticConfig
   nccl_timeout: 600
   megatron:
+    _target_: verl.workers.config.McoreEngineConfig
     param_offload: false
     grad_offload: false
     optimizer_offload: false
@@ -376,6 +388,7 @@ reward_model:
     ranks: []
   nccl_timeout: 600
   megatron:
+    _target_: verl.workers.config.MegatronEngineConfig
     param_offload: false
     tensor_model_parallel_size: 1
     expert_model_parallel_size: 1
@@ -410,7 +423,6 @@ algorithm:
   target_kl: 0.1
   use_pf_ppo: false
   pf_ppo:
-    _target_: verl.trainer.config.PFPPOConfig
     reweight_method: pow
     weight_pow: 2.0
 ray_init:
@@ -5,6 +5,7 @@
 actor_rollout_ref:
   actor:
+    _target_: verl.workers.config.FSDPActorConfig
     strategy: fsdp
     ppo_mini_batch_size: 256
     ppo_micro_batch_size: null
@@ -15,6 +16,7 @@ actor_rollout_ref:
     clip_ratio_low: 0.2
     clip_ratio_high: 0.2
     policy_loss:
+      _target_: verl.workers.config.PolicyLossConfig
       loss_mode: vanilla
       clip_cov_ratio: 0.0002
       clip_cov_lb: 1.0
@@ -31,25 +33,30 @@ actor_rollout_ref:
     ppo_epochs: 1
     shuffle: false
     checkpoint:
+      _target_: verl.trainer.config.CheckpointConfig
       save_contents:
       - model
       - optimizer
       - extra
       load_contents: ${.save_contents}
       async_save: false
     optim:
       lr: 1.0e-06
       lr_warmup_steps_ratio: 0.0
       total_training_steps: -1
       weight_decay: 0.01
       lr_warmup_steps: -1
+      _target_: verl.workers.config.FSDPOptimizerConfig
       min_lr_ratio: 0.0
+      num_cycles: 0.5
       warmup_style: constant
+    use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false}
     grad_clip: 1.0
     ulysses_sequence_parallel_size: 1
     entropy_from_logits_with_chunking: false
     entropy_checkpointing: false
     fsdp_config:
+      _target_: verl.workers.config.FSDPEngineConfig
       wrap_policy:
         min_num_params: 0
       param_offload: false
@@ -58,6 +65,7 @@ actor_rollout_ref:
       reshard_after_forward: true
       fsdp_size: -1
       forward_prefetch: false
+    use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false}
   ref:
     strategy: ${actor_rollout_ref.actor.strategy}
     use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true}
@@ -66,11 +74,12 @@ actor_rollout_ref:
     log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}
     log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}
     fsdp_config:
+      _target_: verl.workers.config.FSDPEngineConfig
+      wrap_policy:
+        min_num_params: 0
       param_offload: false
       reshard_after_forward: true
       forward_prefetch: false
-      wrap_policy:
-        min_num_params: 0
     ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1}
     entropy_from_logits_with_chunking: false
     entropy_checkpointing: false
@@ -249,13 +258,17 @@ data:
   path: null
   name: null
 critic:
+  _target_: verl.workers.config.FSDPCriticConfig
   rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
   strategy: fsdp
   enable: null
   optim:
-    lr: 1.0e-05
     lr_warmup_steps_ratio: 0.0
     total_training_steps: -1
     weight_decay: 0.01
+    lr: 1.0e-05
     lr_warmup_steps: -1
+    _target_: verl.workers.config.FSDPOptimizerConfig
+    min_lr_ratio: null
+    warmup_style: constant
   model:
@@ -264,11 +277,13 @@ critic:
     override_config: {}
     external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null}
     trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false}
+    _target_: verl.workers.config.FSDPCriticModelCfg
     use_shm: false
     enable_gradient_checkpointing: true
     enable_activation_offload: false
     use_remove_padding: false
     fsdp_config:
+      _target_: verl.workers.config.FSDPEngineConfig
       param_offload: false
       optimizer_offload: false
       offload_policy: false
@@ -291,17 +306,18 @@ critic:
   cliprange_value: 0.5
   loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean}
   checkpoint:
+    _target_: verl.trainer.config.CheckpointConfig
     save_contents:
     - model
     - optimizer
     - extra
     load_contents: ${.save_contents}
     async_save: false
   profiler:
     _target_: verl.utils.profiler.ProfilerConfig
     discrete: false
     all_ranks: false
     ranks: []
-  _target_: verl.trainer.config.FSDPCriticConfig
   forward_micro_batch_size: ${oc.select:.ppo_micro_batch_size,null}
   forward_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size_per_gpu,null}
   ulysses_sequence_parallel_size: 1
@@ -318,6 +334,7 @@ reward_model:
     use_remove_padding: false
     use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}
     fsdp_config:
+      _target_: verl.workers.config.FSDPEngineConfig
       wrap_policy:
         min_num_params: 0
       param_offload: false
@@ -360,7 +377,6 @@ algorithm:
   target_kl: 0.1
   use_pf_ppo: false
   pf_ppo:
-    _target_: verl.trainer.config.PFPPOConfig
     reweight_method: pow
     weight_pow: 2.0
 ray_init:
@@ -4,6 +4,9 @@
 # 3. Inline comments (after a field on the same line) are not allowed.
 # 4. Indentation level is respected for nested fields.

+# Target class for this configuration
+_target_: verl.workers.config.ActorConfig
+
 # the abstract actor configs

 # fsdp, fsdp2 or megatron. must be set.
 strategy: ???
@@ -38,6 +41,9 @@ clip_ratio_high: 0.2
 # policy loss config
 policy_loss:

+  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+  _target_: verl.workers.config.PolicyLossConfig
+
   # Loss function mode: vanilla / clip-cov / kl-cov / gpg from https://arxiv.org/abs/2505.22617
   loss_mode: "vanilla"
@@ -87,6 +93,9 @@ shuffle: false
 # checkpoint configs
 checkpoint:

+  # Target dataclass for this configuration
+  _target_: verl.trainer.config.CheckpointConfig
+
   # What to include in saved checkpoints
   # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
   save_contents: ['model', 'optimizer', 'extra']
@@ -95,6 +104,9 @@ checkpoint:
   # .xxx refers to the local variable xxx from the same level of hierarchy similar to python pkg
   load_contents: ${.save_contents}

+  # Whether to save checkpoints asynchronously. Only effective for Megatron as of now.
+  async_save: False
+
 # optimizer configs
 optim:
@@ -109,3 +121,10 @@ optim:
   # Weight decay
   weight_decay: 0.01
+
+  # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio.
+  lr_warmup_steps: -1
+
+# Whether to use custom fused kernels (e.g., FlashAttention, fused MLP)
+use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false}
@@ -13,6 +13,9 @@ defaults:
 # load the reference default config, then apply the fields in the current yaml
 - _self_

+# Target class for this configuration
+_target_: verl.workers.config.FSDPActorConfig
+
 # TODO(haibin.lin): switch to fsdp2
 strategy: fsdp
@@ -32,8 +35,8 @@ entropy_checkpointing: False
 # optimizer configs
 optim:

-  # Warmup steps; negative value delegates to lr_warmup_steps_ratio
-  lr_warmup_steps: -1
+  # Target class for this configuration
+  _target_: verl.workers.config.FSDPOptimizerConfig

   # Minimum LR ratio for cosine schedule
   min_lr_ratio: 0.0
@@ -47,6 +50,9 @@ optim:
 # configs for FSDP
 fsdp_config:

+  # Target class for this configuration
+  _target_: verl.workers.config.FSDPEngineConfig
+
   # policy for wrapping the model
   wrap_policy:
@@ -71,3 +77,6 @@ fsdp_config:
   # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather
   # before the current forward computation.
   forward_prefetch: False
+
+# Whether to remove padding tokens in inputs during training
+use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false}
@@ -4,18 +4,16 @@ defaults:
 # load the reference default config, then apply the fields in the current yaml
 - _self_

+_target_: verl.workers.config.McoreActorConfig
+
 strategy: megatron

 data_loader_seed: null

 load_weight: True

-checkpoint:
-
-  async_save: False
-
 optim:

+  _target_: verl.workers.config.McoreOptimizerConfig
   optimizer: adam

   clip_grad: 1.0
@@ -23,9 +21,6 @@ optim:
   # initial learning rate for warmup, default to 0.0
   lr_warmup_init: 0.0

-  # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio.
-  lr_warmup_steps: null
-
   lr_decay_steps: null

   # select from constant/linear/cosine/inverse_square_root
@@ -47,10 +42,16 @@ optim:

 megatron:

+  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+  _target_: verl.workers.config.McoreEngineConfig
+
+  # Whether to offload model parameters to CPU
   param_offload: False

+  # Whether to offload gradients to CPU
   grad_offload: False

+  # Whether to offload optimizer state to CPU
   optimizer_offload: False

   tensor_model_parallel_size: 1
@@ -104,6 +105,7 @@ megatron:

 # profile the actor model in `update_policy`
 profile:

+  # turn it on when you want to profile the actor model
   use_profile: False
@@ -13,10 +13,12 @@
 # limitations under the License.

 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Any, Optional

 from verl.base_config import BaseConfig

+__all__ = ["AlgoConfig", "FilterGroupsConfig", "KLControlConfig"]
+

 @dataclass
 class KLControlConfig(BaseConfig):
@@ -31,29 +33,12 @@ class KLControlConfig(BaseConfig):
         target_kl (float): Target KL divergence for adaptive controller.
     """

-    _frozen_fields = ["type", "kl_coef", "horizon", "target_kl"]
     type: str = "fixed"
     kl_coef: float = 0.001
     horizon: int = 10000
     target_kl: float = 0.1


-@dataclass
-class PFPPOConfig(BaseConfig):
-    """Configuration for preference feedback PPO.
-
-    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.
-
-    Args:
-        reweight_method (str): Method for reweighting samples. Can be "pow", "max_min", or "max_random".
-        weight_pow (float): Power used for weight scaling in "pow" method.
-    """
-
-    _frozen_fields = ["reweight_method", "weight_pow"]
-    reweight_method: str = "pow"
-    weight_pow: float = 2.0
-
-
 @dataclass
 class FilterGroupsConfig(BaseConfig):
     """Configuration for filter groups (used in DAPO and Entropy).
@@ -66,8 +51,6 @@ class FilterGroupsConfig(BaseConfig):
         max_num_gen_batches (int): Non-positive values mean no upper limit.
     """

-    _frozen_fields = ["enable", "metric", "max_num_gen_batches"]
-
     enable: bool = False
     metric: Optional[str] = None
     max_num_gen_batches: int = 0
@@ -88,20 +71,10 @@ class AlgoConfig(BaseConfig):
         kl_penalty (str): How to estimate KL divergence: "kl", "abs", "mse", "low_var_kl", or "full".
         kl_ctrl (KLControlConfig): KL control configuration.
         use_pf_ppo (bool): Whether to enable preference feedback PPO.
-        pf_ppo (Optional[PFPPOConfig]): Preference feedback PPO settings.
+        pf_ppo (dict[str, Any]): Preference feedback PPO settings.
         filter_groups (Optional[FilterGroupsConfig]): Filter groups configuration, used in DAPO and Entropy
     """

-    _frozen_fields = [
-        "gamma",
-        "lam",
-        "adv_estimator",
-        "norm_adv_by_std_in_grpo",
-        "use_kl_in_reward",
-        "kl_penalty",
-        "use_pf_ppo",
-    ]
-
     gamma: float = 1.0
     lam: float = 1.0
     adv_estimator: str = "gae"
@@ -110,5 +83,5 @@ class AlgoConfig(BaseConfig):
     kl_penalty: str = "kl"
     kl_ctrl: KLControlConfig = field(default_factory=KLControlConfig)
     use_pf_ppo: bool = False
-    pf_ppo: Optional[PFPPOConfig] = None
+    pf_ppo: dict[str, Any] = field(default_factory=dict)
     filter_groups: Optional[FilterGroupsConfig] = None
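With `pf_ppo` now a plain dict, recipe code reads it with dict-style access rather than attribute access on a dedicated dataclass — a sketch with illustrative values:

```python
from verl.trainer.config import AlgoConfig

algo = AlgoConfig(use_pf_ppo=True, pf_ppo={"reweight_method": "pow", "weight_pow": 2.0})
method = algo.pf_ppo.get("reweight_method", "pow")  # dict access replaces PFPPOConfig fields
```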
@@ -17,110 +17,63 @@ from typing import Any, Optional

 from verl.base_config import BaseConfig

+__all__ = ["CheckpointConfig", "ProfileConfig", "BaseModelConfig"]
+

 @dataclass
-class CriticConfig(BaseConfig):
-    """Configuration for critic model training.
+class CheckpointConfig(BaseConfig):
+    """Configuration for model checkpointing.

     The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.

     Args:
-        rollout_n (int): Number of rollouts per update (mirrors actor rollout_n).
-        strategy (str): Strategy used for critic model training (fsdp, fsdp2, megatron).
-        optim (Dict[str, Any]): Optimizer configuration including lr, weight_decay, etc.
-        model (Dict[str, Any]): Model configuration including path, tokenizer_path, etc.
-        ppo_mini_batch_size (int): PPO mini-batch size per update.
-        ppo_micro_batch_size (Optional[int]): Global micro batch size (deprecated).
-        ppo_micro_batch_size_per_gpu (Optional[int]): Local per-GPU micro batch size.
-        use_dynamic_bsz (bool): Whether to automatically adjust batch size at runtime.
-        ppo_max_token_len_per_gpu (int): Max tokens per GPU in one PPO batch.
-        forward_max_token_len_per_gpu (int): Max token length per GPU in forward pass.
-        ppo_epochs (int): Number of PPO epochs per batch.
-        shuffle (bool): Shuffle training data across PPO epochs.
-        cliprange_value (float): PPO value function clipping range.
-        loss_agg_mode (str): Loss aggregation mode.
-        checkpoint (Dict[str, Any]): Checkpoint configuration.
-        profiler (Dict[str, Any]): Profiler configuration.
+        save_contents (list[str]): What to include in saved checkpoints.
+            Options: 'model', 'optimizer', 'extra', 'hf_model'.
+        load_contents (list[str]): Contents to load from checkpoint. Defaults to same as save_contents.
+        async_save (bool): Whether to save checkpoints asynchronously. Only implemented for Megatron as of now.
     """

-    # For legacy reason configs related to batch_size are mutated in each role
-    # In the future they will be added to frozen fields instead
-    _frozen_fields = [
-        "rollout_n",
-        "strategy",
-        "use_dynamic_bsz",
-        "ppo_max_token_len_per_gpu",
-        "forward_max_token_len_per_gpu",
-        "ppo_epochs",
-        "shuffle",
-        "cliprange_value",
-        "loss_agg_mode",
-    ]
-
-    rollout_n: int = 1
-    strategy: str = "fsdp"
-    optim: dict[str, Any] = field(default_factory=dict)
-    model: dict[str, Any] = field(default_factory=dict)
-    ppo_mini_batch_size: int = 1
-    ppo_micro_batch_size: Optional[int] = None
-    ppo_micro_batch_size_per_gpu: Optional[int] = None
-    use_dynamic_bsz: bool = False
-    ppo_max_token_len_per_gpu: int = 32768
-    forward_max_token_len_per_gpu: int = 32768
-    ppo_epochs: int = 1
-    shuffle: bool = True
-    cliprange_value: float = 0.5
-    loss_agg_mode: str = "token-mean"
-    checkpoint: dict[str, Any] = field(default_factory=dict)
-    profiler: dict[str, Any] = field(default_factory=dict)
+    save_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
+    load_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
+    async_save: bool = False


 @dataclass
-class MegatronCriticConfig(CriticConfig):
-    """Configuration for Megatron-based critic model training.
+class ProfileConfig(BaseConfig):
+    """Configuration for profiling.

-    The inheritance from CriticConfig provides all base critic configuration plus Megatron-specific settings.
+    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.

     Args:
-        nccl_timeout (int): NCCL timeout in seconds for distributed operations.
-        megatron (Dict[str, Any]): Megatron-specific parallelism settings.
-        load_weight (bool): Whether to load initial weights.
-        data_loader_seed (Optional[int]): Seed for data loader.
+        use_profile (bool): Whether to enable profiling.
+        profile_ranks (Optional[list[int]]): List of ranks to profile. None means all ranks.
+        step_start (int): Starting step for profiling.
+        step_end (int): Ending step for profiling.
+        save_path (Optional[str]): Path to save profiling results.
     """

-    _frozen_fields = CriticConfig._frozen_fields + [
-        "nccl_timeout",
-        "load_weight",
-        "data_loader_seed",
-    ]
-
-    strategy: str = "megatron"
-    nccl_timeout: int = 600
-    megatron: dict[str, Any] = field(default_factory=dict)
-    load_weight: bool = True
-    data_loader_seed: Optional[int] = None
+    use_profile: bool = False
+    profile_ranks: Optional[list[int]] = None
+    step_start: int = -1
+    step_end: int = -1
+    save_path: Optional[str] = None


 @dataclass
-class FSDPCriticConfig(CriticConfig):
-    """Configuration for FSDP-based critic model training.
-
-    The inheritance from CriticConfig provides all base critic configuration plus FSDP-specific settings.
+class BaseModelConfig(BaseConfig):
+    """Base configuration for a model.
+
+    Contains core settings for loading and initializing a pretrained model checkpoint.

     Args:
-        forward_micro_batch_size (int): Forward-only batch size during inference (global).
-        forward_micro_batch_size_per_gpu (int): Forward-only batch size during inference (per GPU).
-        ulysses_sequence_parallel_size (int): Sequence parallelism size for Ulysses-style model parallelism.
-        grad_clip (float): Gradient clipping for critic updates.
+        path (str): Path to pretrained model weights.
+        tokenizer_path (Optional[str]): Tokenizer path (defaults to actor's model path if not set).
+        override_config (dict): Hugging Face config override.
+        external_lib (Optional[str]): External model implementation (optional).
+        trust_remote_code (bool): Whether to trust remote code from Hugging Face models.
     """

-    _frozen_fields = CriticConfig._frozen_fields + [
-        "ulysses_sequence_parallel_size",
-        "grad_clip",
-    ]
-
-    strategy: str = "fsdp"
-    forward_micro_batch_size: int = 1
-    forward_micro_batch_size_per_gpu: int = 1
-    ulysses_sequence_parallel_size: int = 1
-    grad_clip: float = 1.0
+    path: str = "~/models/deepseek-llm-7b-chat"
+    tokenizer_path: Optional[str] = None
+    override_config: dict[str, Any] = field(default_factory=dict)
+    external_lib: Optional[str] = None
+    trust_remote_code: bool = False
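A quick sketch of the relocated dataclasses in use (the path and step values are placeholders):

```python
from verl.trainer.config import BaseModelConfig, ProfileConfig

model_cfg = BaseModelConfig(path="Qwen/Qwen2.5-1.5B-Instruct", trust_remote_code=False)
profile_cfg = ProfileConfig(use_profile=True, profile_ranks=[0], step_start=2, step_end=4)
assert model_cfg.get("tokenizer_path") is None  # dict-style access via BaseConfig
```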
@@ -1,12 +1,23 @@
+# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+_target_: verl.workers.config.CriticConfig
+
+# Number of rollouts per update (mirrors actor rollout_n)
+rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
+
 # fsdp or fsdp2 strategy used for critic model training
 strategy: ???

+# whether to enable the critic worker.
+# by default it is only enabled if advantage estimator is gae
+# set it to True manually if you always want to enable critic worker
+enable: null
+
 # optimizer configs
 optim:

   # Learning rate
   lr: 1e-5

   # Warmup steps ratio; total steps will be injected at runtime
   lr_warmup_steps_ratio: 0.0
@@ -16,6 +27,10 @@ optim:
   # Weight decay
   weight_decay: 0.01

+  # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio.
+  lr_warmup_steps: -1
+

 # model config for the critic
 model:
@@ -67,6 +82,9 @@ loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean}
 # checkpoint configs
 checkpoint:

+  # Target dataclass for this configuration
+  _target_: verl.trainer.config.CheckpointConfig
+
   # What to include in saved checkpoints
   # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
   save_contents: ['model', 'optimizer', 'extra']
@@ -74,11 +92,14 @@ checkpoint:
   # What to include when loading checkpoints
   load_contents: ${.save_contents}

+  # Whether to save checkpoints asynchronously. Only effective for Megatron as of now.
+  async_save: False
+
 # profiler configs
 # the corresponding dataclass is verl.utils.profiler.ProfilerConfig.
 profiler:

-  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
   _target_: verl.utils.profiler.ProfilerConfig

   # True for each task has its own database, False for all tasks in one training step share one database.
@@ -89,6 +110,3 @@ profiler:

   # The ranks that will be profiled. [] or [0,1,...]
   ranks: []
-
-# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
-_target_: verl.trainer.config.CriticConfig
@@ -13,13 +13,17 @@ defaults:
 # load the reference default config, then apply the fields in the current yaml
 - _self_

+# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+_target_: verl.workers.config.FSDPCriticConfig
+
 # distribution strategy. Options: fsdp (deprecating), fsdp2
 strategy: fsdp

 # optimizer configs
 optim:

   # Learning rate
   lr: 1e-5
+
+  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+  _target_: verl.workers.config.FSDPOptimizerConfig

   # Minimum LR ratio for cosine schedule
   min_lr_ratio: null
@@ -30,6 +34,9 @@ optim:
 # model config for the critic
 model:

+  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+  _target_: verl.workers.config.FSDPCriticModelCfg
+
   # Whether to use shared memory for loading the model
   use_shm: False
@@ -45,6 +52,9 @@ model:
 # FSDP-specific config
 fsdp_config:

+  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+  _target_: verl.workers.config.FSDPEngineConfig
+
   # Whether to offload model parameters to CPU
   param_offload: False
@@ -90,6 +100,3 @@ ulysses_sequence_parallel_size: 1

 # Gradient clipping for critic updates
 grad_clip: 1.0
-
-# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
-_target_: verl.trainer.config.FSDPCriticConfig
@@ -7,6 +7,9 @@ defaults:
 # load the reference default config, then apply the fields in the current yaml
 - _self_

+# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+_target_: verl.workers.config.McoreCriticConfig
+
 strategy: megatron

 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron
@@ -15,21 +18,18 @@ nccl_timeout: 600
 # optimizer configs
 optim:

+  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+  _target_: verl.workers.config.McoreOptimizerConfig
+
   # select optimizer, default is Adam
   optimizer: adam

   # Learning rate
   lr: 1e-6

   # Clip gradients norm
   clip_grad: 1.0

   # initial learning rate for warmup, default to 0.0
   lr_warmup_init: 0.0

-  # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio.
-  lr_warmup_steps: null
-
   lr_decay_steps: null

   # select from constant/linear/cosine/inverse_square_root
@@ -53,6 +53,9 @@ optim:
 # model config for the critic
 model:

+  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+  _target_: verl.trainer.config.BaseModelConfig
+
   # override default empty mapping
   override_config:
@@ -65,6 +68,9 @@ model:
 # megatron-specific parallelism settings
 megatron:

+  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+  _target_: verl.workers.config.McoreEngineConfig
+
   # Whether to offload model parameters to CPU
   param_offload: False
@@ -121,10 +127,3 @@ load_weight: True

 # seed for data loader
 data_loader_seed: ${oc.select:actor_rollout_ref.actor.data_loader_seed,null}
-
-# Asynchronous checkpoint saving
-checkpoint:
-  async_save: False
-
-# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
-_target_: verl.trainer.config.MegatronCriticConfig
@@ -68,7 +68,7 @@ custom_reward_function:
   name: compute_score

 algorithm:
-  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
   _target_: verl.trainer.config.AlgoConfig
   gamma: 1.0
   lam: 1.0
@@ -77,7 +77,7 @@ algorithm:
   use_kl_in_reward: False
   kl_penalty: kl  # how to estimate kl divergence
   kl_ctrl:
-    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
     _target_: verl.trainer.config.KLControlConfig
     type: fixed
     kl_coef: 0.001
@@ -85,8 +85,6 @@ algorithm:
   target_kl: 0.1
   use_pf_ppo: False
   pf_ppo:
-    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
-    _target_: verl.trainer.config.PFPPOConfig
     reweight_method: pow  # ["pow", "max_min", "max_random"]
     weight_pow: 2.0
@@ -112,7 +112,7 @@ actor_rollout_ref:
   # profiler configs
   profiler:

-    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
     _target_: verl.utils.profiler.ProfilerConfig

     # True for each task has its own database, False for all tasks in one training step share one database.
@@ -137,7 +137,7 @@ custom_reward_function:
 # config for the algorithm
 algorithm:

-  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
   _target_: verl.trainer.config.AlgoConfig

   # Discount factor for future rewards
@@ -161,7 +161,7 @@ algorithm:
   # KL control configuration
   kl_ctrl:

-    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
     _target_: verl.trainer.config.KLControlConfig

     # KL control type: "fixed" or "adaptive"
@@ -182,9 +182,6 @@ algorithm:
   # Preference feedback PPO settings
   pf_ppo:

-    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
-    _target_: verl.trainer.config.PFPPOConfig
-
     # Method for reweighting samples: "pow", "max_min", or "max_random"
     reweight_method: pow
@@ -10,6 +10,15 @@ defaults:

# config for FSDP strategy
fsdp_config:

+  # Target class for this configuration
+  _target_: verl.workers.config.FSDPEngineConfig
+
+  # the wrap policy for FSDP model
+  wrap_policy:
+
+    # minimum number of params in a wrapped module
+    min_num_params: 0

  # whether to offload parameters in FSDP
  param_offload: False

@@ -21,12 +30,6 @@ fsdp_config:
  # before the current forward computation.
  forward_prefetch: False

-  # the wrap policy for FSDP model
-  wrap_policy:
-
-    # minimum number of params in a wrapped module
-    min_num_params: 0

# sequence parallel size
# same as actor_rollout_ref.actor.ulysses_sequence_parallel_size if it exists, otherwise 1
ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1}
@@ -7,45 +7,27 @@ defaults:

strategy: megatron

megatron:

+  _target_: verl.workers.config.MegatronEngineConfig
  param_offload: False
  tensor_model_parallel_size: 1
  expert_model_parallel_size: 1
  expert_tensor_parallel_size: None
  pipeline_model_parallel_size: 1
  virtual_pipeline_model_parallel_size: null  # change VPP interface for parallelism tests
  context_parallel_size: 1
  sequence_parallel: True
  use_distributed_optimizer: False
  use_dist_checkpointing: False
  dist_checkpointing_path: null
  seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42}
  override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}}
  use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False}

profile:
  use_profile: False
  profile_ranks: null
  step_start: -1
  step_end: -1
  save_path: null

load_weight: True
@@ -29,8 +29,12 @@ model:

# FSDP-specific config
fsdp_config:

+  # Target configuration dataclass
+  _target_: verl.workers.config.FSDPEngineConfig

  # Policy for wrapping layers with FSDP
  wrap_policy:

    # Minimum number of parameters to trigger wrapping
    min_num_params: 0
@@ -15,6 +15,10 @@ nccl_timeout: 600

# Megatron parallelism & checkpointing config
megatron:

+  # Target configuration dataclass
+  _target_: verl.workers.config.MegatronEngineConfig

  # Whether to offload model parameters to CPU
  param_offload: False
@@ -12,8 +12,8 @@ strategy: ???

# model config for reward scoring
model:

-  # Input tokenizer. If the reward model’s chat template is inconsistent with the policy,
-  # we need to first decode to plaintext, then apply the rm’s chat_template.
+  # Input tokenizer. If the reward model's chat template is inconsistent with the policy,
+  # we need to first decode to plaintext, then apply the rm's chat_template.
  # Then score with RM. If chat_templates are consistent, it can be set to null.
  # set this to null if the chat template is identical
  input_tokenizer: ${actor_rollout_ref.model.path}
@@ -68,7 +68,7 @@ sandbox_fusion:

# profiler configs
profiler:

-  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint
+  # hint for the target config dataclass
  _target_: verl.utils.profiler.ProfilerConfig

  # True for each task has its own database, False for all tasks in one training step share one database.
@@ -58,7 +58,6 @@ trainer:
  total_training_steps: null
  logger: [ 'console', 'wandb' ]
-  seed: 1
  save_freq: -1
  test_freq: -1
  nnodes: 1
@@ -21,6 +21,7 @@ This trainer supports model-agonistic model initialization with huggingface
import json
import os
import uuid
+import warnings
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field

@@ -53,6 +54,7 @@ from verl.trainer.ppo.metric_utils import (
)
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi
+from verl.utils.config import omega_conf_to_dataclass
from verl.utils.debug import marked_timer
from verl.utils.metric import (
    reduce_metrics,
@@ -256,8 +258,8 @@ def compute_advantage(
    if config.get("use_pf_ppo", False):
        data = core_algos.compute_pf_ppo_reweight_data(
            data,
-            config.pf_ppo.reweight_method,
-            config.pf_ppo.weight_pow,
+            config.pf_ppo.get("reweight_method"),
+            config.pf_ppo.get("weight_pow"),
        )
    elif adv_estimator == AdvantageEstimator.GRPO:
        # Initialize the mask for GRPO calculation
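Why the call sites switch to `.get()`: with `_target_` dropped from `pf_ppo` above, this section may arrive either as a plain `DictConfig` or, later, as a `BaseConfig`-style dataclass, and `.get()` works for both since `BaseConfig` implements the Mapping protocol. A small self-contained sketch with a stand-in section:

```python
from omegaconf import OmegaConf

# Stand-in for algorithm.pf_ppo after its _target_ was dropped.
pf_ppo = OmegaConf.create({"reweight_method": "pow", "weight_pow": 2.0})

# .get() works on a DictConfig today and keeps working if this section is
# later converted to a BaseConfig dataclass (BaseConfig implements Mapping).
print(pf_ppo.get("reweight_method"), pf_ppo.get("weight_pow"))  # pow 2.0
```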
@@ -369,21 +371,17 @@ class RayPPOTrainer:
        if self.config.algorithm.use_kl_in_reward:
            self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl)

-        if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE:
+        if config.critic.enable is not None:
+            self.use_critic = bool(config.critic.enable)
+        elif self.config.algorithm.adv_estimator == AdvantageEstimator.GAE:
            self.use_critic = True
-        elif self.config.algorithm.adv_estimator in [
-            AdvantageEstimator.GRPO,
-            AdvantageEstimator.GRPO_PASSK,
-            AdvantageEstimator.REINFORCE_PLUS_PLUS,
-            AdvantageEstimator.REMAX,
-            AdvantageEstimator.RLOO,
-            AdvantageEstimator.OPO,
-            AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE,
-            AdvantageEstimator.GPG,
-        ]:
-            self.use_critic = False
        else:
-            raise NotImplementedError
+            warnings.warn(
+                "Disabled critic as algorithm.adv_estimator != gae. "
+                "If it is not intended, please set critic.enable=True",
+                stacklevel=2,
+            )
+            self.use_critic = False

        self._validate_config()
        self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
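The critic toggle now has an explicit override. A pure-Python sketch of the precedence; the function name here is illustrative, not a verl API:

```python
def resolve_use_critic(enable, adv_estimator: str) -> bool:
    """Illustrative only: mirrors the precedence in the hunk above."""
    if enable is not None:         # an explicit critic.enable always wins
        return bool(enable)
    return adv_estimator == "gae"  # otherwise GAE implies a critic

assert resolve_use_critic(None, "gae") is True    # default: GAE needs a critic
assert resolve_use_critic(False, "gae") is False  # explicit opt-out
assert resolve_use_critic(True, "grpo") is True   # explicit opt-in despite GRPO
```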
@@ -434,8 +432,6 @@ class RayPPOTrainer:
            ValueError: If both parameters are set or neither is set.
        """
        settings = {
-            "actor_rollout_ref.actor": "micro_batch_size",
            "critic": "micro_batch_size",
            "reward_model": "micro_batch_size",
            "actor_rollout_ref.ref": "log_prob_micro_batch_size",
            "actor_rollout_ref.rollout": "log_prob_micro_batch_size",
@@ -456,14 +452,11 @@ class RayPPOTrainer:
                f"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated)."
            )

-        if not config.actor_rollout_ref.actor.use_dynamic_bsz:
-            # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu
-            check_mutually_exclusive(
-                config.actor_rollout_ref.actor.ppo_micro_batch_size,
-                config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu,
-                "actor_rollout_ref.actor",
-            )
+        # Actor validation done in ActorConfig.__post_init__ and validate()
+        actor_config = omega_conf_to_dataclass(config.actor_rollout_ref.actor)
+        actor_config.validate(n_gpus, config.data.train_batch_size, config.actor_rollout_ref.model)

+        if not config.actor_rollout_ref.actor.use_dynamic_bsz:
            if self.use_reference_policy:
                # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
                check_mutually_exclusive(
@@ -479,66 +472,19 @@ class RayPPOTrainer:
                    "actor_rollout_ref.rollout",
                )

-        if self.use_critic and not config.critic.use_dynamic_bsz:
-            # Check for critic micro-batch size conflicts
-            check_mutually_exclusive(
-                config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic"
-            )

        # Check for reward model micro-batch size conflicts
        if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:
            check_mutually_exclusive(
                config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model"
            )

-        # Actor
-        # check if train_batch_size is larger than ppo_mini_batch_size
-        # if NOT dynamic_bsz, we must ensure:
-        #    ppo_mini_batch_size is divisible by ppo_micro_batch_size
-        #    ppo_micro_batch_size * sequence_parallel_size >= n_gpus
-        if not config.actor_rollout_ref.actor.use_dynamic_bsz:
-            assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size
-            sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1)
-            if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None:
-                assert (
-                    config.actor_rollout_ref.actor.ppo_mini_batch_size
-                    % config.actor_rollout_ref.actor.ppo_micro_batch_size
-                    == 0
-                )
-                assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus
-
-        assert config.actor_rollout_ref.actor.loss_agg_mode in [
-            "token-mean",
-            "seq-mean-token-sum",
-            "seq-mean-token-mean",
-            "seq-mean-token-sum-norm",
-        ], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}"

        if self.config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss:
            print("NOTICE: You have both enabled in-reward kl and kl loss.")

        # critic
-        if self.use_critic and not config.critic.use_dynamic_bsz:
-            assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size
-            sp_size = config.critic.get("ulysses_sequence_parallel_size", 1)
-            if config.critic.ppo_micro_batch_size is not None:
-                assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0
-                assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus
-
-        # Check if use_remove_padding is enabled when using sequence parallelism for fsdp
-        if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"} and (
-            config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1
-            or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1
-        ):
-            assert config.actor_rollout_ref.model.use_remove_padding, (
-                "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`."
-            )
-
-        if self.use_critic and config.critic.strategy in {"fsdp", "fsdp2"}:
-            if config.critic.get("ulysses_sequence_parallel_size", 1) > 1:
-                assert config.critic.model.use_remove_padding, (
-                    "When using sequence parallelism for critic, you must enable `use_remove_padding`."
-                )
+        if self.use_critic:
+            critic_config = omega_conf_to_dataclass(config.critic)
+            critic_config.validate(n_gpus, config.data.train_batch_size)

        if config.data.get("val_batch_size", None) is not None:
            print(
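Taken together, the trainer-side validation reduces to: build the dataclass (which runs its `__post_init__` checks), then hand it the runtime facts it cannot know on its own. A sketch under that reading; `validate_trainer_config` is a hypothetical wrapper, not a verl function:

```python
from verl.utils.config import omega_conf_to_dataclass

def validate_trainer_config(config, n_gpus: int, use_critic: bool) -> None:
    """Hypothetical wrapper: dataclass construction runs __post_init__ checks,
    then validate() adds the checks that need runtime information."""
    actor_config = omega_conf_to_dataclass(config.actor_rollout_ref.actor)
    actor_config.validate(n_gpus, config.data.train_batch_size, config.actor_rollout_ref.model)

    if use_critic:
        critic_config = omega_conf_to_dataclass(config.critic)
        critic_config.validate(n_gpus, config.data.train_batch_size)
```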
@@ -850,7 +796,8 @@ class RayPPOTrainer:
        # create critic
        if self.use_critic:
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
-            critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic)
+            critic_cfg = omega_conf_to_dataclass(self.config.critic)
+            critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg)
            self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls

        # create reference policy if needed
@@ -22,17 +22,17 @@ import torch.distributed
from omegaconf import DictConfig
from transformers import PreTrainedTokenizer, ProcessorMixin

+from verl.trainer.config import CheckpointConfig
from verl.utils.device import get_device_name, get_torch_device


class BaseCheckpointManager:
    """
-    A checkpoint manager that saves and loads
+    A checkpoint manager that saves and loads the following states in a SPMD way:
    - model
    - optimizer
    - lr_scheduler
    - extra_states
-    in a SPMD way.

    We save
    - sharded model states and optimizer states

@@ -46,7 +46,7 @@ class BaseCheckpointManager:
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None,
        processing_class: PreTrainedTokenizer | ProcessorMixin = None,
-        checkpoint_config: DictConfig = None,
+        checkpoint_config: DictConfig | CheckpointConfig = None,
    ):
        self.checkpoint_config = checkpoint_config
        checkpoint_load_contents = checkpoint_config.get("load_contents", None) if checkpoint_config else None
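The widened type hint works because the manager only ever uses Mapping-style access on the config. A small sketch showing both flavors passing through the same code path:

```python
from omegaconf import OmegaConf

from verl.trainer.config import CheckpointConfig

# Both config flavors satisfy the widened `DictConfig | CheckpointConfig` hint,
# because the manager only calls .get() on whatever it receives.
for cfg in (OmegaConf.create({"load_contents": ["model"]}), CheckpointConfig()):
    print(type(cfg).__name__, cfg.get("load_contents", None))
```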
@@ -41,8 +41,9 @@ def omega_conf_to_dataclass(config: DictConfig | dict, dataclass_type: Optional[
    if dataclass_type is None:
        assert "_target_" in config, (
-            "When dataclass_type is not provided, config must contain _target_."
-            "See trainer/config/ppo_trainer.yaml algorithm section for an example."
+            "When dataclass_type is not provided, config must contain _target_. "
+            "See trainer/config/ppo_trainer.yaml algorithm section for an example. "
+            f"Got config: {config}"
        )
        from hydra.utils import instantiate

@@ -51,6 +52,9 @@ def omega_conf_to_dataclass(config: DictConfig | dict, dataclass_type: Optional[
    if not is_dataclass(dataclass_type):
        raise ValueError(f"{dataclass_type} must be a dataclass")
    cfg = OmegaConf.create(config)  # in case it's a dict
+    # pop _target_ to avoid hydra instantiate error, as most dataclass do not have _target_
+    if "_target_" in cfg:
+        cfg.pop("_target_")
    cfg_from_dataclass = OmegaConf.structured(dataclass_type)
    # let cfg override the existing vals in `cfg_from_dataclass`
    cfg_merged = OmegaConf.merge(cfg_from_dataclass, cfg)
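Both branches of `omega_conf_to_dataclass` are now exercised in the codebase. A sketch of each, assuming verl is importable:

```python
from omegaconf import OmegaConf

from verl.utils.config import omega_conf_to_dataclass
from verl.utils.profiler import ProfilerConfig

# Branch 1: no dataclass_type, so the embedded _target_ picks the class.
cfg_a = omega_conf_to_dataclass(
    OmegaConf.create({"_target_": "verl.utils.profiler.ProfilerConfig", "discrete": True})
)

# Branch 2: explicit dataclass_type; the stray _target_ key is popped before
# the structured merge, so it cannot collide with fields the dataclass lacks.
cfg_b = omega_conf_to_dataclass(
    {"_target_": "verl.utils.profiler.ProfilerConfig", "discrete": True}, ProfilerConfig
)

assert cfg_a.discrete and cfg_b.discrete
```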
@@ -42,23 +42,25 @@ def get_megatron_optimizer_param_scheduler(
    """
    Get the optimizer parameter scheduler for Megatron.
    """
+    lr_decay_steps = config.lr_decay_steps
+    lr_warmup_steps = config.lr_warmup_steps
    if config.get("lr_decay_steps", None) is None:
-        config.lr_decay_steps = config.total_training_steps
+        lr_decay_steps = config.total_training_steps
    wsd_decay_steps = None
    if config.get("lr_wsd_decay_steps", None) is not None:
        wsd_decay_steps = config.lr_wsd_decay_steps
    if config.get("lr_warmup_steps_ratio", None) is not None and (
        config.get("lr_warmup_steps", None) is None or config.lr_warmup_steps <= 0
    ):
-        config.lr_warmup_steps = int(config.lr_warmup_steps_ratio * config.lr_decay_steps)
+        lr_warmup_steps = int(config.lr_warmup_steps_ratio * lr_decay_steps)

    opt_param_scheduler = OptimizerParamScheduler(
        optimizer,
        init_lr=config.lr_warmup_init,
        max_lr=config.lr,
        min_lr=config.min_lr,
-        lr_warmup_steps=config.lr_warmup_steps,
-        lr_decay_steps=config.lr_decay_steps,
+        lr_warmup_steps=lr_warmup_steps,
+        lr_decay_steps=lr_decay_steps,
        lr_decay_style=config.lr_decay_style,
        start_wd=config.weight_decay,
        end_wd=config.weight_decay,
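A tiny worked example of the warmup arithmetic above, kept in local variables rather than written back into the config, which is the point of the change now that structured configs are frozen: with a ratio of 0.1 over 1000 decay steps, warmup is 100 steps.

```python
# config.lr_decay_steps is unset, so decay falls back to the total step count.
total_training_steps = 1000
lr_warmup_steps_ratio = 0.1

lr_decay_steps = total_training_steps                          # -> 1000
lr_warmup_steps = int(lr_warmup_steps_ratio * lr_decay_steps)  # -> 100
assert (lr_decay_steps, lr_warmup_steps) == (1000, 100)
```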
@@ -13,7 +13,6 @@
# limitations under the License.

from dataclasses import dataclass, field
-from typing import ClassVar

from verl.base_config import BaseConfig

@@ -31,13 +30,8 @@ class ProfilerConfig(BaseConfig):
        ranks (list[int]): The ranks that will be profiled. Defaults to [].
    """

-    # the fields expected to be frozen
-    _frozen_fields: ClassVar[set[str]] = {"discrete", "all_ranks", "ranks"}
-
    discrete: bool = False
-
    all_ranks: bool = False
-
    ranks: list[int] = field(default_factory=list)

    def union(self, other: "ProfilerConfig") -> "ProfilerConfig":
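With the class-local `_frozen_fields` gone, immutability falls through to `BaseConfig`, which freezes any field not opted into `_mutable_fields`. A sketch; the exact exception type (`dataclasses.FrozenInstanceError`) is an assumption here:

```python
import dataclasses

from verl.utils.profiler import ProfilerConfig

prof = ProfilerConfig(discrete=True, ranks=[0, 1])
try:
    prof.discrete = False  # not in _mutable_fields, so frozen after init
except dataclasses.FrozenInstanceError as e:
    print(f"rejected: {e}")
```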
@@ -17,6 +17,7 @@ from typing import Callable, Optional

import torch
import torch.distributed
+from omegaconf import DictConfig, OmegaConf

from .config import ProfilerConfig

@@ -40,6 +41,8 @@ class Profiler:

    def __init__(self, config):
        # note : if we do not set use_profile, it will be set as None, so that all function will be skip
+        if not isinstance(config, DictConfig):
+            config = OmegaConf.create(config)
        self.config = config
        self.skip_prof = False
        self.saved = False
@@ -35,6 +35,7 @@ from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_b
from verl.utils.torch_functional import logprobs_from_logits
from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs
from verl.workers.actor import BasePPOActor
+from verl.workers.config import ActorConfig

if is_cuda_available:
    from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input

@@ -49,18 +50,27 @@ logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


class DataParallelPPOActor(BasePPOActor):
-    def __init__(self, config, actor_module: nn.Module, actor_optimizer: torch.optim.Optimizer = None):
+    """FSDP DataParallel PPO Actor or Ref worker
+
+    Args:
+        config (ActorConfig): Actor config
+        actor_module (nn.Module): Actor or ref module
+        actor_optimizer (torch.optim.Optimizer, optional): Actor optimizer. Defaults to None.
+    """
+
+    def __init__(self, config: ActorConfig, actor_module: nn.Module, actor_optimizer: torch.optim.Optimizer = None):
        """When optimizer is None, it is Reference Policy"""
        super().__init__(config)
        self.actor_module = actor_module
        self.actor_optimizer = actor_optimizer
+        role = "Ref" if actor_optimizer is None else "Actor"

        self.use_remove_padding = self.config.get("use_remove_padding", False)
        if torch.distributed.get_rank() == 0:
-            print(f"Actor use_remove_padding={self.use_remove_padding}")
+            print(f"{role} use_remove_padding={self.use_remove_padding}")
        self.use_fused_kernels = self.config.get("use_fused_kernels", False)
        if torch.distributed.get_rank() == 0:
-            print(f"Actor use_fused_kernels={self.use_fused_kernels}")
+            print(f"{role} use_fused_kernels={self.use_fused_kernels}")

        self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size
        self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1
verl/workers/config/__init__.py — new file (21 lines)
@@ -0,0 +1,21 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .critic import *  # noqa
from .actor import *  # noqa
from .engine import *  # noqa
from .optimizer import *  # noqa
from . import actor, critic, engine, optimizer

__all__ = actor.__all__ + critic.__all__ + engine.__all__ + optimizer.__all__
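Because each submodule defines `__all__` and the package re-exports them, worker configs import from a single namespace:

```python
# One namespace for all worker configs, re-exported via each submodule's __all__:
from verl.workers.config import (
    ActorConfig,
    FSDPActorConfig,
    FSDPCriticConfig,
    McoreEngineConfig,
    OptimizerConfig,
)
```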
verl/workers/config/actor.py — new file (234 lines)
@@ -0,0 +1,234 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Any, Optional

from omegaconf import MISSING

from verl.base_config import BaseConfig
from verl.trainer.config import CheckpointConfig

from .engine import FSDPEngineConfig, McoreEngineConfig
from .optimizer import OptimizerConfig

__all__ = ["PolicyLossConfig", "ActorConfig", "FSDPActorConfig", "McoreActorConfig"]


@dataclass
class PolicyLossConfig(BaseConfig):
    """Configuration for policy loss computation.

    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.

    Args:
        loss_mode (str): Loss function mode. Options: 'vanilla', 'clip-cov', 'kl-cov', 'gpg'.
        clip_cov_ratio (float): Ratio of tokens to be clipped for clip-cov loss.
        clip_cov_lb (float): Lower bound for clip-cov loss.
        clip_cov_ub (float): Upper bound for clip-cov loss.
        kl_cov_ratio (float): Ratio of tokens to be applied KL penalty for kl-cov loss.
        ppo_kl_coef (float): KL divergence penalty coefficient.
    """

    loss_mode: str = "vanilla"
    clip_cov_ratio: float = 0.0002
    clip_cov_lb: float = 1.0
    clip_cov_ub: float = 5.0
    kl_cov_ratio: float = 0.0002
    ppo_kl_coef: float = 0.1


@dataclass
class ActorConfig(BaseConfig):
    """Configuration for actor model training.

    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.

    Args:
        strategy (str): Training strategy. Must be specified.
        ppo_mini_batch_size (int): Mini-batch size for PPO training.
        ppo_micro_batch_size (Optional[int]): Micro-batch size for PPO training.
            If None, uses ppo_micro_batch_size_per_gpu.
        ppo_micro_batch_size_per_gpu (Optional[int]): Micro-batch size per GPU for PPO training.
        use_dynamic_bsz (bool): Whether to use dynamic batch sizing.
        ppo_max_token_len_per_gpu (int): Maximum token length per GPU for PPO training.
        clip_ratio (float): PPO clipping ratio for policy loss.
        clip_ratio_low (float): Lower bound for PPO clipping ratio.
        clip_ratio_high (float): Upper bound for PPO clipping ratio.
        policy_loss (PolicyLossConfig): Configuration for policy loss computation.
        clip_ratio_c (float): Clipping ratio for critic loss.
        loss_agg_mode (str): Loss aggregation mode. Options: 'token-mean', 'sample-mean'.
        entropy_coeff (float): Entropy coefficient for regularization.
        use_kl_loss (bool): Whether to use KL divergence loss.
        use_torch_compile (bool): Whether to use torch.compile for optimization.
        kl_loss_coef (float): KL divergence loss coefficient.
        kl_loss_type (str): Type of KL loss to use.
        ppo_epochs (int): Number of PPO epochs per training step.
        shuffle (bool): Whether to shuffle data during training.
        checkpoint (CheckpointConfig): Configuration for checkpointing.
        optim (OptimizerConfig): Configuration for optimizer.
        use_fused_kernels (bool): Whether to use custom fused kernels (e.g., FlashAttention, fused MLP).
    """

    _mutable_fields = BaseConfig._mutable_fields | {
        "ppo_mini_batch_size",
        "ppo_micro_batch_size",
        "ppo_micro_batch_size_per_gpu",
    }

    strategy: str = MISSING
    ppo_mini_batch_size: int = 256
    ppo_micro_batch_size: Optional[int] = None
    ppo_micro_batch_size_per_gpu: Optional[int] = None
    use_dynamic_bsz: bool = False
    ppo_max_token_len_per_gpu: int = 16384
    clip_ratio: float = 0.2
    clip_ratio_low: float = 0.2
    clip_ratio_high: float = 0.2
    policy_loss: PolicyLossConfig = field(default_factory=PolicyLossConfig)
    clip_ratio_c: float = 3.0
    loss_agg_mode: str = "token-mean"
    entropy_coeff: float = 0
    use_kl_loss: bool = False
    use_torch_compile: bool = True
    kl_loss_coef: float = 0.001
    kl_loss_type: str = "low_var_kl"
    ppo_epochs: int = 1
    shuffle: bool = False
    checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig)
    optim: OptimizerConfig = field(default_factory=OptimizerConfig)
    use_fused_kernels: bool = False

    def __post_init__(self):
        """Validate actor configuration parameters."""
        assert self.strategy != MISSING
        if not self.use_dynamic_bsz:
            if self.ppo_micro_batch_size is not None and self.ppo_micro_batch_size_per_gpu is not None:
                raise ValueError(
                    "[actor] You have set both 'actor.ppo_micro_batch_size' AND 'actor.ppo_micro_batch_size_per_gpu'. "
                    "Please remove 'actor.ppo_micro_batch_size' because only '*_ppo_micro_batch_size_per_gpu' is "
                    "supported (the former is deprecated)."
                )
            else:
                assert not (self.ppo_micro_batch_size is None and self.ppo_micro_batch_size_per_gpu is None), (
                    "[actor] Please set at least one of 'actor.ppo_micro_batch_size' or "
                    "'actor.ppo_micro_batch_size_per_gpu' if use_dynamic_bsz is not enabled."
                )

        valid_loss_agg_modes = [
            "token-mean",
            "seq-mean-token-sum",
            "seq-mean-token-mean",
            "seq-mean-token-sum-norm",
        ]
        if self.loss_agg_mode not in valid_loss_agg_modes:
            raise ValueError(f"Invalid loss_agg_mode: {self.loss_agg_mode}")

    def validate(self, n_gpus: int, train_batch_size: int, model_config: dict = None):
        """Validate actor configuration with runtime parameters."""
        if not self.use_dynamic_bsz:
            if train_batch_size < self.ppo_mini_batch_size:
                raise ValueError(
                    f"train_batch_size ({train_batch_size}) must be >= "
                    f"actor.ppo_mini_batch_size ({self.ppo_mini_batch_size})"
                )

            sp_size = getattr(self, "ulysses_sequence_parallel_size", 1)
            if self.ppo_micro_batch_size is not None:
                if self.ppo_mini_batch_size % self.ppo_micro_batch_size != 0:
                    raise ValueError(
                        f"ppo_mini_batch_size ({self.ppo_mini_batch_size}) must be divisible by "
                        f"ppo_micro_batch_size ({self.ppo_micro_batch_size})"
                    )
                if self.ppo_micro_batch_size * sp_size < n_gpus:
                    raise ValueError(
                        f"ppo_micro_batch_size ({self.ppo_micro_batch_size}) * "
                        f"ulysses_sequence_parallel_size ({sp_size}) must be >= n_gpus ({n_gpus})"
                    )

    @staticmethod
    def _check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
        """Validate mutually exclusive micro batch size configuration options."""
        param = "ppo_micro_batch_size"
        param_per_gpu = f"{param}_per_gpu"

        if mbs is None and mbs_per_gpu is None:
            raise ValueError(f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.")

        if mbs is not None and mbs_per_gpu is not None:
            raise ValueError(
                f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove "
                f"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated)."
            )


@dataclass
class McoreActorConfig(ActorConfig):
    """Configuration for Megatron actor models.

    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.

    Args:
        strategy (str): Training strategy set to 'megatron' for Megatron parallelism.
        data_loader_seed (Optional[int]): Seed for data loader. If None, uses global seed.
        load_weight (bool): Whether to load model weights from checkpoint.
        megatron (dict[str, Any]): Configuration for Megatron parallelism settings.
        profile (dict[str, Any]): Configuration for profiling settings.
    """

    strategy: str = "megatron"
    data_loader_seed: Optional[int] = None
    load_weight: bool = True
    megatron: McoreEngineConfig = field(default_factory=McoreEngineConfig)
    profile: dict[str, Any] = field(default_factory=dict)


@dataclass
class FSDPActorConfig(ActorConfig):
    """Configuration for FSDP actor models.

    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.

    Args:
        strategy (str): Training strategy set to 'fsdp' for Fully Sharded Data Parallel.
        grad_clip (float): Gradient clipping threshold.
        ulysses_sequence_parallel_size (int): Ulysses sequence parallel size for long sequences.
        entropy_from_logits_with_chunking (bool): Whether to compute entropy from logits
            with chunking for memory efficiency.
        entropy_checkpointing (bool): Whether to use gradient checkpointing for entropy computation.
        fsdp_config (dict[str, Any]): Configuration for FSDP settings.
        use_remove_padding (bool): Whether to remove padding tokens in inputs during training
    """

    strategy: str = "fsdp"
    grad_clip: float = 1.0
    ulysses_sequence_parallel_size: int = 1
    entropy_from_logits_with_chunking: bool = False
    entropy_checkpointing: bool = False
    fsdp_config: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)
    use_remove_padding: bool = False

    def __post_init__(self):
        """Validate FSDP actor configuration parameters."""
        super().__post_init__()

    def validate(self, n_gpus: int, train_batch_size: int, model_config: dict = None):
        """Validate FSDP actor configuration with runtime parameters."""
        super().validate(n_gpus, train_batch_size, model_config)

        if self.strategy in {"fsdp", "fsdp2"} and self.ulysses_sequence_parallel_size > 1:
            if model_config and not model_config.get("use_remove_padding", False):
                raise ValueError(
                    "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`."
                )
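A usage sketch for the new actor config: construction runs the `__post_init__` checks, and `validate()` adds the runtime ones. `lr` must be supplied explicitly because `OptimizerConfig.__post_init__` (shown further below) asserts it is set:

```python
from verl.workers.config import FSDPActorConfig, FSDPOptimizerConfig

cfg = FSDPActorConfig(
    optim=FSDPOptimizerConfig(lr=1e-6),  # lr is mandatory (asserted in __post_init__)
    ppo_micro_batch_size_per_gpu=2,
    ulysses_sequence_parallel_size=2,
)
# Runtime checks need facts the dataclass cannot know: GPU count, data batch size,
# and the sibling model section (for the use_remove_padding constraint).
cfg.validate(n_gpus=8, train_batch_size=256, model_config={"use_remove_padding": True})

try:
    FSDPActorConfig(optim=FSDPOptimizerConfig(lr=1e-6), ppo_micro_batch_size_per_gpu=2, loss_agg_mode="mean")
except ValueError as e:
    print(e)  # Invalid loss_agg_mode: mean
```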
verl/workers/config/critic.py — new file (231 lines)
@@ -0,0 +1,231 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional

from omegaconf import MISSING

from verl.base_config import BaseConfig
from verl.trainer.config import BaseModelConfig, CheckpointConfig
from verl.utils.profiler import ProfilerConfig

from .engine import FSDPEngineConfig, McoreEngineConfig
from .optimizer import OptimizerConfig

__all__ = ["CriticConfig", "FSDPCriticConfig", "McoreCriticConfig", "FSDPCriticModelCfg"]


@dataclass
class CriticConfig(BaseConfig):
    """Configuration for critic model training.

    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.

    Args:
        strategy (str): Strategy used for critic model training (fsdp, fsdp2, megatron).
        ppo_micro_batch_size_per_gpu (int): Local per-GPU micro batch size.
        rollout_n (int): Number of rollouts per update (mirrors actor rollout_n).
        optim (Dict[str, Any]): Optimizer configuration including lr, weight_decay, etc.
        model (Dict[str, Any]): Model configuration including path, tokenizer_path, etc.
        ppo_mini_batch_size (int): PPO mini-batch size per update.
        ppo_micro_batch_size (Optional[int]): Global micro batch size (deprecated).
        use_dynamic_bsz (bool): Whether to automatically adjust batch size at runtime.
        ppo_max_token_len_per_gpu (int): Max tokens per GPU in one PPO batch.
        forward_max_token_len_per_gpu (int): Max token length per GPU in forward pass.
        ppo_epochs (int): Number of PPO epochs per batch.
        shuffle (bool): Shuffle training data across PPO epochs.
        cliprange_value (float): PPO value function clipping range.
        loss_agg_mode (str): Loss aggregation mode.
        checkpoint (Dict[str, Any]): Checkpoint configuration.
        profiler (Dict[str, Any]): Profiler configuration.
        enable (Optional[bool]): Whether to enable the critic.
    """

    _mutable_fields = BaseConfig._mutable_fields | {
        "ppo_micro_batch_size_per_gpu",
        "ppo_mini_batch_size",
        "ppo_micro_batch_size",
    }

    strategy: str = MISSING
    ppo_micro_batch_size_per_gpu: Optional[int] = None
    enable: Optional[bool] = None
    rollout_n: int = 1
    ppo_mini_batch_size: int = 1
    use_dynamic_bsz: bool = False
    ppo_max_token_len_per_gpu: int = 32768
    forward_max_token_len_per_gpu: int = 32768
    ppo_epochs: int = 1
    shuffle: bool = True
    cliprange_value: float = 0.5
    loss_agg_mode: str = "token-mean"
    ppo_micro_batch_size: Optional[int] = None
    optim: OptimizerConfig = field(default_factory=OptimizerConfig)
    model: BaseModelConfig = field(default_factory=BaseModelConfig)
    checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig)
    profiler: ProfilerConfig = field(default_factory=ProfilerConfig)

    def __post_init__(self):
        """Validate critic configuration parameters."""
        assert self.strategy != MISSING
        if not self.use_dynamic_bsz:
            self._check_mutually_exclusive(self.ppo_micro_batch_size, self.ppo_micro_batch_size_per_gpu, "critic")

            if self.ppo_micro_batch_size is not None:
                if self.ppo_mini_batch_size % self.ppo_micro_batch_size != 0:
                    raise ValueError(
                        f"[critic] ppo_mini_batch_size ({self.ppo_mini_batch_size}) must be divisible by "
                        f"ppo_micro_batch_size ({self.ppo_micro_batch_size})"
                    )

    def validate(self, n_gpus: int, train_batch_size: int):
        """Validate critic configuration with runtime parameters.

        Args:
            n_gpus: Total number of GPUs available
            train_batch_size: Training batch size from data config
        """
        if not self.use_dynamic_bsz:
            if train_batch_size < self.ppo_mini_batch_size:
                raise ValueError(
                    f"train_batch_size ({train_batch_size}) must be >= "
                    f"critic.ppo_mini_batch_size ({self.ppo_mini_batch_size})"
                )

    @staticmethod
    def _check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
        """Validate mutually exclusive micro batch size configuration options.

        Ensures that users don't set both deprecated micro_batch_size and
        the new micro_batch_size_per_gpu parameters simultaneously.

        Args:
            mbs: Deprecated micro batch size parameter value.
            mbs_per_gpu: New micro batch size per GPU parameter value.
            name (str): Configuration section name for error messages.

        Raises:
            ValueError: If both parameters are set or neither is set.
        """
        param = "micro_batch_size"
        param_per_gpu = f"{param}_per_gpu"

        if mbs is None and mbs_per_gpu is None:
            raise ValueError(f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.")

        if mbs is not None and mbs_per_gpu is not None:
            raise ValueError(
                f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove "
                f"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated)."
            )


@dataclass
class McoreCriticConfig(CriticConfig):
    """Configuration for Megatron-based critic model training.

    The inheritance from CriticConfig provides all base critic configuration plus Megatron-specific settings.

    Args:
        nccl_timeout (int): NCCL timeout in seconds for distributed operations.
        megatron (Dict[str, Any]): Megatron-specific parallelism settings.
        load_weight (bool): Whether to load initial weights.
        data_loader_seed (Optional[int]): Seed for data loader.
    """

    strategy: str = "megatron"
    nccl_timeout: int = 600
    megatron: McoreEngineConfig = field(default_factory=McoreEngineConfig)
    load_weight: bool = True
    data_loader_seed: Optional[int] = None

    def validate(self, n_gpus: int, train_batch_size: int):
        """Validate Megatron critic configuration with runtime parameters."""
        super().validate(n_gpus, train_batch_size)


@dataclass
class FSDPCriticConfig(CriticConfig):
    """Configuration for FSDP-based critic model training.

    The inheritance from CriticConfig provides all base critic configuration plus FSDP-specific settings.

    Args:
        forward_micro_batch_size (int): Forward-only batch size during inference (global).
        forward_micro_batch_size_per_gpu (int): Forward-only batch size during inference (per GPU).
        ulysses_sequence_parallel_size (int): Sequence parallelism size for Ulysses-style model parallelism.
        grad_clip (float): Gradient clipping for critic updates.
    """

    _mutable_fields = CriticConfig._mutable_fields | {
        "forward_micro_batch_size",
        "forward_micro_batch_size_per_gpu",
    }

    strategy: str = "fsdp"
    forward_micro_batch_size: int = 1
    forward_micro_batch_size_per_gpu: int = 1
    ulysses_sequence_parallel_size: int = 1
    grad_clip: float = 1.0

    def __post_init__(self):
        """Validate FSDP critic configuration parameters."""
        super().__post_init__()

        if self.strategy in {"fsdp", "fsdp2"}:
            if self.ulysses_sequence_parallel_size > 1:
                if not self.model.get("use_remove_padding", False):
                    raise ValueError(
                        "When using sequence parallelism for critic, you must enable `use_remove_padding`."
                    )

    def validate(self, n_gpus: int, train_batch_size: int):
        """Validate FSDP critic configuration with runtime parameters."""
        super().validate(n_gpus, train_batch_size)

        if not self.use_dynamic_bsz:
            sp_size = self.ulysses_sequence_parallel_size
            if self.ppo_micro_batch_size is not None:
                if self.ppo_micro_batch_size * sp_size < n_gpus:
                    raise ValueError(
                        f"critic.ppo_micro_batch_size ({self.ppo_micro_batch_size}) * "
                        f"ulysses_sequence_parallel_size ({sp_size}) must be >= n_gpus ({n_gpus})"
                    )


@dataclass
class FSDPCriticModelCfg(BaseModelConfig):
    """FSDP-enabled critic model configuration.

    Inherits base critic settings and adds distributed-memory and LoRA options.

    Args:
        use_shm (bool): Whether to use shared memory for loading the model.
        enable_activation_offload (bool): Offload activations to CPU to reduce GPU memory usage.
        use_remove_padding (bool): Use remove-padding optimization (saves compute).
        enable_gradient_checkpointing (bool): Enable gradient checkpointing for memory efficiency.
        fsdp_config (FSDPEngineConfig): FSDP-specific configuration block.
        lora_rank (int): Set to positive value to enable LoRA (e.g., 32).
        lora_alpha (int): LoRA scaling factor.
        target_modules (Union[str, List[str]]): LoRA target modules: "all-linear" or list of layer names.
    """

    use_shm: bool = False
    enable_activation_offload: bool = False
    use_remove_padding: bool = False
    enable_gradient_checkpointing: bool = True
    fsdp_config: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)
    lora_rank: int = 0
    lora_alpha: int = 16
    target_modules: str | list[str] = "all-linear"
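The critic config behaves the same way. A sketch of one passing and one failing construction, again with `lr` supplied explicitly:

```python
from verl.workers.config import FSDPCriticConfig, OptimizerConfig

ok = FSDPCriticConfig(optim=OptimizerConfig(lr=1e-5), ppo_micro_batch_size_per_gpu=4, ppo_mini_batch_size=16)
ok.validate(n_gpus=8, train_batch_size=64)  # passes: 64 >= 16

try:
    FSDPCriticConfig(optim=OptimizerConfig(lr=1e-5), ppo_micro_batch_size=3, ppo_mini_batch_size=8)
except ValueError as e:
    print(e)  # 8 is not divisible by 3
```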
verl/workers/config/engine.py — new file (105 lines)
@@ -0,0 +1,105 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from dataclasses import dataclass, field
from typing import Any, Optional

from verl.base_config import BaseConfig

__all__ = ["FSDPEngineConfig", "McoreEngineConfig"]


@dataclass
class McoreEngineConfig(BaseConfig):
    """Configuration for Megatron parallelism.

    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.

    Args:
        param_offload (bool): Whether to offload parameters to CPU.
        grad_offload (bool): Whether to offload gradients to CPU.
        optimizer_offload (bool): Whether to offload optimizer states to CPU.
        tensor_model_parallel_size (int): Tensor model parallel size.
        expert_model_parallel_size (int): Expert model parallel size for MoE models.
        expert_tensor_parallel_size (Optional[int]): Expert tensor parallel size for MoE models.
        pipeline_model_parallel_size (int): Pipeline model parallel size.
        virtual_pipeline_model_parallel_size (Optional[int]): Virtual pipeline model parallel size
            for interleaved scheduling.
        context_parallel_size (int): Context parallel size for long sequences.
        sequence_parallel (bool): Whether to enable sequence parallelism.
        use_distributed_optimizer (bool): Whether to use distributed optimizer.
        use_dist_checkpointing (bool): Whether to use distributed checkpointing.
        dist_checkpointing_path (Optional[str]): Path for distributed checkpointing.
        seed (int): Random seed for reproducibility.
        override_ddp_config (dict[str, Any]): Override configuration for DDP.
        override_transformer_config (dict[str, Any]): Override configuration for transformer.
        use_mbridge (bool): Whether to use MBridge for communication.
    """

    # sequence_parallel is not listed as a frozen field for auto-correction purpose
    _mutable_fields = BaseConfig._mutable_fields | {"sequence_parallel"}

    param_offload: bool = False
    grad_offload: bool = False
    optimizer_offload: bool = False
    tensor_model_parallel_size: int = 1
    expert_model_parallel_size: int = 1
    expert_tensor_parallel_size: Optional[int] = None
    pipeline_model_parallel_size: int = 1
    virtual_pipeline_model_parallel_size: Optional[int] = None
    context_parallel_size: int = 1
    sequence_parallel: bool = True
    use_distributed_optimizer: bool = True
    use_dist_checkpointing: bool = False
    dist_checkpointing_path: Optional[str] = None
    seed: int = 42
    override_ddp_config: dict[str, Any] = field(default_factory=dict)
    override_transformer_config: dict[str, Any] = field(default_factory=dict)
    use_mbridge: bool = False

    def __post_init__(self) -> None:
        """config validation logics go here"""
        if self.tensor_model_parallel_size == 1:
            warnings.warn("set sequence parallel to false as TP size is 1", stacklevel=2)
            self.sequence_parallel = False


@dataclass
class FSDPEngineConfig(BaseConfig):
    """Configuration for FSDP (Fully Sharded Data Parallel).

    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.

    Args:
        wrap_policy (Dict[str, Any]): Configuration for FSDP wrap policy.
        param_offload (bool): Whether to offload parameters to CPU, default False
        optimizer_offload (bool): Whether to offload optimizer states to CPU, default False
        offload_policy (bool): Whether to offload policy model parameters, default False
        reshard_after_forward (bool): Whether to reshard parameters after forward pass, default True
        fsdp_size (int): FSDP group size. -1 means use all available GPUs.
        forward_prefetch (bool): Whether to prefetch parameters for next forward pass, default False
        model_dtype (str): Model data type used to initialize the transformers model. default "fp32"
        use_orig_params (bool): Whether to use original parameters when initialize FSDP1, default False
    """

    wrap_policy: dict[str, Any] = field(default_factory=dict)
    param_offload: bool = False
    optimizer_offload: bool = False
    offload_policy: bool = False
    reshard_after_forward: bool = True
    fsdp_size: int = -1
    forward_prefetch: bool = False
    model_dtype: str = "fp32"
    use_orig_params: bool = False
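Note the deliberate `_mutable_fields` exception: `__post_init__` auto-corrects `sequence_parallel` when TP size is 1, since sequence parallelism is meaningless without tensor parallelism. A sketch:

```python
import warnings

from verl.workers.config import FSDPEngineConfig, McoreEngineConfig

with warnings.catch_warnings(record=True):
    warnings.simplefilter("always")
    cfg = McoreEngineConfig(tensor_model_parallel_size=1, sequence_parallel=True)
assert cfg.sequence_parallel is False  # auto-corrected, with a warning

fsdp = FSDPEngineConfig(param_offload=True)  # fsdp_size=-1: use all GPUs
```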
verl/workers/config/optimizer.py — new file (94 lines)
@@ -0,0 +1,94 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional

from omegaconf import MISSING

from verl.base_config import BaseConfig

__all__ = ["OptimizerConfig", "FSDPOptimizerConfig", "McoreOptimizerConfig"]


@dataclass
class OptimizerConfig(BaseConfig):
    """Base optimizer configuration.

    Args:
        lr (float): learning rate. Must be specified.
        lr_warmup_steps_ratio (float): Warmup steps ratio; total steps will be injected at runtime.
        total_training_steps (int): Total training steps (must be overridden at runtime).
        weight_decay (float): Weight decay factor.
        lr_warmup_steps (Optional[int]): Number of warmup steps; None delegates to lr_warmup_steps_ratio.
    """

    lr: float = MISSING
    lr_warmup_steps_ratio: float = 0.0
    total_training_steps: int = -1
    weight_decay: float = 0.01
    lr_warmup_steps: Optional[int] = -1

    def __post_init__(self):
        assert self.lr != MISSING


@dataclass
class FSDPOptimizerConfig(OptimizerConfig):
    """FSDP optimizer configuration extending base OptimizerConfig.

    Args:
        lr (float): Learning rate.
        min_lr_ratio (Optional[float]): Minimum LR ratio for cosine schedule.
        warmup_style (str): LR warmup style: "constant" or "cosine".
        num_cycles (float): Number of cosine cycles in LR schedule.
    """

    min_lr_ratio: Optional[float] = None
    warmup_style: str = "constant"
    num_cycles: float = 0.5

    def __post_init__(self):
        assert self.warmup_style in ["constant", "cosine"]
        return super().__post_init__()


@dataclass
class McoreOptimizerConfig(OptimizerConfig):
    """Mcore optimizer configuration extending base OptimizerConfig.

    Args:
        optimizer (str): Optimizer name; default is "adam".
        lr (float): Learning rate.
        clip_grad (float): Gradient clipping norm.
        lr_warmup_init (float): Initial learning rate for warmup; defaults to 0.0.
        lr_decay_steps (Optional[int]): Number of decay steps.
        lr_decay_style (str): LR decay style: "constant", "linear", "cosine", or "inverse_square_root".
        min_lr (float): Minimum learning rate.
        weight_decay_incr_style (str): Weight decay increment style: "constant" or "cosine".
        lr_wsd_decay_style (str): Weight-standard-deviation decay style: "constant", "exponential", or "cosine".
        lr_wsd_decay_steps (Optional[int]): Number of steps for weight-standard-deviation decay.
        use_checkpoint_opt_param_scheduler (bool): Whether to use checkpoint optimizer parameter scheduler.
    """

    optimizer: str = "adam"
    clip_grad: float = 1.0
    lr_warmup_init: float = 0.0
    lr_decay_steps: Optional[int] = None
    lr_decay_style: str = "linear"
    min_lr: float = 0.0
    weight_decay_incr_style: str = "constant"
    lr_wsd_decay_style: str = "exponential"
    lr_wsd_decay_steps: Optional[int] = None
    use_checkpoint_opt_param_scheduler: bool = False
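Optimizer configs validate their schedule options at construction time. A sketch:

```python
from verl.workers.config import FSDPOptimizerConfig, McoreOptimizerConfig

opt = FSDPOptimizerConfig(lr=1e-5, warmup_style="cosine", min_lr_ratio=0.1)

try:
    FSDPOptimizerConfig(lr=1e-5, warmup_style="linear")  # only constant/cosine allowed
except AssertionError:
    print("warmup_style must be 'constant' or 'cosine'")

mcore = McoreOptimizerConfig(lr=1e-5, lr_decay_style="cosine", min_lr=1e-6)
```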
@@ -84,9 +84,6 @@ class MegatronPPOCritic(BasePPOCritic):
        assert config.get("ulysses_sequence_parallel_size", 1) == 1
        if config.shuffle:
            assert config.data_loader_seed is not None, "If shuffle dataloader, seed must be manually set"
-        if config.megatron.tensor_model_parallel_size == 1:
-            print("[Warining] Because critic tp size == 1, set sp to False")
-            config.megatron.sequence_parallel = False
        self.config = config

    @GPUMemoryLogger("megatron critic", logger=logger)
@@ -195,7 +195,7 @@ class FSDPEngine(BaseEngine):
        else:
            self.tokenizer.chat_template = self.config.model.custom_chat_template

-        override_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+        override_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
        override_config_kwargs = {
            "bos_token_id": self.tokenizer.bos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
@@ -72,6 +72,7 @@ from verl.utils.model import compute_position_id_with_mask
from verl.utils.profiler import DistProfiler, DistProfilerExtension, log_gpu_memory_usage, simple_timer
from verl.utils.profiler.performance import reduce_timing
from verl.utils.py_functional import convert_to_regular_types
+from verl.workers.config import FSDPCriticConfig, FSDPEngineConfig
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

logger = logging.getLogger(__file__)

@@ -209,7 +210,7 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
    def _build_model_optimizer(
        self,
        model_path,
-        fsdp_config,
+        fsdp_config: FSDPEngineConfig,
        optim_config,
        override_model_config,
        use_remove_padding=False,

@@ -378,8 +379,8 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
                mixed_precision=mixed_precision,
                sync_module_states=True,
                device_mesh=self.device_mesh,
-                use_orig_params=self.config.actor.fsdp_config.get("use_orig_params", False),
-                forward_prefetch=self.config.actor.fsdp_config.get("forward_prefetch", False),
+                use_orig_params=fsdp_config.get("use_orig_params", False),
+                forward_prefetch=fsdp_config.get("forward_prefetch", False),
            )
        elif fsdp_strategy == "fsdp2":
            assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
@@ -566,8 +567,7 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
        # This is used to import external_lib into the huggingface systems
        import_external_libs(self.config.model.get("external_lib", None))

-        override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
-
+        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
        use_remove_padding = self.config.model.get("use_remove_padding", False)
        use_shm = self.config.model.get("use_shm", False)
        use_fused_kernels = self.config.model.get("use_fused_kernels", False)

@@ -576,10 +576,10 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
        # we need the model for actor and rollout
        if self._is_actor:
            optim_config = self.config.actor.optim
-            fsdp_config = self.config.actor.fsdp_config
+            fsdp_config = omega_conf_to_dataclass(self.config.actor.fsdp_config)
        else:
            optim_config = None
-            fsdp_config = OmegaConf.create()
+            fsdp_config = FSDPEngineConfig()

        local_path = copy_to_local(self.config.model.path, use_shm=use_shm)
        (

@@ -614,12 +614,9 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
            log_gpu_memory_usage("After offload actor optimizer during init", logger=logger)

        if self._is_actor:
-            OmegaConf.set_struct(self.config.actor, True)
-            with open_dict(self.config.actor):
-                self.config.actor.use_remove_padding = use_remove_padding
-                self.config.actor.use_fused_kernels = use_fused_kernels
+            actor_cfg = omega_conf_to_dataclass(self.config.actor)
            self.actor = DataParallelPPOActor(
-                config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer
+                config=actor_cfg, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer
            )

        if self._is_rollout:

@@ -631,7 +628,7 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
            local_path = copy_to_local(self.config.model.path, use_shm=use_shm)
            self.ref_module_fsdp = self._build_model_optimizer(
                model_path=local_path,
-                fsdp_config=self.config.ref.fsdp_config,
+                fsdp_config=omega_conf_to_dataclass(self.config.ref.fsdp_config),
                optim_config=None,
                override_model_config=override_model_config,
                use_remove_padding=use_remove_padding,
```diff
@@ -916,18 +913,16 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):


 class CriticWorker(Worker, DistProfilerExtension):
-    def __init__(self, config):
+    def __init__(self, config: FSDPCriticConfig):
         Worker.__init__(self)
-        DistProfilerExtension.__init__(
-            self, DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler")))
-        )
+        DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=config.get("profiler")))
         import torch.distributed

         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group(
                 backend=get_nccl_backend(), init_method=os.environ.get("DIST_INIT_METHOD", None)
             )
-        self.config = config
+        self.config: FSDPCriticConfig = config

         # build device mesh for Ulysses Sequence Parallel
         world_size = torch.distributed.get_world_size()
```
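With the constructor annotated as `FSDPCriticConfig`, the nested `profiler` entry already arrives as a structured config, which is why the `omega_conf_to_dataclass` wrapper around it could be dropped. A sketch under two assumptions — that `FSDPCriticConfig` is exported from `verl.workers.config` and is constructible from defaults:

```python
from verl.workers.config import FSDPCriticConfig  # import path assumed

config = FSDPCriticConfig()  # defaults only, for illustration

# `.get()` comes from the Mapping-style interface; the profiler entry is
# already a structured config (not a DictConfig), so it can be handed to
# DistProfiler without conversion.
profiler_cfg = config.get("profiler")
print(type(profiler_cfg))
```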
```diff
@@ -996,8 +991,7 @@ class CriticWorker(Worker, DistProfilerExtension):
                 self.processor.chat_template = self.config.model.custom_chat_template
             else:
                 self.tokenizer.chat_template = self.config.model.custom_chat_template

-        override_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+        override_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
         override_config_kwargs = {
             "bos_token_id": self.tokenizer.bos_token_id,
             "eos_token_id": self.tokenizer.eos_token_id,
```
```diff
@@ -1163,8 +1157,14 @@ class CriticWorker(Worker, DistProfilerExtension):
                 optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps
             )
         elif warmup_style == "cosine":
+            min_lr_ratio = config.optim.get("min_lr_ratio", 0.0)
+            num_cycles = config.optim.get("num_cycles", 0.5)
             critic_lr_scheduler = get_cosine_schedule_with_warmup(
-                optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps
+                optimizer=critic_optimizer,
+                num_warmup_steps=num_warmup_steps,
+                num_training_steps=total_steps,
+                min_lr_ratio=min_lr_ratio,
+                num_cycles=num_cycles,
             )
         else:
             raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
```
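The cosine branch previously ignored `min_lr_ratio` and `num_cycles`; now it reads them from the critic's optim config (defaulting to 0.0 and 0.5) and forwards them. A sketch of the resulting call, assuming `get_cosine_schedule_with_warmup` is the verl helper from `verl.utils.torch_functional` with the keyword arguments shown above:

```python
import torch
from verl.utils.torch_functional import get_cosine_schedule_with_warmup  # import path assumed

model = torch.nn.Linear(8, 1)
critic_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

critic_lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=critic_optimizer,
    num_warmup_steps=10,
    num_training_steps=100,
    min_lr_ratio=0.0,  # lr floor as a fraction of the peak lr
    num_cycles=0.5,    # half a cosine period: decay from peak to floor
)
```

Setting `critic.optim.min_lr_ratio` / `critic.optim.num_cycles` in the trainer config is now enough to change the schedule.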
The remaining hunks give the Megatron workers the same treatment. `open_dict` drops out of the imports because in-place config mutation is gone:

```diff
@@ -26,7 +26,7 @@ import torch
 import torch.distributed
 from codetiming import Timer
 from megatron.core import parallel_state as mpu
-from omegaconf import DictConfig, OmegaConf, open_dict
+from omegaconf import DictConfig, OmegaConf

 from verl import DataProto
 from verl.single_controller.base.decorator import Dispatch, register
```
```diff
@@ -53,6 +53,7 @@ from verl.utils.profiler import (
 )
 from verl.utils.profiler.performance import reduce_timing
 from verl.workers.actor.megatron_actor import MegatronPPOActor
+from verl.workers.config import McoreCriticConfig
 from verl.workers.critic.megatron_critic import MegatronPPOCritic
 from verl.workers.reward_model.megatron.reward_model import MegatronRewardModel
```
```diff
@@ -202,7 +203,7 @@ class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):
             return parallel_model

         override_ddp_config = OmegaConf.to_container(
-            self.config.actor.megatron.get("override_ddp_config", OmegaConf.create()), resolve=True
+            OmegaConf.create(self.config.actor.megatron.get("override_ddp_config", {}))
         )
         return get_model(
             megatron_actor_model_provider,
```
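One behavioral footnote: the rewritten calls also drop `resolve=True`. That is harmless when the stored value is already a plain `dict` (any interpolations were materialized when the structured config was built), but it is observable if a raw `DictConfig` with live interpolations reaches this path. A quick OmegaConf refresher:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"hidden": 1024, "ffn_size": "${hidden}"})

print(OmegaConf.to_container(cfg))                # {'hidden': 1024, 'ffn_size': '${hidden}'}
print(OmegaConf.to_container(cfg, resolve=True))  # {'hidden': 1024, 'ffn_size': 1024}
```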
```diff
@@ -390,14 +391,14 @@ class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):

         from verl.utils.torch_dtypes import PrecisionType

-        override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
         if self._is_actor:
             override_transformer_config = OmegaConf.to_container(
-                self.config.actor.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True
+                OmegaConf.create(self.config.actor.megatron.get("override_transformer_config", {}))
             )
         elif self._is_ref:
             override_transformer_config = OmegaConf.to_container(
-                self.config.ref.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True
+                OmegaConf.create(self.config.ref.megatron.get("override_transformer_config", {}))
             )
         else:
             override_transformer_config = {}
```
```diff
@@ -427,12 +428,9 @@ class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):
             log_gpu_memory_usage("After offload actor optimizer during init", logger=logger)

         if self._is_actor:
-            OmegaConf.set_struct(self.config.actor, True)
-            with open_dict(self.config.actor):
-                use_fused_kernels = self.config.model.get("use_fused_kernels", False)
-                self.config.actor.use_fused_kernels = use_fused_kernels
+            actor_cfg = omega_conf_to_dataclass(self.config.actor)
             self.actor = MegatronPPOActor(
-                config=self.config.actor,
+                config=actor_cfg,
                 model_config=self.actor_model_config,
                 hf_config=self.hf_config,
                 tf_config=self.tf_config,
```
```diff
@@ -712,12 +710,10 @@ class AsyncActorRolloutRefWorker(ActorRolloutRefWorker):


 class CriticWorker(MegatronWorker, DistProfilerExtension):
-    def __init__(self, config):
+    def __init__(self, config: McoreCriticConfig):
         MegatronWorker.__init__(self)
-        DistProfilerExtension.__init__(
-            self, DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler")))
-        )
-        self.config = config
+        DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=config.get("profiler")))
+        self.config: McoreCriticConfig = config

         # NOTE(sgm): We utilize colocate WorkerGroup by default.
         # As a result, Workers for different model share the same process.
```
```diff
@@ -807,7 +803,7 @@ class CriticWorker(MegatronWorker, DistProfilerExtension):
             return parallel_model

         override_ddp_config = OmegaConf.to_container(
-            self.config.megatron.get("override_ddp_config", OmegaConf.create()), resolve=True
+            OmegaConf.create(self.config.megatron.get("override_ddp_config", {}))
         )
         # Step 3: initialize the megatron model
         critic_module = get_model(
```
```diff
@@ -861,9 +857,9 @@ class CriticWorker(MegatronWorker, DistProfilerExtension):
             import importlib

             importlib.import_module(self.config.model.external_lib)
-        override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
         override_transformer_config = OmegaConf.to_container(
-            self.config.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True
+            OmegaConf.create(self.config.megatron.get("override_transformer_config", {}))
         )
         self.param_dtype = torch.bfloat16
         self.dtype = PrecisionType.to_dtype(self.param_dtype)
```
```diff
@@ -1108,9 +1104,9 @@ class RewardModelWorker(MegatronWorker, DistProfilerExtension):
             import importlib

             importlib.import_module(self.config.model.external_lib)
-        override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
         override_transformer_config = OmegaConf.to_container(
-            self.config.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True
+            OmegaConf.create(self.config.megatron.get("override_transformer_config", {}))
         )

         use_shm = self.config.model.get("use_shm", False)
```