Files
verl/tests/special_distributed/test_tensor_dict.py
H c936ec7d5c [trainer, cfg] feat: add BaseConfig for all dataclass configs. Introduce dataclass for algorithm related configs (#2147)
### What does this PR do?

This PR introduces a BaseConfig class that bridges dataclasses and Hydra's
DictConfig in the codebase. In this PR, the algorithm-related and
profiler-related configs are instantiated as dataclasses upfront for
both main_ppo and main_dapo. The config-related changes are expected to
be backward compatible (the `xx_config.get()` API is still supported).
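
For illustration, a minimal sketch of what such a dataclass/DictConfig bridge could look like. This is an assumption for illustration only, not the actual `BaseConfig` implementation; the class name `BaseConfigSketch` is made up.

```python
# Hypothetical sketch only -- the real BaseConfig in verl may differ.
# The idea: keep dataclass typing while still exposing the dict-style
# access (`get`, `[]`) that existing DictConfig call sites rely on.
from dataclasses import dataclass, fields


@dataclass(frozen=True)
class BaseConfigSketch:
    def get(self, key, default=None):
        # Mirror DictConfig.get(): return the field value if it exists,
        # otherwise the provided default.
        return getattr(self, key, default)

    def __getitem__(self, key):
        # Mirror DictConfig["key"] access for existing fields only.
        if key not in {f.name for f in fields(self)}:
            raise KeyError(key)
        return getattr(self, key)
```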

Besides, this PR also moves the profiler-related files from
verl.utils.debug to verl.utils.profiler.xx.
`verl.utils.debug.performance.py` is kept for backward-compatibility
purposes and will be dropped in a later version.

Main principles:
- Users are not forced to use dataclass configs. All changes are
backward compatible.
- Dataclass configs are converted upfront on a per-entrypoint basis.
Here we target main_ppo.py and main_dapo.py; the other recipes'
entrypoints are left intact.
- The new dataclasses are intentionally frozen: configs should not be
mutable. Whenever a new field is needed, make a copy of the config
instead of mutating it in place.
- Whenever a dataclass config is introduced, we encourage adding simple
CPU-based unit tests for the basic functionality of the functions that
rely on it (e.g. the GRPO advantage estimation in core_algorithm.py),
and updating the type annotations of the affected functions.
- In the YAML file, the `_target_` field should be specified for
dataclass conversion, e.g. `_target_: verl.xxx.XXConfig` (see the
sketch below).
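
A minimal sketch of the last two points in practice. The DictConfig here is hand-built for illustration (in the trainer it comes from the Hydra-loaded YAML), and the `_target_` path and the `gamma` field are assumptions:

```python
# Illustrative only: a hand-built DictConfig standing in for the hydra-loaded
# YAML; the `_target_` path and the `gamma` field are assumptions.
from dataclasses import FrozenInstanceError, replace

from omegaconf import OmegaConf

from verl.utils import omega_conf_to_dataclass

cfg = OmegaConf.create({"_target_": "verl.trainer.config.AlgorithmConfig", "gamma": 1.0})
algo = omega_conf_to_dataclass(cfg)

# Frozen: direct mutation raises instead of silently changing shared config.
try:
    algo.gamma = 0.99
except FrozenInstanceError:
    pass

# Need a different value? Make a copy with the change instead of mutating.
algo_discounted = replace(algo, gamma=0.99)
```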

This PR is built on top of @liuzhenhai93's contribution.

### Checklist Before Describing the Details

- [x] Searched for similar PR(s).
- [x] PR title is in the format of: `[modules] type: Title`
  - modules: `trainer, cfg`
  - type: `feat`

### Test

- Added comprehensive unit tests in
`tests/trainer/config/test_algorithm_config_on_cpu.py` and
`test_base_config_on_cpu.py`
- Tests cover dataclass creation, nested configuration handling,
backward compatibility, and integration with core algorithms
- All tests pass successfully, validating the functionality and
integration with existing code

### High-Level Design

The design introduces three dataclasses:
1. **`KLControlConfig`**: Handles KL control parameters (type, kl_coef,
horizon, target_kl)
2. **`PFPPOConfig`**: Manages preference feedback PPO parameters
(reweight_method, weight_pow)
3. **`AlgorithmConfig`**: Main algorithm configuration containing all
fields from the YAML config

The conversion uses the existing `verl.utils.omega_conf_to_dataclass`
utility to seamlessly convert from OmegaConf DictConfig to typed
dataclasses.
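
As a rough sketch of the shape of these classes, the snippet below lists only the fields mentioned in this description; the defaults are assumptions, and the real classes derive from BaseConfig rather than being plain frozen dataclasses:

```python
# Partial, illustrative field lists; defaults are assumptions, and the real
# classes derive from BaseConfig instead of being plain frozen dataclasses.
from dataclasses import dataclass, field


@dataclass(frozen=True)
class KLControlConfig:
    type: str = "fixed"
    kl_coef: float = 0.001
    horizon: float = 10000
    target_kl: float = 0.1


@dataclass(frozen=True)
class PFPPOConfig:
    reweight_method: str = "pow"
    weight_pow: float = 2.0


@dataclass(frozen=True)
class AlgorithmConfig:
    gamma: float = 1.0
    use_kl_in_reward: bool = False
    kl_penalty: str = "kl"
    kl_ctrl: KLControlConfig = field(default_factory=KLControlConfig)
    pf_ppo: PFPPOConfig = field(default_factory=PFPPOConfig)
```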


### API and Usage Example

The API maintains backward compatibility while providing type-safe
access:

```python
# Before (DictConfig)
if config.algorithm.use_kl_in_reward:
    kl_penalty = config.algorithm.kl_penalty
    kl_coef = config.algorithm.kl_ctrl.get("kl_coef", 0.001)

# After (Dataclass) - Type-safe with IDE support
algorithm_config = omega_conf_to_dataclass(config.algorithm)
if algorithm_config.use_kl_in_reward:
    kl_penalty = algorithm_config.kl_penalty  # Type-safe access
    kl_coef = algorithm_config.kl_ctrl.kl_coef  # Nested config access

# Backward compatibility maintained
gamma = algorithm_config.get("gamma", 1.0)  # Still works


# Other cases (excerpted from a unittest.TestCase, hence the self.assert* calls):
# converting a profiler DictConfig yields a typed ProfilerConfig that preserves
# dict-style access and rejects mutation because the dataclass is frozen.
from dataclasses import FrozenInstanceError

profiler_config = omega_conf_to_dataclass(config)
assert isinstance(profiler_config, ProfilerConfig)
self.assertEqual(profiler_config.discrete, config.discrete)
self.assertEqual(profiler_config.all_ranks, config.all_ranks)
self.assertEqual(profiler_config.ranks, config.ranks)
with self.assertRaises(AttributeError):
    _ = profiler_config.non_existing_key
assert config.get("non_existing_key") == profiler_config.get("non_existing_key")
assert config.get("non_existing_key", 1) == profiler_config.get("non_existing_key", 1)
assert config["discrete"] == profiler_config["discrete"]
with self.assertRaises(FrozenInstanceError):
    profiler_config.discrete = False

```

### Checklist Before Submitting

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting):
`pre-commit run --show-diff-on-failure --color=always --all-files`
- [ ] Add `[BREAKING]` to the PR title `description` if it breaks any
API.
- [ ] Update the documentation about your changes in the
[docs](https://github.com/volcengine/verl/tree/main/docs).
- [x] New CI unit test(s) are added to cover the code path.
- [x] Rely on existing unit tests on CI that covers the code path.

**Note**: This change is fully backward compatible and does not break
any existing APIs. The dataclass provides the same interface as the
original DictConfig while adding type safety and better structure.

---------

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
2025-07-04 08:12:09 -07:00


# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

os.environ["NCCL_DEBUG"] = "WARN"
import numpy as np
import torch
import torch.distributed

from verl.protocol import DataProto, all_gather_data_proto
from verl.utils.distributed import initialize_global_process_group


def test_all_gather_data_proto():
    # 2x2 device mesh: 2-way data parallel x 2-way tensor parallel (4 ranks in total).
    device_mesh = torch.distributed.device_mesh.init_device_mesh("cuda", mesh_shape=[2, 2], mesh_dim_names=["dp", "tp"])

    global_rank = torch.distributed.get_rank()

    obs = torch.tensor([[1 * global_rank, 2 * global_rank + 1], [3 * global_rank, 4 * global_rank + 1]])
    labels = ["a", "b"] if global_rank % 2 == 0 else ["b", "a"]
    labels = np.array(labels, dtype=object)
    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"})

    # Gather the DataProto across the dp group; every rank should end up with
    # the concatenation of its dp-group members' data.
    all_gather_data_proto(data=data, process_group=device_mesh.get_group("dp"))

    if global_rank == 0:
        expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device="cuda")
        expected_labels = ["a", "b", "a", "b"]
    elif global_rank == 1:
        expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device="cuda")
        expected_labels = ["b", "a", "b", "a"]
    elif global_rank == 2:
        expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device="cuda")
        expected_labels = ["a", "b", "a", "b"]
    elif global_rank == 3:
        expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device="cuda")
        expected_labels = ["b", "a", "b", "a"]

    torch.testing.assert_close(data.batch["obs"], expected_obs, atol=0, rtol=0)
    assert (data.non_tensor_batch["labels"] == expected_labels).all()
    assert data.meta_info == {"info": "test_info"}


def test_vocab_parallel_entropy():
    from megatron.core import parallel_state as mpu

    from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy
    from verl.utils.profiler import log_gpu_memory_usage
    from verl.utils.torch_functional import entropy_from_logits

    mpu.initialize_model_parallel(
        tensor_model_parallel_size=2, pipeline_model_parallel_size=1, virtual_pipeline_model_parallel_size=None
    )

    batch_size = 2
    seqlen = 128
    vocab_size = 155136

    logits = torch.randn(batch_size * seqlen, vocab_size, device="cuda", requires_grad=True)
    target = torch.randint(low=0, high=vocab_size, size=(batch_size * seqlen,), device="cuda", dtype=torch.int64)

    # broadcast across tp so every tp rank starts from the same full logits/target
    torch.distributed.broadcast(
        logits, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()
    )
    torch.distributed.broadcast(
        target, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()
    )

    tp_rank = mpu.get_tensor_model_parallel_rank()
    vocab_size_per_tp = vocab_size // mpu.get_tensor_model_parallel_world_size()

    # get the local (vocab-sharded) logits of each tp rank
    vocab_parallel_logits = (
        logits.clone().detach()[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp].requires_grad_()
    )
    logits.grad = None
    vocab_parallel_logits.grad = None

    log_gpu_memory_usage("begin")
    output_entropy = vocab_parallel_entropy(vocab_parallel_logits)
    log_gpu_memory_usage("after forward")
    grad_output = torch.randn_like(output_entropy)
    output_entropy.backward(grad_output)
    log_gpu_memory_usage("after backward")

    # compare against the single-device reference, forward and backward
    target_entropy = entropy_from_logits(logits)
    torch.testing.assert_close(output_entropy, target_entropy)
    target_entropy.backward(grad_output)
    torch.testing.assert_close(
        logits.grad[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], vocab_parallel_logits.grad
    )
    # make sure logits is not altered
    torch.testing.assert_close(
        logits[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], vocab_parallel_logits
    )

    if mpu.get_tensor_model_parallel_rank() == 0:
        print("test_vocab_parallel_entropy passes")

    mpu.destroy_model_parallel()


if __name__ == "__main__":
    local_rank, rank, world_size = initialize_global_process_group()
    test_all_gather_data_proto()
    test_vocab_parallel_entropy()