[cfg] refactor: support +extra.any_key usage for the base dataclass config in verl (#2502)

### What does this PR do?

This PR updates the base config in verl:
- support `+extra.any_key` usage for the base config in verl
- allow selected subfields to be frozen (see the sketch below)
- add an auto-generated config yaml file
`verl/trainer/config/_generated_ppo_trainer.yaml` for reference, in case
the nested inheritance structure makes the config information too
scattered
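
For illustration, here is a minimal sketch of the resulting behavior. The `MyOptimConfig` class is hypothetical; real verl configs such as `ProfilerConfig` and `AlgoConfig` (see the diffs below) follow the same pattern:

```python
from dataclasses import FrozenInstanceError, dataclass

from verl.base_config import BaseConfig


@dataclass
class MyOptimConfig(BaseConfig):
    """Hypothetical config for illustration; only `lr` is declared frozen."""

    _frozen_fields = ["lr"]

    lr: float = 1e-6
    weight_decay: float = 0.01


cfg = MyOptimConfig()
cfg.weight_decay = 0.1            # non-frozen fields stay mutable
cfg.extra["any_new_key"] = "val"  # ad hoc keys land in the inherited `extra` dict

try:
    cfg.lr = 1e-5                 # frozen field: rejected after __init__
except FrozenInstanceError as e:
    print(e)
```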

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
    - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

- added frozen field tests

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

Now you can pass `+xx.profiler.extra.any_new_key=any_plain_value` on the
command line to a dataclass inheriting `verl.BaseConfig`. This way we
can still pass dataclass configs inside verl while allowing some
flexibility in accepting new keys from users' ad hoc usage.
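
For example, the override path exercised by the new test looks roughly like this. The sketch mimics what `scripts/print_cfg.py` does; the inline `DictConfig` stands in for the node Hydra builds from `ppo_trainer.yaml` plus a `+critic.profiler.extra.any_new_key=any_plain_value` override:

```python
from omegaconf import OmegaConf

from verl.utils.config import omega_conf_to_dataclass
from verl.utils.profiler.config import ProfilerConfig

# Roughly what `config.critic.profiler` contains after the CLI adds a key under `extra`:
profiler_node = OmegaConf.create(
    {
        "_target_": "verl.utils.profiler.ProfilerConfig",
        "discrete": True,
        "all_ranks": False,
        "ranks": [],
        "extra": {"any_new_key": "any_plain_value"},
    }
)

profiler_config = omega_conf_to_dataclass(profiler_node)
assert isinstance(profiler_config, ProfilerConfig)
assert profiler_config.extra["any_new_key"] == "any_plain_value"
```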


### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).

---------

Co-authored-by: Lin <haibin@Lins-Laptop.hsd1.wa.comcast.net>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Author: H
Date: 2025-07-14 18:06:56 -07:00
Committed by: GitHub
Parent: def5b28e3d
Commit: d0c7bbbc05
13 changed files with 556 additions and 8 deletions

View File

@@ -66,6 +66,9 @@ jobs:
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Install the current repository
run: |
pip install -e .
- name: Set ruff --output-format=github
run: |
sed -i 's/--output-format=full/--output-format=github/' .pre-commit-config.yaml

View File

@@ -25,6 +25,9 @@ jobs:
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Install the current repository
run: |
pip install -e .
- name: Set ruff --output-format=github
run: |
sed -i 's/--output-format=full/--output-format=github/' .pre-commit-config.yaml

View File

@@ -6,3 +6,27 @@ repos:
args: ["--fix", "--show-fixes", "--output-format=full"]
exclude: ^.*\.(ipynb)$
- id: ruff-format
- repo: local
hooks:
- id: autogen-trainer-cfg
name: Generate and verify verl/trainer/config/_generated_*.yaml
entry: scripts/generate_trainer_config.sh
language: script
pass_filenames: false
- repo: local
hooks:
- id: check-docstrings
name: Check doc string coverage
entry: python3 tests/special_sanity/check_docstrings.py
language: python
pass_filenames: false
- repo: local
hooks:
- id: check-license
name: Check license
entry: python3 tests/special_sanity/check_license.py --directory .
language: python
pass_filenames: false

View File

@@ -31,7 +31,11 @@ pre-commit install
# for staged changes
pre-commit run
# for all files in the repo
# pre-commit run --all-files
pre-commit run --all-files
# run a specific hook with pre-commit
# pre-commit run --all-files --show-diff-on-failure --color=always <hook-id>
pre-commit run --all-files --show-diff-on-failure --color=always ruff
pre-commit run --all-files --show-diff-on-failure --color=always autogen-trainer-cfg
```
## Testing

View File

@@ -0,0 +1,28 @@
#!/usr/bin/env bash
set -euox pipefail
# 1. Dump the full config to a temp file
target_cfg=verl/trainer/config/_generated_ppo_trainer.yaml
tmp_header=$(mktemp)
tmp_cfg=$(mktemp)
echo "# This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh'" > "$tmp_header"
echo "# in which it invokes 'python3 scripts/print_cfg.py --cfg job' to flatten the 'verl/trainer/config/ppo_trainer.yaml' config fields into a single file." >> "$tmp_header"
echo "# Do not modify this file directly." >> "$tmp_header"
echo "# The file is usually only for reference and never used." >> "$tmp_header"
echo "" >> "$tmp_header"
python3 scripts/print_cfg.py --cfg job > "$tmp_cfg"
# 2. Extract from the line starting with "actor_rollout_ref" onward
cat "$tmp_header" > "$target_cfg"
sed -n '/^actor_rollout_ref/,$p' "$tmp_cfg" >> "$target_cfg"
# 3. Clean up
rm "$tmp_cfg" "$tmp_header"
# 4. Verify that verl/trainer/config/_generated_ppo_trainer.yaml wasn't changed on disk
if ! git diff --exit-code -- "$target_cfg" >/dev/null; then
echo "$target_cfg is out of date. Please regenerate via 'scripts/generate_trainer_config.sh' and commit the changes."
exit 1
fi
echo "All good"
exit 0

scripts/print_cfg.py (new file, 35 lines)
View File

@@ -0,0 +1,35 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
import hydra
except ImportError as e:
raise ImportError("Please install hydra-core via 'pip install hydra-core' and retry.") from e
@hydra.main(config_path="../verl/trainer/config", config_name="ppo_trainer", version_base=None)
def main(config):
"""Main entry point for PPO training with Hydra configuration management.
Args:
config_dict: Hydra configuration dictionary containing training parameters.
"""
print(config)
from verl.utils.config import omega_conf_to_dataclass
profiler_config = omega_conf_to_dataclass(config.critic.profiler)
print(profiler_config)
if __name__ == "__main__":
main()

View File

@@ -67,5 +67,29 @@ class TestConfigOnCPU(unittest.TestCase):
assert isinstance(cfg.model, TestDataclass)
class TestPrintCfgCommand(unittest.TestCase):
"""Test suite for the print_cfg.py command-line tool."""
def test_command_with_override(self):
"""Test that the command runs without error when overriding config values."""
import subprocess
# Run the command
result = subprocess.run(
["python3", "scripts/print_cfg.py", "critic.profiler.discrete=True", "+critic.profiler.extra.any_key=val"],
capture_output=True,
text=True,
)
# Verify the command exited successfully
self.assertEqual(result.returncode, 0, f"Command failed with stderr: {result.stderr}")
# Verify the output contains expected config information
self.assertIn("critic", result.stdout)
self.assertIn("profiler", result.stdout)
self.assertIn("discrete=True", result.stdout)
self.assertIn("extra={'any_key': 'val'}", result.stdout)
if __name__ == "__main__":
unittest.main()

View File

@@ -50,6 +50,37 @@ class TestProfilerConfig(unittest.TestCase):
with self.assertRaises(FrozenInstanceError):
profiler_config.discrete = False
def test_frozen_config(self):
"""Test that modifying frozen keys in ProfilerConfig raises exceptions."""
from dataclasses import FrozenInstanceError
from verl.utils.profiler.config import ProfilerConfig
# Create a new ProfilerConfig instance
config = ProfilerConfig(discrete=True, all_ranks=False, ranks=[0])
# Test direct attribute assignment
with self.assertRaises(FrozenInstanceError):
config.discrete = False
with self.assertRaises(FrozenInstanceError):
config.all_ranks = True
with self.assertRaises(FrozenInstanceError):
config.ranks = [1, 2, 3]
# Test dictionary-style assignment
with self.assertRaises(TypeError):
config["discrete"] = False
with self.assertRaises(TypeError):
config["all_ranks"] = True
with self.assertRaises(TypeError):
config["ranks"] = [1, 2, 3]
config["extra"]["key"] = "value"
class TestNsightSystemsProfiler(unittest.TestCase):
"""Test suite for NsightSystemsProfiler functionality.

View File

@@ -13,11 +13,16 @@
# limitations under the License.
import collections
from dataclasses import fields # Import the fields function to inspect dataclass fields
from dataclasses import (
dataclass,
field,
fields, # Import the fields function to inspect dataclass fields
)
from typing import Any
# BaseConfig class inherits from collections.abc.Mapping, which means it can act like a dictionary
@dataclass
class BaseConfig(collections.abc.Mapping):
"""The BaseConfig provides omegaconf DictConfig-like interface for a dataclass config.
@@ -25,6 +30,18 @@ class BaseConfig(collections.abc.Mapping):
This allows instances of this class to be used like dictionaries.
"""
extra: dict[str, Any] = field(default_factory=dict)
def __setattr__(self, name: str, value):
# if the field already exists (i.e. was set in __init__)
# and is in our frozen list, block assignment
if hasattr(self, "_frozen_fields") and name in self._frozen_fields and name in self.__dict__:
from dataclasses import FrozenInstanceError
raise FrozenInstanceError(f"Field '{name}' is frozen and cannot be modified")
# otherwise do the normal thing
super().__setattr__(name, value)
def get(self, key: str, default: Any = None) -> Any:
"""Get the value associated with the given key. If the key does not exist, return the default value.

View File

@@ -0,0 +1,361 @@
# This reference configuration yaml is automatically generated via 'scripts/generate_trainer_config.sh'
# in which it invokes 'python3 scripts/print_cfg.py --cfg job' to flatten the 'verl/trainer/config/ppo_trainer.yaml' config fields into a single file.
# Do not modify this file directly.
# The file is usually only for reference and never used.
actor_rollout_ref:
actor:
strategy: fsdp
ppo_mini_batch_size: 256
ppo_micro_batch_size: null
ppo_micro_batch_size_per_gpu: null
use_dynamic_bsz: false
ppo_max_token_len_per_gpu: 16384
clip_ratio: 0.2
clip_ratio_low: 0.2
clip_ratio_high: 0.2
policy_loss:
loss_mode: vanilla
clip_cov_ratio: 0.0002
clip_cov_lb: 1.0
clip_cov_ub: 5.0
kl_cov_ratio: 0.0002
ppo_kl_coef: 0.1
clip_ratio_c: 3.0
loss_agg_mode: token-mean
entropy_coeff: 0
use_kl_loss: false
use_torch_compile: true
kl_loss_coef: 0.001
kl_loss_type: low_var_kl
ppo_epochs: 1
shuffle: false
checkpoint:
save_contents:
- model
- optimizer
- extra
load_contents: ${.save_contents}
optim:
lr: 1.0e-06
lr_warmup_steps_ratio: 0.0
total_training_steps: -1
weight_decay: 0.01
lr_warmup_steps: -1
min_lr_ratio: 0.0
num_cycles: 0.5
warmup_style: constant
grad_clip: 1.0
ulysses_sequence_parallel_size: 1
entropy_from_logits_with_chunking: false
entropy_checkpointing: false
fsdp_config:
wrap_policy:
min_num_params: 0
param_offload: false
optimizer_offload: false
offload_policy: false
reshard_after_forward: true
fsdp_size: -1
forward_prefetch: false
ref:
strategy: ${actor_rollout_ref.actor.strategy}
use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true}
log_prob_micro_batch_size: null
log_prob_micro_batch_size_per_gpu: null
log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}
log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}
fsdp_config:
param_offload: false
reshard_after_forward: true
forward_prefetch: false
wrap_policy:
min_num_params: 0
ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1}
entropy_from_logits_with_chunking: false
entropy_checkpointing: false
rollout:
name: vllm
mode: sync
temperature: 1.0
top_k: -1
top_p: 1
prompt_length: ${oc.select:data.max_prompt_length,512}
response_length: ${oc.select:data.max_response_length,512}
dtype: bfloat16
gpu_memory_utilization: 0.5
ignore_eos: false
enforce_eager: true
free_cache_engine: true
tensor_model_parallel_size: 2
max_num_batched_tokens: 8192
max_model_len: null
max_num_seqs: 1024
log_prob_micro_batch_size: null
log_prob_micro_batch_size_per_gpu: null
log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}
log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}
disable_log_stats: true
do_sample: true
'n': 1
multi_stage_wake_up: false
engine_kwargs:
vllm:
swap_space: null
disable_mm_preprocessor_cache: false
sglang:
attention_backend: null
val_kwargs:
top_k: -1
top_p: 1.0
temperature: 0
'n': 1
do_sample: false
multi_turn:
enable: false
max_assistant_turns: null
tool_config_path: null
max_user_turns: null
max_parallel_calls: 1
max_tool_response_length: 256
tool_response_truncate_side: middle
interaction_config_path: null
completion_callback: null
use_inference_chat_template: false
tokenization_sanity_check_mode: strict
format: hermes
calculate_log_probs: false
agent:
num_workers: 8
custom_async_server:
path: null
name: null
enable_chunked_prefill: true
load_format: dummy_dtensor
layered_summon: false
hybrid_engine: true
model:
path: ~/models/deepseek-llm-7b-chat
custom_chat_template: null
use_shm: false
external_lib: null
override_config: {}
enable_gradient_checkpointing: true
enable_activation_offload: false
use_remove_padding: false
lora_rank: 0
lora_alpha: 16
target_modules: all-linear
exclude_modules: null
use_liger: false
use_fused_kernels: false
fused_kernel_options:
impl_backend: torch
trust_remote_code: false
profiler:
_target_: verl.utils.profiler.ProfilerConfig
discrete: false
all_ranks: false
ranks: []
trainer:
npu_profile:
options:
save_path: ./profiler_data
level: level1
with_memory: false
record_shapes: false
with_npu: true
with_cpu: true
with_module: false
with_stack: false
analysis: true
balance_batch: true
total_epochs: 30
total_training_steps: null
profile_steps: null
controller_nsight_options:
trace: cuda,nvtx,cublas,ucx
cuda-memory-usage: 'true'
cuda-graph-trace: graph
worker_nsight_options:
trace: cuda,nvtx,cublas,ucx
cuda-memory-usage: 'true'
cuda-graph-trace: graph
capture-range: cudaProfilerApi
capture-range-end: null
kill: none
project_name: verl_examples
experiment_name: gsm8k
logger:
- console
- wandb
log_val_generations: 0
rollout_data_dir: null
validation_data_dir: null
nnodes: 1
n_gpus_per_node: 8
save_freq: -1
esi_redundant_time: 0
resume_mode: auto
resume_from_path: null
val_before_train: true
val_only: false
test_freq: -1
critic_warmup: 0
default_hdfs_dir: null
del_local_ckpt_after_load: false
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
max_actor_ckpt_to_keep: null
max_critic_ckpt_to_keep: null
ray_wait_register_center_timeout: 300
device: cuda
data:
tokenizer: null
use_shm: false
train_files: ~/data/rlhf/gsm8k/train.parquet
val_files: ~/data/rlhf/gsm8k/test.parquet
prompt_key: prompt
reward_fn_key: data_source
max_prompt_length: 512
max_response_length: 512
train_batch_size: 1024
val_batch_size: null
return_raw_input_ids: false
return_raw_chat: false
return_full_prompt: false
shuffle: true
dataloader_num_workers: 8
validation_shuffle: false
filter_overlong_prompts: false
filter_overlong_prompts_workers: 1
truncation: error
image_key: images
video_key: videos
trust_remote_code: false
custom_cls:
path: null
name: null
return_multi_modal_inputs: true
sampler:
class_path: null
class_name: null
datagen:
path: null
name: null
critic:
rollout_n: ${actor_rollout_ref.rollout.n}
strategy: fsdp
optim:
lr_warmup_steps_ratio: 0.0
total_training_steps: -1
weight_decay: 0.01
lr: 1.0e-05
min_lr_ratio: null
warmup_style: constant
model:
path: ~/models/deepseek-llm-7b-chat
tokenizer_path: ${actor_rollout_ref.model.path}
override_config: {}
external_lib: ${actor_rollout_ref.model.external_lib}
enable_gradient_checkpointing: true
trust_remote_code: ${actor_rollout_ref.model.trust_remote_code}
use_shm: false
enable_activation_offload: false
use_remove_padding: false
fsdp_config:
param_offload: false
optimizer_offload: false
offload_policy: false
reshard_after_forward: true
wrap_policy:
min_num_params: 0
fsdp_size: -1
forward_prefetch: false
lora_rank: 0
lora_alpha: 16
target_modules: all-linear
ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
ppo_micro_batch_size: null
ppo_micro_batch_size_per_gpu: null
use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
ppo_max_token_len_per_gpu: 32768
forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu}
ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
shuffle: ${actor_rollout_ref.actor.shuffle}
cliprange_value: 0.5
loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode}
checkpoint:
save_contents:
- model
- optimizer
- extra
load_contents: ${.save_contents}
profiler:
_target_: verl.utils.profiler.ProfilerConfig
discrete: false
all_ranks: false
ranks: []
forward_micro_batch_size: ${critic.ppo_micro_batch_size}
forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
ulysses_sequence_parallel_size: 1
grad_clip: 1.0
reward_model:
enable: false
strategy: fsdp
model:
input_tokenizer: ${actor_rollout_ref.model.path}
path: ~/models/FsfairX-LLaMA3-RM-v0.1
external_lib: ${actor_rollout_ref.model.external_lib}
trust_remote_code: false
use_shm: false
use_remove_padding: false
use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}
fsdp_config:
wrap_policy:
min_num_params: 0
param_offload: false
reshard_after_forward: true
fsdp_size: -1
forward_prefetch: false
micro_batch_size: null
micro_batch_size_per_gpu: null
max_length: null
use_dynamic_bsz: ${critic.use_dynamic_bsz}
forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
reward_manager: naive
launch_reward_fn_async: false
sandbox_fusion:
url: null
max_concurrent: 64
memory_limit_mb: 1024
profiler:
_target_: verl.utils.profiler.ProfilerConfig
discrete: false
all_ranks: false
ranks: []
ulysses_sequence_parallel_size: 1
custom_reward_function:
path: null
name: compute_score
algorithm:
_target_: verl.trainer.config.AlgoConfig
gamma: 1.0
lam: 1.0
adv_estimator: gae
norm_adv_by_std_in_grpo: true
use_kl_in_reward: false
kl_penalty: kl
kl_ctrl:
_target_: verl.trainer.config.KLControlConfig
type: fixed
kl_coef: 0.001
horizon: 10000
target_kl: 0.1
use_pf_ppo: false
pf_ppo:
_target_: verl.trainer.config.PFPPOConfig
reweight_method: pow
weight_pow: 2.0
ray_init:
num_cpus: null
timeline_json_file: null

View File

@@ -18,7 +18,7 @@ from typing import Optional
from verl.base_config import BaseConfig
@dataclass(frozen=True)
@dataclass
class KLControlConfig(BaseConfig):
"""Configuration for KL control.
@@ -31,13 +31,14 @@ class KLControlConfig(BaseConfig):
target_kl (float): Target KL divergence for adaptive controller.
"""
_frozen_fields = ["type", "kl_coef", "horizon", "target_kl"]
type: str = "fixed"
kl_coef: float = 0.001
horizon: int = 10000
target_kl: float = 0.1
@dataclass(frozen=True)
@dataclass
class PFPPOConfig(BaseConfig):
"""Configuration for preference feedback PPO.
@@ -48,11 +49,12 @@ class PFPPOConfig(BaseConfig):
weight_pow (float): Power used for weight scaling in "pow" method.
"""
_frozen_fields = ["reweight_method", "weight_pow"]
reweight_method: str = "pow"
weight_pow: float = 2.0
@dataclass(frozen=True)
@dataclass
class FilterGroupsConfig(BaseConfig):
"""Configuration for filter groups (used in DAPO and Entropy).
@@ -64,12 +66,14 @@ class FilterGroupsConfig(BaseConfig):
max_num_gen_batches (int): Non-positive values mean no upper limit.
"""
_frozen_fields = ["enable", "metric", "max_num_gen_batches"]
enable: bool = False
metric: Optional[str] = None
max_num_gen_batches: int = 0
@dataclass(frozen=True)
@dataclass
class AlgoConfig(BaseConfig):
"""Configuration for the algorithm.
@@ -88,6 +92,16 @@ class AlgoConfig(BaseConfig):
filter_groups (Optional[FilterGroupsConfig]): Filter groups configuration, used in DAPO and Entropy
"""
_frozen_fields = [
"gamma",
"lam",
"adv_estimator",
"norm_adv_by_std_in_grpo",
"use_kl_in_reward",
"kl_penalty",
"use_pf_ppo",
]
gamma: float = 1.0
lam: float = 1.0
adv_estimator: str = "gae"
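
For reference, a hedged sketch of how the frozen-field list above behaves at runtime, assuming `AlgoConfig`'s nested sub-configs construct from their defaults:

```python
from dataclasses import FrozenInstanceError

from verl.trainer.config import AlgoConfig

algo = AlgoConfig(gamma=0.99)

try:
    algo.gamma = 1.0  # "gamma" is listed in _frozen_fields, so reassignment is rejected
except FrozenInstanceError as e:
    print(e)

# Ad hoc keys still have a home in the `extra` dict inherited from BaseConfig.
algo.extra["any_new_key"] = "any_plain_value"
print(algo.get("gamma"), algo.extra)
```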

View File

@@ -1,4 +1,4 @@
# actor_rollout_ref.rollout.name: hf/vllm/sglang.
# actor_rollout_ref.rollout.name: hf/vllm/sglang. The default value will be removed in the future
name: vllm
# sync: LLM, async: AsyncLLM

View File

@@ -13,11 +13,12 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import ClassVar
from verl.base_config import BaseConfig
@dataclass(frozen=True)
@dataclass
class ProfilerConfig(BaseConfig):
"""Worker profiler config. Currently only support Nsight system profiler.
@@ -30,6 +31,9 @@ class ProfilerConfig(BaseConfig):
ranks (list[int]): The ranks that will be profiled. Defaults to [].
"""
# the fields expected to be frozen
_frozen_fields: ClassVar[set[str]] = {"discrete", "all_ranks", "ranks"}
discrete: bool = False
all_ranks: bool = False