[recipe] feat: integrate entropy-mechanism recipe: Clip-Cov and KL-Cov methods (#1830)

### Checklist Before Starting

- [x] Search for similar PR(s).

### What does this PR do?

> Add support for the Clip-Cov and KL-Cov methods from the paper *The Entropy
Mechanism of Reinforcement Learning for Reasoning Language Models*. Also
add the verifier used in the paper.
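
For quick orientation, here is a compact paraphrase of the two objectives as they are implemented in `core_algos.py` below (our shorthand, not necessarily the paper's exact formulation). Both methods rank response tokens by a per-token covariance proxy

$\mathrm{cov}_i = (A_i - \bar{A})\,(\log\pi_\theta(y_i) - \overline{\log\pi_\theta})$

where $A_i$ is the advantage and the means are taken over masked response tokens. **KL-Cov** keeps the unclipped policy-gradient loss $-A_i r_i$ (with $r_i = \pi_\theta / \pi_{\text{old}}$) and adds an absolute-KL penalty $k_1\,|\log\pi_\theta - \log\pi_{\text{old}}|$ on the top `k_ratio` fraction of tokens by covariance. **Clip-Cov** keeps the standard clipped PPO objective but zeroes out (detaches) the loss of a small random fraction (`clip_ratio`) of tokens whose covariance lies within `(clip_cov_lb, clip_cov_ub)`.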

### High-Level Design

> Demonstrate the high-level design if this PR is complex.

### Specific Changes

> List the specific changes.
In `core_algos.py`, we add the Clip-Cov and KL-Cov losses:
```
def compute_policy_loss_clip_cov(
    old_log_prob,
    log_prob,
    advantages,
    response_mask,
    cliprange=None,
    cliprange_low=None,
    cliprange_high=None,
    loss_agg_mode="token-mean",
    clip_ratio=0.0002,
    clip_cov_lb=1.0,
    clip_cov_ub=5.0,
):
    """
    Compute the clipped policy objective and related metrics for Clip-Cov.
    Adapted from
    https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py
    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        cliprange (float, optional):
            Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
            Defaults to None (must be provided).
        cliprange_low (float, optional):
            Lower clip range for dual-clip PPO. Defaults to same as `cliprange`.
        cliprange_high (float, optional):
            Upper clip range for dual-clip PPO. Defaults to same as `cliprange`.
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".
        clip_ratio (float, optional):
            Ratio for clipping the covariance. Defaults to 0.0002.
        clip_cov_lb (float, optional):
            Lower bound for clipping covariance. Defaults to 1.0.
        clip_cov_ub (float, optional):
            Upper bound for clipping covariance. Defaults to 5.0.
    """
    assert clip_ratio > 0, "clip_ratio should be larger than 0."
    negative_approx_kl = log_prob - old_log_prob
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    pg_losses1 = -advantages * ratio

    if cliprange_low is None:
        cliprange_low = cliprange
    if cliprange_high is None:
        cliprange_high = cliprange

    corr = torch.ones_like(advantages)
    pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
    clip_by_origin = (pg_losses2 > pg_losses1) & (response_mask > 0)

    cov_all = (advantages - verl_F.masked_mean(advantages, response_mask)) * (
        log_prob - verl_F.masked_mean(log_prob.detach(), response_mask)
    )
    cov_all[response_mask == 0] = -torch.inf
    cov_all[clip_by_origin] = -torch.inf

    clip_num = max(int(clip_ratio * response_mask.sum().item()), 1)
    top_k_idx = (cov_all < clip_cov_ub) & (cov_all > clip_cov_lb) & (response_mask > 0)
    top_k_idx = torch.nonzero(top_k_idx)

    if len(top_k_idx) > 0:
        perm = torch.randperm(len(top_k_idx))
        top_k_idx = top_k_idx[perm[:min(clip_num, len(top_k_idx))]]
    else:
        top_k_idx = torch.empty((0, 2), device=cov_all.device, dtype=torch.long)

    corr[top_k_idx[:, 0], top_k_idx[:, 1]] = 0

    pg_clipfrac = verl_F.masked_mean((corr==0).float(), response_mask)

    pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr
    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    return pg_loss, pg_clipfrac, ppo_kl, torch.tensor(0.)


def compute_policy_loss_kl_cov(
    old_log_prob,
    log_prob,
    advantages,
    response_mask,
    loss_agg_mode="token-mean",
    k_ratio=0.0002,
    ppo_kl_coef=1,
):
    """
    Compute the policy loss and related metrics for KL-Cov.
    Adapted from
    https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py
    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".
        k_ratio (float, optional):
            Ratio for selecting the top-k covariance values. Defaults to 0.0002.
        ppo_kl_coef (float, optional):
            Coefficient for the KL penalty term in the loss. Defaults to 1.
    """
    assert k_ratio > 0, "k_ratio should be larger than 0."
    negative_approx_kl = log_prob - old_log_prob
    abs_kl = negative_approx_kl.abs()
    ratio = torch.exp(negative_approx_kl)
    ppo_kl_abs = verl_F.masked_mean(negative_approx_kl.abs(), response_mask)
    pg_losses1 = -advantages * ratio
    pg_losses_kl = -advantages * ratio + ppo_kl_coef * abs_kl
    pg_losses = pg_losses1

    all_valid = (response_mask > 0)
    all_valid_idx = torch.nonzero(all_valid.reshape(-1), as_tuple=True)[0] 
    all_valid_adv = advantages[all_valid].detach().reshape(-1).cpu()
    all_valid_logp = log_prob[all_valid].detach().reshape(-1).cpu()

    # `k_ratio` is a fraction, so this min() only guards against the degenerate
    # case of zero valid response tokens.
    k = min(k_ratio, len(all_valid_adv))

    if k != 0:
        cov_lst_all = (all_valid_adv - all_valid_adv.mean()) * (all_valid_logp - all_valid_logp.mean())
        k_percent_nums = max(1, int(len(cov_lst_all) * k_ratio))
        large_cov_idxs = torch.topk(cov_lst_all, k_percent_nums, largest=True).indices

        if len(large_cov_idxs) != 0:
            large_cov_idxs = all_valid_idx[large_cov_idxs]
            rows = large_cov_idxs // advantages.shape[1]
            cols = large_cov_idxs % advantages.shape[1]
            pg_losses[rows, cols] = pg_losses_kl[rows, cols]

    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    return pg_loss, torch.tensor(0.), ppo_kl_abs, torch.tensor(0.)

```
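
As a quick sanity check of the new API, here is a minimal sketch; the import path assumes the functions land in `verl/trainer/ppo/core_algos.py` as described above, and the tensor shapes follow the docstrings:

```
import torch

# Hypothetical smoke test for the KL-Cov loss added in this PR.
from verl.trainer.ppo.core_algos import compute_policy_loss_kl_cov

batch_size, response_length = 4, 16
old_log_prob = -torch.rand(batch_size, response_length)  # (bs, resp_len) log-probs
log_prob = old_log_prob + 0.01 * torch.randn(batch_size, response_length)
advantages = torch.randn(batch_size, response_length)
response_mask = torch.ones(batch_size, response_length)

pg_loss, _, ppo_kl_abs, _ = compute_policy_loss_kl_cov(
    old_log_prob=old_log_prob,
    log_prob=log_prob,
    advantages=advantages,
    response_mask=response_mask,
    loss_agg_mode="token-mean",
    k_ratio=0.0002,
    ppo_kl_coef=1.0,
)
print(pg_loss.item(), ppo_kl_abs.item())  # scalar loss and mean |approx. KL|
```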

In `dp_actor.py`, we add a loss-mode switch:
```
                    loss_mode = self.config.get("loss_mode", "vanilla")
                    if loss_mode not in ["vanilla", "clip_cov", "kl_cov"]:
                        raise ValueError(f"Unsupported loss mode: {loss_mode}. Supported modes are: 'vanilla', 'clip_cov', 'kl_cov'.")

                    if loss_mode == "vanilla":
                        pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
                            old_log_prob=old_log_prob,
                            log_prob=log_prob,
                            advantages=advantages,
                            response_mask=response_mask,
                            cliprange=clip_ratio,
                            cliprange_low=clip_ratio_low,
                            cliprange_high=clip_ratio_high,
                            clip_ratio_c=clip_ratio_c,
                            loss_agg_mode=loss_agg_mode,
                        )

                    elif loss_mode == "clip_cov":
                        pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss_clip_cov(
                            old_log_prob=old_log_prob,
                            log_prob=log_prob,
                            advantages=advantages,
                            response_mask=response_mask,
                            cliprange=clip_ratio,
                            cliprange_low=clip_ratio_low,
                            cliprange_high=clip_ratio_high,
                            loss_agg_mode=loss_agg_mode,
                            clip_ratio=self.config.clip_cov_ratio,
                            clip_cov_lb=self.config.clip_cov_lb,
                            clip_cov_ub=self.config.clip_cov_ub,
                        )

                    elif loss_mode == "kl_cov":
                        pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss_kl_cov(
                            old_log_prob=old_log_prob,
                            log_prob=log_prob,
                            advantages=advantages,
                            response_mask=response_mask,
                            loss_agg_mode=loss_agg_mode,
                            k_ratio=self.config.k_ratio,
                            ppo_kl_coef=self.config.ppo_kl_coef,
                        )
```
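
The switch only reads a handful of actor config fields. Below is a hypothetical minimal fragment that would exercise the KL-Cov branch; field names follow the snippet above, while the committed recipe nests them under `actor_rollout_ref.actor.policy_loss` (see the recipe config and scripts further down):

```
from omegaconf import OmegaConf

# Hypothetical actor-config fragment mirroring the fields read by the switch above;
# default values are taken from the recipe scripts shown later in this PR.
actor_config = OmegaConf.create(
    {
        "loss_mode": "kl_cov",     # "vanilla" | "clip_cov" | "kl_cov"
        "k_ratio": 0.0002,         # fraction of tokens that receive the |KL| penalty (kl_cov)
        "ppo_kl_coef": 1.0,        # weight of the |KL| penalty (kl_cov)
        "clip_cov_ratio": 0.0002,  # fraction of tokens whose loss is zeroed (clip_cov)
        "clip_cov_lb": 1.0,
        "clip_cov_ub": 5.0,
    }
)

loss_mode = actor_config.get("loss_mode", "vanilla")
assert loss_mode in ["vanilla", "clip_cov", "kl_cov"]
```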


### Usage Example

> Provide usage example(s) for easier usage.

We create a recipe named `entropy` (built on the DAPO recipe) to store our
scripts, for example `7b_kl_cov.sh`:

```

#!/usr/bin/env bash
set -xeuo pipefail

export WANDB_API_KEY=YOUR_WANDB_API_KEY
# export VLLM_USE_V1=1

project_name='Qwen2.5-7B'
exp_name='klcov'

adv_estimator=grpo

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.2

max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=False
overlong_buffer_len=$((1024 * 2))
overlong_penalty_factor=1.0

loss_agg_mode="token-mean"
loss_mode="kl_cov"
enable_filter_groups=False
filter_groups_metric=acc
max_num_gen_batches=10
train_prompt_bsz=256
gen_prompt_bsz=$((train_prompt_bsz * 3))
train_prompt_mini_bsz=256
n_resp_per_prompt=8
max_token=20480

# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-4}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"}
CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"}
TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"}
TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]}

# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
ppo_kl_coef=1
k_ratio=0.002

# Mathematically equivalent
use_dynamic_bsz=True
infer_micro_batch_size=null
train_micro_batch_size=null
offload=False

HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
    data.train_files="${TRAIN_FILE}" \
    data.val_files="${TEST_FILE}" \
    data.prompt_key=prompt \
    data.truncation='left' \
    data.filter_overlong_prompts=False \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    data.gen_batch_size=${gen_prompt_bsz} \
    data.train_batch_size=${train_prompt_bsz} \
    data.return_raw_chat=True \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
    actor_rollout_ref.actor.clip_ratio_c=10.0 \
    actor_rollout_ref.actor.loss_mode=${loss_mode} \
    actor_rollout_ref.actor.k_ratio=${k_ratio} \
    actor_rollout_ref.actor.ppo_kl_coef=${ppo_kl_coef} \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
    actor_rollout_ref.rollout.mode=sync \
    algorithm.adv_estimator=${adv_estimator} \
    algorithm.use_kl_in_reward=${use_kl_in_reward} \
    algorithm.kl_ctrl.kl_coef=${kl_coef} \
    algorithm.filter_groups.enable=${enable_filter_groups} \
    algorithm.filter_groups.metric=${filter_groups_metric} \
    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.weight_decay=0 \
    actor_rollout_ref.actor.optim.warmup_style=constant \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.grad_clip=1.0 \
    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.enable_chunked_prefill=True \
    actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \
    actor_rollout_ref.rollout.temperature=${temperature} \
    actor_rollout_ref.rollout.top_p=${top_p} \
    actor_rollout_ref.rollout.top_k="${top_k}" \
    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
    actor_rollout_ref.rollout.val_kwargs.do_sample=False \
    actor_rollout_ref.rollout.val_kwargs.n=1 \
    actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \
    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
    reward_model.reward_manager=dapo \
    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
    reward_model.overlong_buffer.len=${overlong_buffer_len} \
    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
    trainer.logger=['console','wandb'] \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes="${NNODES}" \
    trainer.val_before_train=False \
    trainer.test_freq=4 \
    trainer.save_freq=32 \
    trainer.total_epochs=1000 \
    trainer.default_local_dir="${CKPTS_DIR}" \
    trainer.resume_mode=disable

```

### Test

Please refer to Fig. 11 and Tab. 2 in https://arxiv.org/pdf/2505.22617
for detailed results.

### Additional Info.

NA

### Checklist Before Submitting

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [x] Add `[BREAKING]` to the PR title if it breaks any API.
- [x] Update the documentation about your changes in the
[docs](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add CI test(s) if necessary.

---------

Co-authored-by: Jiacheng Chen <jackchan9345@gmail.com>
Co-authored-by: H <linhaibin.eric@gmail.com>
Commit: 39b7250b0a (parent: ba908710ff)
Author: Yuchen Zhang
Date: 2025-06-20 06:08:43 +08:00
Committed by: GitHub
22 changed files with 3356 additions and 26 deletions

@@ -217,6 +217,7 @@ verl is inspired by the design of Nemo-Aligner, Deepspeed-chat and OpenRLHF. The
- [LUFFY](https://arxiv.org/pdf/2504.14945): Learning to Reason under Off-Policy Guidance![GitHub Repo stars](https://img.shields.io/github/stars/ElliottYan/LUFFY)
- [verl-tool](https://github.com/TIGER-AI-Lab/verl-tool): An unified and easy-to-extend tool-agent training framework based on verl![GitHub Repo stars](https://img.shields.io/github/stars/TIGER-AI-Lab/verl-tool)
- [DeepMath](https://github.com/zwhe99/DeepMath): DeepMath-103K data and series models for math reasoning![GitHub Repo stars](https://img.shields.io/github/stars/zwhe99/DeepMath)
- [Entropy Mechanism of RL](https://github.com/PRIME-RL/Entropy-Mechanism-of-RL): The Entropy Mechanism of Reinforcement Learning for Large Language Model Reasoning![GitHub Repo stars](https://img.shields.io/github/stars/PRIME-RL/Entropy-Mechanism-of-RL)
- [LLaSA-TTS-GRPO](https://github.com/channel-io/ch-tts-llasa-rl-grpo): TTS fine-tuning with GRPO optimization based on LLASA models ![GitHub Repo stars](https://img.shields.io/github/stars/channel-io/ch-tts-llasa-rl-grpo)
- [RL-Factory](https://github.com/Simple-Efficient/RL-Factory): An easy and efficient RL post-training framework for Agentic Learning ![GitHub Repo stars](https://img.shields.io/github/stars/Simple-Efficient/RL-Factory)
- [RACRO](https://github.com/gyhdog99/RACRO2): Build multi-modal reasoning models via decoupling it into query-conditioned captioning and text-only reasoning ![GitHub Repo stars](https://img.shields.io/github/stars/gyhdog99/RACRO2)

docs/algo/entropy.md (new file)

@@ -0,0 +1,114 @@
# Recipe: Entropy Mechanism
<div align="center">
The Entropy Mechanism of Reinforcement Learning for Large Language Model Reasoning.
[![Paper](https://img.shields.io/badge/paper-A42C25?style=for-the-badge&logo=arxiv&logoColor=white)](https://arxiv.org/pdf/2505.22617) [![Github](https://img.shields.io/badge/PRIME-000000?style=for-the-badge&logo=github&logoColor=000&logoColor=white)](https://github.com/PRIME-RL/Entropy-Mechanism-of-RL) [![alphaXiv](https://img.shields.io/badge/discussion-A42C25?style=for-the-badge&logo=arxiv&logoColor=white&color=blue
)](https://www.alphaxiv.org/abs/2505.22617) [![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/stingning/status/1928088554166505667) [![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/charlesfornlp/status/1928089451080585283) [![Twitter-ak](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/_akhaliq/status/1928077929105268861)
<div align="center" style="font-family: Arial, sans-serif;">
<p>
<a href="#🎉news" style="text-decoration: none; font-weight: bold;">🎉 News</a>
<a href="#✨getting-started" style="text-decoration: none; font-weight: bold;">✨ Getting Started</a>
<a href="#📖introduction" style="text-decoration: none; font-weight: bold;">📖 Introduction</a>
</p>
<p>
<a href="#🎈citation" style="text-decoration: none; font-weight: bold;">🎈 Citation</a>
<a href="#🌻acknowledgement" style="text-decoration: none; font-weight: bold;">🌻 Acknowledgement</a>
<a href="#📬Contact" style="text-decoration: none; font-weight: bold;">📬 Contact</a>
<a href="#📈star-history" style="text-decoration: none; font-weight: bold;">📈 Star History</a>
</p>
</div>
</div>
# 🎉News
- **[2025/05/29]** 🎉 Ranked **#1** of the day on [Huggingface Daily Papers](https://huggingface.co/papers?date=2025-05-29).
- **[2025/05/29]** Released our Paper on arXiv. See [here](https://arxiv.org/pdf/2505.22617). We provide insights into the entropy mechanism of RL for LLMs and propose two simple yet effective strategies to alleviate the entropy collapse.
# ✨Getting started
After preparing the training data, for training Qwen2.5-7B on a single node, taking the KL-Cov approach as an example, you can simply run:
```
cd verl
conda activate your_env
bash recipe/entropy/7b_kl_cov.sh
```
To train Qwen2.5-32B on multiple nodes, run:
```
cd verl
conda activate your_env
bash recipe/entropy/32b_kl_cov.sh
```
# 📖Introduction
<div align="left">
<img src="https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/e2a.jpg?raw=true" alt="issue" style="width: 96%; height: auto;">
</div>
This paper addresses the entropy collapse issue in scaling reinforcement learning (RL) for large language models (LLMs), where policy entropy drops sharply during training, leading to overconfidence and performance saturation. We empirically establish a relationship between entropy ($H$) and performance ($R$): $R = -a\exp(H) + b$, showing that performance is bottlenecked by entropy exhaustion.
<div align="left">
<img src="https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/cov.jpg?raw=true" alt="issue" style="width: 96%; height: auto;">
</div>
Theoretically, we find entropy changes are driven by the covariance between action probability and logit updates, which correlates with advantage in Policy Gradient methods. High-probability, high-advantage actions reduce entropy, while rare, high-advantage actions increase it. Empirically, the covariance term remains positive, explaining entropy's monotonic decline. To mitigate this, we propose Clip-Cov and KL-Cov, which restrict updates for high-covariance tokens. These methods effectively prevent entropy collapse and improve performance.
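Stated schematically (our hedged paraphrase, not the paper's exact statement): for a softmax policy updated by policy gradient, the step-to-step entropy change behaves roughly like $\Delta H \approx -\eta\,\mathrm{Cov}_{a\sim\pi_\theta}(\log\pi_\theta(a),\,A(a))$, so a persistently positive covariance between log-probability and advantage drives entropy down; this per-token covariance is exactly the quantity that Clip-Cov and KL-Cov act on (see `core_algos.py` above).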
# 📃Evaluation
<div align="left">
<img src="https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/performance_fig.jpg?raw=true" alt="issue" style="width: 96%; height: auto;">
</div>
Our method is able to maintain a considerably higher level of entropy throughout training. For example, when the baseline's entropy reaches a plateau and can no longer be consumed, the KL-Cov method still sustains an entropy level over 10 times higher. Meanwhile, the response length of the policy model steadily increases, and its performance on the test set consistently surpasses that of the baseline. This indicates that our model is able to explore more freely during training, learning a better policy through RL.
| **Method** | **AIME24** | **AIME25** | **AMC** | **MATH-500** | **OMNI-MATH** | **OlympiadBench** | **Minerva** | **Avg.** |
| ----------------- | ---------: | ---------: | -------: | -----------: | ------------: | ----------------: | ----------: | -------: |
| *Qwen2.5-7B* | | | | | | | | |
| GRPO | 21.2 | 9.6 | 58.7 | 78.8 | 27.9 | 40.7 | 36.7 | 38.6 |
| w. Clip-higher | 18.1 | 11.5 | 56.6 | 79.2 | 29.8 | 43.3 | 40.4 | 38.8 |
| w. **`CLIP-Cov`** | 22.1 | **15.8** | 58.2 | 80.4 | **30.5** | **44.1** | **41.1** | 40.4 |
| w. **`KL-Cov`** | **22.6** | 12.9 | **61.4** | **80.8** | 29.1 | 42.6 | 38.2 | **40.6** |
| *Qwen2.5-32B* | | | | | | | | |
| GRPO | 21.8 | 16.2 | 69.7 | 84.2 | 35.2 | 43.6 | 45.5 | 45.8 |
| w. Clip-higher | 35.6 | 22.3 | 69.5 | 77.2 | 35.1 | 42.5 | 43.0 | 47.2 |
| w. **`CLIP-Cov`** | 32.3 | 22.7 | 67.2 | **87.0** | **42.0** | **57.2** | 46.0 | 50.3 |
| w. **`KL-Cov`** | **36.8** | **30.8** | **74.5** | 84.6 | 39.1 | 49.0 | **46.3** | **52.2** |
Our two approaches both achieve non-trivial improvements across all benchmarks. Compared to GRPO, our method outperforms it by 2.0% on average for the 7B model and by 6.4% for the 32B model. Moreover, we observe that our method yields more substantial gains on the larger Qwen2.5-32B. Specifically, our method achieves improvements of 15.0% and 14.6% compared to GRPO on the most challenging benchmarks, AIME24 and AIME25, respectively.
# 🎈Citation
If you find this paper or repo helpful, please cite us.
```bibtex
@article{cui2025entropy,
title={The Entropy Mechanism of Reinforcement Learning for Reasoning Language Models},
author={Cui, Ganqu and Zhang, Yuchen and Chen, Jiacheng and Yuan, Lifan and Wang, Zhi and Zuo, Yuxin and Li, Haozhan and Fan, Yuchen and Chen, Huayu and Chen, Weize and others},
journal={arXiv preprint arXiv:2505.22617},
year={2025}
}
```
# 🌻Acknowledgement
We implement our reinforcement learning algorithm by extending [verl](https://github.com/volcengine/verl). We utilize [vLLM](https://github.com/vllm-project/vllm) for inference. Our models are trained primarily on the [Qwen2.5 family](https://github.com/QwenLM/Qwen2.5). Our training data is built from [DAPO-MATH](https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k). Thanks for their great contributions!
# 📬 Contact
For questions, discussion, or collaboration opportunities, feel free to contact:
- Ganqu Cui: cuiganqu@pjlab.org.cn
- Yuchen Zhang: yuchen.zhang2003@gmail.com
- Jiacheng Chen: jackchan9345@gmail.com
- Ning Ding: ningding.cs@gmail.com

@@ -70,6 +70,7 @@ verl is fast with:
algo/dapo.md
algo/spin.md
algo/sppo.md
algo/entropy.md
algo/opo.md
algo/baseline.md

@@ -0,0 +1,148 @@
#!/usr/bin/env bash
set -xeuo pipefail
export WANDB_API_KEY=YOUR_WANDB_API_KEY
# export VLLM_USE_V1=1
project_name='Qwen2.5-32B'
exp_name='clipcov'
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=1
clip_ratio_high=1
clip_cov_ratio=0.0002
clip_cov_lb=1.0
clip_cov_ub=5.0
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=False
overlong_buffer_len=$((1024 * 2))
overlong_penalty_factor=1.0
loss_agg_mode="token-mean"
loss_mode="clip_cov"
enable_filter_groups=True
filter_groups_metric=acc
max_num_gen_batches=10
train_prompt_bsz=256
gen_prompt_bsz=$((train_prompt_bsz * 3))
train_prompt_mini_bsz=32
n_resp_per_prompt=8
max_token=20480
# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-4}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"}
CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"}
TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"}
TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]}
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
ppo_kl_coef=1
kl_cov_ratio=0.02
# Mathematically equivalent
use_dynamic_bsz=True
infer_micro_batch_size=null
train_micro_batch_size=null
offload=False
HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.filter_overlong_prompts=False \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.gen_batch_size=${gen_prompt_bsz} \
data.train_batch_size=${train_prompt_bsz} \
data.return_raw_chat=True \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
actor_rollout_ref.actor.policy_loss.clip_cov_ratio=${clip_cov_ratio} \
actor_rollout_ref.actor.policy_loss.clip_cov_lb=${clip_cov_lb} \
actor_rollout_ref.actor.policy_loss.clip_cov_ub=${clip_cov_ub} \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.mode=sync \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
algorithm.filter_groups.enable=${enable_filter_groups} \
algorithm.filter_groups.metric=${filter_groups_metric} \
algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.weight_decay=0 \
actor_rollout_ref.actor.optim.warmup_style=constant \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
actor_rollout_ref.actor.clip_cov_ratio=${clip_cov_ratio} \
actor_rollout_ref.actor.clip_cov_lb=${clip_cov_lb} \
actor_rollout_ref.actor.clip_cov_ub=${clip_cov_ub} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k="${top_k}" \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=False \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
reward_model.reward_manager=dapo \
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
reward_model.overlong_buffer.len=${overlong_buffer_len} \
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
trainer.logger=['console','wandb'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.n_gpus_per_node=8 \
trainer.nnodes="${NNODES}" \
trainer.val_before_train=False \
trainer.test_freq=4 \
trainer.save_freq=32 \
trainer.total_epochs=1000 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=disable

@@ -0,0 +1,142 @@
#!/usr/bin/env bash
set -xeuo pipefail
export WANDB_API_KEY=YOUR_WANDB_API_KEY
# export VLLM_USE_V1=1
project_name='Qwen2.5-32B'
exp_name='klcov'
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.2
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=False
overlong_buffer_len=$((1024 * 2))
overlong_penalty_factor=1.0
loss_agg_mode="token-mean"
loss_mode="kl_cov"
enable_filter_groups=True
filter_groups_metric=acc
max_num_gen_batches=10
train_prompt_bsz=256
gen_prompt_bsz=$((train_prompt_bsz * 3))
train_prompt_mini_bsz=32
n_resp_per_prompt=8
max_token=20480
# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-4}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"}
CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"}
TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"}
TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]}
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
ppo_kl_coef=1
kl_cov_ratio=0.0002
# Mathematically equivalent
use_dynamic_bsz=True
infer_micro_batch_size=null
train_micro_batch_size=null
offload=False
HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.filter_overlong_prompts=False \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.gen_batch_size=${gen_prompt_bsz} \
data.train_batch_size=${train_prompt_bsz} \
data.return_raw_chat=True \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.actor.loss_mode=${loss_mode} \
actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
actor_rollout_ref.actor.policy_loss.kl_cov_ratio=${kl_cov_ratio} \
actor_rollout_ref.actor.policy_loss.ppo_kl_coef=${ppo_kl_coef} \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.mode=sync \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
algorithm.filter_groups.enable=${enable_filter_groups} \
algorithm.filter_groups.metric=${filter_groups_metric} \
algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.weight_decay=0 \
actor_rollout_ref.actor.optim.warmup_style=constant \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k="${top_k}" \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=False \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
reward_model.reward_manager=dapo \
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
reward_model.overlong_buffer.len=${overlong_buffer_len} \
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
trainer.logger=['console','wandb'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.n_gpus_per_node=8 \
trainer.nnodes="${NNODES}" \
trainer.val_before_train=False \
trainer.test_freq=4 \
trainer.save_freq=32 \
trainer.total_epochs=1000 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=disable

@@ -0,0 +1,141 @@
#!/usr/bin/env bash
set -xeuo pipefail
export WANDB_API_KEY=YOUR_WANDB_API_KEY
# export VLLM_USE_V1=1
project_name='Qwen2.5-32B'
exp_name='klcov'
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.2
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=False
overlong_buffer_len=$((1024 * 2))
overlong_penalty_factor=1.0
loss_agg_mode="token-mean"
loss_mode="kl_cov"
enable_filter_groups=True
filter_groups_metric=acc
max_num_gen_batches=10
train_prompt_bsz=256
gen_prompt_bsz=$((train_prompt_bsz * 3))
train_prompt_mini_bsz=16
n_resp_per_prompt=8
max_token=20480
# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-4}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"}
CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"}
TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"}
TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]}
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
ppo_kl_coef=1
kl_cov_ratio=0.0002
# Mathematically equivalent
use_dynamic_bsz=True
infer_micro_batch_size=null
train_micro_batch_size=null
offload=False
HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.filter_overlong_prompts=False \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.gen_batch_size=${gen_prompt_bsz} \
data.train_batch_size=${train_prompt_bsz} \
data.return_raw_chat=True \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
actor_rollout_ref.actor.policy_loss.kl_cov_ratio=${kl_cov_ratio} \
actor_rollout_ref.actor.policy_loss.ppo_kl_coef=${ppo_kl_coef} \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.mode=sync \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
algorithm.filter_groups.enable=${enable_filter_groups} \
algorithm.filter_groups.metric=${filter_groups_metric} \
algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.weight_decay=0 \
actor_rollout_ref.actor.optim.warmup_style=constant \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k="${top_k}" \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=False \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
reward_model.reward_manager=dapo \
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
reward_model.overlong_buffer.len=${overlong_buffer_len} \
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
trainer.logger=['console','wandb'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.n_gpus_per_node=8 \
trainer.nnodes="${NNODES}" \
trainer.val_before_train=False \
trainer.test_freq=4 \
trainer.save_freq=32 \
trainer.total_epochs=1000 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=disable

@@ -0,0 +1,145 @@
#!/usr/bin/env bash
set -xeuo pipefail
export WANDB_API_KEY=YOUR_WANDB_API_KEY
# export VLLM_USE_V1=1
project_name='Qwen2.5-7B'
exp_name='clipcov'
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=1
clip_ratio_high=1
clip_cov_ratio=0.0002
clip_cov_lb=1.0
clip_cov_ub=5.0
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=False
overlong_buffer_len=$((1024 * 2))
overlong_penalty_factor=1.0
loss_agg_mode="token-mean"
loss_mode="clip_cov"
enable_filter_groups=True
filter_groups_metric=acc
max_num_gen_batches=10
train_prompt_bsz=256
gen_prompt_bsz=$((train_prompt_bsz * 3))
train_prompt_mini_bsz=32
n_resp_per_prompt=8
max_token=30720
# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-4}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"}
CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"}
TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"}
TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]}
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
ppo_kl_coef=1
kl_cov_ratio=0.2
# Mathematically equivalent
use_dynamic_bsz=True
infer_micro_batch_size=null
train_micro_batch_size=null
offload=False
HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.filter_overlong_prompts=False \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.gen_batch_size=${gen_prompt_bsz} \
data.train_batch_size=${train_prompt_bsz} \
data.return_raw_chat=True \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
actor_rollout_ref.actor.policy_loss.clip_cov_ratio=${clip_cov_ratio} \
actor_rollout_ref.actor.policy_loss.clip_cov_lb=${clip_cov_lb} \
actor_rollout_ref.actor.policy_loss.clip_cov_ub=${clip_cov_ub} \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.mode=sync \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
algorithm.filter_groups.enable=${enable_filter_groups} \
algorithm.filter_groups.metric=${filter_groups_metric} \
algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.weight_decay=0 \
actor_rollout_ref.actor.optim.warmup_style=constant \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k="${top_k}" \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=False \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
reward_model.reward_manager=dapo \
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
reward_model.overlong_buffer.len=${overlong_buffer_len} \
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
trainer.logger=['console','wandb'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.n_gpus_per_node=8 \
trainer.nnodes="${NNODES}" \
trainer.val_before_train=False \
trainer.test_freq=4 \
trainer.save_freq=32 \
trainer.total_epochs=1000 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=disable

recipe/entropy/7b_kl_cov.sh (new file)

@@ -0,0 +1,141 @@
#!/usr/bin/env bash
set -xeuo pipefail
export WANDB_API_KEY=YOUR_WANDB_API_KEY
# export VLLM_USE_V1=1
project_name='Qwen2.5-7B'
exp_name='klcov'
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.2
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=False
overlong_buffer_len=$((1024 * 2))
overlong_penalty_factor=1.0
loss_agg_mode="token-mean"
loss_mode="kl_cov"
enable_filter_groups=True
filter_groups_metric=acc
max_num_gen_batches=10
train_prompt_bsz=256
gen_prompt_bsz=$((train_prompt_bsz * 3))
train_prompt_mini_bsz=32
n_resp_per_prompt=8
max_token=30720
# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-4}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"}
CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"}
TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"}
TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]}
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
ppo_kl_coef=1
kl_cov_ratio=0.002
# Mathematically equivalent
use_dynamic_bsz=True
infer_micro_batch_size=null
train_micro_batch_size=null
offload=False
HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.filter_overlong_prompts=False \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.gen_batch_size=${gen_prompt_bsz} \
data.train_batch_size=${train_prompt_bsz} \
data.return_raw_chat=True \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
actor_rollout_ref.actor.policy_loss.kl_cov_ratio=${kl_cov_ratio} \
actor_rollout_ref.actor.policy_loss.ppo_kl_coef=${ppo_kl_coef} \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.mode=sync \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
algorithm.filter_groups.enable=${enable_filter_groups} \
algorithm.filter_groups.metric=${filter_groups_metric} \
algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.weight_decay=0 \
actor_rollout_ref.actor.optim.warmup_style=constant \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k="${top_k}" \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=False \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
reward_model.reward_manager=dapo \
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
reward_model.overlong_buffer.len=${overlong_buffer_len} \
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
trainer.logger=['console','wandb'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.n_gpus_per_node=8 \
trainer.nnodes="${NNODES}" \
trainer.val_before_train=False \
trainer.test_freq=4 \
trainer.save_freq=32 \
trainer.total_epochs=1000 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=disable

recipe/entropy/README.md (new file)

@@ -0,0 +1,110 @@
<div align="center">
# The Entropy Mechanism of Reinforcement Learning for Large Language Model Reasoning.
[![Paper](https://img.shields.io/badge/paper-A42C25?style=for-the-badge&logo=arxiv&logoColor=white)](https://arxiv.org/pdf/2505.22617) [![Github](https://img.shields.io/badge/PRIME-000000?style=for-the-badge&logo=github&logoColor=000&logoColor=white)](https://github.com/PRIME-RL/Entropy-Mechanism-of-RL) [![alphaXiv](https://img.shields.io/badge/discussion-A42C25?style=for-the-badge&logo=arxiv&logoColor=white&color=blue
)](https://www.alphaxiv.org/abs/2505.22617) [![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/stingning/status/1928088554166505667) [![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/charlesfornlp/status/1928089451080585283) [![Twitter-ak](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/_akhaliq/status/1928077929105268861)
<div align="center" style="font-family: Arial, sans-serif;">
<p>
<a href="#🎉news" style="text-decoration: none; font-weight: bold;">🎉 News</a>
<a href="#✨getting-started" style="text-decoration: none; font-weight: bold;">✨ Getting Started</a>
<a href="#📖introduction" style="text-decoration: none; font-weight: bold;">📖 Introduction</a>
</p>
<p>
<a href="#🎈citation" style="text-decoration: none; font-weight: bold;">🎈 Citation</a>
<a href="#🌻acknowledgement" style="text-decoration: none; font-weight: bold;">🌻 Acknowledgement</a>
<a href="#📬Contact" style="text-decoration: none; font-weight: bold;">📬 Contact</a>
<a href="#📈star-history" style="text-decoration: none; font-weight: bold;">📈 Star History</a>
</p>
</div>
</div>
# 🎉News
- **[2025/05/29]** 🎉 Ranked **#1** of the day on [Huggingface Daily Papers](https://huggingface.co/papers?date=2025-05-29).
- **[2025/05/29]** Released our Paper on arXiv. See [here](https://arxiv.org/pdf/2505.22617). We provide insights into the entropy mechanism of RL for LLMs and propose two simple yet effective strategies to alleviate the entropy collapse.
# ✨Getting started
After preparing the training data, for training Qwen2.5-7B on a single node, taking the KL-Cov approach as an example, you can simply run:
```
cd verl
conda activate your_env
bash recipe/entropy/7b_kl_cov.sh
```
To train Qwen2.5-32B on multiple nodes, run:
```
cd verl
conda activate your_env
bash recipe/entropy/32b_kl_cov.sh
```
# 📖Introduction
<div align="left">
<img src="https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/e2a.jpg?raw=true" alt="issue" style="width: 96%; height: auto;">
</div>
This paper addresses the entropy collapse issue in scaling reinforcement learning (RL) for large language models (LLMs), where policy entropy drops sharply during training, leading to overconfidence and performance saturation. We empirically establish a relationship between entropy ($H$) and performance ($R$): $R = -a\exp(H) + b$, showing that performance is bottlenecked by entropy exhaustion.
<div align="left">
<img src="https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/cov.jpg?raw=true" alt="issue" style="width: 96%; height: auto;">
</div>
Theoretically, we find entropy changes are driven by the covariance between action probability and logit updates, which correlates with advantage in Policy Gradient methods. High-probability, high-advantage actions reduce entropy, while rare, high-advantage actions increase it. Empirically, the covariance term remains positive, explaining entropy's monotonic decline. To mitigate this, we propose Clip-Cov and KL-Cov, which restrict updates for high-covariance tokens. These methods effectively prevent entropy collapse and improve performance.
# 📃Evaluation
<div align="left">
<img src="https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/performance_fig.jpg?raw=true" alt="issue" style="width: 96%; height: auto;">
</div>
Our method is able to maintain a considerably higher level of entropy throughout training. For example, when the baseline's entropy reaches a plateau and can no longer be consumed, the KL-Cov method still sustains an entropy level over 10 times higher. Meanwhile, the response length of the policy model steadily increases, and its performance on the test set consistently surpasses that of the baseline. This indicates that our model is able to explore more freely during training, learning a better policy through RL.
| **Method** | **AIME24** | **AIME25** | **AMC** | **MATH-500** | **OMNI-MATH** | **OlympiadBench** | **Minerva** | **Avg.** |
| ----------------- | ---------: | ---------: | -------: | -----------: | ------------: | ----------------: | ----------: | -------: |
| *Qwen2.5-7B* | | | | | | | | |
| GRPO | 21.2 | 9.6 | 58.7 | 78.8 | 27.9 | 40.7 | 36.7 | 38.6 |
| w. Clip-higher | 18.1 | 11.5 | 56.6 | 79.2 | 29.8 | 43.3 | 40.4 | 38.8 |
| w. **`CLIP-Cov`** | 22.1 | **15.8** | 58.2 | 80.4 | **30.5** | **44.1** | **41.1** | 40.4 |
| w. **`KL-Cov`** | **22.6** | 12.9 | **61.4** | **80.8** | 29.1 | 42.6 | 38.2 | **40.6** |
| *Qwen2.5-32B* | | | | | | | | |
| GRPO | 21.8 | 16.2 | 69.7 | 84.2 | 35.2 | 43.6 | 45.5 | 45.8 |
| w. Clip-higher | 35.6 | 22.3 | 69.5 | 77.2 | 35.1 | 42.5 | 43.0 | 47.2 |
| w. **`CLIP-Cov`** | 32.3 | 22.7 | 67.2 | **87.0** | **42.0** | **57.2** | 46.0 | 50.3 |
| w. **`KL-Cov`** | **36.8** | **30.8** | **74.5** | 84.6 | 39.1 | 49.0 | **46.3** | **52.2** |
Our two approaches both achieve non-trivial improvements across all benchmarks. Compared to GRPO, our method outperforms it by 2.0% on average for the 7B model and by 6.4% for the 32B model. Moreover, we observe that our method yields more substantial gains on the larger Qwen2.5-32B. Specifically, our method achieves improvements of 15.0% and 14.6% compared to GRPO on the most challenging benchmarks, AIME24 and AIME25, respectively.
# 🎈Citation
If you find this paper or repo helpful, please cite us.
```bibtex
@article{cui2025entropy,
title={The Entropy Mechanism of Reinforcement Learning for Reasoning Language Models},
author={Cui, Ganqu and Zhang, Yuchen and Chen, Jiacheng and Yuan, Lifan and Wang, Zhi and Zuo, Yuxin and Li, Haozhan and Fan, Yuchen and Chen, Huayu and Chen, Weize and others},
journal={arXiv preprint arXiv:2505.22617},
year={2025}
}
```
# 🌻Acknowledgement
Our reinforcement learning algorithm is implemented as an extension of [verl](https://github.com/volcengine/verl). We use [vLLM](https://github.com/vllm-project/vllm) for inference. Our models are built primarily on the [Qwen2.5 family](https://github.com/QwenLM/Qwen2.5). Our training data is built from [DAPO-MATH](https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k). Thanks for their great contributions!
# 📬 Contact
For questions, discussion, or collaboration opportunities, feel free to contact:
- Ganqu Cui: cuiganqu@pjlab.org.cn
- Yuchen Zhang: yuchen.zhang2003@gmail.com
- Jiacheng Chen: jackchan9345@gmail.com
- Ning Ding: ningding.cs@gmail.com

View File

@ -0,0 +1,39 @@
hydra:
searchpath:
- file://verl/trainer/config
defaults:
- ppo_trainer
- _self_
data:
gen_batch_size: ${data.train_batch_size}
reward_model:
reward_kwargs:
overlong_buffer_cfg: ${reward_model.overlong_buffer}
reward_manager: dapo
overlong_buffer:
enable: False
len: 0
penalty_factor: 0.0
log: False
algorithm:
filter_groups:
enable: False # We try to avoid forgetting to set enable
metric: null # acc / score / seq_reward / seq_final_reward / ...
max_num_gen_batches: 0 # Non-positive values mean no upper limit
trainer:
project_name: verl-entropy
actor_rollout_ref:
actor:
policy_loss:
loss_mode: "vanilla" # /clip-cov / kl-cov from https://arxiv.org/abs/2505.
clip_cov_ratio: 0.0002 # for clip-cov loss
clip_cov_lb: 1.0 # for clip-cov loss
clip_cov_ub: 5.0 # for clip-cov loss
kl_cov_ratio: 0.0002 # for kl-cov loss
ppo_kl_coef: 0.1 # for kl-cov loss
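As a usage note (not part of the recipe itself): these policy-loss entries are ordinary OmegaConf keys, so selecting Clip-Cov or KL-Cov amounts to overriding `actor_rollout_ref.actor.policy_loss.loss_mode`. The sketch below builds a tiny stand-in config and flips it to KL-Cov; note that the names registered in `core_algos.py` later in this PR are `clip_cov` and `kl_cov` (with underscores).
```python
# Illustrative only: mirror the keys above in a small OmegaConf object and
# override the loss mode, the same way a Hydra command-line override would.
from omegaconf import OmegaConf

base = OmegaConf.create({
    "actor_rollout_ref": {"actor": {"policy_loss": {
        "loss_mode": "vanilla",
        "clip_cov_ratio": 0.0002, "clip_cov_lb": 1.0, "clip_cov_ub": 5.0,
        "kl_cov_ratio": 0.0002, "ppo_kl_coef": 0.1,
    }}}
})
override = OmegaConf.from_dotlist([
    "actor_rollout_ref.actor.policy_loss.loss_mode=kl_cov",
])
cfg = OmegaConf.merge(base, override)
print(cfg.actor_rollout_ref.actor.policy_loss.loss_mode)  # -> kl_cov
```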

View File

@ -0,0 +1,311 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agnostic model initialization with huggingface
"""
import uuid
from collections import defaultdict
from copy import deepcopy
from pprint import pprint
import numpy as np
import torch
from tqdm import tqdm
from verl import DataProto
from verl.trainer.ppo.metric_utils import (
compute_data_metrics,
compute_throughout_metrics,
compute_timing_metrics,
reduce_metrics,
)
from verl.trainer.ppo.ray_trainer import AdvantageEstimator, RayPPOTrainer, apply_kl_penalty, compute_advantage, compute_response_mask
from verl.utils.debug import simple_timer
class RayEntropyTrainer(RayPPOTrainer):
"""
Note that this trainer runs on the driver process on a single CPU/GPU node.
"""
def fit(self):
"""
The training loop of PPO.
The driver process only needs to call the compute functions of the worker group through RPC
to construct the PPO dataflow.
The lightweight advantage computation is done on the driver process.
"""
from omegaconf import OmegaConf
from verl.utils.tracking import Tracking
logger = Tracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True),
)
self.global_steps = 0
# load checkpoint before doing anything
self._load_checkpoint()
# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
val_metrics = self._validate()
assert val_metrics, f"{val_metrics=}"
pprint(f"Initial validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get("val_only", False):
return
# add tqdm
progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
# we start from step 1
self.global_steps += 1
last_val_metrics = None
timing_raw = defaultdict(float)
batch = None
num_prompt_in_batch = 0
num_gen_batches = 0
for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
metrics = {}
new_batch: DataProto = DataProto.from_single_dict(batch_dict)
num_gen_batches += 1
# pop those keys for generation
if "multi_modal_inputs" in new_batch.non_tensor_batch.keys():
gen_batch = new_batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"],
)
else:
gen_batch = new_batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["raw_prompt_ids"],
)
is_last_step = self.global_steps >= self.total_training_steps
with simple_timer("step", timing_raw):
# generate a batch
# with simple_timer("gen", timing_raw):
# gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
with simple_timer("gen", timing_raw):
if not self.async_rollout_mode:
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
else:
self.async_rollout_manager.wake_up()
gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
self.async_rollout_manager.sleep()
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
with simple_timer("gen_max", timing_raw):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["do_sample"] = False
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
new_batch = new_batch.union(gen_baseline_output)
reward_baseline_tensor = self.reward_fn(new_batch)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
new_batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
new_batch.batch["reward_baselines"] = reward_baseline_tensor
del gen_baseline_batch, gen_baseline_output
new_batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object)
# repeat to align with repeated responses in rollout
new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
new_batch = new_batch.union(gen_batch_output)
with simple_timer("reward", timing_raw):
# compute scores. Support both model and function-based.
# We first compute the scores using reward model. Then, we call reward_fn to combine
# the results from reward model and rule-based results.
if self.use_rm:
# we first compute reward model score
reward_tensor = self.rm_wg.compute_rm_score(new_batch)
new_batch = new_batch.union(reward_tensor)
# we combine with rule-based rm
reward_extra_infos_dict: dict[str, list]
try:
reward_result = self.reward_fn(new_batch, return_dict=True)
reward_tensor = reward_result["reward_tensor"]
reward_extra_infos_dict = reward_result["reward_extra_info"]
except Exception as e:
print(f"Error in reward_fn: {e}")
reward_tensor = self.reward_fn(new_batch)
reward_extra_infos_dict = {}
new_batch.batch["token_level_scores"] = reward_tensor
print(f"{list(reward_extra_infos_dict.keys())=}")
if reward_extra_infos_dict:
new_batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
# compute rewards. apply_kl_penalty if available
if self.config.algorithm.use_kl_in_reward:
new_batch, kl_metrics = apply_kl_penalty(new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty)
metrics.update(kl_metrics) # TODO: This will be cleared if we use multiple generation batches
else:
new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"]
if not self.config.algorithm.filter_groups.enable:
batch = new_batch
else: # NOTE: When the number of prompts after filtering is less than the train batch size,
# we skip to the next generation batch
metric_name = self.config.algorithm.filter_groups.metric
if metric_name == "seq_final_reward":
# Turn to numpy for easier filtering
new_batch.non_tensor_batch["seq_final_reward"] = new_batch.batch["token_level_rewards"].sum(dim=-1).numpy()
elif metric_name == "seq_reward":
new_batch.non_tensor_batch["seq_reward"] = new_batch.batch["token_level_scores"].sum(dim=-1).numpy()
# Collect the sequence reward for each trajectory
prompt_uid2metric_vals = defaultdict(list)
for uid, metric_val in zip(new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name]):
prompt_uid2metric_vals[uid].append(metric_val)
prompt_uid2metric_std = {}
for prompt_uid, metric_vals in prompt_uid2metric_vals.items():
prompt_uid2metric_std[prompt_uid] = np.std(metric_vals)
kept_prompt_uids = [uid for uid, std in prompt_uid2metric_std.items() if std > 0 or len(prompt_uid2metric_vals[uid]) == 1]
num_prompt_in_batch += len(kept_prompt_uids)
kept_traj_idxs = []
for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]):
if traj_from_prompt_uid in kept_prompt_uids:
kept_traj_idxs.append(idx)
new_batch = new_batch[kept_traj_idxs]
batch = new_batch if batch is None else DataProto.concat([batch, new_batch])
prompt_bsz = self.config.data.train_batch_size
if num_prompt_in_batch < prompt_bsz:
print(f"{num_prompt_in_batch=} < {prompt_bsz=}")
max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
print(f"{num_gen_batches=}. Keep generating...")
continue
else:
raise ValueError(f"{num_gen_batches=} >= {max_num_gen_batches=}." + " Generated too many. Please check if your data are too difficult." + " You could also try setting max_num_gen_batches=0 to enable endless trials.")
else:
# Align the batch
traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
print(f"Collected {num_prompt_in_batch} / {self.config.data.train_batch_size} prompt. Collecting finished.")
batch = batch[:traj_bsz]
# === Updating ===
batch.batch["response_mask"] = compute_response_mask(batch)
# balance the number of valid tokens on each dp rank.
# Note that this breaks the order of data inside the batch.
# Please take care when you implement group based adv computation such as GRPO and rloo
if self.config.trainer.balance_batch:
self._balance_batch(batch, metrics=metrics)
# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
# recompute old_log_probs
with simple_timer("old_log_prob", timing_raw):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
batch = batch.union(old_log_prob)
if self.use_reference_policy:
# compute reference log_prob
with simple_timer("ref", timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
# compute values
if self.use_critic:
with simple_timer("values", timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
with simple_timer("adv", timing_raw):
# compute advantages, executed on the driver process
norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)
batch = compute_advantage(
batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.actor_rollout_ref.rollout.n,
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
)
# update critic
if self.use_critic:
with simple_timer("update_critic", timing_raw):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
metrics.update(critic_output_metrics)
# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with simple_timer("update_actor", timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)
# validate
if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
with simple_timer("testing", timing_raw):
val_metrics: dict = self._validate()
if is_last_step:
last_val_metrics = val_metrics
metrics.update(val_metrics)
if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0):
with simple_timer("save_checkpoint", timing_raw):
self._save_checkpoint()
# collect metrics
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
# TODO: implement actual tflops and theoretical tflops
n_gpus = self.resource_pool_manager.get_n_gpus()
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
timing_raw = defaultdict(float) # clear timing
metrics["train/num_gen_batches"] = num_gen_batches
batch = None
num_prompt_in_batch = 0
num_gen_batches = 0
# TODO: make a canonical logger that supports various backend
logger.log(data=metrics, step=self.global_steps)
if is_last_step:
pprint(f"Final validation metrics: {last_val_metrics}")
progress_bar.close()
return
progress_bar.update(1)
self.global_steps += 1
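The `filter_groups` branch above keeps only prompts whose sampled rollouts show non-zero variance in the chosen metric, so that group-relative advantages carry signal. A stripped-down sketch of that selection step (standalone, with toy inputs, not the trainer's actual API) could look like:
```python
# Minimal sketch of the filter_groups selection used in the loop above:
# keep prompts whose per-group metric has non-zero std (or only one sample).
from collections import defaultdict
import numpy as np

def kept_trajectory_indices(uids, metric_vals):
    """uids: prompt uid per trajectory; metric_vals: reward/score per trajectory."""
    groups = defaultdict(list)
    for uid, val in zip(uids, metric_vals):
        groups[uid].append(val)
    kept_uids = {uid for uid, vals in groups.items()
                 if np.std(vals) > 0 or len(vals) == 1}
    return [i for i, uid in enumerate(uids) if uid in kept_uids]

# Example: prompt "a" has identical rewards (filtered out), prompt "b" does not.
uids = ["a", "a", "b", "b"]
rewards = [1.0, 1.0, 0.0, 1.0]
print(kept_trajectory_indices(uids, rewards))  # -> [2, 3]
```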

View File

@ -0,0 +1,227 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that we don't merge this main entry point into ray_trainer, since ray_trainer is shared by other main entry points.
"""
import hydra
import ray
from verl.trainer.ppo.reward import load_reward_manager
from .entropy_ray_trainer import RayEntropyTrainer
@hydra.main(config_path="config", config_name="entropy_trainer", version_base=None)
def main(config):
run_ppo(config)
def run_ppo(config) -> None:
if not ray.is_initialized():
# this is for local ray cluster
ray.init(
runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN", "WANDB_API_KEY": "YOUR_WANDB_API_KEY"}},
num_cpus=config.ray_init.num_cpus,
)
runner = TaskRunner.remote()
ray.get(runner.run.remote(config))
def merge_dict(a: dict, b: dict) -> dict:
"""Return a new dict that has `a` updated with `b` (b wins on conflicts).
Example::
>>> d1 = {"x": 1, "y": 2}
>>> d2 = {"y": 20, "z": 3}
>>> new_dict = merge_dict(d1, d2)
>>> print(new_dict) # {'x': 1, 'y': 20, 'z': 3}
>>> print(d1) # {"x": 1, "y": 2} (unchanged)
>>> print(d2) # {"y": 20, "z": 3} (unchanged)
"""
return a | b
@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
class TaskRunner:
def run(self, config):
# print initial config
from pprint import pprint
from omegaconf import OmegaConf
from verl.utils.fs import copy_to_local
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
# download the checkpoint from hdfs
local_path = copy_to_local(config.actor_rollout_ref.model.path)
print(f"{config.actor_rollout_ref.model.path}")
# instantiate tokenizer
from verl.utils import hf_processor, hf_tokenizer
trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none
# define worker classes
if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]:
assert config.critic.strategy in ["fsdp", "fsdp2"]
from verl.single_controller.ray import RayWorkerGroup
from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker
actor_rollout_cls = AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == "megatron":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
actor_rollout_cls = ActorRolloutRefWorker
ray_worker_group_cls = NVMegatronRayWorkerGroup
else:
raise NotImplementedError
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
role_worker_mapping = {
Role.ActorRollout: ray.remote(actor_rollout_cls),
Role.Critic: ray.remote(CriticWorker),
}
global_pool_id = "global_pool"
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {
Role.ActorRollout: global_pool_id,
Role.Critic: global_pool_id,
}
# we should adopt a multi-source reward function here
# - for rule-based rm, we directly call a reward score
# - for model-based rm, we call a model
# - for code related prompt, we send to a sandbox if there are test cases
# - finally, we combine all the rewards together
# - The reward type depends on the tag of the data
if config.reward_model.enable:
if config.reward_model.strategy in ["fsdp", "fsdp2"]:
from verl.workers.fsdp_workers import RewardModelWorker
elif config.reward_model.strategy == "megatron":
from verl.workers.megatron_workers import RewardModelWorker
else:
raise NotImplementedError
role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
mapping[Role.RewardModel] = global_pool_id
# use reference model
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
mapping[Role.RefPolicy] = global_pool_id
reward_kwargs = {"max_resp_len": config.data.max_response_length, "overlong_buffer_cfg": config.reward_model.overlong_buffer}
cfg_reward_kwargs = config.reward_model.get("reward_kwargs", {})
reward_fn = load_reward_manager(config, tokenizer, num_examine=0, **(merge_dict(reward_kwargs, cfg_reward_kwargs)))
val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1, **reward_kwargs)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
from verl.utils.dataset.rl_dataset import collate_fn
train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
train_sampler = create_rl_sampler(config.data, train_dataset)
trainer = RayEntropyTrainer(
config=config,
tokenizer=tokenizer,
processor=processor,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
train_dataset=train_dataset,
val_dataset=val_dataset,
collate_fn=collate_fn,
train_sampler=train_sampler,
)
trainer.init_workers()
trainer.fit()
def create_rl_dataset(data_paths, data_config, tokenizer, processor):
"""Create a dataset.
Arguments:
data_paths: The path(s) to the data file(s).
data_config: The data config.
tokenizer (Tokenizer): The tokenizer.
processor (Processor): The processor.
Returns:
dataset (Dataset): The dataset.
"""
from torch.utils.data import Dataset
from verl.utils.dataset.rl_dataset import RLHFDataset
if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None:
from verl.utils.import_utils import load_extern_type
dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)
if not issubclass(dataset_cls, Dataset):
raise TypeError(f"The custom dataset class '{data_config.custom_cls.name}' from '{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset")
else:
dataset_cls = RLHFDataset
print(f"Using dataset class: {dataset_cls.__name__}")
dataset = dataset_cls(
data_files=data_paths,
tokenizer=tokenizer,
processor=processor,
config=data_config,
)
return dataset
def create_rl_sampler(data_config, dataset):
"""Create a sampler for the dataset.
Arguments:
data_config: The data config.
dataset (Dataset): The dataset.
Returns:
sampler (Sampler): The sampler.
"""
import torch
from torch.utils.data import RandomSampler, SequentialSampler
# use sampler for better ckpt resume
if data_config.shuffle:
train_dataloader_generator = torch.Generator()
train_dataloader_generator.manual_seed(data_config.get("seed", 1))
sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator)
else:
sampler = SequentialSampler(data_source=dataset)
return sampler
if __name__ == "__main__":
main()

View File

@ -62,6 +62,13 @@ actor_rollout_ref:
ppo_epochs: 1
data_loader_seed: null
shuffle: False
policy_loss: # policy loss config
loss_mode: "vanilla" # Loss function mode: vanilla / clip-cov / kl-cov from https://arxiv.org/abs/2505.22617
clip_cov_ratio: 0.0002 # Ratio of tokens to be clipped for clip-cov loss
clip_cov_lb: 1.0 # Lower bound for clip-cov loss
clip_cov_ub: 5.0 # Upper bound for clip-cov loss
kl_cov_ratio: 0.0002 # Ratio of tokens to be applied kl penalty for kl-cov loss
ppo_kl_coef: 0.1 # KL divergence penalty coefficient
optim:
optimizer: adam
lr: 1e-6

View File

@ -177,6 +177,27 @@ actor_rollout_ref:
# Upper bound for asymmetric clipping (used in dual-clip PPO)
clip_ratio_high: 0.2
# policy loss config
policy_loss:
# Loss function mode: vanilla / clip-cov / kl-cov from https://arxiv.org/abs/2505.22617
loss_mode: "vanilla"
# Ratio of tokens to be clipped for clip-cov loss
clip_cov_ratio: 0.0002
# Lower bound for clip-cov loss
clip_cov_lb: 1.0
# Upper bound for clip-cov loss
clip_cov_ub: 5.0
# Ratio of tokens to be applied kl penalty for kl-cov loss
kl_cov_ratio: 0.0002
# KL divergence penalty coefficient
ppo_kl_coef: 0.1
# Constant C in Dual-clip PPO; clips when advantage < -C
clip_ratio_c: 3.0

View File

@ -28,6 +28,33 @@ import torch
import verl.utils.torch_functional as verl_F
POLICY_LOSS_REGISTRY = {}
def register_policy_loss(name):
def decorator(func):
POLICY_LOSS_REGISTRY[name] = func
return func
return decorator
def get_policy_loss_fn(name):
"""Get the policy loss with a given name.
Args:
name: `(str)`
The name of the policy loss.
Returns:
`(callable)`: The policy loss function.
"""
loss_name = name
if loss_name not in POLICY_LOSS_REGISTRY:
raise ValueError(f"Unsupported loss mode: {loss_name}. Supported modes are: {list(POLICY_LOSS_REGISTRY.keys())}")
return POLICY_LOSS_REGISTRY[loss_name]
ADV_ESTIMATOR_REGISTRY = {}
@ -596,6 +623,159 @@ def compute_policy_loss(
return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
@register_policy_loss("clip_cov")
def compute_policy_loss_clip_cov(
old_log_prob,
log_prob,
advantages,
response_mask,
loss_agg_mode="token-mean",
config=None,
):
"""
Compute the clipped policy objective and related metrics for Clip-Cov.
Adapted from
https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py
Args:
old_log_prob (torch.Tensor):
Log-probabilities of actions under the old policy, shape (batch_size, response_length).
log_prob (torch.Tensor):
Log-probabilities of actions under the current policy, shape (batch_size, response_length).
advantages (torch.Tensor):
Advantage estimates for each action, shape (batch_size, response_length).
response_mask (torch.Tensor):
Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
cliprange (float, optional):
Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
Defaults to None (must be provided).
cliprange_low (float, optional):
Lower clip range for dual-clip PPO. Defaults to same as `cliprange`.
cliprange_high (float, optional):
Upper clip range for dual-clip PPO. Defaults to same as `cliprange`.
loss_agg_mode (str, optional):
Aggregation mode for `agg_loss`. Defaults to "token-mean".
clip_cov_ratio (float, optional):
Ratio for clipping the covariance. Defaults to 0.0002.
clip_cov_lb (float, optional):
Lower bound for clipping covariance. Defaults to 1.0.
clip_cov_ub (float, optional):
Upper bound for clipping covariance. Defaults to 5.0.
"""
clip_cov_ratio = config.policy_loss.clip_cov_ratio if config.policy_loss.clip_cov_ratio is not None else 0.0002
cliprange = config.clip_ratio
cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else cliprange
cliprange_high = config.clip_ratio_high if config.clip_ratio_high is not None else cliprange
clip_cov_ub = config.policy_loss.clip_cov_ub if config.policy_loss.clip_cov_ub is not None else 5.0
clip_cov_lb = config.policy_loss.clip_cov_lb if config.policy_loss.clip_cov_lb is not None else 1.0
assert clip_cov_ratio > 0, "clip_cov_ratio should be larger than 0."
negative_approx_kl = log_prob - old_log_prob
ratio = torch.exp(negative_approx_kl)
ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
pg_losses1 = -advantages * ratio
if cliprange_low is None:
cliprange_low = cliprange
if cliprange_high is None:
cliprange_high = cliprange
corr = torch.ones_like(advantages)
pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
clip_by_origin = (pg_losses2 > pg_losses1) & (response_mask > 0)
cov_all = (advantages - verl_F.masked_mean(advantages, response_mask)) * (log_prob - verl_F.masked_mean(log_prob.detach(), response_mask))
cov_all[response_mask == 0] = -torch.inf
cov_all[clip_by_origin] = -torch.inf
clip_num = max(int(clip_cov_ratio * response_mask.sum().item()), 1)
top_k_idx = (cov_all < clip_cov_ub) & (cov_all > clip_cov_lb) & (response_mask > 0)
top_k_idx = torch.nonzero(top_k_idx)
if len(top_k_idx) > 0:
perm = torch.randperm(len(top_k_idx))
top_k_idx = top_k_idx[perm[: min(clip_num, len(top_k_idx))]]
else:
top_k_idx = torch.empty((0, 2), device=cov_all.device, dtype=torch.long)
corr[top_k_idx[:, 0], top_k_idx[:, 1]] = 0
pg_clipfrac = verl_F.masked_mean((corr == 0).float(), response_mask)
pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
return pg_loss, pg_clipfrac, ppo_kl, torch.tensor(0.0)
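A toy invocation of the Clip-Cov loss above might look like the following sketch. The `cfg` object is a hypothetical stand-in that exposes the attributes the function reads (`clip_ratio`, `clip_ratio_low`/`clip_ratio_high`, and the `policy_loss` sub-config); in verl these come from the actor config:
```python
# Illustrative only; not part of core_algos.py.
import torch
from types import SimpleNamespace

from verl.trainer.ppo.core_algos import compute_policy_loss_clip_cov

cfg = SimpleNamespace(
    clip_ratio=0.2, clip_ratio_low=0.2, clip_ratio_high=0.28,
    policy_loss=SimpleNamespace(clip_cov_ratio=0.0002, clip_cov_lb=1.0, clip_cov_ub=5.0),
)
batch_size, response_length = 2, 16
old_log_prob = torch.randn(batch_size, response_length)
log_prob = old_log_prob + 0.01 * torch.randn(batch_size, response_length)
advantages = torch.randn(batch_size, response_length)
response_mask = torch.ones(batch_size, response_length)

pg_loss, pg_clipfrac, ppo_kl, _ = compute_policy_loss_clip_cov(
    old_log_prob, log_prob, advantages, response_mask,
    loss_agg_mode="token-mean", config=cfg,
)
print(pg_loss.item(), pg_clipfrac.item(), ppo_kl.item())
```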
@register_policy_loss("kl_cov")
def compute_policy_loss_kl_cov(
old_log_prob,
log_prob,
advantages,
response_mask,
loss_agg_mode="token-mean",
config=None,
):
"""
Compute the clipped policy objective and related metrics for KL-Cov.
Adapted from
https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py
Args:
old_log_prob (torch.Tensor):
Log-probabilities of actions under the old policy, shape (batch_size, response_length).
log_prob (torch.Tensor):
Log-probabilities of actions under the current policy, shape (batch_size, response_length).
advantages (torch.Tensor):
Advantage estimates for each action, shape (batch_size, response_length).
response_mask (torch.Tensor):
Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
loss_agg_mode (str, optional):
Aggregation mode for `agg_loss`. Defaults to "token-mean".
kl_cov_ratio (float, optional):
Ratio for selecting the top-k covariance values. Defaults to 0.0002.
ppo_kl_coef (float, optional):
Coefficient for the KL penalty term in the loss. Defaults to 1.
"""
kl_cov_ratio = config.policy_loss.kl_cov_ratio if config.policy_loss.kl_cov_ratio is not None else 0.0002
ppo_kl_coef = config.policy_loss.ppo_kl_coef if config.policy_loss.ppo_kl_coef is not None else 1.0
assert kl_cov_ratio > 0, "kl_cov_ratio should be larger than 0."
negative_approx_kl = log_prob - old_log_prob
abs_kl = negative_approx_kl.abs()
ratio = torch.exp(negative_approx_kl)
ppo_kl_abs = verl_F.masked_mean(negative_approx_kl.abs(), response_mask)
pg_losses1 = -advantages * ratio
pg_losses_kl = -advantages * ratio + ppo_kl_coef * abs_kl
pg_losses = pg_losses1
all_valid = response_mask > 0
all_valid_idx = torch.nonzero(all_valid.reshape(-1), as_tuple=True)[0]
all_valid_adv = advantages[all_valid].detach().reshape(-1).cpu()
all_valid_logp = log_prob[all_valid].detach().reshape(-1).cpu()
k = min(kl_cov_ratio, len(all_valid_adv))
if k != 0:
cov_lst_all = (all_valid_adv - all_valid_adv.mean()) * (all_valid_logp - all_valid_logp.mean())
k_percent_nums = max(1, int(len(cov_lst_all) * kl_cov_ratio))
large_cov_idxs = torch.topk(cov_lst_all, k_percent_nums, largest=True).indices
if len(large_cov_idxs) != 0:
large_cov_idxs = all_valid_idx[large_cov_idxs]
pg_losses[large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]] = pg_losses_kl[large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]]
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0)
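Similarly, KL-Cov can be exercised through the registry added at the top of this file, which is how the actor dispatches non-vanilla loss modes later in this PR. A minimal sketch, again with a hypothetical config stand-in:
```python
# Illustrative only; not part of core_algos.py.
import torch
from types import SimpleNamespace

from verl.trainer.ppo.core_algos import get_policy_loss_fn

cfg = SimpleNamespace(policy_loss=SimpleNamespace(kl_cov_ratio=0.0002, ppo_kl_coef=0.1))
loss_fn = get_policy_loss_fn("kl_cov")  # resolves to compute_policy_loss_kl_cov

batch_size, response_length = 2, 16
old_log_prob = torch.randn(batch_size, response_length)
log_prob = old_log_prob + 0.01 * torch.randn(batch_size, response_length)
advantages = torch.randn(batch_size, response_length)
response_mask = torch.ones(batch_size, response_length)

pg_loss, _, ppo_kl_abs, _ = loss_fn(
    old_log_prob, log_prob, advantages, response_mask,
    loss_agg_mode="token-mean", config=cfg,
)
print(pg_loss.item(), ppo_kl_abs.item())
```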
def compute_entropy_loss(logits, response_mask, loss_agg_mode: str = "token-mean"):
"""Compute categorical entropy loss (For backward compatibility)

View File

@ -655,9 +655,11 @@ class RayPPOTrainer:
sample_scores.extend(scores)
reward_extra_infos_dict["reward"].extend(scores)
print(f"len reward_extra_infos_dict['reward']: {len(reward_extra_infos_dict['reward'])}")
if "reward_extra_info" in result:
for key, lst in result["reward_extra_info"].items():
reward_extra_infos_dict[key].extend(lst)
print(f"len reward_extra_infos_dict['{key}']: {len(reward_extra_infos_dict[key])}")
data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0]))

View File

@ -83,6 +83,7 @@ def default_compute_score(data_source, solution_str, ground_truth, extra_info=No
from . import search_r1_like_qa_em
res = search_r1_like_qa_em.compute_score(solution_str, ground_truth)
else:
raise NotImplementedError(f"Reward function is not implemented for {data_source=}")

File diff suppressed because it is too large.

View File

@ -0,0 +1,342 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Microsoft Corporation.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE
# Copyright (c) 2023 OpenAI
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Copyright (c) 2021 Dan Hendrycks
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:
- https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py
- https://github.com/microsoft/ProphetNet/tree/master/CRITIC
- https://github.com/openai/prm800k
"""
import contextlib
import math
import re
from math import isclose
from typing import Union
# sympy related
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
# verl related
from verl.utils.py_functional import timeout_limit
def is_digit(s):
try:
if "{,}" in str(s):
num = float(str(s).replace("{,}", ""))
return True, num
num = float(str(s).replace(",", ""))
return True, num
except ValueError:
return False, None
def normalize(answer, pi) -> str:
# checking if answer is $<number> and removing $ in that case to compare
if isinstance(answer, str) and bool(re.match(r"\$\d+(\.\d+)?", answer)):
return answer[1:]
# checking if answer is <number>% or <number>\\% and removing %
if isinstance(answer, str) and (bool(re.match(r"^\d+(\.\d+)?%$", answer)) or bool(re.match(r"^\d+(\.\d+)?\\%$", answer))):
return answer.replace("\\%", "").replace("%", "")
# handle base
answer = handle_base(answer)
# handle pi
answer = handle_pi(answer, pi)
return answer
def handle_base(x) -> str:
if isinstance(x, str) and "_" in x:
# Due to base
x = x.split("_")[0]
x = float(x)
return int(x)
return x
def handle_pi(string, pi):
if isinstance(string, str) and "\pi" in string:
# Find the first occurrence of "\pi"
idx = string.find("\pi")
# Iterate over the string and find all occurrences of "\pi" with a valid previous character
while idx != -1:
if idx > 0 and string[idx - 1].isdigit():
# Replace "\pi" with "*math.pi" if the previous character is a digit
string = string[:idx] + f"*{pi}" + string[idx + 3 :]
else:
# Replace "\pi" with "1*math.pi" if the previous character is not a digit
string = string[:idx] + f"1*{pi}" + string[idx + 3 :]
# Find the next occurrence of "\pi"
idx = string.find("\pi", idx + 1)
# Evaluate the expression using eval() function
with contextlib.suppress(Exception):
string = eval(string)
return string
def math_equal(
prediction: Union[bool, float, str],
reference: Union[float, str],
include_percentage: bool = True,
tolerance: float = 1e-4,
timeout: float = 10.0,
pi: float = math.pi,
) -> bool:
"""
Exact match of math if and only if:
1. numerical equal: both can convert to float and are equal
2. symbolic equal: both can convert to sympy expression and are equal
"""
prediction = normalize(prediction, pi)
reference = normalize(reference, pi)
if isinstance(prediction, str) and len(prediction) > 1000: # handling weird corner-cases
prediction = prediction[:1000]
# 0. string comparison
if isinstance(prediction, str) and isinstance(reference, str):
if prediction.strip().lower() == reference.strip().lower():
return True
if prediction.replace(" ", "") == reference.replace(" ", ""):
return True
try: # 1. numerical equal
if is_digit(prediction)[0] and is_digit(reference)[0]:
prediction = is_digit(prediction)[1]
reference = is_digit(reference)[1]
# number questions
gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference]
for item in gt_result:
try:
if isclose(item, prediction, rel_tol=tolerance):
return True
except Exception:
continue
return False
except Exception:
pass
if not prediction and prediction not in [0, False]:
return False
# 2. symbolic equal
reference = str(reference).strip()
prediction = str(prediction).strip()
## deal with [], (), {}
prediction = format_intervals(prediction)
pred_str, ref_str = prediction, reference
if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or (prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")):
pred_str = pred_str.strip("[]()")
ref_str = ref_str.strip("[]()")
for s in ["{", "}", "(", ")"]:
ref_str = ref_str.replace(s, "")
pred_str = pred_str.replace(s, "")
if pred_str == ref_str:
return True
## [a, b] vs. [c, d], return a==c and b==d
if prediction and reference and prediction[0] in "([" and prediction[-1] in ")]" and prediction[0] == reference[0] and prediction[-1] == reference[-1]:
pred_parts = prediction[1:-1].split(",")
ref_parts = reference[1:-1].split(",")
if len(pred_parts) == len(ref_parts) and all([math_equal(pred_pt, ref_pt, include_percentage, tolerance) for pred_pt, ref_pt in zip(pred_parts, ref_parts)]):
return True
if "," in prediction and "," in reference:
pred_parts = [item.strip() for item in prediction.split(",")]
ref_parts = [item.strip() for item in reference.split(",")]
if len(pred_parts) == len(ref_parts):
return bool(all([math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance) for i in range(len(pred_parts))]))
# if we have point == tuple of values
if prediction.startswith("Point") and reference[0] == "(" and reference[-1] == ")":
pred_parts = prediction[prediction.find("(") + 1 : -1].split(",")
ref_parts = reference[1:-1].split(",")
if len(pred_parts) == len(ref_parts) and all([math_equal(pred_pt, ref_pt, include_percentage, tolerance) for pred_pt, ref_pt in zip(pred_parts, ref_parts)]):
return True
# if reference is a matrix
if "\begin{pmatrix}" in reference and prediction.startswith("Matrix"):
try:
pred_matrix = parse_expr(prediction)
ref_matrix_items = reference.split()[1:-1:2]
if len(pred_matrix) == len(ref_matrix_items) and all([math_equal(pred, ref, include_percentage, tolerance) for ref, pred in zip(ref_matrix_items, pred_matrix)]):
return True
except Exception:
pass
elif "\begin{pmatrix}" in reference and prediction.startswith("[") and prediction.endswith("]"):
if isinstance(eval(prediction), list):
try:
pred_matrix = eval(prediction)
# ref_matrix_items = reference.split()[1:-1:2]
ref_matrix_items = reference.lstrip("\\begin{pmatrix}").lstrip("\begin{pmatrix}").rstrip("\\end{pmatrix}").rstrip("\end{pmatrix}") # noqa: B005
ref_matrix_items = ref_matrix_items.split("\\")
ref_matrix_items = [row.split("&") if "&" in row else row for row in ref_matrix_items]
if len(pred_matrix) == len(ref_matrix_items) and all([math_equal(pred, ref, include_percentage, tolerance) for ref, pred in zip(ref_matrix_items, pred_matrix)]):
return True
except Exception:
pass
return symbolic_equal(prediction, reference, tolerance, timeout)
def symbolic_equal(a, b, tolerance, timeout=10.0):
def _parse(s):
for f in [parse_expr, parse_latex]:
try:
with timeout_limit(seconds=timeout):
return f(s)
except TimeoutError:
print(f"Parsing timed out for {s}")
continue
except Exception:
continue
return s
a = _parse(a)
b = _parse(b)
try:
with timeout_limit(seconds=timeout):
if simplify(a - b) == 0:
return True
except TimeoutError:
print(f"Simplification timed out for {a} - {b}")
pass
except Exception:
pass
try:
with timeout_limit(seconds=timeout):
if isclose(N(a), N(b), rel_tol=tolerance):
return True
except TimeoutError:
print(f"Numerical evaluation timed out for {a}, {b}")
pass
except Exception:
pass
return False
def format_intervals(prediction):
patterns = {
"Interval(": r"^Interval\((.*)\)$",
"Interval.Ropen(": r"^Interval\.Ropen\((.*)\)$",
"Interval.Lopen(": r"^Interval\.Lopen\((.*)\)$",
"Interval.open(": r"^Interval\.open\((.*)\)$",
}
for key, pattern in patterns.items():
match = re.match(pattern, prediction)
if match:
inner_content = match.group(1)
if key == "Interval(": # Intarval(a, b) == [a, b]
return f"[{inner_content}]"
elif key == "Interval.Ropen(": # Intarval.Ropen(a, b) == [a, b)
return f"[{inner_content})"
elif key == "Interval.Lopen(": # Intarval.Lopen(a, b) == (a, b]
return f"({inner_content}]"
elif key == "Interval.open(": # Intarval.open(a, b) == (a, b)
return f"({inner_content})"
return prediction
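A few illustrative checks of the grader above (assuming this verifier module and sympy's LaTeX parser are available; the expected results follow from the logic shown):
```python
# Illustrative only; assumes math_equal from this verifier module is in scope.
print(math_equal("0.5", "\\frac{1}{2}"))  # True: symbolic equivalence via sympy
print(math_equal("3,600", "3600"))        # True: numeric match after stripping commas
print(math_equal("10\\%", "10"))          # True: percent sign normalized away
```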

View File

@ -0,0 +1,191 @@
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 Dan Hendrycks
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This logic is largely copied from the Hendrycks' MATH release (math_equivalence).
From: https://github.com/openai/prm800k/blob/main/prm800k/grading/math_normalize.py
"""
import re
from typing import Optional
def normalize_answer(answer: Optional[str]) -> Optional[str]:
if answer is None:
return None
answer = answer.strip()
try:
# Remove enclosing `\text{}`.
m = re.search("^\\\\text\{(?P<text>.+?)\}$", answer)
if m is not None:
answer = m.group("text").strip()
return _strip_string(answer)
except: # noqa: E722
return answer
def _fix_fracs(string):
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except: # noqa: E722
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def _fix_a_slash_b(string):
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
a = int(a)
b = int(b)
assert string == "{}/{}".format(a, b)
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
except: # noqa: E722
return string
def _remove_right_units(string):
# "\\text{ " only ever occurs (at least in the val set) when describing units
if "\\text{ " in string:
splits = string.split("\\text{ ")
assert len(splits) == 2
return splits[0]
else:
return string
def _fix_sqrt(string):
if "\\sqrt" not in string:
return string
splits = string.split("\\sqrt")
new_string = splits[0]
for split in splits[1:]:
if split[0] != "{":
a = split[0]
new_substr = "\\sqrt{" + a + "}" + split[1:]
else:
new_substr = "\\sqrt" + split
new_string += new_substr
return new_string
def _strip_string(string):
# linebreaks
string = string.replace("\n", "")
# remove inverse spaces
string = string.replace("\\!", "")
# replace \\ with \
string = string.replace("\\\\", "\\")
# replace tfrac and dfrac with frac
string = string.replace("tfrac", "frac")
string = string.replace("dfrac", "frac")
# remove \left and \right
string = string.replace("\\left", "")
string = string.replace("\\right", "")
# Remove circ (degrees)
string = string.replace("^{\\circ}", "")
string = string.replace("^\\circ", "")
# remove dollar signs
string = string.replace("\\$", "")
# remove units (on the right)
string = _remove_right_units(string)
# remove percentage
string = string.replace("\\%", "")
string = string.replace("\%", "")
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(" .", " 0.")
string = string.replace("{.", "{0.")
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == ".":
string = "0" + string
# to consider: get rid of e.g. "k = " or "q = " at beginning
if len(string.split("=")) == 2 and len(string.split("=")[0]) <= 2:
string = string.split("=")[1]
# fix sqrt3 --> sqrt{3}
string = _fix_sqrt(string)
# remove spaces
string = string.replace(" ", "")
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
string = _fix_fracs(string)
# manually change 0.5 --> \frac{1}{2}
if string == "0.5":
string = "\\frac{1}{2}"
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string = _fix_a_slash_b(string)
return string
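For instance, the normalizer above maps several surface forms onto one canonical string (illustrative, assuming `normalize_answer` from this module is in scope):
```python
# Illustrative only; assumes normalize_answer from this module is in scope.
print(normalize_answer("\\dfrac{1}{2}"))  # \frac{1}{2}  (dfrac rewritten to frac)
print(normalize_answer("\\frac12"))       # \frac{1}{2}  (braces inserted)
print(normalize_answer("0.5"))            # \frac{1}{2}  (0.5 mapped to the fraction form)
print(normalize_answer("10\\%"))          # 10           (percent sign stripped)
```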

View File

@ -28,7 +28,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import verl.utils.torch_functional as verl_F
from verl import DataProto
from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, kl_penalty
from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty
from verl.utils.debug import GPUMemoryLogger
from verl.utils.device import get_device_id, get_device_name, is_cuda_available, is_npu_available
from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_
@ -399,17 +399,23 @@ class DataParallelPPOActor(BasePPOActor):
calculate_entropy = True
entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature, calculate_entropy=calculate_entropy)
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
response_mask=response_mask,
cliprange=clip_ratio,
cliprange_low=clip_ratio_low,
cliprange_high=clip_ratio_high,
clip_ratio_c=clip_ratio_c,
loss_agg_mode=loss_agg_mode,
)
loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
if loss_mode == "vanilla":
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
response_mask=response_mask,
cliprange=clip_ratio,
cliprange_low=clip_ratio_low,
cliprange_high=clip_ratio_high,
clip_ratio_c=clip_ratio_c,
loss_agg_mode=loss_agg_mode,
)
else:
policy_loss_fn = get_policy_loss_fn(loss_mode)
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(old_log_prob, log_prob, advantages, response_mask, loss_agg_mode, self.config)
if entropy_coeff != 0:
entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

View File

@ -38,7 +38,7 @@ from omegaconf import OmegaConf
from torch import nn
from verl import DataProto
from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, kl_penalty
from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty
from verl.utils.debug import GPUMemoryLogger
from verl.utils.debug.profile import Profiler
from verl.utils.device import get_device_id, get_torch_device
@ -352,21 +352,30 @@ class MegatronPPOActor(BasePPOActor):
old_log_prob = data["old_log_probs"]
advantages = data["advantages"]
clip_ratio = meta_info["clip_ratio"]
clip_ratio = self.config.clip_ratio
clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio
clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio
clip_ratio_c = meta_info["clip_ratio_c"]
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
response_mask=response_mask,
cliprange=clip_ratio,
cliprange_low=clip_ratio_low,
cliprange_high=clip_ratio_high,
clip_ratio_c=clip_ratio_c,
loss_agg_mode=loss_agg_mode,
)
clip_ratio_c = self.config.get("clip_ratio_c", 3.0)
entropy_coeff = self.config.entropy_coeff
loss_agg_mode = self.config.loss_agg_mode
loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
if loss_mode == "vanilla":
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
response_mask=response_mask,
cliprange=clip_ratio,
cliprange_low=clip_ratio_low,
cliprange_high=clip_ratio_high,
clip_ratio_c=clip_ratio_c,
loss_agg_mode=loss_agg_mode,
)
else:
policy_loss_fn = get_policy_loss_fn(loss_mode)
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(old_log_prob, log_prob, advantages, response_mask, loss_agg_mode, self.config)
policy_loss = pg_loss
if calculate_entropy:
entropy = output["entropy"][:, -response_length - 1 : -1].contiguous()