diff --git a/.github/workflows/e2e_ppo_trainer.yml b/.github/workflows/e2e_ppo_trainer.yml
index 1419a27d9..c67dce301 100644
--- a/.github/workflows/e2e_ppo_trainer.yml
+++ b/.github/workflows/e2e_ppo_trainer.yml
@@ -66,6 +66,9 @@ jobs:
         uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
         with:
           python-version: ${{ matrix.python-version }}
+      - name: Install the current repository
+        run: |
+          pip install -e .
       - name: Set ruff --output-format=github
         run: |
           sed -i 's/--output-format=full/--output-format=github/' .pre-commit-config.yaml
diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
index 21dce5c47..80cfa0945 100644
--- a/.github/workflows/pre-commit.yml
+++ b/.github/workflows/pre-commit.yml
@@ -25,6 +25,9 @@ jobs:
         uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
         with:
           python-version: ${{ matrix.python-version }}
+      - name: Install the current repository
+        run: |
+          pip install -e .
       - name: Set ruff --output-format=github
         run: |
           sed -i 's/--output-format=full/--output-format=github/' .pre-commit-config.yaml
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 76dfe8cb5..e68644a01 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -6,3 +6,27 @@ repos:
         args: ["--fix", "--show-fixes", "--output-format=full"]
         exclude: ^.*\.(ipynb)$
       - id: ruff-format
+
+  - repo: local
+    hooks:
+      - id: autogen-trainer-cfg
+        name: Generate and verify verl/trainer/config/_generated_*.yaml
+        entry: scripts/generate_trainer_config.sh
+        language: script
+        pass_filenames: false
+
+  - repo: local
+    hooks:
+      - id: check-docstrings
+        name: Check docstring coverage
+        entry: python3 tests/special_sanity/check_docstrings.py
+        language: python
+        pass_filenames: false
+
+  - repo: local
+    hooks:
+      - id: check-license
+        name: Check license
+        entry: python3 tests/special_sanity/check_license.py --directory .
+        language: python
+        pass_filenames: false
\ No newline at end of file
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 7487e43ad..e953f113e 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -31,7 +31,11 @@ pre-commit install
 # for staged changes
 pre-commit run
 # for all files in the repo
