mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 21:53:50 +08:00
### What does this PR do?

This PR introduces a `BaseConfig` class that bridges dataclasses and Hydra's `DictConfig` in the codebase. The algorithm-related and profiler-related configs are instantiated as dataclasses upfront for both `main_ppo` and `main_dapo`. The config-related changes are expected to be backward compatible (supporting the `xx_config.get()` API).

This PR also moves the profiler-related files from `verl.utils.debug` to `verl.utils.profiler.xx`. `verl.utils.debug.performance.py` is kept for backward compatibility and will be dropped in a later version.

Main principles:

- Users are not forced to use dataclass configs. All changes are backward compatible.
- Dataclass configs are converted upfront on a per-entrypoint basis. Here we target `main_ppo.py` and `main_dapo.py`; the other recipes' entrypoints are left intact.
- The new dataclasses are intentionally frozen. Configs should not be mutable. Whenever a new field is needed, make a copy of the config to create a new one.
- Whenever a dataclass config is introduced, we encourage adding simple CPU-based unit tests for the basic functionality of the functions that rely on it (e.g. the GRPO advantage estimation in `core_algorithm.py`), and updating the type annotations of all impacted functions.
- In the YAML file, the `_target_` field should be specified for dataclass conversion, e.g. `_target_: verl.xxx.XXConfig`.

The PR is built on top of @liuzhenhai93's contribution.

### Checklist Before Describing the Details

- [x] Searched for similar PR(s).
- [x] PR title is in the format of: `[modules] type: Title`
  - modules: `trainer, cfg`
  - type: `feat`

### Test

- Added comprehensive unit tests in `tests/trainer/config/test_algorithm_config_on_cpu.py` and `test_base_config_on_cpu.py`.
- Tests cover dataclass creation, nested configuration handling, backward compatibility, and integration with core algorithms.
- All tests pass, validating the functionality and the integration with existing code.

### High-Level Design

The design introduces three dataclasses:

1. **`KLControlConfig`**: Handles KL control parameters (`type`, `kl_coef`, `horizon`, `target_kl`)
2. **`PFPPOConfig`**: Manages preference feedback PPO parameters (`reweight_method`, `weight_pow`)
3. **`AlgorithmConfig`**: Main algorithm configuration containing all fields from the YAML config

The conversion uses the existing `verl.utils.omega_conf_to_dataclass` utility to convert from an OmegaConf `DictConfig` to typed dataclasses.

### API and Usage Example

The API maintains backward compatibility while providing type-safe access:

```python
# Before (DictConfig)
if config.algorithm.use_kl_in_reward:
    kl_penalty = config.algorithm.kl_penalty
    kl_coef = config.algorithm.kl_ctrl.get("kl_coef", 0.001)

# After (Dataclass) - Type-safe with IDE support
algorithm_config = omega_conf_to_dataclass(config.algorithm)
if algorithm_config.use_kl_in_reward:
    kl_penalty = algorithm_config.kl_penalty  # Type-safe access
    kl_coef = algorithm_config.kl_ctrl.kl_coef  # Nested config access

# Backward compatibility maintained
gamma = algorithm_config.get("gamma", 1.0)  # Still works

# other cases
profiler_config = omega_conf_to_dataclass(config)
self.assertEqual(profiler_config.discrete, config.discrete)
self.assertEqual(profiler_config.all_ranks, config.all_ranks)
self.assertEqual(profiler_config.ranks, config.ranks)
assert isinstance(profiler_config, ProfilerConfig)
with self.assertRaises(AttributeError):
    _ = profiler_config.non_existing_key
assert config.get("non_existing_key") == profiler_config.get("non_existing_key")
assert config.get("non_existing_key", 1) == profiler_config.get("non_existing_key", 1)
assert config["discrete"] == profiler_config["discrete"]
from dataclasses import FrozenInstanceError
with self.assertRaises(FrozenInstanceError):
    profiler_config.discrete = False
```

### Checklist Before Submitting

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting): `pre-commit run --show-diff-on-failure --color=always --all-files`
- [ ] Add `[BREAKING]` to the PR title `description` if it breaks any API.
- [ ] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs).
- [x] New CI unit test(s) are added to cover the code path.
- [x] Rely on existing unit tests on CI that cover the code path.

**Note**: This change is fully backward compatible and does not break any existing APIs. The dataclass provides the same interface as the original DictConfig while adding type safety and better structure.

---------

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
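To make the bridging idea in the description above concrete, here is a minimal, self-contained sketch of a frozen dataclass that also answers dict-style access. It is not verl's `BaseConfig`: the names `MappingConfigSketch` and `ProfilerConfigSketch` are hypothetical, and the field defaults are assumptions chosen for illustration. It mirrors only the behaviours the description lists: attribute access, `[]` access, `.get()` with a default, and immutability.

```python
from collections.abc import Mapping
from dataclasses import FrozenInstanceError, dataclass, fields, replace
from typing import Any


@dataclass(frozen=True)
class MappingConfigSketch(Mapping):
    """Hypothetical stand-in for a dataclass/DictConfig bridge (not verl's code)."""

    def __getitem__(self, key: str) -> Any:
        # Dict-style access delegates to attribute lookup, so a missing
        # key raises AttributeError and a non-string key raises TypeError.
        return getattr(self, key)

    def __iter__(self):
        # Iterate over declared dataclass field names, like dict keys.
        return (f.name for f in fields(self))

    def __len__(self) -> int:
        return len(fields(self))

    def get(self, key: str, default: Any = None) -> Any:
        # DictConfig-like .get(): return the default instead of raising.
        return getattr(self, key, default)


@dataclass(frozen=True)
class ProfilerConfigSketch(MappingConfigSketch):
    # Field names follow the PR example; the defaults are assumptions.
    discrete: bool = False
    all_ranks: bool = False
    ranks: tuple = ()


cfg = ProfilerConfigSketch(discrete=True)

# Attribute access and dict-style access agree.
same_value = cfg.discrete == cfg["discrete"]

# Frozen: assignment raises instead of mutating the config.
try:
    cfg.discrete = False
    mutation_blocked = False
except FrozenInstanceError:
    mutation_blocked = True

# To change a field, build a modified copy rather than mutating in place.
updated = replace(cfg, discrete=False)
```

Using `dataclasses.replace` for updates matches the stated principle that configs stay immutable and that a copy is made whenever a different value is needed.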
43 lines
1.4 KiB
Python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from verl.base_config import BaseConfig


@pytest.fixture
def base_config_mock():
    """Fixture to create a mock BaseConfig instance with test attributes."""
    mock_config = BaseConfig()
    mock_config.test_attr = "test_value"
    return mock_config


def test_getitem_success(base_config_mock):
    """Test __getitem__ with existing attribute (happy path)."""
    assert base_config_mock["test_attr"] == "test_value"


def test_getitem_nonexistent_attribute(base_config_mock):
    """Test __getitem__ with non-existent attribute (exception path 1)."""
    with pytest.raises(AttributeError):
        _ = base_config_mock["nonexistent_attr"]


def test_getitem_invalid_key_type(base_config_mock):
    """Test __getitem__ with invalid key type (exception path 2)."""
    with pytest.raises(TypeError):
        _ = base_config_mock[123]  # type: ignore