mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 21:53:50 +08:00
### What does this PR do? This PR introduces a BaseConfig class that bridges dataclasses and Hydra's DictConfig in the codebase. In this PR, the algorithm-related configs and profiler-related configs are instantiated as dataclasses upfront for both main_ppo and main_dapo. The config-related changes are expected to be backward compatible (supporting the xx_config.get() API). Besides, this PR also moves the profiler-related files under verl.utils.debug to verl.utils.profiler.xx. The `verl.utils.debug.performance.py` module is kept for backward compatibility and will be dropped in a later version. Main principles: - Users are not forced to use dataclass configs. All changes are backward compatible. - Dataclass configs are converted upfront on a per-entrypoint basis. Here we target main_ppo.py and main_dapo.py; the other recipes' entrypoints are left intact. - The new dataclasses are intentionally frozen. Configs should not be mutable. Whenever a new field is needed, we should make a copy of the config with the new field. - Whenever a dataclass config is introduced, we encourage adding simple CPU-based unit tests for the basic functionality of the functions that rely on it (e.g. the GRPO advantage estimation in core_algorithm.py), and updating the type annotations of all impacted functions. - In the YAML file, the `_target_` field should be specified for dataclass conversion, e.g. `_target_: verl.xxx.XXConfig`. The PR is built on top of @liuzhenhai93's contribution. ### Checklist Before Describing the Details - [x] Searched for similar PR(s).
- [x] PR title is in the format of: `[modules] type: Title` - modules: `trainer, cfg` - type: `feat` ### Test - Added comprehensive unit tests in `tests/trainer/config/test_algorithm_config_on_cpu.py`, `test_base_config_on_cpu.py` - Tests cover dataclass creation, nested configuration handling, backward compatibility, and integration with core algorithms - All tests pass successfully, validating the functionality and integration with existing code ### High-Level Design The design introduces three dataclasses: 1. **`KLControlConfig`**: Handles KL control parameters (type, kl_coef, horizon, target_kl) 2. **`PFPPOConfig`**: Manages preference feedback PPO parameters (reweight_method, weight_pow) 3. **`AlgorithmConfig`**: Main algorithm configuration containing all fields from the YAML config The conversion uses the existing `verl.utils.omega_conf_to_dataclass` utility to seamlessly convert from OmegaConf DictConfig to typed dataclasses. ### API and Usage Example The API maintains backward compatibility while providing type-safe access: ```python # Before (DictConfig) if config.algorithm.use_kl_in_reward: kl_penalty = config.algorithm.kl_penalty kl_coef = config.algorithm.kl_ctrl.get("kl_coef", 0.001) # After (Dataclass) - Type-safe with IDE support algorithm_config = omega_conf_to_dataclass(config.algorithm) if algorithm_config.use_kl_in_reward: kl_penalty = algorithm_config.kl_penalty # Type-safe access kl_coef = algorithm_config.kl_ctrl.kl_coef # Nested config access # Backward compatibility maintained gamma = algorithm_config.get("gamma", 1.0) # Still works # other cases profiler_config = omega_conf_to_dataclass(config) self.assertEqual(profiler_config.discrete, config.discrete) self.assertEqual(profiler_config.all_ranks, config.all_ranks) self.assertEqual(profiler_config.ranks, config.ranks) assert isinstance(profiler_config, ProfilerConfig) with self.assertRaises(AttributeError): _ = profiler_config.non_existing_key assert config.get("non_existing_key") == 
profiler_config.get("non_existing_key") assert config.get("non_existing_key", 1) == profiler_config.get("non_existing_key", 1) assert config["discrete"] == profiler_config["discrete"] from dataclasses import FrozenInstanceError with self.assertRaises(FrozenInstanceError): profiler_config.discrete = False ``` ### Checklist Before Submitting - [x] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting): `pre-commit run --show-diff-on-failure --color=always --all-files` - [ ] Add `[BREAKING]` to the PR title `description` if it breaks any API. - [ ] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs). - [x] New CI unit test(s) are added to cover the code path. - [x] Rely on existing unit tests on CI that covers the code path. **Note**: This change is fully backward compatible and does not break any existing APIs. The dataclass provides the same interface as the original DictConfig while adding type safety and better structure. --------- Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
122 lines
4.7 KiB
Python
122 lines
4.7 KiB
Python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import os
|
|
|
|
# Ask NCCL for warning-level logging only. This assignment sits above the
# torch imports, presumably so the value is already in the environment before
# any NCCL communicator is created.
os.environ["NCCL_DEBUG"] = "WARN"
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.distributed
|
|
|
|
from verl.protocol import DataProto, all_gather_data_proto
|
|
from verl.utils.distributed import initialize_global_process_group
|
|
|
|
|
|
def test_all_gather_data_proto():
    """Check that ``all_gather_data_proto`` gathers a ``DataProto`` over the "dp" group.

    Every rank builds a small rank-dependent batch (a 2x2 integer tensor, a
    two-element object array of labels and a fixed ``meta_info`` dict), runs
    the gather on the "dp" sub-group of a 2x2 ``("dp", "tp")`` device mesh and
    then inspects ``data`` itself: the call's return value is not used, so the
    gathered result is expected to be written back into the same object.

    Must be launched with an initialised process group; the 2x2 mesh implies
    four ranks.
    """
    mesh = torch.distributed.device_mesh.init_device_mesh(
        "cuda", mesh_shape=[2, 2], mesh_dim_names=["dp", "tp"]
    )
    global_rank = torch.distributed.get_rank()
    is_even_rank = global_rank % 2 == 0

    # Rank-dependent payload, so the gathered result shows which ranks contributed.
    first_row = [1 * global_rank, 2 * global_rank + 1]
    second_row = [3 * global_rank, 4 * global_rank + 1]
    obs = torch.tensor([first_row, second_row])
    labels = np.array(["a", "b"] if is_even_rank else ["b", "a"], dtype=object)

    data = DataProto.from_dict(
        tensors={"obs": obs},
        non_tensors={"labels": labels},
        meta_info={"info": "test_info"},
    )

    all_gather_data_proto(data=data, process_group=mesh.get_group("dp"))

    # Ranks 0 and 2 expect the same result (rank 0's rows followed by rank 2's),
    # and likewise ranks 1 and 3. Any other rank leaves the expectations unbound.
    if global_rank in (0, 2):
        expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device="cuda")
        expected_labels = ["a", "b", "a", "b"]
    elif global_rank in (1, 3):
        expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device="cuda")
        expected_labels = ["b", "a", "b", "a"]

    # Integer payload: require an exact match.
    torch.testing.assert_close(data.batch["obs"], expected_obs, atol=0, rtol=0)
    assert (data.non_tensor_batch["labels"] == expected_labels).all()
    # meta_info must come through unchanged.
    assert data.meta_info == {"info": "test_info"}
|
|
|
|
|
|
def test_vocab_parallel_entropy():
    """Compare ``vocab_parallel_entropy`` on vocab shards with ``entropy_from_logits`` on full logits.

    Random logits are created on every rank and broadcast from the
    tensor-parallel source rank so that all ranks of a tensor-parallel group
    hold the same values. Each rank then takes its own contiguous slice along
    the vocab dimension, runs ``vocab_parallel_entropy`` forward and backward
    on that slice, and checks (with ``assert_close`` default tolerances) that:

    * the entropy matches ``entropy_from_logits`` computed on the full logits,
    * the gradient w.r.t. the local slice matches the corresponding slice of
      the full-logits gradient,
    * the values in the local slice were not modified by the call.

    Megatron's model-parallel state is torn down in a ``finally`` block so it
    is released even when one of the assertions fails.

    Must be launched with an initialised process group and CUDA devices.
    """
    from megatron.core import parallel_state as mpu

    from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy
    from verl.utils.profiler import log_gpu_memory_usage
    from verl.utils.torch_functional import entropy_from_logits

    mpu.initialize_model_parallel(
        tensor_model_parallel_size=2, pipeline_model_parallel_size=1, virtual_pipeline_model_parallel_size=None
    )

    try:
        batch_size = 2
        seqlen = 128
        vocab_size = 155136
        num_tokens = batch_size * seqlen

        logits = torch.randn(num_tokens, vocab_size, device="cuda", requires_grad=True)
        # `target` is not read by any assertion below; it is still created and
        # broadcast so the sequence of RNG draws and collectives is unchanged.
        target = torch.randint(low=0, high=vocab_size, size=(num_tokens,), device="cuda", dtype=torch.int64)

        # Broadcast across tp so every rank of the group sees identical inputs.
        tp_src_rank = mpu.get_tensor_model_parallel_src_rank()
        tp_group = mpu.get_tensor_model_parallel_group()
        torch.distributed.broadcast(logits, tp_src_rank, group=tp_group)
        torch.distributed.broadcast(target, tp_src_rank, group=tp_group)

        tp_rank = mpu.get_tensor_model_parallel_rank()
        vocab_size_per_tp = vocab_size // mpu.get_tensor_model_parallel_world_size()
        # Columns of the vocab dimension owned by this tp rank.
        local_vocab = slice(tp_rank * vocab_size_per_tp, (tp_rank + 1) * vocab_size_per_tp)

        # Detached copy of the local shard that tracks its own gradient.
        vocab_parallel_logits = logits.clone().detach()[:, local_vocab].requires_grad_()
        logits.grad = None
        vocab_parallel_logits.grad = None

        log_gpu_memory_usage("begin")
        output_entropy = vocab_parallel_entropy(vocab_parallel_logits)
        log_gpu_memory_usage("after forward")
        grad_output = torch.randn_like(output_entropy)
        output_entropy.backward(grad_output)
        log_gpu_memory_usage("after backward")

        # Reference: entropy over the full, unsharded logits.
        target_entropy = entropy_from_logits(logits)
        torch.testing.assert_close(output_entropy, target_entropy)

        # Same upstream gradient, so the local slice of the full gradient must match.
        target_entropy.backward(grad_output)
        torch.testing.assert_close(logits.grad[:, local_vocab], vocab_parallel_logits.grad)

        # make sure logits is not altered
        torch.testing.assert_close(logits[:, local_vocab], vocab_parallel_logits)

        if mpu.get_tensor_model_parallel_rank() == 0:
            print("test_vocab_parallel_entropy passes")
    finally:
        # Always release the global model-parallel state, even on failure.
        mpu.destroy_model_parallel()
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: set up the global process group once, then run the
    # distributed checks in order within the same processes.
    local_rank, rank, world_size = initialize_global_process_group()
    for distributed_test in (test_all_gather_data_proto, test_vocab_parallel_entropy):
        distributed_test()
|