Mirror of https://github.com/volcengine/verl.git, synced 2025-10-20 13:43:50 +08:00
[util] docs: add docstrings to metric util functions that recipes reuse (#1395)
### Checklist Before Starting

- [x] Search for similar PR(s).

### What does this PR do?

In `/recipes`, a few functions under `trainer/ppo/metric_utils` are imported and reused. Right now many of them are task dependent and assume specific keys in the input metric dict. To make these functions more robust and backward compatible, a few tests are added. Additionally, one method is moved to `verl.utils` as a public API due to its general-purpose nature, and an API doc page is added correspondingly.

To make it easier for others to customize verl trainers, many other classes still require further documentation, such as:

- `AdvantageEstimator`, `RayPPOTrainer`, `apply_kl_penalty`, `compute_advantage`
- `from verl.single_controller.ray import RayWorkerGroup`
- `from verl.trainer.ppo.core_algos import agg_loss`
- `from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role, WorkerType`
- `from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path`

They shall be enhanced in future PRs.

### High-Level Design

None

### Specific Changes

- Added tests.
- Added the `verl.utils.metric` namespace.

### API

`verl.trainer.ppo.metric_utils.reduce_metrics` changed to `verl.utils.metric.reduce_metrics`. Deprecation warnings are added.

### Usage Example

None

### Test

Added.

### Additional Info.

- **Issue Number**: Fixes https://github.com/volcengine/verl/issues/1354
- **Training**: [Note which backend this PR will affect: FSDP, Megatron, both, or none]
- **Inference**: [Note which backend this PR will affect: vLLM, SGLang, both, or none]

### Checklist Before Submitting

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [ ] Add `[BREAKING]` to the PR title if it breaks any API.
- [ ] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add CI test(s) if necessary.

---------

Co-authored-by: openhands <openhands@all-hands.dev>
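For context, here is a minimal sketch of the migration path described above. It is not part of this PR's diff; it only assumes verl is installed from this branch, so that both import paths named in the API section exist.

```python
# Sketch of the rename: both imports reach the same implementation, but the old
# path now goes through the @deprecated wrapper and emits a DeprecationWarning.
import warnings

from verl.trainer.ppo.metric_utils import reduce_metrics as old_reduce_metrics  # deprecated path
from verl.utils.metric import reduce_metrics  # new public API

print(reduce_metrics({"loss": [1.0, 2.0, 3.0]})["loss"])  # 2.0

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_reduce_metrics({"loss": [1.0, 2.0, 3.0]})  # same result, plus a deprecation warning
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```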

.github/workflows/e2e_ascend.yml (vendored, 5 lines changed)

@@ -14,6 +14,11 @@ on:
      - "**/*.py"
      - .github/workflows/e2e_ascend.yml

# Cancel jobs on the same ref if a new one is triggered
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

permissions:
  contents: read

.github/workflows/e2e_dapo.yml (vendored, 6 lines changed)

@@ -29,6 +29,12 @@ on:
      - "examples/data_preprocess/gsm8k.py"
      - "tests/e2e/run_dapo.sh"

# Cancel jobs on the same ref if a new one is triggered
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

# Declare permissions just read content.
permissions:
  contents: read

.github/workflows/e2e_eval_aime24.yml (vendored, 5 lines changed)

@@ -28,6 +28,11 @@ on:
      - "verl/trainer/main_generation.py"
      - "verl/trainer/config/generation.yaml"

# Cancel jobs on the same ref if a new one is triggered
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

# Declare permissions just read content.
permissions:
  contents: read

.github/workflows/e2e_prime.yml (vendored, 7 lines changed)

@@ -29,6 +29,11 @@ on:
      - "examples/data_preprocess/gsm8k.py"
      - "tests/e2e/run_prime.sh"

# Cancel jobs on the same ref if a new one is triggered
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

# Declare permissions just read content.
permissions:
  contents: read
@@ -36,7 +41,7 @@ permissions:
jobs:
  e2e_prime:
    runs-on: [L20x8]
    timeout-minutes: 40 # Increase this timeout value as needed
    timeout-minutes: 50 # Increase this timeout value as needed
    env:
      HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
      HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}

.github/workflows/verl_unit_test.yml (vendored, 5 lines changed)

@@ -49,3 +49,8 @@ jobs:
        run: |
          cd tests/utils
          pytest -s -x --ignore=dataset/ --ignore=checkpoint/ --ignore=test_flops_counter.py --ignore=test_torch_functional.py .
          cd -
      - name: Running trainer tests
        run: |
          cd tests/trainer
          pytest -s -x .

docs/api/utils.rst (new file, 8 lines)

@@ -0,0 +1,8 @@
Training utils
=========================

Core APIs
~~~~~~~~~~~~~~~~~

.. automodule:: verl.utils.metric
   :members: reduce_metrics

@@ -60,6 +60,13 @@ verl is fast with:
   examples/gsm8k_example
   examples/multi_modal_example

.. toctree::
   :maxdepth: 1
   :caption: Algorithms

   experiment/ppo


.. toctree::
   :maxdepth: 1
   :caption: PPO Trainer and Workers
@@ -77,12 +84,6 @@ verl is fast with:
   README_vllm0.8.md
   perf/device_tuning

.. toctree::
   :maxdepth: 1
   :caption: Experimental Results

   experiment/ppo


.. toctree::
   :maxdepth: 1
   :caption: Advance Usage and Extension
@@ -92,12 +93,14 @@ verl is fast with:
   advance/fsdp_extension
   advance/megatron_extension
   advance/checkpoint
   sglang_multiturn/multiturn.rst

.. toctree::
   :maxdepth: 1
   :caption: API References

   data.rst
   data
   api/utils


.. toctree::

@@ -30,8 +30,8 @@ from verl.trainer.ppo.ray_trainer import (
    compute_advantage,
    compute_data_metrics,
    compute_timing_metrics,
    reduce_metrics,
)
from verl.utils.metric import reduce_metrics


def fit(self):

@@ -30,7 +30,8 @@ from verl import DataProto
from verl.single_controller.ray import RayWorkerGroup
from verl.trainer.ppo.core_algos import agg_loss
from verl.trainer.ppo.metric_utils import _compute_response_info
from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType, _timer, reduce_metrics
from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType, _timer
from verl.utils.metric import reduce_metrics
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn

@@ -78,6 +78,8 @@ python3 -m recipe.dapo.main_dapo \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    trainer.logger=['console'] \
    trainer.project_name='verl-test' \
    trainer.experiment_name="${exp_name}" \
    trainer.n_gpus_per_node=${NUM_GPUS} \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \

tests/trainer/__init__.py (new file, 16 lines)

@@ -0,0 +1,16 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for the trainer module.
"""

tests/trainer/ppo/__init__.py (new file, 16 lines)

@@ -0,0 +1,16 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for the PPO trainer module.
"""

tests/trainer/ppo/test_metric_utils.py (new file, 319 lines)

@@ -0,0 +1,319 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for the metric utilities in verl.trainer.ppo.metric_utils.
"""

import unittest
from unittest.mock import MagicMock, patch

import numpy as np
import torch

from verl.trainer.ppo.metric_utils import (
    bootstrap_metric,
    calc_maj_val,
    compute_data_metrics,
    compute_throughout_metrics,
    compute_timing_metrics,
    process_validation_metrics,
)

from verl.utils.metric import (
    reduce_metrics,
)


class TestReduceMetrics(unittest.TestCase):
    """Tests for the reduce_metrics function."""

    def test_reduce_metrics_basic(self):
        """Test that reduce_metrics correctly computes means."""
        metrics = {
            "loss": [1.0, 2.0, 3.0],
            "accuracy": [0.0, 0.5, 1.0],
        }
        result = reduce_metrics(metrics)

        self.assertEqual(result["loss"], 2.0)
        self.assertEqual(result["accuracy"], 0.5)

    def test_reduce_metrics_empty(self):
        """Test that reduce_metrics handles empty lists."""
        metrics = {
            "empty": [],
        }
        result = reduce_metrics(metrics)

        self.assertTrue(np.isnan(result["empty"]))

    def test_reduce_metrics_single_value(self):
        """Test that reduce_metrics works with single values."""
        metrics = {
            "single": [5.0],
        }
        result = reduce_metrics(metrics)

        self.assertEqual(result["single"], 5.0)


class TestComputeDataMetrics(unittest.TestCase):
    """Tests for the compute_data_metrics function."""

    def setUp(self):
        """Set up common test data."""
        # Create a mock DataProto object
        self.batch = MagicMock()
        self.batch.batch = {
            "token_level_scores": torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
            "token_level_rewards": torch.tensor([[0.5, 1.0], [1.5, 2.0]]),
            "advantages": torch.tensor([[0.1, 0.2], [0.3, 0.4]]),
            "returns": torch.tensor([[1.1, 1.2], [1.3, 1.4]]),
            "responses": torch.zeros((2, 2)),  # 2 samples, 2 tokens each
            "attention_mask": torch.tensor([
                [1, 1, 1, 1],  # 2 prompt tokens, 2 response tokens
                [1, 1, 1, 1],
            ]),
            "values": torch.tensor([[0.9, 1.0], [1.1, 1.2]]),
        }

    def test_compute_data_metrics_with_critic(self):
        """Test compute_data_metrics with critic enabled."""
        metrics = compute_data_metrics(self.batch, use_critic=True)

        # Check that all expected metrics are present
        self.assertIn("critic/score/mean", metrics)
        self.assertIn("critic/rewards/mean", metrics)
        self.assertIn("critic/advantages/mean", metrics)
        self.assertIn("critic/returns/mean", metrics)
        self.assertIn("critic/values/mean", metrics)
        self.assertIn("critic/vf_explained_var", metrics)
        self.assertIn("response_length/mean", metrics)
        self.assertIn("prompt_length/mean", metrics)

        # Check some specific values
        self.assertAlmostEqual(metrics["critic/score/mean"], 5.0)  # Sum of token_level_scores
        self.assertAlmostEqual(metrics["critic/rewards/mean"], 2.5)  # Sum of token_level_rewards

    def test_compute_data_metrics_without_critic(self):
        """Test compute_data_metrics with critic disabled."""
        metrics = compute_data_metrics(self.batch, use_critic=False)

        # Check that critic-specific metrics are not present
        self.assertNotIn("critic/values/mean", metrics)
        self.assertNotIn("critic/vf_explained_var", metrics)

        # Check that other metrics are still present
        self.assertIn("critic/score/mean", metrics)
        self.assertIn("critic/rewards/mean", metrics)
        self.assertIn("response_length/mean", metrics)


class TestComputeTimingMetrics(unittest.TestCase):
    """Tests for the compute_timing_metrics function."""

    def setUp(self):
        """Set up common test data."""
        # Create a mock DataProto object
        self.batch = MagicMock()
        self.batch.batch = {
            "responses": torch.zeros((2, 3)),  # 2 samples, 3 response tokens each
            "attention_mask": torch.tensor([
                [1, 1, 1, 1, 1, 1],  # 3 prompt tokens, 3 response tokens
                [1, 1, 1, 1, 1, 1],
            ]),
        }

        # Mock the _compute_response_info function to return known values
        self.response_info = {
            "prompt_length": torch.tensor([3.0, 3.0]),
            "response_length": torch.tensor([3.0, 3.0]),
            "response_mask": torch.ones((2, 3)),
        }

    @patch("verl.trainer.ppo.metric_utils._compute_response_info")
    def test_compute_timing_metrics(self, mock_compute_response_info):
        """Test compute_timing_metrics with various timing data."""
        mock_compute_response_info.return_value = self.response_info

        timing_raw = {
            "gen": 0.5,  # 500ms
            "ref": 0.3,  # 300ms
            "values": 0.2,  # 200ms
        }

        metrics = compute_timing_metrics(self.batch, timing_raw)

        # Check raw timing metrics
        self.assertEqual(metrics["timing_s/gen"], 0.5)
        self.assertEqual(metrics["timing_s/ref"], 0.3)
        self.assertEqual(metrics["timing_s/values"], 0.2)

        # Check per-token timing metrics
        # gen uses only response tokens (6 tokens)
        self.assertAlmostEqual(metrics["timing_per_token_ms/gen"], 0.5 * 1000 / 6, places=5)

        # ref and values use all tokens (12 tokens)
        self.assertAlmostEqual(metrics["timing_per_token_ms/ref"], 0.3 * 1000 / 12, places=5)
        self.assertAlmostEqual(metrics["timing_per_token_ms/values"], 0.2 * 1000 / 12, places=5)


class TestComputeThroughputMetrics(unittest.TestCase):
    """Tests for the compute_throughout_metrics function."""

    def setUp(self):
        """Set up common test data."""
        # Create a mock DataProto object
        self.batch = MagicMock()
        self.batch.meta_info = {
            "global_token_num": [100, 200, 300],  # 600 tokens total
        }

    def test_compute_throughout_metrics(self):
        """Test compute_throughout_metrics with various timing data."""
        timing_raw = {
            "step": 2.0,  # 2 seconds per step
        }

        # Test with 1 GPU
        metrics = compute_throughout_metrics(self.batch, timing_raw, n_gpus=1)

        self.assertEqual(metrics["perf/total_num_tokens"], 600)
        self.assertEqual(metrics["perf/time_per_step"], 2.0)
        self.assertEqual(metrics["perf/throughput"], 600 / 2.0)  # 300 tokens/sec

        # Test with 2 GPUs
        metrics = compute_throughout_metrics(self.batch, timing_raw, n_gpus=2)

        self.assertEqual(metrics["perf/total_num_tokens"], 600)
        self.assertEqual(metrics["perf/time_per_step"], 2.0)
        self.assertEqual(metrics["perf/throughput"], 600 / (2.0 * 2))  # 150 tokens/sec/GPU


class TestBootstrapMetric(unittest.TestCase):
    """Tests for the bootstrap_metric function."""

    def test_bootstrap_metric_basic(self):
        """Test bootstrap_metric with simple data and functions."""
        data = [1, 2, 3, 4, 5]
        reduce_fns = [np.mean, np.max]

        # Use a fixed seed for reproducibility
        result = bootstrap_metric(data, subset_size=3, reduce_fns=reduce_fns, n_bootstrap=100, seed=42)

        # Check that we get two results (one for each reduce_fn)
        self.assertEqual(len(result), 2)

        # Each result should be a tuple of (mean, std)
        mean_result, max_result = result
        self.assertEqual(len(mean_result), 2)
        self.assertEqual(len(max_result), 2)

        # The mean of means should be close to the true mean (3.0)
        self.assertAlmostEqual(mean_result[0], 3.0, delta=0.3)

        # The mean of maxes should be close to the expected value for samples of size 3
        # For samples of size 3 from [1,2,3,4,5], the expected max is around 4.0-4.5
        self.assertGreater(max_result[0], 3.5)
        self.assertLess(max_result[0], 5.0)

    def test_bootstrap_metric_empty(self):
        """Test bootstrap_metric with empty data."""
        with self.assertRaises(ValueError):
            bootstrap_metric([], subset_size=1, reduce_fns=[np.mean])


class TestCalcMajVal(unittest.TestCase):
    """Tests for the calc_maj_val function."""

    def test_calc_maj_val_basic(self):
        """Test calc_maj_val with simple data."""
        data = [
            {"pred": "A", "val": 0.9},
            {"pred": "B", "val": 0.8},
            {"pred": "A", "val": 0.7},
        ]

        result = calc_maj_val(data, vote_key="pred", val_key="val")

        # "A" is the majority vote, so we should get the first "val" for "A"
        self.assertEqual(result, 0.9)

    def test_calc_maj_val_tie(self):
        """Test calc_maj_val with tied votes."""
        data = [
            {"pred": "A", "val": 0.9},
            {"pred": "B", "val": 0.8},
            {"pred": "B", "val": 0.7},
            {"pred": "A", "val": 0.6},
        ]

        # In case of a tie, the first key in sorted order wins
        # This depends on Python's dict implementation, but for this test
        # we just verify that one of the valid values is returned
        result = calc_maj_val(data, vote_key="pred", val_key="val")

        self.assertTrue(result in [0.9, 0.8])


class TestProcessValidationMetrics(unittest.TestCase):
    """Tests for the process_validation_metrics function."""

    def test_process_validation_metrics_basic(self):
        """Test process_validation_metrics with simple data."""
        data_sources = ["source1", "source1", "source2"]
        sample_inputs = ["prompt1", "prompt1", "prompt2"]
        infos_dict = {
            "score": [0.8, 0.9, 0.7],
        }

        result = process_validation_metrics(
            data_sources, sample_inputs, infos_dict, seed=42
        )

        # Check the structure of the result
        self.assertIn("source1", result)
        self.assertIn("source2", result)

        # Check that source1 has metrics for score
        self.assertIn("score", result["source1"])

        # Check that mean@2 is present for source1/score
        self.assertIn("mean@2", result["source1"]["score"])

        # Check the value of mean@2 for source1/score
        self.assertAlmostEqual(result["source1"]["score"]["mean@2"], 0.85)

    def test_process_validation_metrics_with_pred(self):
        """Test process_validation_metrics with prediction data."""
        data_sources = ["source1", "source1", "source1"]
        sample_inputs = ["prompt1", "prompt1", "prompt1"]
        infos_dict = {
            "score": [0.8, 0.9, 0.7],
            "pred": ["A", "B", "A"],
        }

        result = process_validation_metrics(
            data_sources, sample_inputs, infos_dict, seed=42
        )

        # Check that majority voting metrics are present
        self.assertIn("maj@2/mean", result["source1"]["score"])

        # For bootstrap with n=2, the majority vote could be either A or B
        # depending on the random sampling, so we don't check the exact value


if __name__ == "__main__":
    unittest.main()

@@ -23,15 +23,44 @@ import numpy as np
import torch

from verl import DataProto
from verl.utils.import_utils import deprecated


@deprecated("verl.utils.metric.reduce_metrics")
def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]:
    for key, val in metrics.items():
        metrics[key] = np.mean(val)
    return metrics
    """
    Reduces a dictionary of metric lists by computing the mean of each list.

    Args:
        metrics: A dictionary mapping metric names to lists of metric values.

    Returns:
        A dictionary with the same keys but with each list replaced by its mean value.

    Example:
        >>> metrics = {"loss": [1.0, 2.0, 3.0], "accuracy": [0.8, 0.9, 0.7]}
        >>> reduce_metrics(metrics)
        {"loss": 2.0, "accuracy": 0.8}
    """
    from verl.utils.metric import reduce_metrics

    return reduce_metrics(metrics)


def _compute_response_info(batch: DataProto) -> Dict[str, Any]:
    """
    Computes information about prompts and responses from a batch.

    This is an internal helper function that extracts masks and lengths for prompts and responses.

    Args:
        batch: A DataProto object containing batch data with responses and attention masks.

    Returns:
        A dictionary containing:
            - response_mask: Attention mask for the response tokens
            - prompt_length: Tensor of prompt lengths for each item in the batch
            - response_length: Tensor of response lengths for each item in the batch
    """
    response_length = batch.batch["responses"].shape[-1]

    prompt_mask = batch.batch["attention_mask"][:, :-response_length]
@@ -48,7 +77,28 @@ def _compute_response_info(batch: DataProto) -> Dict[str, Any]:

def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> Dict[str, Any]:
    # TODO: add response length
    """
    Computes various metrics from a batch of data for PPO training.

    This function calculates metrics related to scores, rewards, advantages, returns, values,
    and sequence lengths from a batch of data. It provides statistical information (mean, max, min)
    for each metric category.

    Args:
        batch: A DataProto object containing batch data with token-level scores, rewards, advantages, etc.
        use_critic: Whether to include critic-specific metrics. Defaults to True.

    Returns:
        A dictionary of metrics including:
            - critic/score/mean, max, min: Statistics about sequence scores
            - critic/rewards/mean, max, min: Statistics about sequence rewards
            - critic/advantages/mean, max, min: Statistics about advantages
            - critic/returns/mean, max, min: Statistics about returns
            - critic/values/mean, max, min: Statistics about critic values (if use_critic=True)
            - critic/vf_explained_var: Explained variance of the value function (if use_critic=True)
            - response_length/mean, max, min, clip_ratio: Statistics about response lengths
            - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths
    """
    sequence_score = batch.batch["token_level_scores"].sum(-1)
    sequence_reward = batch.batch["token_level_rewards"].sum(-1)
@@ -119,6 +169,28 @@ compute_data_metrics(batch: DataProto, use_critic: bool = True) -> Dict[str, Any]:

def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]:
    """
    Computes timing metrics for different processing stages in PPO training.

    This function calculates both raw timing metrics (in seconds) and per-token timing metrics
    (in milliseconds) for various processing stages like generation, reference computation,
    value computation, advantage computation, and model updates.

    Args:
        batch: A DataProto object containing batch data with responses and attention masks.
        timing_raw: A dictionary mapping stage names to their execution times in seconds.

    Returns:
        A dictionary containing:
            - timing_s/{name}: Raw timing in seconds for each stage
            - timing_per_token_ms/{name}: Per-token timing in milliseconds for each stage

    Note:
        Different stages use different token counts for normalization:
        - "gen" uses only response tokens
        - Other stages ("ref", "values", "adv", "update_critic", "update_actor") use all tokens
          (prompt + response)
    """
    response_info = _compute_response_info(batch)
    num_prompt_tokens = torch.sum(response_info["prompt_length"]).item()
    num_response_tokens = torch.sum(response_info["response_length"]).item()
@@ -136,6 +208,29 @@ compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]:

def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]:
    """
    Computes throughput metrics for PPO training.

    This function calculates performance metrics related to token processing speed,
    including the total number of tokens processed, time per step, and throughput
    (tokens per second per GPU).

    Args:
        batch: A DataProto object containing batch data with meta information about token counts.
        timing_raw: A dictionary mapping stage names to their execution times in seconds.
            Must contain a "step" key with the total step time.
        n_gpus: Number of GPUs used for training.

    Returns:
        A dictionary containing:
            - perf/total_num_tokens: Total number of tokens processed in the batch
            - perf/time_per_step: Time taken for the step in seconds
            - perf/throughput: Tokens processed per second per GPU

    Note:
        The throughput is calculated as total_tokens / (time * n_gpus) to normalize
        across different GPU counts.
    """
    total_num_tokens = sum(batch.meta_info["global_token_num"])
    time = timing_raw["step"]
    # estimated_flops, promised_flops = flops_function.estimate_flops(num_tokens, time)
@@ -155,6 +250,29 @@ bootstrap_metric(
    n_bootstrap: int = 1000,
    seed: int = 42,
) -> list[tuple[float, float]]:
    """
    Performs bootstrap resampling to estimate statistics of metrics.

    This function uses bootstrap resampling to estimate the mean and standard deviation
    of metrics computed by the provided reduction functions on random subsets of the data.

    Args:
        data: List of data points to bootstrap from.
        subset_size: Size of each bootstrap sample.
        reduce_fns: List of functions that compute a metric from a subset of data.
        n_bootstrap: Number of bootstrap iterations. Defaults to 1000.
        seed: Random seed for reproducibility. Defaults to 42.

    Returns:
        A list of tuples, where each tuple contains (mean, std) for a metric
        corresponding to each reduction function in reduce_fns.

    Example:
        >>> data = [1, 2, 3, 4, 5]
        >>> reduce_fns = [np.mean, np.max]
        >>> bootstrap_metric(data, 3, reduce_fns)
        [(3.0, 0.5), (4.5, 0.3)]  # Example values
    """
    np.random.seed(seed)

    bootstrap_metric_lsts = [[] for _ in range(len(reduce_fns))]
@@ -168,7 +286,27 @@

def calc_maj_val(data: list[dict[str, Any]], vote_key: str, val_key: str) -> float:
    """
    Calculate the majority voting metric
    Calculate a value based on majority voting.

    This function identifies the most common value for a specified vote key
    in the data, then returns the corresponding value for that majority vote.

    Args:
        data: List of dictionaries, where each dictionary contains both vote_key and val_key.
        vote_key: The key in each dictionary used for voting/counting.
        val_key: The key in each dictionary whose value will be returned for the majority vote.

    Returns:
        The value associated with the most common vote.

    Example:
        >>> data = [
        ...     {"pred": "A", "val": 0.9},
        ...     {"pred": "B", "val": 0.8},
        ...     {"pred": "A", "val": 0.7}
        ... ]
        >>> calc_maj_val(data, vote_key="pred", val_key="val")
        0.9  # Returns the first "val" for the majority vote "A"
    """
    vote2vals = defaultdict(list)
    for d in data:
@@ -183,15 +321,46 @@ calc_maj_val(data: list[dict[str, Any]], vote_key: str, val_key: str) -> float:

def process_validation_metrics(data_sources: list[str], sample_inputs: list[str], infos_dict: dict[str, list[Any]], seed: int = 42) -> dict[str, dict[str, dict[str, float]]]:
    """Process validation metrics into a structured format.

    """
    Process validation metrics into a structured format with statistical analysis.

    This function organizes validation metrics by data source and prompt, then computes
    various statistical measures including means, standard deviations, best/worst values,
    and majority voting results. It also performs bootstrap sampling to estimate statistics
    for different sample sizes.

    Args:
        data_sources: Array of data source identifiers for each sample
        sample_inputs: List of input prompts
        infos_dict: variable name -> list of values for each sample
        data_sources: List of data source identifiers for each sample.
        sample_inputs: List of input prompts corresponding to each sample.
        infos_dict: Dictionary mapping variable names to lists of values for each sample.
        seed: Random seed for bootstrap sampling. Defaults to 42.

    Returns:
        dict[str, dict[str, dict[str, float]]]: data source -> variable name -> metric value
        A nested dictionary with the structure:
        {
            data_source: {
                variable_name: {
                    metric_name: value
                }
            }
        }

        Where metric_name includes:
            - "mean@N": Mean value across N samples
            - "std@N": Standard deviation across N samples
            - "best@N/mean": Mean of the best values in bootstrap samples of size N
            - "best@N/std": Standard deviation of the best values in bootstrap samples
            - "worst@N/mean": Mean of the worst values in bootstrap samples
            - "worst@N/std": Standard deviation of the worst values in bootstrap samples
            - "maj@N/mean": Mean of majority voting results in bootstrap samples (if "pred" exists)
            - "maj@N/std": Standard deviation of majority voting results (if "pred" exists)

    Example:
        >>> data_sources = ["source1", "source1", "source2"]
        >>> sample_inputs = ["prompt1", "prompt1", "prompt2"]
        >>> infos_dict = {"score": [0.8, 0.9, 0.7], "pred": ["A", "A", "B"]}
        >>> result = process_validation_metrics(data_sources, sample_inputs, infos_dict)
        >>> # result will contain statistics for each data source and variable
    """
    # Group metrics by data source, prompt and variable
    data_src2prompt2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))

@@ -50,10 +50,12 @@ from verl.trainer.ppo.metric_utils import (
    compute_throughout_metrics,
    compute_timing_metrics,
    process_validation_metrics,
    reduce_metrics,
)
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.metric import (
    reduce_metrics,
)
from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from verl.utils.torch_functional import masked_mean
from verl.utils.tracking import ValidationGenerationsLogger

@@ -81,3 +81,27 @@ def load_extern_type(file_path: Optional[str], type_name: Optional[str]):
        raise AttributeError(f"Custom type '{type_name}' not found in '{file_path}'.")

    return getattr(module, type_name)


def _get_qualified_name(func):
    """Get full qualified name including module and class (if any)."""
    module = func.__module__
    qualname = func.__qualname__
    return f"{module}.{qualname}"

def deprecated(replacement: str = ""):
    """Decorator to mark APIs as deprecated."""
    import warnings
    import functools

    def decorator(func):
        qualified_name = _get_qualified_name(func)
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            msg = f"Warning: API '{qualified_name}' is deprecated."
            if replacement:
                msg += f" Please use '{replacement}' instead."
            warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)
        return wrapped
    return decorator
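As a quick sanity check of the decorator added above, the following standalone sketch shows the warning a caller would see. It is not part of the diff; `old_mean` and the suggested replacement path are made-up names for illustration.

```python
# Hypothetical illustration of the @deprecated decorator from verl.utils.import_utils.
# `old_mean` and "my_pkg.metrics.new_mean" are invented for this example only.
import warnings

from verl.utils.import_utils import deprecated


@deprecated("my_pkg.metrics.new_mean")
def old_mean(values):
    return sum(values) / len(values)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert old_mean([1.0, 2.0, 3.0]) == 2.0  # original behavior is preserved
    assert issubclass(caught[0].category, DeprecationWarning)
    print(caught[0].message)
    # Warning: API '<module>.old_mean' is deprecated. Please use 'my_pkg.metrics.new_mean' instead.
```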

verl/utils/metric/__init__.py (new file, 17 lines)

@@ -0,0 +1,17 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .utils import reduce_metrics

__all__ = ["reduce_metrics"]

verl/utils/metric/utils.py (new file, 40 lines)

@@ -0,0 +1,40 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Metrics utils.
"""
from typing import Any, Dict, List

import numpy as np


def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]:
    """
    Reduces a dictionary of metric lists by computing the mean of each list.
    The reduce operation is done via ``np.mean``.

    Args:
        metrics: A dictionary mapping metric names to lists of metric values.

    Returns:
        A dictionary with the same keys but with each list replaced by its mean value.

    Example:
        >>> metrics = {"loss": [1.0, 2.0, 3.0], "accuracy": [0.8, 0.9, 0.7]}
        >>> reduce_metrics(metrics)
        {"loss": 2.0, "accuracy": 0.8}
    """
    for key, val in metrics.items():
        metrics[key] = np.mean(val)
    return metrics
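One behavioral detail worth noting for reviewers, illustrated with a standalone snippet that is not taken from the diff: as written above, `reduce_metrics` reduces the dictionary in place and returns the same object, so callers should not expect the original lists to survive the call.

```python
# Minimal sketch of the in-place behavior of verl.utils.metric.reduce_metrics.
from verl.utils.metric import reduce_metrics

raw = {"loss": [1.0, 2.0, 3.0]}
reduced = reduce_metrics(raw)

assert reduced is raw              # the same dict object is returned
assert float(raw["loss"]) == 2.0   # the list has been replaced by its mean
```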