[cfg] refactor: support +extra.any_key usage for the base dataclass config in verl (#2502)

### What does this PR do?

This PR makes the following updates to the base dataclass config in verl:
- support `+extra.any_key` usage for configs that inherit the base config in verl
- allow selected subfields to be frozen (see the sketch below)
- add an auto-generated config YAML file
`verl/trainer/config/_generated_ppo_trainer.yaml` for reference, in case
the nested inheritance structure makes the config information too
scattered
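
A minimal sketch of the two behaviors above (illustrative only: `MyWorkerConfig` and its `lr` field are hypothetical, while `BaseConfig`, `_frozen_fields`, and the `extra` dict come from this PR):

```python
from dataclasses import FrozenInstanceError, dataclass

from verl.base_config import BaseConfig


@dataclass
class MyWorkerConfig(BaseConfig):  # hypothetical subclass, for illustration only
    _frozen_fields = ["lr"]  # subfields listed here are frozen after __init__
    lr: float = 1e-6


cfg = MyWorkerConfig(extra={"any_key": "any_plain_value"})  # ad hoc keys live under `extra`
cfg.extra["another_key"] = 123  # `extra` itself stays a mutable dict

try:
    cfg.lr = 1e-5  # blocked: "lr" is listed in _frozen_fields
except FrozenInstanceError:
    print("lr is frozen")
```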

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

- Added frozen-field tests

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

Now you can pass `+xx.profiler.extra.any_new_key=any_plain_value` on the
command line for any dataclass that inherits `verl.BaseConfig` (`xx` is a
placeholder section name). This way we can still pass dataclass configs
inside verl while allowing some flexibility to accept new keys from users'
ad hoc usage.
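
A hedged sketch of the equivalent programmatic flow (the key name `any_key` is illustrative; `omega_conf_to_dataclass` and the `_target_` entry are taken from the diff below):

```python
# The same override exercised by the new test, from the command line:
#   python3 scripts/print_cfg.py critic.profiler.discrete=True +critic.profiler.extra.any_key=val
from omegaconf import OmegaConf

from verl.utils.config import omega_conf_to_dataclass

conf = OmegaConf.create(
    {
        "_target_": "verl.utils.profiler.ProfilerConfig",
        "discrete": True,
        "all_ranks": False,
        "ranks": [],
        "extra": {"any_key": "val"},  # unknown key accepted under `extra`
    }
)
profiler_config = omega_conf_to_dataclass(conf)
print(profiler_config.extra)  # expected: {'any_key': 'val'}
```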


### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).

---------

Co-authored-by: Lin <haibin@Lins-Laptop.hsd1.wa.comcast.net>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Commit d0c7bbbc05 (parent def5b28e3d), authored by H on 2025-07-14 18:06:56 -07:00, committed by GitHub.
13 changed files with 556 additions and 8 deletions.


@@ -66,6 +66,9 @@ jobs:
         uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
         with:
           python-version: ${{ matrix.python-version }}
+      - name: Install the current repository
+        run: |
+          pip install -e .
       - name: Set ruff --output-format=github
         run: |
           sed -i 's/--output-format=full/--output-format=github/' .pre-commit-config.yaml


@@ -25,6 +25,9 @@ jobs:
         uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
         with:
           python-version: ${{ matrix.python-version }}
+      - name: Install the current repository
+        run: |
+          pip install -e .
       - name: Set ruff --output-format=github
         run: |
           sed -i 's/--output-format=full/--output-format=github/' .pre-commit-config.yaml


@@ -6,3 +6,27 @@ repos:
         args: ["--fix", "--show-fixes", "--output-format=full"]
         exclude: ^.*\.(ipynb)$
       - id: ruff-format
+  - repo: local
+    hooks:
+      - id: autogen-trainer-cfg
+        name: Generate and verify verl/trainer/config/_generated_*.yaml
+        entry: scripts/generate_trainer_config.sh
+        language: script
+        pass_filenames: false
+  - repo: local
+    hooks:
+      - id: check-docstrings
+        name: Check doc string coverage
+        entry: python3 tests/special_sanity/check_docstrings.py
+        language: python
+        pass_filenames: false
+  - repo: local
+    hooks:
+      - id: check-license
+        name: Check license
+        entry: python3 tests/special_sanity/check_license.py --directory .
+        language: python
+        pass_filenames: false


@@ -31,7 +31,11 @@ pre-commit install
 # for staged changes
 pre-commit run
 # for all files in the repo
-# pre-commit run --all-files
+pre-commit run --all-files
+# run a specific hook with pre-commit
+# pre-commit run --all-files --show-diff-on-failure --color=always <hook-id>
+pre-commit run --all-files --show-diff-on-failure --color=always ruff
+pre-commit run --all-files --show-diff-on-failure --color=always autogen-trainer-cfg
 ```

 ## Testing


@@ -0,0 +1,28 @@
#!/usr/bin/env bash
set -euox pipefail

# 1. Dump the full config to a temp file
target_cfg=verl/trainer/config/_generated_ppo_trainer.yaml
tmp_header=$(mktemp)
tmp_cfg=$(mktemp)
echo "# This reference configuration yaml is automatically generated via 'scripts/generate_trainer_config.sh'" > "$tmp_header"
echo "# in which it invokes 'python3 scripts/print_cfg.py --cfg job' to flatten the 'verl/trainer/config/ppo_trainer.yaml' config fields into a single file." >> "$tmp_header"
echo "# Do not modify this file directly." >> "$tmp_header"
echo "# The file is usually only for reference and never used." >> "$tmp_header"
echo "" >> "$tmp_header"
python3 scripts/print_cfg.py --cfg job > "$tmp_cfg"

# 2. Extract from the line starting with "actor_rollout_ref" onward
cat $tmp_header > $target_cfg
sed -n '/^actor_rollout_ref/,$p' "$tmp_cfg" >> $target_cfg

# 3. Clean up
rm "$tmp_cfg" "$tmp_header"

# 4. Verify that verl/trainer/config/_generated_ppo_trainer.yaml wasn't changed on disk
if ! git diff --exit-code -- "$target_cfg" >/dev/null; then
    echo "$target_cfg is out of date. Please regenerate via 'scripts/generate_trainer_config.sh' and commit the changes."
    exit 1
fi
echo "All good"
exit 0

scripts/print_cfg.py (new file, 35 lines)

@@ -0,0 +1,35 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

try:
    import hydra
except ImportError as e:
    raise ImportError("Please install hydra-core via 'pip install hydra-core' and retry.") from e


@hydra.main(config_path="../verl/trainer/config", config_name="ppo_trainer", version_base=None)
def main(config):
    """Main entry point for PPO training with Hydra configuration management.

    Args:
        config: Hydra configuration object containing training parameters.
    """
    print(config)

    from verl.utils.config import omega_conf_to_dataclass

    profiler_config = omega_conf_to_dataclass(config.critic.profiler)
    print(profiler_config)


if __name__ == "__main__":
    main()


@@ -67,5 +67,29 @@ class TestConfigOnCPU(unittest.TestCase):
         assert isinstance(cfg.model, TestDataclass)
+
+
+class TestPrintCfgCommand(unittest.TestCase):
+    """Test suite for the print_cfg.py command-line tool."""
+
+    def test_command_with_override(self):
+        """Test that the command runs without error when overriding config values."""
+        import subprocess
+
+        # Run the command
+        result = subprocess.run(
+            ["python3", "scripts/print_cfg.py", "critic.profiler.discrete=True", "+critic.profiler.extra.any_key=val"],
+            capture_output=True,
+            text=True,
+        )
+
+        # Verify the command exited successfully
+        self.assertEqual(result.returncode, 0, f"Command failed with stderr: {result.stderr}")
+
+        # Verify the output contains expected config information
+        self.assertIn("critic", result.stdout)
+        self.assertIn("profiler", result.stdout)
+        self.assertIn("discrete=True", result.stdout)
+        self.assertIn("extra={'any_key': 'val'}", result.stdout)
+
+
 if __name__ == "__main__":
     unittest.main()


@@ -50,6 +50,37 @@ class TestProfilerConfig(unittest.TestCase):
         with self.assertRaises(FrozenInstanceError):
             profiler_config.discrete = False
+
+    def test_frozen_config(self):
+        """Test that modifying frozen keys in ProfilerConfig raises exceptions."""
+        from dataclasses import FrozenInstanceError
+
+        from verl.utils.profiler.config import ProfilerConfig
+
+        # Create a new ProfilerConfig instance
+        config = ProfilerConfig(discrete=True, all_ranks=False, ranks=[0])
+
+        # Test direct attribute assignment
+        with self.assertRaises(FrozenInstanceError):
+            config.discrete = False
+        with self.assertRaises(FrozenInstanceError):
+            config.all_ranks = True
+        with self.assertRaises(FrozenInstanceError):
+            config.ranks = [1, 2, 3]
+
+        # Test dictionary-style assignment
+        with self.assertRaises(TypeError):
+            config["discrete"] = False
+        with self.assertRaises(TypeError):
+            config["all_ranks"] = True
+        with self.assertRaises(TypeError):
+            config["ranks"] = [1, 2, 3]
+
+        config["extra"]["key"] = "value"
+
+
 class TestNsightSystemsProfiler(unittest.TestCase):
     """Test suite for NsightSystemsProfiler functionality.


@@ -13,11 +13,16 @@
 # limitations under the License.

 import collections
-from dataclasses import fields  # Import the fields function to inspect dataclass fields
+from dataclasses import (
+    dataclass,
+    field,
+    fields,  # Import the fields function to inspect dataclass fields
+)
 from typing import Any


 # BaseConfig class inherits from collections.abc.Mapping, which means it can act like a dictionary
+@dataclass
 class BaseConfig(collections.abc.Mapping):
     """The BaseConfig provides omegaconf DictConfig-like interface for a dataclass config.

@@ -25,6 +30,18 @@ class BaseConfig(collections.abc.Mapping):
     This allows instances of this class to be used like dictionaries.
     """

+    extra: dict[str, Any] = field(default_factory=dict)
+
+    def __setattr__(self, name: str, value):
+        # if the field already exists (i.e. was set in __init__)
+        # and is in our frozen list, block assignment
+        if hasattr(self, "_frozen_fields") and name in self._frozen_fields and name in self.__dict__:
+            from dataclasses import FrozenInstanceError
+
+            raise FrozenInstanceError(f"Field '{name}' is frozen and cannot be modified")
+        # otherwise do the normal thing
+        super().__setattr__(name, value)
+
     def get(self, key: str, default: Any = None) -> Any:
         """Get the value associated with the given key. If the key does not exist, return the default value.


@@ -0,0 +1,361 @@
# This reference configuration yaml is automatically generated via 'scripts/generate_trainer_config.sh'
# in which it invokes 'python3 scripts/print_cfg.py --cfg job' to flatten the 'verl/trainer/config/ppo_trainer.yaml' config fields into a single file.
# Do not modify this file directly.
# The file is usually only for reference and never used.
actor_rollout_ref:
actor:
strategy: fsdp
ppo_mini_batch_size: 256
ppo_micro_batch_size: null
ppo_micro_batch_size_per_gpu: null
use_dynamic_bsz: false
ppo_max_token_len_per_gpu: 16384
clip_ratio: 0.2
clip_ratio_low: 0.2
clip_ratio_high: 0.2
policy_loss:
loss_mode: vanilla
clip_cov_ratio: 0.0002
clip_cov_lb: 1.0
clip_cov_ub: 5.0
kl_cov_ratio: 0.0002
ppo_kl_coef: 0.1
clip_ratio_c: 3.0
loss_agg_mode: token-mean
entropy_coeff: 0
use_kl_loss: false
use_torch_compile: true
kl_loss_coef: 0.001
kl_loss_type: low_var_kl
ppo_epochs: 1
shuffle: false
checkpoint:
save_contents:
- model
- optimizer
- extra
load_contents: ${.save_contents}
optim:
lr: 1.0e-06
lr_warmup_steps_ratio: 0.0
total_training_steps: -1
weight_decay: 0.01
lr_warmup_steps: -1
min_lr_ratio: 0.0
num_cycles: 0.5
warmup_style: constant
grad_clip: 1.0
ulysses_sequence_parallel_size: 1
entropy_from_logits_with_chunking: false
entropy_checkpointing: false
fsdp_config:
wrap_policy:
min_num_params: 0
param_offload: false
optimizer_offload: false
offload_policy: false
reshard_after_forward: true
fsdp_size: -1
forward_prefetch: false
ref:
strategy: ${actor_rollout_ref.actor.strategy}
use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true}
log_prob_micro_batch_size: null
log_prob_micro_batch_size_per_gpu: null
log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}
log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}
fsdp_config:
param_offload: false
reshard_after_forward: true
forward_prefetch: false
wrap_policy:
min_num_params: 0
ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1}
entropy_from_logits_with_chunking: false
entropy_checkpointing: false
rollout:
name: vllm
mode: sync
temperature: 1.0
top_k: -1
top_p: 1
prompt_length: ${oc.select:data.max_prompt_length,512}
response_length: ${oc.select:data.max_response_length,512}
dtype: bfloat16
gpu_memory_utilization: 0.5
ignore_eos: false
enforce_eager: true
free_cache_engine: true
tensor_model_parallel_size: 2
max_num_batched_tokens: 8192
max_model_len: null
max_num_seqs: 1024
log_prob_micro_batch_size: null
log_prob_micro_batch_size_per_gpu: null
log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}
log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}
disable_log_stats: true
do_sample: true
'n': 1
multi_stage_wake_up: false
engine_kwargs:
vllm:
swap_space: null
disable_mm_preprocessor_cache: false
sglang:
attention_backend: null
val_kwargs:
top_k: -1
top_p: 1.0
temperature: 0
'n': 1
do_sample: false
multi_turn:
enable: false
max_assistant_turns: null
tool_config_path: null
max_user_turns: null
max_parallel_calls: 1
max_tool_response_length: 256
tool_response_truncate_side: middle
interaction_config_path: null
completion_callback: null
use_inference_chat_template: false
tokenization_sanity_check_mode: strict
format: hermes
calculate_log_probs: false
agent:
num_workers: 8
custom_async_server:
path: null
name: null
enable_chunked_prefill: true
load_format: dummy_dtensor
layered_summon: false
hybrid_engine: true
model:
path: ~/models/deepseek-llm-7b-chat
custom_chat_template: null
use_shm: false
external_lib: null
override_config: {}
enable_gradient_checkpointing: true
enable_activation_offload: false
use_remove_padding: false
lora_rank: 0
lora_alpha: 16
target_modules: all-linear
exclude_modules: null
use_liger: false
use_fused_kernels: false
fused_kernel_options:
impl_backend: torch
trust_remote_code: false
profiler:
_target_: verl.utils.profiler.ProfilerConfig
discrete: false
all_ranks: false
ranks: []
trainer:
npu_profile:
options:
save_path: ./profiler_data
level: level1
with_memory: false
record_shapes: false
with_npu: true
with_cpu: true
with_module: false
with_stack: false
analysis: true
balance_batch: true
total_epochs: 30
total_training_steps: null
profile_steps: null
controller_nsight_options:
trace: cuda,nvtx,cublas,ucx
cuda-memory-usage: 'true'
cuda-graph-trace: graph
worker_nsight_options:
trace: cuda,nvtx,cublas,ucx
cuda-memory-usage: 'true'
cuda-graph-trace: graph
capture-range: cudaProfilerApi
capture-range-end: null
kill: none
project_name: verl_examples
experiment_name: gsm8k
logger:
- console
- wandb
log_val_generations: 0
rollout_data_dir: null
validation_data_dir: null
nnodes: 1
n_gpus_per_node: 8
save_freq: -1
esi_redundant_time: 0
resume_mode: auto
resume_from_path: null
val_before_train: true
val_only: false
test_freq: -1
critic_warmup: 0
default_hdfs_dir: null
del_local_ckpt_after_load: false
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
max_actor_ckpt_to_keep: null
max_critic_ckpt_to_keep: null
ray_wait_register_center_timeout: 300
device: cuda
data:
tokenizer: null
use_shm: false
train_files: ~/data/rlhf/gsm8k/train.parquet
val_files: ~/data/rlhf/gsm8k/test.parquet
prompt_key: prompt
reward_fn_key: data_source
max_prompt_length: 512
max_response_length: 512
train_batch_size: 1024
val_batch_size: null
return_raw_input_ids: false
return_raw_chat: false
return_full_prompt: false
shuffle: true
dataloader_num_workers: 8
validation_shuffle: false
filter_overlong_prompts: false
filter_overlong_prompts_workers: 1
truncation: error
image_key: images
video_key: videos
trust_remote_code: false
custom_cls:
path: null
name: null
return_multi_modal_inputs: true
sampler:
class_path: null
class_name: null
datagen:
path: null
name: null
critic:
rollout_n: ${actor_rollout_ref.rollout.n}
strategy: fsdp
optim:
lr_warmup_steps_ratio: 0.0
total_training_steps: -1
weight_decay: 0.01
lr: 1.0e-05
min_lr_ratio: null
warmup_style: constant
model:
path: ~/models/deepseek-llm-7b-chat
tokenizer_path: ${actor_rollout_ref.model.path}
override_config: {}
external_lib: ${actor_rollout_ref.model.external_lib}
enable_gradient_checkpointing: true
trust_remote_code: ${actor_rollout_ref.model.trust_remote_code}
use_shm: false
enable_activation_offload: false
use_remove_padding: false
fsdp_config:
param_offload: false
optimizer_offload: false
offload_policy: false
reshard_after_forward: true
wrap_policy:
min_num_params: 0
fsdp_size: -1
forward_prefetch: false
lora_rank: 0
lora_alpha: 16
target_modules: all-linear
ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
ppo_micro_batch_size: null
ppo_micro_batch_size_per_gpu: null
use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
ppo_max_token_len_per_gpu: 32768
forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu}
ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
shuffle: ${actor_rollout_ref.actor.shuffle}
cliprange_value: 0.5
loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode}
checkpoint:
save_contents:
- model
- optimizer
- extra
load_contents: ${.save_contents}
profiler:
_target_: verl.utils.profiler.ProfilerConfig
discrete: false
all_ranks: false
ranks: []
forward_micro_batch_size: ${critic.ppo_micro_batch_size}
forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
ulysses_sequence_parallel_size: 1
grad_clip: 1.0
reward_model:
enable: false
strategy: fsdp
model:
input_tokenizer: ${actor_rollout_ref.model.path}
path: ~/models/FsfairX-LLaMA3-RM-v0.1
external_lib: ${actor_rollout_ref.model.external_lib}
trust_remote_code: false
use_shm: false
use_remove_padding: false
use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}
fsdp_config:
wrap_policy:
min_num_params: 0
param_offload: false
reshard_after_forward: true
fsdp_size: -1
forward_prefetch: false
micro_batch_size: null
micro_batch_size_per_gpu: null
max_length: null
use_dynamic_bsz: ${critic.use_dynamic_bsz}
forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
reward_manager: naive
launch_reward_fn_async: false
sandbox_fusion:
url: null
max_concurrent: 64
memory_limit_mb: 1024
profiler:
_target_: verl.utils.profiler.ProfilerConfig
discrete: false
all_ranks: false
ranks: []
ulysses_sequence_parallel_size: 1
custom_reward_function:
path: null
name: compute_score
algorithm:
_target_: verl.trainer.config.AlgoConfig
gamma: 1.0
lam: 1.0
adv_estimator: gae
norm_adv_by_std_in_grpo: true
use_kl_in_reward: false
kl_penalty: kl
kl_ctrl:
_target_: verl.trainer.config.KLControlConfig
type: fixed
kl_coef: 0.001
horizon: 10000
target_kl: 0.1
use_pf_ppo: false
pf_ppo:
_target_: verl.trainer.config.PFPPOConfig
reweight_method: pow
weight_pow: 2.0
ray_init:
num_cpus: null
timeline_json_file: null


@@ -18,7 +18,7 @@ from typing import Optional
 from verl.base_config import BaseConfig


-@dataclass(frozen=True)
+@dataclass
 class KLControlConfig(BaseConfig):
     """Configuration for KL control.

@@ -31,13 +31,14 @@ class KLControlConfig(BaseConfig):
         target_kl (float): Target KL divergence for adaptive controller.
     """

+    _frozen_fields = ["type", "kl_coef", "horizon", "target_kl"]
+
     type: str = "fixed"
     kl_coef: float = 0.001
     horizon: int = 10000
     target_kl: float = 0.1


-@dataclass(frozen=True)
+@dataclass
 class PFPPOConfig(BaseConfig):
     """Configuration for preference feedback PPO.

@@ -48,11 +49,12 @@ class PFPPOConfig(BaseConfig):
         weight_pow (float): Power used for weight scaling in "pow" method.
     """

+    _frozen_fields = ["reweight_method", "weight_pow"]
+
     reweight_method: str = "pow"
     weight_pow: float = 2.0


-@dataclass(frozen=True)
+@dataclass
 class FilterGroupsConfig(BaseConfig):
     """Configuration for filter groups (used in DAPO and Entropy).

@@ -64,12 +66,14 @@ class FilterGroupsConfig(BaseConfig):
         max_num_gen_batches (int): Non-positive values mean no upper limit.
     """

+    _frozen_fields = ["enable", "metric", "max_num_gen_batches"]
+
     enable: bool = False
     metric: Optional[str] = None
     max_num_gen_batches: int = 0


-@dataclass(frozen=True)
+@dataclass
 class AlgoConfig(BaseConfig):
     """Configuration for the algorithm.

@@ -88,6 +92,16 @@ class AlgoConfig(BaseConfig):
         filter_groups (Optional[FilterGroupsConfig]): Filter groups configuration, used in DAPO and Entropy
     """

+    _frozen_fields = [
+        "gamma",
+        "lam",
+        "adv_estimator",
+        "norm_adv_by_std_in_grpo",
+        "use_kl_in_reward",
+        "kl_penalty",
+        "use_pf_ppo",
+    ]
+
     gamma: float = 1.0
     lam: float = 1.0
     adv_estimator: str = "gae"


@@ -1,4 +1,4 @@
-# actor_rollout_ref.rollout.name: hf/vllm/sglang.
+# actor_rollout_ref.rollout.name: hf/vllm/sglang. The default value will be removed in the future
 name: vllm

 # sync: LLM, async: AsyncLLM


@@ -13,11 +13,12 @@
 # limitations under the License.

 from dataclasses import dataclass, field
+from typing import ClassVar

 from verl.base_config import BaseConfig


-@dataclass(frozen=True)
+@dataclass
 class ProfilerConfig(BaseConfig):
     """Worker profiler config. Currently only support Nsight system profiler.

@@ -30,6 +31,9 @@ class ProfilerConfig(BaseConfig):
         ranks (list[int]): The ranks that will be profiled. Defaults to [].
     """

+    # the fields expected to be frozen
+    _frozen_fields: ClassVar[set[str]] = {"discrete", "all_ranks", "ranks"}
+
     discrete: bool = False
     all_ranks: bool = False