-# pre-commit run --all-files
+pre-commit run --all-files
+# run a specific hook with pre-commit
+# pre-commit run --all-files --show-diff-on-failure --color=always
+pre-commit run --all-files --show-diff-on-failure --color=always ruff
+pre-commit run --all-files --show-diff-on-failure --color=always autogen-trainer-cfg
 ```
 
 ## Testing
diff --git a/scripts/generate_trainer_config.sh b/scripts/generate_trainer_config.sh
new file mode 100755
index 000000000..fc2386e53
--- /dev/null
+++ b/scripts/generate_trainer_config.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env bash
+set -euox pipefail
+
+# 1. Dump the full config to a temp file
+target_cfg=verl/trainer/config/_generated_ppo_trainer.yaml
+tmp_header=$(mktemp)
+tmp_cfg=$(mktemp)
+echo "# This reference configuration yaml is automatically generated via 'scripts/generate_trainer_config.sh'" > "$tmp_header"
+echo "# which invokes 'python3 scripts/print_cfg.py --cfg job' to flatten the 'verl/trainer/config/ppo_trainer.yaml' config fields into a single file." >> "$tmp_header"
+echo "# Do not modify this file directly." >> "$tmp_header"
+echo "# The file is for reference only and is never loaded at runtime." >> "$tmp_header"
+echo "" >> "$tmp_header"
+python3 scripts/print_cfg.py --cfg job > "$tmp_cfg"
+
+# 2. Extract from the line starting with "actor_rollout_ref" onward
+cat "$tmp_header" > "$target_cfg"
+sed -n '/^actor_rollout_ref/,$p' "$tmp_cfg" >> "$target_cfg"
+
+# 3. Clean up
+rm "$tmp_cfg" "$tmp_header"
+
+# 4. Verify that verl/trainer/config/_generated_ppo_trainer.yaml wasn't changed on disk
+if ! git diff --exit-code -- "$target_cfg" >/dev/null; then
+  echo "✖ $target_cfg is out of date. Please regenerate via 'scripts/generate_trainer_config.sh' and commit the changes."
+  exit 1
+fi
+echo "All good"
+exit 0
\ No newline at end of file
diff --git a/scripts/print_cfg.py b/scripts/print_cfg.py
new file mode 100644
index 000000000..287756fb1
--- /dev/null
+++ b/scripts/print_cfg.py
@@ -0,0 +1,35 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+try:
+    import hydra
+except ImportError as e:
+    raise ImportError("Please install hydra-core via 'pip install hydra-core' and retry.") from e
+
+
+@hydra.main(config_path="../verl/trainer/config", config_name="ppo_trainer", version_base=None)
+def main(config):
+    """Print the resolved PPO trainer config and the critic profiler dataclass.
+
+    Args:
+        config: Hydra configuration object containing the trainer parameters.
+    """
+    print(config)
+    from verl.utils.config import omega_conf_to_dataclass
+
+    profiler_config = omega_conf_to_dataclass(config.critic.profiler)
+    print(profiler_config)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/utils/test_config_on_cpu.py b/tests/utils/test_config_on_cpu.py
index 03d952c9f..42dc8e1f2 100644
--- a/tests/utils/test_config_on_cpu.py
+++ b/tests/utils/test_config_on_cpu.py
@@ -67,5 +67,29 @@ class TestConfigOnCPU(unittest.TestCase):
         assert isinstance(cfg.model, TestDataclass)
 
 
+class TestPrintCfgCommand(unittest.TestCase):
+    """Test suite for the print_cfg.py command-line tool."""
+
+    def test_command_with_override(self):
+        """Test that the command runs without error when overriding config values."""
+        import subprocess
+
+        # Run the command
+        result = subprocess.run(
+            ["python3", "scripts/print_cfg.py", "critic.profiler.discrete=True", "+critic.profiler.extra.any_key=val"],
+            capture_output=True,
+            text=True,
+        )
+
+        # Verify the command exited successfully
+        self.assertEqual(result.returncode, 0, f"Command failed with stderr: {result.stderr}")
+
+        # Verify the output contains expected config information
+        self.assertIn("critic", result.stdout)
+        self.assertIn("profiler", result.stdout)
+        self.assertIn("discrete=True", result.stdout)
+        self.assertIn("extra={'any_key': 'val'}", result.stdout)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/utils/test_nvtx_profile.py b/tests/utils/test_nvtx_profile.py
index 3450260c9..817d03000 100644
--- a/tests/utils/test_nvtx_profile.py
+++ b/tests/utils/test_nvtx_profile.py
@@ -50,6 +50,38 @@ class TestProfilerConfig(unittest.TestCase):
         with self.assertRaises(FrozenInstanceError):
             profiler_config.discrete = False
 
+    def test_frozen_config(self):
+        """Test that modifying frozen keys in ProfilerConfig raises exceptions."""
+        from dataclasses import FrozenInstanceError
+
+        from verl.utils.profiler.config import ProfilerConfig
+
+        # Create a new ProfilerConfig instance
+        config = ProfilerConfig(discrete=True, all_ranks=False, ranks=[0])
+
+        # Test direct attribute assignment
+        with self.assertRaises(FrozenInstanceError):
+            config.discrete = False
+
+        with self.assertRaises(FrozenInstanceError):
+            config.all_ranks = True
+
+        with self.assertRaises(FrozenInstanceError):
+            config.ranks = [1, 2, 3]
+
+        # Test dictionary-style assignment
+        with self.assertRaises(TypeError):
+            config["discrete"] = False
+
+        with self.assertRaises(TypeError):
+            config["all_ranks"] = True
+
+        with self.assertRaises(TypeError):
+            config["ranks"] = [1, 2, 3]
+
+        # The 'extra' dict is intentionally left mutable
+        config["extra"]["key"] = "value"
+
 
 class TestNsightSystemsProfiler(unittest.TestCase):
     """Test suite for NsightSystemsProfiler functionality.
diff --git a/verl/base_config.py b/verl/base_config.py
index d413160de..0cd117bb6 100644
--- a/verl/base_config.py
+++ b/verl/base_config.py
@@ -13,11 +13,16 @@
 # limitations under the License.
 
 import collections
-from dataclasses import fields  # Import the fields function to inspect dataclass fields
+from dataclasses import (
+    dataclass,
+    field,
+    fields,  # Import the fields function to inspect dataclass fields
+)
 from typing import Any
 
 
 # BaseConfig class inherits from collections.abc.Mapping, which means it can act like a dictionary
+@dataclass
 class BaseConfig(collections.abc.Mapping):
     """The BaseConfig provides omegaconf DictConfig-like interface for a dataclass config.
 
@@ -25,6 +30,18 @@ class BaseConfig(collections.abc.Mapping):
     This allows instances of this class to be used like dictionaries.
     """
 
+    extra: dict[str, Any] = field(default_factory=dict)
+
+    def __setattr__(self, name: str, value):
+        # if the field already exists (i.e. was set in __init__)
+        # and is in our frozen list, block assignment
+        if hasattr(self, "_frozen_fields") and name in self._frozen_fields and name in self.__dict__:
+            from dataclasses import FrozenInstanceError
+
+            raise FrozenInstanceError(f"Field '{name}' is frozen and cannot be modified")
+        # otherwise do the normal thing
+        super().__setattr__(name, value)
+
     def get(self, key: str, default: Any = None) -> Any:
         """Get the value associated with the given key. If the key does not exist, return the default value.
 
diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml
new file mode 100644
index 000000000..87b090c4e
--- /dev/null
+++ b/verl/trainer/config/_generated_ppo_trainer.yaml
@@ -0,0 +1,361 @@
+# This reference configuration yaml is automatically generated via 'scripts/generate_trainer_config.sh'
+# which invokes 'python3 scripts/print_cfg.py --cfg job' to flatten the 'verl/trainer/config/ppo_trainer.yaml' config fields into a single file.
+# Do not modify this file directly.
+# The file is for reference only and is never loaded at runtime.
+ +actor_rollout_ref: + actor: + strategy: fsdp + ppo_mini_batch_size: 256 + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: false + ppo_max_token_len_per_gpu: 16384 + clip_ratio: 0.2 + clip_ratio_low: 0.2 + clip_ratio_high: 0.2 + policy_loss: + loss_mode: vanilla + clip_cov_ratio: 0.0002 + clip_cov_lb: 1.0 + clip_cov_ub: 5.0 + kl_cov_ratio: 0.0002 + ppo_kl_coef: 0.1 + clip_ratio_c: 3.0 + loss_agg_mode: token-mean + entropy_coeff: 0 + use_kl_loss: false + use_torch_compile: true + kl_loss_coef: 0.001 + kl_loss_type: low_var_kl + ppo_epochs: 1 + shuffle: false + checkpoint: + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + optim: + lr: 1.0e-06 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + min_lr_ratio: 0.0 + num_cycles: 0.5 + warmup_style: constant + grad_clip: 1.0 + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + fsdp_config: + wrap_policy: + min_num_params: 0 + param_offload: false + optimizer_offload: false + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + ref: + strategy: ${actor_rollout_ref.actor.strategy} + use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true} + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + fsdp_config: + param_offload: false + reshard_after_forward: true + forward_prefetch: false + wrap_policy: + min_num_params: 0 + ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1} + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + rollout: + name: vllm + mode: sync + temperature: 1.0 + top_k: -1 + top_p: 1 + prompt_length: ${oc.select:data.max_prompt_length,512} + response_length: ${oc.select:data.max_response_length,512} + dtype: bfloat16 + gpu_memory_utilization: 0.5 + ignore_eos: false + enforce_eager: true + free_cache_engine: true + tensor_model_parallel_size: 2 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + disable_log_stats: true + do_sample: true + 'n': 1 + multi_stage_wake_up: false + engine_kwargs: + vllm: + swap_space: null + disable_mm_preprocessor_cache: false + sglang: + attention_backend: null + val_kwargs: + top_k: -1 + top_p: 1.0 + temperature: 0 + 'n': 1 + do_sample: false + multi_turn: + enable: false + max_assistant_turns: null + tool_config_path: null + max_user_turns: null + max_parallel_calls: 1 + max_tool_response_length: 256 + tool_response_truncate_side: middle + interaction_config_path: null + completion_callback: null + use_inference_chat_template: false + tokenization_sanity_check_mode: strict + format: hermes + calculate_log_probs: false + agent: + num_workers: 8 + custom_async_server: + path: null + name: null + enable_chunked_prefill: true + load_format: dummy_dtensor + layered_summon: false + hybrid_engine: true + model: + path: ~/models/deepseek-llm-7b-chat + custom_chat_template: null + use_shm: false + 
external_lib: null + override_config: {} + enable_gradient_checkpointing: true + enable_activation_offload: false + use_remove_padding: false + lora_rank: 0 + lora_alpha: 16 + target_modules: all-linear + exclude_modules: null + use_liger: false + use_fused_kernels: false + fused_kernel_options: + impl_backend: torch + trust_remote_code: false + profiler: + _target_: verl.utils.profiler.ProfilerConfig + discrete: false + all_ranks: false + ranks: [] +trainer: + npu_profile: + options: + save_path: ./profiler_data + level: level1 + with_memory: false + record_shapes: false + with_npu: true + with_cpu: true + with_module: false + with_stack: false + analysis: true + balance_batch: true + total_epochs: 30 + total_training_steps: null + profile_steps: null + controller_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + worker_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + capture-range: cudaProfilerApi + capture-range-end: null + kill: none + project_name: verl_examples + experiment_name: gsm8k + logger: + - console + - wandb + log_val_generations: 0 + rollout_data_dir: null + validation_data_dir: null + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + esi_redundant_time: 0 + resume_mode: auto + resume_from_path: null + val_before_train: true + val_only: false + test_freq: -1 + critic_warmup: 0 + default_hdfs_dir: null + del_local_ckpt_after_load: false + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null + ray_wait_register_center_timeout: 300 + device: cuda +data: + tokenizer: null + use_shm: false + train_files: ~/data/rlhf/gsm8k/train.parquet + val_files: ~/data/rlhf/gsm8k/test.parquet + prompt_key: prompt + reward_fn_key: data_source + max_prompt_length: 512 + max_response_length: 512 + train_batch_size: 1024 + val_batch_size: null + return_raw_input_ids: false + return_raw_chat: false + return_full_prompt: false + shuffle: true + dataloader_num_workers: 8 + validation_shuffle: false + filter_overlong_prompts: false + filter_overlong_prompts_workers: 1 + truncation: error + image_key: images + video_key: videos + trust_remote_code: false + custom_cls: + path: null + name: null + return_multi_modal_inputs: true + sampler: + class_path: null + class_name: null + datagen: + path: null + name: null +critic: + rollout_n: ${actor_rollout_ref.rollout.n} + strategy: fsdp + optim: + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr: 1.0e-05 + min_lr_ratio: null + warmup_style: constant + model: + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${actor_rollout_ref.model.path} + override_config: {} + external_lib: ${actor_rollout_ref.model.external_lib} + enable_gradient_checkpointing: true + trust_remote_code: ${actor_rollout_ref.model.trust_remote_code} + use_shm: false + enable_activation_offload: false + use_remove_padding: false + fsdp_config: + param_offload: false + optimizer_offload: false + offload_policy: false + reshard_after_forward: true + wrap_policy: + min_num_params: 0 + fsdp_size: -1 + forward_prefetch: false + lora_rank: 0 + lora_alpha: 16 + target_modules: all-linear + ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + ppo_max_token_len_per_gpu: 32768 + forward_max_token_len_per_gpu: 
${.ppo_max_token_len_per_gpu} + ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} + shuffle: ${actor_rollout_ref.actor.shuffle} + cliprange_value: 0.5 + loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode} + checkpoint: + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + discrete: false + all_ranks: false + ranks: [] + forward_micro_batch_size: ${critic.ppo_micro_batch_size} + forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} + ulysses_sequence_parallel_size: 1 + grad_clip: 1.0 +reward_model: + enable: false + strategy: fsdp + model: + input_tokenizer: ${actor_rollout_ref.model.path} + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + trust_remote_code: false + use_shm: false + use_remove_padding: false + use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels} + fsdp_config: + wrap_policy: + min_num_params: 0 + param_offload: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + micro_batch_size: null + micro_batch_size_per_gpu: null + max_length: null + use_dynamic_bsz: ${critic.use_dynamic_bsz} + forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} + reward_manager: naive + launch_reward_fn_async: false + sandbox_fusion: + url: null + max_concurrent: 64 + memory_limit_mb: 1024 + profiler: + _target_: verl.utils.profiler.ProfilerConfig + discrete: false + all_ranks: false + ranks: [] + ulysses_sequence_parallel_size: 1 +custom_reward_function: + path: null + name: compute_score +algorithm: + _target_: verl.trainer.config.AlgoConfig + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + norm_adv_by_std_in_grpo: true + use_kl_in_reward: false + kl_penalty: kl + kl_ctrl: + _target_: verl.trainer.config.KLControlConfig + type: fixed + kl_coef: 0.001 + horizon: 10000 + target_kl: 0.1 + use_pf_ppo: false + pf_ppo: + _target_: verl.trainer.config.PFPPOConfig + reweight_method: pow + weight_pow: 2.0 +ray_init: + num_cpus: null + timeline_json_file: null diff --git a/verl/trainer/config/algorithm.py b/verl/trainer/config/algorithm.py index e9600a93b..5bc6cf943 100644 --- a/verl/trainer/config/algorithm.py +++ b/verl/trainer/config/algorithm.py @@ -18,7 +18,7 @@ from typing import Optional from verl.base_config import BaseConfig -@dataclass(frozen=True) +@dataclass class KLControlConfig(BaseConfig): """Configuration for KL control. @@ -31,13 +31,14 @@ class KLControlConfig(BaseConfig): target_kl (float): Target KL divergence for adaptive controller. """ + _frozen_fields = ["type", "kl_coef", "horizon", "target_kl"] type: str = "fixed" kl_coef: float = 0.001 horizon: int = 10000 target_kl: float = 0.1 -@dataclass(frozen=True) +@dataclass class PFPPOConfig(BaseConfig): """Configuration for preference feedback PPO. @@ -48,11 +49,12 @@ class PFPPOConfig(BaseConfig): weight_pow (float): Power used for weight scaling in "pow" method. """ + _frozen_fields = ["reweight_method", "weight_pow"] reweight_method: str = "pow" weight_pow: float = 2.0 -@dataclass(frozen=True) +@dataclass class FilterGroupsConfig(BaseConfig): """Configuration for filter groups (used in DAPO and Entropy). @@ -64,12 +66,14 @@ class FilterGroupsConfig(BaseConfig): max_num_gen_batches (int): Non-positive values mean no upper limit. 
""" + _frozen_fields = ["enable", "metric", "max_num_gen_batches"] + enable: bool = False metric: Optional[str] = None max_num_gen_batches: int = 0 -@dataclass(frozen=True) +@dataclass class AlgoConfig(BaseConfig): """Configuration for the algorithm. @@ -88,6 +92,16 @@ class AlgoConfig(BaseConfig): filter_groups (Optional[FilterGroupsConfig]): Filter groups configuration, used in DAPO and Entropy """ + _frozen_fields = [ + "gamma", + "lam", + "adv_estimator", + "norm_adv_by_std_in_grpo", + "use_kl_in_reward", + "kl_penalty", + "use_pf_ppo", + ] + gamma: float = 1.0 lam: float = 1.0 adv_estimator: str = "gae" diff --git a/verl/trainer/config/rollout/rollout.yaml b/verl/trainer/config/rollout/rollout.yaml index 3370019dc..914202256 100644 --- a/verl/trainer/config/rollout/rollout.yaml +++ b/verl/trainer/config/rollout/rollout.yaml @@ -1,4 +1,4 @@ -# actor_rollout_ref.rollout.name: hf/vllm/sglang. +# actor_rollout_ref.rollout.name: hf/vllm/sglang. The default value will be removed in the future name: vllm # sync: LLM, async: AsyncLLM diff --git a/verl/utils/profiler/config.py b/verl/utils/profiler/config.py index 8acf07502..d4fb53650 100644 --- a/verl/utils/profiler/config.py +++ b/verl/utils/profiler/config.py @@ -13,11 +13,12 @@ # limitations under the License. from dataclasses import dataclass, field +from typing import ClassVar from verl.base_config import BaseConfig -@dataclass(frozen=True) +@dataclass class ProfilerConfig(BaseConfig): """Worker profiler config. Currently only support Nsight system profiler. @@ -30,6 +31,9 @@ class ProfilerConfig(BaseConfig): ranks (list[int]): The ranks that will be profiled. Defaults to []. """ + # the fields expected to be frozen + _frozen_fields: ClassVar[set[str]] = {"discrete", "all_ranks", "ranks"} + discrete: bool = False all_ranks: bool = False