mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 13:43:50 +08:00
[recipe] feat: integrate entropy-mechanism recipe: Clip-Cov and KL-Cov methods (#1830)
### Checklist Before Starting

- [x] Search for similar PR(s).

### What does this PR do?

> Add support for the Clip-Cov and KL-Cov methods from the paper *The Entropy Mechanism of Reinforcement Learning for Reasoning Language Models*. Also add the verifier used in the paper.

### High-Level Design

> Demonstrate the high-level design if this PR is complex.

### Specific Changes

> List the specific changes.

In `core_algos.py`, we add the Clip-Cov and KL-Cov losses:

```python
def compute_policy_loss_clip_cov(
    old_log_prob,
    log_prob,
    advantages,
    response_mask,
    cliprange=None,
    cliprange_low=None,
    cliprange_high=None,
    loss_agg_mode="token-mean",
    clip_ratio=0.0002,
    clip_cov_lb=1.0,
    clip_cov_ub=5.0,
):
    """
    Compute the clipped policy objective and related metrics for Clip-Cov.

    Adapted from https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py

    Args:
        old_log_prob (torch.Tensor): Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor): Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor): Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor): Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        cliprange (float, optional): Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347. Defaults to None (must be provided).
        cliprange_low (float, optional): Lower clip range for dual-clip PPO. Defaults to the same value as `cliprange`.
        cliprange_high (float, optional): Upper clip range for dual-clip PPO. Defaults to the same value as `cliprange`.
        loss_agg_mode (str, optional): Aggregation mode for `agg_loss`. Defaults to "token-mean".
        clip_ratio (float, optional): Ratio for clipping the covariance. Defaults to 0.0002.
        clip_cov_lb (float, optional): Lower bound for clipping covariance. Defaults to 1.0.
        clip_cov_ub (float, optional): Upper bound for clipping covariance. Defaults to 5.0.
    """
    assert clip_ratio > 0, "clip_ratio should be larger than 0."

    negative_approx_kl = log_prob - old_log_prob
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    pg_losses1 = -advantages * ratio

    if cliprange_low is None:
        cliprange_low = cliprange
    if cliprange_high is None:
        cliprange_high = cliprange

    corr = torch.ones_like(advantages)
    pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
    clip_by_origin = (pg_losses2 > pg_losses1) & (response_mask > 0)

    cov_all = (advantages - verl_F.masked_mean(advantages, response_mask)) * (
        log_prob - verl_F.masked_mean(log_prob.detach(), response_mask)
    )
    cov_all[response_mask == 0] = -torch.inf
    cov_all[clip_by_origin] = -torch.inf

    clip_num = max(int(clip_ratio * response_mask.sum().item()), 1)
    top_k_idx = (cov_all < clip_cov_ub) & (cov_all > clip_cov_lb) & (response_mask > 0)
    top_k_idx = torch.nonzero(top_k_idx)

    if len(top_k_idx) > 0:
        perm = torch.randperm(len(top_k_idx))
        top_k_idx = top_k_idx[perm[: min(clip_num, len(top_k_idx))]]
    else:
        top_k_idx = torch.empty((0, 2), device=cov_all.device, dtype=torch.long)

    corr[top_k_idx[:, 0], top_k_idx[:, 1]] = 0

    pg_clipfrac = verl_F.masked_mean((corr == 0).float(), response_mask)
    pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr
    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    return pg_loss, pg_clipfrac, ppo_kl, torch.tensor(0.0)
```
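For readers who want to poke at the selection logic in isolation, here is a minimal, self-contained sketch of the covariance-based masking idea in plain PyTorch on toy tensors. The `masked_mean` helper, shapes, and thresholds below are illustrative stand-ins chosen for this sketch, not verl's actual API or the recommended hyperparameters:

```python
import torch

def masked_mean(x, mask):
    # stand-in for verl_F.masked_mean: mean over unmasked tokens only
    return (x * mask).sum() / mask.sum().clamp(min=1)

torch.manual_seed(0)
B, T = 4, 8                                    # toy batch size and response length
advantages = torch.randn(B, T)
log_prob = torch.randn(B, T)
response_mask = (torch.rand(B, T) > 0.2).float()

# per-token covariance term between centered advantage and centered log-prob
cov = (advantages - masked_mean(advantages, response_mask)) * (
    log_prob - masked_mean(log_prob, response_mask)
)

# candidate tokens: valid and with covariance inside (lb, ub); loose toy bounds here
clip_cov_lb, clip_cov_ub, clip_ratio = 0.0, 5.0, 0.25
candidates = torch.nonzero((cov > clip_cov_lb) & (cov < clip_cov_ub) & (response_mask > 0))

# randomly drop at most a clip_ratio fraction of the valid tokens from the PG loss
clip_num = max(int(clip_ratio * response_mask.sum().item()), 1)
keep = torch.ones(B, T)
if len(candidates) > 0:
    picked = candidates[torch.randperm(len(candidates))[:clip_num]]
    keep[picked[:, 0], picked[:, 1]] = 0.0     # these tokens contribute no gradient

print("fraction of tokens clipped:", masked_mean((keep == 0).float(), response_mask).item())
```

Running this a few times shows the clipped fraction sitting near the requested `clip_ratio` (or below it when few tokens fall inside the covariance band), which is what `pg_clipfrac` tracks in the function above.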
```python
def compute_policy_loss_kl_cov(
    old_log_prob,
    log_prob,
    advantages,
    response_mask,
    loss_agg_mode="token-mean",
    k_ratio=0.0002,
    ppo_kl_coef=1,
):
    """
    Compute the clipped policy objective and related metrics for KL-Cov.

    Adapted from https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py

    Args:
        old_log_prob (torch.Tensor): Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor): Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor): Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor): Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional): Aggregation mode for `agg_loss`. Defaults to "token-mean".
        k_ratio (float, optional): Ratio for selecting the top-k covariance values. Defaults to 0.0002.
        ppo_kl_coef (float, optional): Coefficient for the KL penalty term in the loss. Defaults to 1.
    """
    assert k_ratio > 0, "k_ratio should be larger than 0."

    negative_approx_kl = log_prob - old_log_prob
    abs_kl = negative_approx_kl.abs()
    ratio = torch.exp(negative_approx_kl)
    ppo_kl_abs = verl_F.masked_mean(negative_approx_kl.abs(), response_mask)

    pg_losses1 = -advantages * ratio
    pg_losses_kl = -advantages * ratio + ppo_kl_coef * abs_kl
    pg_losses = pg_losses1

    all_valid = (response_mask > 0)
    all_valid_idx = torch.nonzero(all_valid.reshape(-1), as_tuple=True)[0]
    all_valid_adv = advantages[all_valid].detach().reshape(-1).cpu()
    all_valid_logp = log_prob[all_valid].detach().reshape(-1).cpu()

    k = min(k_ratio, len(all_valid_adv))

    if k != 0:
        cov_lst_all = (all_valid_adv - all_valid_adv.mean()) * (all_valid_logp - all_valid_logp.mean())
        k_percent_nums = max(1, int(len(cov_lst_all) * k_ratio))
        large_cov_idxs = torch.topk(cov_lst_all, k_percent_nums, largest=True).indices

        if len(large_cov_idxs) != 0:
            large_cov_idxs = all_valid_idx[large_cov_idxs]
            pg_losses[large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]] = pg_losses_kl[
                large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]
            ]

    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0)
```
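The least obvious part of `compute_policy_loss_kl_cov` is that the top-k selection happens over the flattened list of valid tokens, so the selected indices must be mapped back to (batch, position) coordinates before the KL-penalized loss can be swapped in. Below is a small sketch of just that index arithmetic, with toy shapes and placeholder losses of our own choosing:

```python
import torch

torch.manual_seed(0)
B, T = 3, 5
advantages = torch.randn(B, T)
log_prob = torch.randn(B, T)
response_mask = torch.rand(B, T) > 0.3

pg_losses = -advantages.clone()                 # placeholder per-token PG loss
pg_losses_kl = pg_losses + 1.0 * log_prob.abs() # placeholder KL-penalized loss

# covariance over the flattened valid tokens only
valid_idx = torch.nonzero(response_mask.reshape(-1), as_tuple=True)[0]
adv_v = advantages.reshape(-1)[valid_idx]
logp_v = log_prob.reshape(-1)[valid_idx]
cov = (adv_v - adv_v.mean()) * (logp_v - logp_v.mean())

# pick the top 20% of valid tokens by covariance (toy ratio)
k = max(1, int(len(cov) * 0.2))
top = torch.topk(cov, k, largest=True).indices
flat = valid_idx[top]                            # indices into the flat (B*T) view

# recover (row, col) with integer division / modulo, then swap in the penalized loss
rows, cols = flat // T, flat % T
pg_losses[rows, cols] = pg_losses_kl[rows, cols]
print("penalized positions:", list(zip(rows.tolist(), cols.tolist())))
```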
In `dp_actor.py`, we add the loss-mode switch:

```python
loss_mode = self.config.get("loss_mode", "vanilla")
if loss_mode not in ["vanilla", "clip_cov", "kl_cov"]:
    raise ValueError(f"Unsupported loss mode: {loss_mode}. Supported modes are: 'vanilla', 'clip_cov', 'kl_cov'.")

if loss_mode == "vanilla":
    pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
        old_log_prob=old_log_prob,
        log_prob=log_prob,
        advantages=advantages,
        response_mask=response_mask,
        cliprange=clip_ratio,
        cliprange_low=clip_ratio_low,
        cliprange_high=clip_ratio_high,
        clip_ratio_c=clip_ratio_c,
        loss_agg_mode=loss_agg_mode,
    )
elif loss_mode == "clip_cov":
    pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss_clip_cov(
        old_log_prob=old_log_prob,
        log_prob=log_prob,
        advantages=advantages,
        response_mask=response_mask,
        cliprange=clip_ratio,
        cliprange_low=clip_ratio_low,
        cliprange_high=clip_ratio_high,
        loss_agg_mode=loss_agg_mode,
        clip_ratio=self.config.clip_cov_ratio,
        clip_cov_lb=self.config.clip_cov_lb,
        clip_cov_ub=self.config.clip_cov_ub,
    )
elif loss_mode == "kl_cov":
    pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss_kl_cov(
        old_log_prob=old_log_prob,
        log_prob=log_prob,
        advantages=advantages,
        response_mask=response_mask,
        loss_agg_mode=loss_agg_mode,
        k_ratio=self.config.k_ratio,
        ppo_kl_coef=self.config.ppo_kl_coef,
    )
```

### Usage Example

> Provide usage example(s) for easier usage.

We create a recipe (built on the DAPO recipe) named `entropy` to store our scripts, for example `7b_kl_cov.sh`:

```bash
#!/usr/bin/env bash
set -xeuo pipefail

export WANDB_API_KEY=YOUR_WANDB_API_KEY
# export VLLM_USE_V1=1

project_name='Qwen2.5-7B'
exp_name='klcov'

adv_estimator=grpo

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.2

max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=False
overlong_buffer_len=$((1024 * 2))
overlong_penalty_factor=1.0

loss_agg_mode="token-mean"
loss_mode="kl_cov"
enable_filter_groups=False
filter_groups_metric=acc
max_num_gen_batches=10
train_prompt_bsz=256
gen_prompt_bsz=$((train_prompt_bsz * 3))
train_prompt_mini_bsz=256
n_resp_per_prompt=8
max_token=20480

# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-4}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"}
CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"}
TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"}
TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]}

# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
ppo_kl_coef=1
k_ratio=0.002

# Mathematically equivalent
use_dynamic_bsz=True
infer_micro_batch_size=null
train_micro_batch_size=null
offload=False

HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.filter_overlong_prompts=False \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.gen_batch_size=${gen_prompt_bsz} \
data.train_batch_size=${train_prompt_bsz} \
data.return_raw_chat=True \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.actor.loss_mode=${loss_mode} \
actor_rollout_ref.actor.k_ratio=${k_ratio} \
actor_rollout_ref.actor.ppo_kl_coef=${ppo_kl_coef} \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.mode=sync \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
algorithm.filter_groups.enable=${enable_filter_groups} \
algorithm.filter_groups.metric=${filter_groups_metric} \
algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.weight_decay=0 \
actor_rollout_ref.actor.optim.warmup_style=constant \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k="${top_k}" \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=False \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
reward_model.reward_manager=dapo \
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
reward_model.overlong_buffer.len=${overlong_buffer_len} \
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
trainer.logger=['console','wandb'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.n_gpus_per_node=8 \
trainer.nnodes="${NNODES}" \
trainer.val_before_train=False \
trainer.test_freq=4 \
trainer.save_freq=32 \
trainer.total_epochs=1000 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=disable
```

### Test

Please refer to Fig. 11 and Tab. 2 in https://arxiv.org/pdf/2505.22617 for detailed results.

### Additional Info.

NA

### Checklist Before Submitting

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [x] Add `[BREAKING]` to the PR title if it breaks any API.
- [x] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add CI test(s) if necessary.

---------

Co-authored-by: Jiacheng Chen <jackchan9345@gmail.com>
Co-authored-by: H <linhaibin.eric@gmail.com>
@@ -217,6 +217,7 @@ verl is inspired by the design of Nemo-Aligner, Deepspeed-chat and OpenRLHF. The

- [LUFFY](https://arxiv.org/pdf/2504.14945): Learning to Reason under Off-Policy Guidance
- [verl-tool](https://github.com/TIGER-AI-Lab/verl-tool): A unified and easy-to-extend tool-agent training framework based on verl
- [DeepMath](https://github.com/zwhe99/DeepMath): DeepMath-103K data and series models for math reasoning
- [Entropy Mechanism of RL](https://github.com/PRIME-RL/Entropy-Mechanism-of-RL): The Entropy Mechanism of Reinforcement Learning for Large Language Model Reasoning
- [LLaSA-TTS-GRPO](https://github.com/channel-io/ch-tts-llasa-rl-grpo): TTS fine-tuning with GRPO optimization based on LLASA models
- [RL-Factory](https://github.com/Simple-Efficient/RL-Factory): An easy and efficient RL post-training framework for Agentic Learning
- [RACRO](https://github.com/gyhdog99/RACRO2): Build multi-modal reasoning models via decoupling it into query-conditioned captioning and text-only reasoning
114  docs/algo/entropy.md  Normal file
@@ -0,0 +1,114 @@

# Recipe: Entropy Mechanism

<div align="center">

The Entropy Mechanism of Reinforcement Learning for Large Language Model Reasoning.

[](https://arxiv.org/pdf/2505.22617) [](https://github.com/PRIME-RL/Entropy-Mechanism-of-RL) [](https://www.alphaxiv.org/abs/2505.22617) [](https://x.com/stingning/status/1928088554166505667) [](https://x.com/charlesfornlp/status/1928089451080585283) [](https://x.com/_akhaliq/status/1928077929105268861)

<div align="center" style="font-family: Arial, sans-serif;">
<p>
<a href="#🎉news" style="text-decoration: none; font-weight: bold;">🎉 News</a> •
<a href="#✨getting-started" style="text-decoration: none; font-weight: bold;">✨ Getting Started</a> •
<a href="#📖introduction" style="text-decoration: none; font-weight: bold;">📖 Introduction</a>
</p>
<p>
<a href="#🎈citation" style="text-decoration: none; font-weight: bold;">🎈 Citation</a> •
<a href="#🌻acknowledgement" style="text-decoration: none; font-weight: bold;">🌻 Acknowledgement</a> •
<a href="#📬Contact" style="text-decoration: none; font-weight: bold;">📬 Contact</a> •
<a href="#📈star-history" style="text-decoration: none; font-weight: bold;">📈 Star History</a>
</p>
</div>

</div>

# 🎉News

- **[2025/05/29]** 🎉 Ranked **#1** of the day on [Huggingface Daily Papers](https://huggingface.co/papers?date=2025-05-29).
- **[2025/05/29]** Released our paper on arXiv. See [here](https://arxiv.org/pdf/2505.22617). We provide insights into the entropy mechanism of RL for LLMs and propose two simple yet effective strategies to alleviate the entropy collapse.

# ✨Getting started

After preparing the training data, you can train Qwen2.5-7B on a single node, taking the KL-Cov approach as an example, by simply running:

```
cd verl
conda activate your_env
bash recipe/entropy/7b_kl_cov.sh
```

To train Qwen2.5-32B on multiple nodes, run the following commands:

```
cd verl
conda activate your_env
bash recipe/entropy/32b_kl_cov.sh
```

# 📖Introduction

<div align="left">
<img src="https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/e2a.jpg?raw=true" alt="issue" style="width: 96%; height: auto;">
</div>

This paper addresses the entropy collapse issue in scaling reinforcement learning (RL) for large language models (LLMs), where policy entropy drops sharply during training, leading to overconfidence and performance saturation. We empirically establish a relationship between entropy ($H$) and performance ($R$): $R = -a\exp(H) + b$, showing that performance is bottlenecked by entropy exhaustion.
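Spelled out with per-model fitting coefficients $a, b > 0$ (the explicit ceiling below is an algebraic restatement of that fit, not an additional result), the curve is decreasing in $H$, so the attainable performance when entropy is fully exhausted ($H \to 0$) is capped at $b - a$:

$$
R = -a\,e^{H} + b, \qquad \frac{dR}{dH} = -a\,e^{H} < 0, \qquad R\big|_{H \to 0} = b - a .
$$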
<div align="left">
<img src="https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/cov.jpg?raw=true" alt="issue" style="width: 96%; height: auto;">
</div>

Theoretically, we find entropy changes are driven by the covariance between action probability and logit updates, which correlates with advantage in policy gradient methods. High-probability, high-advantage actions reduce entropy, while rare, high-advantage actions increase it. Empirically, the covariance term remains positive, explaining entropy's monotonic decline. To mitigate this, we propose Clip-Cov and KL-Cov, which restrict updates for high-covariance tokens. These methods effectively prevent entropy collapse and improve performance.
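Schematically, and as a paraphrase of the analysis above (with $\pi_\theta$ the policy and $A$ the advantage), the driver can be written as

$$
\Delta \mathcal{H} \;\propto\; -\,\mathrm{Cov}_{a \sim \pi_\theta}\!\big(\log \pi_\theta(a),\, A(a)\big),
$$

so steps where high-probability tokens also carry high advantage push entropy down the hardest; these are exactly the tokens Clip-Cov drops from the loss and KL-Cov penalizes.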
# 📃Evaluation

<div align="left">
<img src="https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/performance_fig.jpg?raw=true" alt="issue" style="width: 96%; height: auto;">
</div>

Our method is able to maintain a considerably higher level of entropy throughout training. For example, when the baseline's entropy reaches a plateau and can no longer be consumed, the KL-Cov method still sustains an entropy level over 10 times higher. Meanwhile, the response length of the policy model steadily increases, and its performance on the test set consistently surpasses that of the baseline. This indicates that our model is able to explore more freely during training, learning a better policy through RL.

| **Method** | **AIME24** | **AIME25** | **AMC** | **MATH-500** | **OMNI-MATH** | **OlympiadBench** | **Minerva** | **Avg.** |
| ----------------- | ---------: | ---------: | -------: | -----------: | ------------: | ----------------: | ----------: | -------: |
| *Qwen2.5-7B* | | | | | | | | |
| GRPO | 21.2 | 9.6 | 58.7 | 78.8 | 27.9 | 40.7 | 36.7 | 38.6 |
| w. Clip-higher | 18.1 | 11.5 | 56.6 | 79.2 | 29.8 | 43.3 | 40.4 | 38.8 |
| w. **`CLIP-Cov`** | 22.1 | **15.8** | 58.2 | 80.4 | **30.5** | **44.1** | **41.1** | 40.4 |
| w. **`KL-Cov`** | **22.6** | 12.9 | **61.4** | **80.8** | 29.1 | 42.6 | 38.2 | **40.6** |
| *Qwen2.5-32B* | | | | | | | | |
| GRPO | 21.8 | 16.2 | 69.7 | 84.2 | 35.2 | 43.6 | 45.5 | 45.8 |
| w. Clip-higher | 35.6 | 22.3 | 69.5 | 77.2 | 35.1 | 42.5 | 43.0 | 47.2 |
| w. **`CLIP-Cov`** | 32.3 | 22.7 | 67.2 | **87.0** | **42.0** | **57.2** | 46.0 | 50.3 |
| w. **`KL-Cov`** | **36.8** | **30.8** | **74.5** | 84.6 | 39.1 | 49.0 | **46.3** | **52.2** |

Our two approaches both achieve non-trivial improvements across all benchmarks. Compared to GRPO, our method outperforms it by 2.0% on average for the 7B model and by 6.4% for the 32B model. Moreover, we observe that our method yields more substantial gains on the larger Qwen2.5-32B. Specifically, our method achieves improvements of 15.0% and 14.6% compared to GRPO on the most challenging benchmarks, AIME24 and AIME25, respectively.

# 🎈Citation

If you find this paper or repo helpful, please cite us.

```bibtex
@article{cui2025entropy,
  title={The Entropy Mechanism of Reinforcement Learning for Reasoning Language Models},
  author={Cui, Ganqu and Zhang, Yuchen and Chen, Jiacheng and Yuan, Lifan and Wang, Zhi and Zuo, Yuxin and Li, Haozhan and Fan, Yuchen and Chen, Huayu and Chen, Weize and others},
  journal={arXiv preprint arXiv:2505.22617},
  year={2025}
}
```

# 🌻Acknowledgement

We implement our reinforcement learning algorithm extending from [verl](https://github.com/volcengine/verl). We utilize [vLLM](https://github.com/vllm-project/vllm) for inference. Our models are trained primarily on the [Qwen2.5 family](https://github.com/QwenLM/Qwen2.5). Our training data is built from [DAPO-MATH](https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k). Thanks for their great contributions!

# 📬 Contact

For questions, discussion, or collaboration opportunities, feel free to contact:
- Ganqu Cui: cuiganqu@pjlab.org.cn
- Yuchen Zhang: yuchen.zhang2003@gmail.com
- Jiacheng Chen: jackchan9345@gmail.com
- Ning Ding: ningding.cs@gmail.com
@@ -70,6 +70,7 @@ verl is fast with:

algo/dapo.md
algo/spin.md
algo/sppo.md
algo/entropy.md
algo/opo.md
algo/baseline.md
148  recipe/entropy/32b_clip_cov.sh  Normal file
@@ -0,0 +1,148 @@

#!/usr/bin/env bash
|
||||
set -xeuo pipefail
|
||||
|
||||
export WANDB_API_KEY=YOUR_WANDB_API_KEY
|
||||
# export VLLM_USE_V1=1
|
||||
|
||||
project_name='Qwen2.5-32B'
|
||||
exp_name='clipcov'
|
||||
|
||||
adv_estimator=grpo
|
||||
|
||||
use_kl_in_reward=False
|
||||
kl_coef=0.0
|
||||
use_kl_loss=False
|
||||
kl_loss_coef=0.0
|
||||
|
||||
clip_ratio_low=1
|
||||
clip_ratio_high=1
|
||||
clip_cov_ratio=0.0002
|
||||
clip_cov_lb=1.0
|
||||
clip_cov_ub=5.0
|
||||
|
||||
max_prompt_length=$((1024 * 2))
|
||||
max_response_length=$((1024 * 8))
|
||||
enable_overlong_buffer=False
|
||||
overlong_buffer_len=$((1024 * 2))
|
||||
overlong_penalty_factor=1.0
|
||||
|
||||
loss_agg_mode="token-mean"
|
||||
loss_mode="clip_cov"
|
||||
enable_filter_groups=True
|
||||
filter_groups_metric=acc
|
||||
max_num_gen_batches=10
|
||||
train_prompt_bsz=256
|
||||
gen_prompt_bsz=$((train_prompt_bsz * 3))
|
||||
train_prompt_mini_bsz=32
|
||||
n_resp_per_prompt=8
|
||||
max_token=20480
|
||||
|
||||
# Ray
|
||||
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
|
||||
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
|
||||
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
|
||||
NNODES=${NNODES:-4}
|
||||
# Paths
|
||||
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
|
||||
MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"}
|
||||
CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"}
|
||||
TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"}
|
||||
TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]}
|
||||
|
||||
# Algorithm
|
||||
temperature=1.0
|
||||
top_p=1.0
|
||||
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
|
||||
ppo_kl_coef=1
|
||||
kl_cov_ratio=0.02
|
||||
|
||||
# Mathematically equivalent
|
||||
use_dynamic_bsz=True
|
||||
infer_micro_batch_size=null
|
||||
train_micro_batch_size=null
|
||||
offload=False
|
||||
|
||||
HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
|
||||
data.train_files="${TRAIN_FILE}" \
|
||||
data.val_files="${TEST_FILE}" \
|
||||
data.prompt_key=prompt \
|
||||
data.truncation='left' \
|
||||
data.filter_overlong_prompts=False \
|
||||
data.max_prompt_length=${max_prompt_length} \
|
||||
data.max_response_length=${max_response_length} \
|
||||
data.gen_batch_size=${gen_prompt_bsz} \
|
||||
data.train_batch_size=${train_prompt_bsz} \
|
||||
data.return_raw_chat=True \
|
||||
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
|
||||
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
|
||||
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
|
||||
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
|
||||
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
|
||||
actor_rollout_ref.actor.clip_ratio_c=10.0 \
|
||||
actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
|
||||
actor_rollout_ref.actor.policy_loss.clip_cov_ratio=${clip_cov_ratio} \
|
||||
actor_rollout_ref.actor.policy_loss.clip_cov_lb=${clip_cov_lb} \
|
||||
actor_rollout_ref.actor.policy_loss.clip_cov_ub=${clip_cov_ub} \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
|
||||
actor_rollout_ref.rollout.mode=sync \
|
||||
algorithm.adv_estimator=${adv_estimator} \
|
||||
algorithm.use_kl_in_reward=${use_kl_in_reward} \
|
||||
algorithm.kl_ctrl.kl_coef=${kl_coef} \
|
||||
algorithm.filter_groups.enable=${enable_filter_groups} \
|
||||
algorithm.filter_groups.metric=${filter_groups_metric} \
|
||||
algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
|
||||
actor_rollout_ref.model.use_remove_padding=True \
|
||||
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \
|
||||
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \
|
||||
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \
|
||||
actor_rollout_ref.model.path="${MODEL_PATH}" \
|
||||
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.actor.optim.weight_decay=0 \
|
||||
actor_rollout_ref.actor.optim.warmup_style=constant \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
|
||||
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
|
||||
actor_rollout_ref.actor.entropy_coeff=0 \
|
||||
actor_rollout_ref.actor.grad_clip=1.0 \
|
||||
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
|
||||
actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
|
||||
actor_rollout_ref.actor.clip_cov_ratio=${clip_cov_ratio} \
|
||||
actor_rollout_ref.actor.clip_cov_lb=${clip_cov_lb} \
|
||||
actor_rollout_ref.actor.clip_cov_ub=${clip_cov_ub} \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
|
||||
actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
|
||||
actor_rollout_ref.rollout.enable_chunked_prefill=True \
|
||||
actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \
|
||||
actor_rollout_ref.rollout.temperature=${temperature} \
|
||||
actor_rollout_ref.rollout.top_p=${top_p} \
|
||||
actor_rollout_ref.rollout.top_k="${top_k}" \
|
||||
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
|
||||
actor_rollout_ref.rollout.val_kwargs.do_sample=False \
|
||||
actor_rollout_ref.rollout.val_kwargs.n=1 \
|
||||
actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
|
||||
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
|
||||
actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \
|
||||
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
|
||||
reward_model.reward_manager=dapo \
|
||||
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
|
||||
reward_model.overlong_buffer.len=${overlong_buffer_len} \
|
||||
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
|
||||
trainer.logger=['console','wandb'] \
|
||||
trainer.project_name="${project_name}" \
|
||||
trainer.experiment_name="${exp_name}" \
|
||||
trainer.n_gpus_per_node=8 \
|
||||
trainer.nnodes="${NNODES}" \
|
||||
trainer.val_before_train=False \
|
||||
trainer.test_freq=4 \
|
||||
trainer.save_freq=32 \
|
||||
trainer.total_epochs=1000 \
|
||||
trainer.default_local_dir="${CKPTS_DIR}" \
|
||||
trainer.resume_mode=disable
142  recipe/entropy/32b_kl_cov.sh  Normal file
@@ -0,0 +1,142 @@

#!/usr/bin/env bash
|
||||
set -xeuo pipefail
|
||||
|
||||
export WANDB_API_KEY=YOUR_WANDB_API_KEY
|
||||
# export VLLM_USE_V1=1
|
||||
|
||||
project_name='Qwen2.5-32B'
|
||||
exp_name='klcov'
|
||||
|
||||
adv_estimator=grpo
|
||||
|
||||
use_kl_in_reward=False
|
||||
kl_coef=0.0
|
||||
use_kl_loss=False
|
||||
kl_loss_coef=0.0
|
||||
|
||||
clip_ratio_low=0.2
|
||||
clip_ratio_high=0.2
|
||||
|
||||
max_prompt_length=$((1024 * 2))
|
||||
max_response_length=$((1024 * 8))
|
||||
enable_overlong_buffer=False
|
||||
overlong_buffer_len=$((1024 * 2))
|
||||
overlong_penalty_factor=1.0
|
||||
|
||||
loss_agg_mode="token-mean"
|
||||
loss_mode="kl_cov"
|
||||
enable_filter_groups=True
|
||||
filter_groups_metric=acc
|
||||
max_num_gen_batches=10
|
||||
train_prompt_bsz=256
|
||||
gen_prompt_bsz=$((train_prompt_bsz * 3))
|
||||
train_prompt_mini_bsz=32
|
||||
n_resp_per_prompt=8
|
||||
max_token=20480
|
||||
|
||||
# Ray
|
||||
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
|
||||
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
|
||||
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
|
||||
NNODES=${NNODES:-4}
|
||||
# Paths
|
||||
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
|
||||
MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"}
|
||||
CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"}
|
||||
TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"}
|
||||
TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]}
|
||||
|
||||
# Algorithm
|
||||
temperature=1.0
|
||||
top_p=1.0
|
||||
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
|
||||
ppo_kl_coef=1
|
||||
kl_cov_ratio=0.0002
|
||||
|
||||
# Mathematically equivalent
|
||||
use_dynamic_bsz=True
|
||||
infer_micro_batch_size=null
|
||||
train_micro_batch_size=null
|
||||
offload=False
|
||||
|
||||
HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
|
||||
data.train_files="${TRAIN_FILE}" \
|
||||
data.val_files="${TEST_FILE}" \
|
||||
data.prompt_key=prompt \
|
||||
data.truncation='left' \
|
||||
data.filter_overlong_prompts=False \
|
||||
data.max_prompt_length=${max_prompt_length} \
|
||||
data.max_response_length=${max_response_length} \
|
||||
data.gen_batch_size=${gen_prompt_bsz} \
|
||||
data.train_batch_size=${train_prompt_bsz} \
|
||||
data.return_raw_chat=True \
|
||||
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
|
||||
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
|
||||
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
|
||||
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
|
||||
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
|
||||
actor_rollout_ref.actor.clip_ratio_c=10.0 \
|
||||
actor_rollout_ref.actor.loss_mode=${loss_mode} \
|
||||
actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
|
||||
actor_rollout_ref.actor.policy_loss.kl_cov_ratio=${kl_cov_ratio} \
|
||||
actor_rollout_ref.actor.policy_loss.ppo_kl_coef=${ppo_kl_coef} \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
|
||||
actor_rollout_ref.rollout.mode=sync \
|
||||
algorithm.adv_estimator=${adv_estimator} \
|
||||
algorithm.use_kl_in_reward=${use_kl_in_reward} \
|
||||
algorithm.kl_ctrl.kl_coef=${kl_coef} \
|
||||
algorithm.filter_groups.enable=${enable_filter_groups} \
|
||||
algorithm.filter_groups.metric=${filter_groups_metric} \
|
||||
algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
|
||||
actor_rollout_ref.model.use_remove_padding=True \
|
||||
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \
|
||||
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \
|
||||
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \
|
||||
actor_rollout_ref.model.path="${MODEL_PATH}" \
|
||||
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.actor.optim.weight_decay=0 \
|
||||
actor_rollout_ref.actor.optim.warmup_style=constant \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
|
||||
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
|
||||
actor_rollout_ref.actor.entropy_coeff=0 \
|
||||
actor_rollout_ref.actor.grad_clip=1.0 \
|
||||
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
|
||||
actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
|
||||
actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
|
||||
actor_rollout_ref.rollout.enable_chunked_prefill=True \
|
||||
actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \
|
||||
actor_rollout_ref.rollout.temperature=${temperature} \
|
||||
actor_rollout_ref.rollout.top_p=${top_p} \
|
||||
actor_rollout_ref.rollout.top_k="${top_k}" \
|
||||
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
|
||||
actor_rollout_ref.rollout.val_kwargs.do_sample=False \
|
||||
actor_rollout_ref.rollout.val_kwargs.n=1 \
|
||||
actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
|
||||
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
|
||||
actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \
|
||||
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
|
||||
reward_model.reward_manager=dapo \
|
||||
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
|
||||
reward_model.overlong_buffer.len=${overlong_buffer_len} \
|
||||
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
|
||||
trainer.logger=['console','wandb'] \
|
||||
trainer.project_name="${project_name}" \
|
||||
trainer.experiment_name="${exp_name}" \
|
||||
trainer.n_gpus_per_node=8 \
|
||||
trainer.nnodes="${NNODES}" \
|
||||
trainer.val_before_train=False \
|
||||
trainer.test_freq=4 \
|
||||
trainer.save_freq=32 \
|
||||
trainer.total_epochs=1000 \
|
||||
trainer.default_local_dir="${CKPTS_DIR}" \
|
||||
trainer.resume_mode=disable
141  recipe/entropy/32b_kl_cov_mininbsz.sh  Normal file
@@ -0,0 +1,141 @@

#!/usr/bin/env bash
|
||||
set -xeuo pipefail
|
||||
|
||||
export WANDB_API_KEY=YOUR_WANDB_API_KEY
|
||||
# export VLLM_USE_V1=1
|
||||
|
||||
project_name='Qwen2.5-32B'
|
||||
exp_name='klcov'
|
||||
|
||||
adv_estimator=grpo
|
||||
|
||||
use_kl_in_reward=False
|
||||
kl_coef=0.0
|
||||
use_kl_loss=False
|
||||
kl_loss_coef=0.0
|
||||
|
||||
clip_ratio_low=0.2
|
||||
clip_ratio_high=0.2
|
||||
|
||||
max_prompt_length=$((1024 * 2))
|
||||
max_response_length=$((1024 * 8))
|
||||
enable_overlong_buffer=False
|
||||
overlong_buffer_len=$((1024 * 2))
|
||||
overlong_penalty_factor=1.0
|
||||
|
||||
loss_agg_mode="token-mean"
|
||||
loss_mode="kl_cov"
|
||||
enable_filter_groups=True
|
||||
filter_groups_metric=acc
|
||||
max_num_gen_batches=10
|
||||
train_prompt_bsz=256
|
||||
gen_prompt_bsz=$((train_prompt_bsz * 3))
|
||||
train_prompt_mini_bsz=16
|
||||
n_resp_per_prompt=8
|
||||
max_token=20480
|
||||
|
||||
# Ray
|
||||
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
|
||||
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
|
||||
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
|
||||
NNODES=${NNODES:-4}
|
||||
# Paths
|
||||
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
|
||||
MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"}
|
||||
CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"}
|
||||
TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"}
|
||||
TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]}
|
||||
|
||||
# Algorithm
|
||||
temperature=1.0
|
||||
top_p=1.0
|
||||
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
|
||||
ppo_kl_coef=1
|
||||
kl_cov_ratio=0.0002
|
||||
|
||||
# Mathematically equivalent
|
||||
use_dynamic_bsz=True
|
||||
infer_micro_batch_size=null
|
||||
train_micro_batch_size=null
|
||||
offload=False
|
||||
|
||||
HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
|
||||
data.train_files="${TRAIN_FILE}" \
|
||||
data.val_files="${TEST_FILE}" \
|
||||
data.prompt_key=prompt \
|
||||
data.truncation='left' \
|
||||
data.filter_overlong_prompts=False \
|
||||
data.max_prompt_length=${max_prompt_length} \
|
||||
data.max_response_length=${max_response_length} \
|
||||
data.gen_batch_size=${gen_prompt_bsz} \
|
||||
data.train_batch_size=${train_prompt_bsz} \
|
||||
data.return_raw_chat=True \
|
||||
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
|
||||
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
|
||||
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
|
||||
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
|
||||
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
|
||||
actor_rollout_ref.actor.clip_ratio_c=10.0 \
|
||||
actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
|
||||
actor_rollout_ref.actor.policy_loss.kl_cov_ratio=${kl_cov_ratio} \
|
||||
actor_rollout_ref.actor.policy_loss.ppo_kl_coef=${ppo_kl_coef} \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
|
||||
actor_rollout_ref.rollout.mode=sync \
|
||||
algorithm.adv_estimator=${adv_estimator} \
|
||||
algorithm.use_kl_in_reward=${use_kl_in_reward} \
|
||||
algorithm.kl_ctrl.kl_coef=${kl_coef} \
|
||||
algorithm.filter_groups.enable=${enable_filter_groups} \
|
||||
algorithm.filter_groups.metric=${filter_groups_metric} \
|
||||
algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
|
||||
actor_rollout_ref.model.use_remove_padding=True \
|
||||
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \
|
||||
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \
|
||||
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \
|
||||
actor_rollout_ref.model.path="${MODEL_PATH}" \
|
||||
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.actor.optim.weight_decay=0 \
|
||||
actor_rollout_ref.actor.optim.warmup_style=constant \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
|
||||
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
|
||||
actor_rollout_ref.actor.entropy_coeff=0 \
|
||||
actor_rollout_ref.actor.grad_clip=1.0 \
|
||||
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
|
||||
actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
|
||||
actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
|
||||
actor_rollout_ref.rollout.enable_chunked_prefill=True \
|
||||
actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \
|
||||
actor_rollout_ref.rollout.temperature=${temperature} \
|
||||
actor_rollout_ref.rollout.top_p=${top_p} \
|
||||
actor_rollout_ref.rollout.top_k="${top_k}" \
|
||||
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
|
||||
actor_rollout_ref.rollout.val_kwargs.do_sample=False \
|
||||
actor_rollout_ref.rollout.val_kwargs.n=1 \
|
||||
actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
|
||||
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
|
||||
actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \
|
||||
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
|
||||
reward_model.reward_manager=dapo \
|
||||
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
|
||||
reward_model.overlong_buffer.len=${overlong_buffer_len} \
|
||||
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
|
||||
trainer.logger=['console','wandb'] \
|
||||
trainer.project_name="${project_name}" \
|
||||
trainer.experiment_name="${exp_name}" \
|
||||
trainer.n_gpus_per_node=8 \
|
||||
trainer.nnodes="${NNODES}" \
|
||||
trainer.val_before_train=False \
|
||||
trainer.test_freq=4 \
|
||||
trainer.save_freq=32 \
|
||||
trainer.total_epochs=1000 \
|
||||
trainer.default_local_dir="${CKPTS_DIR}" \
|
||||
trainer.resume_mode=disable
145  recipe/entropy/7b_clip_cov.sh  Normal file
@@ -0,0 +1,145 @@

#!/usr/bin/env bash
|
||||
set -xeuo pipefail
|
||||
|
||||
export WANDB_API_KEY=YOUR_WANDB_API_KEY
|
||||
# export VLLM_USE_V1=1
|
||||
|
||||
project_name='Qwen2.5-7B'
|
||||
exp_name='clipcov'
|
||||
|
||||
adv_estimator=grpo
|
||||
|
||||
use_kl_in_reward=False
|
||||
kl_coef=0.0
|
||||
use_kl_loss=False
|
||||
kl_loss_coef=0.0
|
||||
|
||||
clip_ratio_low=1
|
||||
clip_ratio_high=1
|
||||
clip_cov_ratio=0.0002
|
||||
clip_cov_lb=1.0
|
||||
clip_cov_ub=5.0
|
||||
|
||||
max_prompt_length=$((1024 * 2))
|
||||
max_response_length=$((1024 * 8))
|
||||
enable_overlong_buffer=False
|
||||
overlong_buffer_len=$((1024 * 2))
|
||||
overlong_penalty_factor=1.0
|
||||
|
||||
loss_agg_mode="token-mean"
|
||||
loss_mode="clip_cov"
|
||||
enable_filter_groups=True
|
||||
filter_groups_metric=acc
|
||||
max_num_gen_batches=10
|
||||
train_prompt_bsz=256
|
||||
gen_prompt_bsz=$((train_prompt_bsz * 3))
|
||||
train_prompt_mini_bsz=32
|
||||
n_resp_per_prompt=8
|
||||
max_token=30720
|
||||
|
||||
# Ray
|
||||
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
|
||||
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
|
||||
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
|
||||
NNODES=${NNODES:-4}
|
||||
# Paths
|
||||
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
|
||||
MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"}
|
||||
CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"}
|
||||
TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"}
|
||||
TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]}
|
||||
|
||||
# Algorithm
|
||||
temperature=1.0
|
||||
top_p=1.0
|
||||
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
|
||||
ppo_kl_coef=1
|
||||
kl_cov_ratio=0.2
|
||||
|
||||
# Mathematically equivalent
|
||||
use_dynamic_bsz=True
|
||||
infer_micro_batch_size=null
|
||||
train_micro_batch_size=null
|
||||
offload=False
|
||||
|
||||
HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
|
||||
data.train_files="${TRAIN_FILE}" \
|
||||
data.val_files="${TEST_FILE}" \
|
||||
data.prompt_key=prompt \
|
||||
data.truncation='left' \
|
||||
data.filter_overlong_prompts=False \
|
||||
data.max_prompt_length=${max_prompt_length} \
|
||||
data.max_response_length=${max_response_length} \
|
||||
data.gen_batch_size=${gen_prompt_bsz} \
|
||||
data.train_batch_size=${train_prompt_bsz} \
|
||||
data.return_raw_chat=True \
|
||||
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
|
||||
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
|
||||
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
|
||||
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
|
||||
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
|
||||
actor_rollout_ref.actor.clip_ratio_c=10.0 \
|
||||
actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
|
||||
actor_rollout_ref.actor.policy_loss.clip_cov_ratio=${clip_cov_ratio} \
|
||||
actor_rollout_ref.actor.policy_loss.clip_cov_lb=${clip_cov_lb} \
|
||||
actor_rollout_ref.actor.policy_loss.clip_cov_ub=${clip_cov_ub} \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
|
||||
actor_rollout_ref.rollout.mode=sync \
|
||||
algorithm.adv_estimator=${adv_estimator} \
|
||||
algorithm.use_kl_in_reward=${use_kl_in_reward} \
|
||||
algorithm.kl_ctrl.kl_coef=${kl_coef} \
|
||||
algorithm.filter_groups.enable=${enable_filter_groups} \
|
||||
algorithm.filter_groups.metric=${filter_groups_metric} \
|
||||
algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
|
||||
actor_rollout_ref.model.use_remove_padding=True \
|
||||
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \
|
||||
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \
|
||||
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \
|
||||
actor_rollout_ref.model.path="${MODEL_PATH}" \
|
||||
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.actor.optim.weight_decay=0 \
|
||||
actor_rollout_ref.actor.optim.warmup_style=constant \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
|
||||
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
|
||||
actor_rollout_ref.actor.entropy_coeff=0 \
|
||||
actor_rollout_ref.actor.grad_clip=1.0 \
|
||||
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
|
||||
actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
|
||||
actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
|
||||
actor_rollout_ref.rollout.enable_chunked_prefill=True \
|
||||
actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \
|
||||
actor_rollout_ref.rollout.temperature=${temperature} \
|
||||
actor_rollout_ref.rollout.top_p=${top_p} \
|
||||
actor_rollout_ref.rollout.top_k="${top_k}" \
|
||||
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
|
||||
actor_rollout_ref.rollout.val_kwargs.do_sample=False \
|
||||
actor_rollout_ref.rollout.val_kwargs.n=1 \
|
||||
actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
|
||||
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
|
||||
actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \
|
||||
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
|
||||
reward_model.reward_manager=dapo \
|
||||
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
|
||||
reward_model.overlong_buffer.len=${overlong_buffer_len} \
|
||||
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
|
||||
trainer.logger=['console','wandb'] \
|
||||
trainer.project_name="${project_name}" \
|
||||
trainer.experiment_name="${exp_name}" \
|
||||
trainer.n_gpus_per_node=8 \
|
||||
trainer.nnodes="${NNODES}" \
|
||||
trainer.val_before_train=False \
|
||||
trainer.test_freq=4 \
|
||||
trainer.save_freq=32 \
|
||||
trainer.total_epochs=1000 \
|
||||
trainer.default_local_dir="${CKPTS_DIR}" \
|
||||
trainer.resume_mode=disable
141  recipe/entropy/7b_kl_cov.sh  Normal file
@@ -0,0 +1,141 @@

#!/usr/bin/env bash
|
||||
set -xeuo pipefail
|
||||
|
||||
export WANDB_API_KEY=YOUR_WANDB_API_KEY
|
||||
# export VLLM_USE_V1=1
|
||||
|
||||
project_name='Qwen2.5-7B'
|
||||
exp_name='klcov'
|
||||
|
||||
adv_estimator=grpo
|
||||
|
||||
use_kl_in_reward=False
|
||||
kl_coef=0.0
|
||||
use_kl_loss=False
|
||||
kl_loss_coef=0.0
|
||||
|
||||
clip_ratio_low=0.2
|
||||
clip_ratio_high=0.2
|
||||
|
||||
max_prompt_length=$((1024 * 2))
|
||||
max_response_length=$((1024 * 8))
|
||||
enable_overlong_buffer=False
|
||||
overlong_buffer_len=$((1024 * 2))
|
||||
overlong_penalty_factor=1.0
|
||||
|
||||
loss_agg_mode="token-mean"
|
||||
loss_mode="kl_cov"
|
||||
enable_filter_groups=True
|
||||
filter_groups_metric=acc
|
||||
max_num_gen_batches=10
|
||||
train_prompt_bsz=256
|
||||
gen_prompt_bsz=$((train_prompt_bsz * 3))
|
||||
train_prompt_mini_bsz=32
|
||||
n_resp_per_prompt=8
|
||||
max_token=30720
|
||||
|
||||
# Ray
|
||||
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
|
||||
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
|
||||
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
|
||||
NNODES=${NNODES:-4}
|
||||
# Paths
|
||||
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
|
||||
MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"}
|
||||
CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"}
|
||||
TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"}
|
||||
TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]}
|
||||
|
||||
# Algorithm
|
||||
temperature=1.0
|
||||
top_p=1.0
|
||||
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
|
||||
ppo_kl_coef=1
|
||||
kl_cov_ratio=0.002
|
||||
|
||||
# Mathematically equivalent
|
||||
use_dynamic_bsz=True
|
||||
infer_micro_batch_size=null
|
||||
train_micro_batch_size=null
|
||||
offload=False
|
||||
|
||||
HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
|
||||
data.train_files="${TRAIN_FILE}" \
|
||||
data.val_files="${TEST_FILE}" \
|
||||
data.prompt_key=prompt \
|
||||
data.truncation='left' \
|
||||
data.filter_overlong_prompts=False \
|
||||
data.max_prompt_length=${max_prompt_length} \
|
||||
data.max_response_length=${max_response_length} \
|
||||
data.gen_batch_size=${gen_prompt_bsz} \
|
||||
data.train_batch_size=${train_prompt_bsz} \
|
||||
data.return_raw_chat=True \
|
||||
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
|
||||
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
|
||||
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
|
||||
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
|
||||
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
|
||||
actor_rollout_ref.actor.clip_ratio_c=10.0 \
|
||||
actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
|
||||
actor_rollout_ref.actor.policy_loss.kl_cov_ratio=${kl_cov_ratio} \
|
||||
actor_rollout_ref.actor.policy_loss.ppo_kl_coef=${ppo_kl_coef} \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
|
||||
actor_rollout_ref.rollout.mode=sync \
|
||||
algorithm.adv_estimator=${adv_estimator} \
|
||||
algorithm.use_kl_in_reward=${use_kl_in_reward} \
|
||||
algorithm.kl_ctrl.kl_coef=${kl_coef} \
|
||||
algorithm.filter_groups.enable=${enable_filter_groups} \
|
||||
algorithm.filter_groups.metric=${filter_groups_metric} \
|
||||
algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
|
||||
actor_rollout_ref.model.use_remove_padding=True \
|
||||
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \
|
||||
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \
|
||||
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \
|
||||
actor_rollout_ref.model.path="${MODEL_PATH}" \
|
||||
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.actor.optim.weight_decay=0 \
|
||||
actor_rollout_ref.actor.optim.warmup_style=constant \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
|
||||
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
|
||||
actor_rollout_ref.actor.entropy_coeff=0 \
|
||||
actor_rollout_ref.actor.grad_clip=1.0 \
|
||||
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
|
||||
actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
|
||||
actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
|
||||
actor_rollout_ref.rollout.enable_chunked_prefill=True \
|
||||
actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \
|
||||
actor_rollout_ref.rollout.temperature=${temperature} \
|
||||
actor_rollout_ref.rollout.top_p=${top_p} \
|
||||
actor_rollout_ref.rollout.top_k="${top_k}" \
|
||||
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
|
||||
actor_rollout_ref.rollout.val_kwargs.do_sample=False \
|
||||
actor_rollout_ref.rollout.val_kwargs.n=1 \
|
||||
actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
|
||||
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
|
||||
actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \
|
||||
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
|
||||
reward_model.reward_manager=dapo \
|
||||
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
|
||||
reward_model.overlong_buffer.len=${overlong_buffer_len} \
|
||||
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
|
||||
trainer.logger=['console','wandb'] \
|
||||
trainer.project_name="${project_name}" \
|
||||
trainer.experiment_name="${exp_name}" \
|
||||
trainer.n_gpus_per_node=8 \
|
||||
trainer.nnodes="${NNODES}" \
|
||||
trainer.val_before_train=False \
|
||||
trainer.test_freq=4 \
|
||||
trainer.save_freq=32 \
|
||||
trainer.total_epochs=1000 \
|
||||
trainer.default_local_dir="${CKPTS_DIR}" \
|
||||
trainer.resume_mode=disable
|
110
recipe/entropy/README.md
Normal file
110
recipe/entropy/README.md
Normal file
@ -0,0 +1,110 @@
|
||||
<div align="center">
|
||||
|
||||
# The Entropy Mechanism of Reinforcement Learning for Large Language Model Reasoning.
|
||||
|
||||
[](https://arxiv.org/pdf/2505.22617) [](https://github.com/PRIME-RL/Entropy-Mechanism-of-RL) [](https://www.alphaxiv.org/abs/2505.22617) [](https://x.com/stingning/status/1928088554166505667) [](https://x.com/charlesfornlp/status/1928089451080585283) [](https://x.com/_akhaliq/status/1928077929105268861)
|
||||
|
||||
|
||||
<div align="center" style="font-family: Arial, sans-serif;">
|
||||
<p>
|
||||
<a href="#🎉news" style="text-decoration: none; font-weight: bold;">🎉 News</a> •
|
||||
<a href="#✨getting-started" style="text-decoration: none; font-weight: bold;">✨ Getting Started</a> •
|
||||
<a href="#📖introduction" style="text-decoration: none; font-weight: bold;">📖 Introduction</a>
|
||||
</p>
|
||||
<p>
|
||||
<a href="#🎈citation" style="text-decoration: none; font-weight: bold;">🎈 Citation</a> •
|
||||
<a href="#🌻acknowledgement" style="text-decoration: none; font-weight: bold;">🌻 Acknowledgement</a> •
|
||||
<a href="#📬Contact" style="text-decoration: none; font-weight: bold;">📬 Contact</a> •
|
||||
<a href="#📈star-history" style="text-decoration: none; font-weight: bold;">📈 Star History</a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
# 🎉News
|
||||
|
||||
- **[2025/05/29]** 🎉 Ranked **#1** of the day on [Huggingface Daily Papers](https://huggingface.co/papers?date=2025-05-29).
|
||||
- **[2025/05/29]** Released our Paper on arXiv. See [here](https://arxiv.org/pdf/2505.22617). We provide insights into the entropy mechanism of RL for LLMs and propose two simple yet effective strategies to alleviate the entropy collapse.
|
||||
|
||||
|
||||
|
||||
# ✨Getting started
|
||||
|
||||
After preparing the training data, you can train Qwen2.5-7B on a single node. Taking the KL-Cov approach as an example, simply run:
|
||||
|
||||
```bash
|
||||
cd verl
|
||||
conda activate your_env
|
||||
bash recipe/entropy/7b_kl_cov.sh
|
||||
```
|
||||
|
||||
To train Qwen2.5-32B on multiple nodes, run the following commands:
|
||||
|
||||
```bash
|
||||
cd verl
|
||||
conda activate your_env
|
||||
bash recipe/entropy/32b_kl_cov.sh
|
||||
```
|
||||
|
||||
# 📖Introduction
|
||||
|
||||
<div align="left">
|
||||
<img src="https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/e2a.jpg?raw=true" alt="issue" style="width: 96%; height: auto;">
|
||||
</div>
|
||||
|
||||
This paper addresses the entropy collapse issue in scaling reinforcement learning (RL) for large language models (LLMs), where policy entropy drops sharply during training, leading to overconfidence and performance saturation. We empirically establish a relationship between entropy ($H$) and performance ($R$): $R = -a\exp(H) + b$, showing performance is bottlenecked by entropy exhaustion.
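
As a rough illustration (not from the paper or its released code; `a` and `b` below are hypothetical fitted coefficients), the fitted curve implies a hard ceiling once entropy is exhausted: at $H = 0$ the predicted performance is simply $b - a$.

```python
import math

# Hypothetical fitted coefficients for R = -a * exp(H) + b (illustrative values only).
a, b = 0.1, 0.9

def predicted_score(entropy: float) -> float:
    """Evaluate the empirical fit R = -a * exp(H) + b at a given policy entropy H."""
    return -a * math.exp(entropy) + b

print(predicted_score(1.0))  # early training: higher entropy, lower predicted score
print(predicted_score(0.0))  # entropy exhausted: predicted ceiling b - a
```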
|
||||
|
||||
<div align="left">
|
||||
<img src="https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/cov.jpg?raw=true" alt="issue" style="width: 96%; height: auto;">
|
||||
</div>
|
||||
|
||||
Theoretically, we find entropy changes are driven by the covariance between action probability and logit updates, which correlates with advantage in Policy Gradient methods. High-probability, high-advantage actions reduce entropy, while rare, high-advantage actions increase it. Empirically, the covariance term remains positive, explaining entropy’s monotonic decline. To mitigate this, we propose Clip-Cov and KL-Cov, which restrict updates for high-covariance tokens. These methods effectively prevent entropy collapse, and improve performance.
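
Below is a minimal sketch of the quantity both methods act on (hypothetical tensors, mirroring but simplifying the recipe's loss code): the per-token covariance term is the product of the centered advantage and the centered log-probability. KL-Cov penalizes the top fraction of tokens by this value, while Clip-Cov detaches a small random subset of them from the gradient.

```python
import torch

# Hypothetical per-token advantages and current-policy log-probs (4 responses, 16 tokens each).
advantages = torch.randn(4, 16)
log_prob = torch.randn(4, 16)

# Per-token covariance term: centered advantage times centered log-probability.
cov = (advantages - advantages.mean()) * (log_prob - log_prob.mean())

# KL-Cov: flag the top kl_cov_ratio fraction of tokens (at least one) and add a KL
# penalty only at those positions; Clip-Cov instead zeroes their policy-gradient term.
kl_cov_ratio = 0.25  # exaggerated here so the toy example selects several tokens
k = max(1, int(cov.numel() * kl_cov_ratio))
flagged = torch.topk(cov.flatten(), k).indices
print(f"{k} of {cov.numel()} tokens would receive the extra treatment")
```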
|
||||
|
||||
# 📃Evaluation
|
||||
|
||||
<div align="left">
|
||||
<img src="https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/performance_fig.jpg?raw=true" alt="issue" style="width: 96%; height: auto;">
|
||||
</div>
|
||||
|
||||
|
||||
Our method maintains a considerably higher level of entropy throughout training. For example, when the baseline's entropy reaches a plateau and can no longer be consumed, the KL-Cov method still sustains an entropy level over 10 times higher. Meanwhile, the response length of the policy model steadily increases, and its performance on the test set consistently surpasses that of the baseline. This indicates that our model explores more freely during training, learning a better policy through RL.
|
||||
| **Method** | **AIME24** | **AIME25** | **AMC** | **MATH-500** | **OMNI-MATH** | **OlympiadBench** | **Minerva** | **Avg.** |
|
||||
| ----------------- | ---------: | ---------: | -------: | -----------: | ------------: | ----------------: | ----------: | -------: |
|
||||
| *Qwen2.5-7B* | | | | | | | | |
|
||||
| GRPO | 21.2 | 9.6 | 58.7 | 78.8 | 27.9 | 40.7 | 36.7 | 38.6 |
|
||||
| w. Clip-higher | 18.1 | 11.5 | 56.6 | 79.2 | 29.8 | 43.3 | 40.4 | 38.8 |
|
||||
| w. **`CLIP-Cov`** | 22.1 | **15.8** | 58.2 | 80.4 | **30.5** | **44.1** | **41.1** | 40.4 |
|
||||
| w. **`KL-Cov`** | **22.6** | 12.9 | **61.4** | **80.8** | 29.1 | 42.6 | 38.2 | **40.6** |
|
||||
| *Qwen2.5-32B* | | | | | | | | |
|
||||
| GRPO | 21.8 | 16.2 | 69.7 | 84.2 | 35.2 | 43.6 | 45.5 | 45.8 |
|
||||
| w. Clip-higher | 35.6 | 22.3 | 69.5 | 77.2 | 35.1 | 42.5 | 43.0 | 47.2 |
|
||||
| w. **`CLIP-Cov`** | 32.3 | 22.7 | 67.2 | **87.0** | **42.0** | **57.2** | 46.0 | 50.3 |
|
||||
| w. **`KL-Cov`** | **36.8** | **30.8** | **74.5** | 84.6 | 39.1 | 49.0 | **46.3** | **52.2** |
|
||||
|
||||
Both of our approaches achieve non-trivial improvements across all benchmarks. Our method outperforms GRPO by 2.0% on average for the 7B model and by 6.4% for the 32B model. Moreover, the gains are more substantial on the larger Qwen2.5-32B: on the most challenging benchmarks, AIME24 and AIME25, our method improves over GRPO by 15.0% and 14.6%, respectively.
|
||||
|
||||
|
||||
# 🎈Citation
|
||||
If you find this paper or repo helpful, please cite us.
|
||||
|
||||
```bibtex
|
||||
@article{cui2025entropy,
|
||||
title={The Entropy Mechanism of Reinforcement Learning for Reasoning Language Models},
|
||||
author={Cui, Ganqu and Zhang, Yuchen and Chen, Jiacheng and Yuan, Lifan and Wang, Zhi and Zuo, Yuxin and Li, Haozhan and Fan, Yuchen and Chen, Huayu and Chen, Weize and others},
|
||||
journal={arXiv preprint arXiv:2505.22617},
|
||||
year={2025}
|
||||
}
|
||||
```
|
||||
# 🌻Acknowledgement
|
||||
We implement our reinforcement learning algorithm by extending [verl](https://github.com/volcengine/verl). We use [vLLM](https://github.com/vllm-project/vllm) for inference. Our models are trained primarily on the [Qwen2.5 family](https://github.com/QwenLM/Qwen2.5). Our training data is built from [DAPO-MATH](https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k). Thanks for their great contributions!
|
||||
|
||||
# 📬 Contact
|
||||
|
||||
For questions, discussion, or collaboration opportunities, feel free to contact:
|
||||
- Ganqu Cui: cuiganqu@pjlab.org.cn
|
||||
- Yuchen Zhang: yuchen.zhang2003@gmail.com
|
||||
- Jiacheng Chen: jackchan9345@gmail.com
|
||||
- Ning Ding: ningding.cs@gmail.com
|
||||
|
recipe/entropy/config/entropy_trainer.yaml (new file, 39 lines)
@@ -0,0 +1,39 @@
|
||||
hydra:
|
||||
searchpath:
|
||||
- file://verl/trainer/config
|
||||
|
||||
defaults:
|
||||
- ppo_trainer
|
||||
- _self_
|
||||
|
||||
data:
|
||||
gen_batch_size: ${data.train_batch_size}
|
||||
|
||||
reward_model:
|
||||
reward_kwargs:
|
||||
overlong_buffer_cfg: ${reward_model.overlong_buffer}
|
||||
reward_manager: dapo
|
||||
overlong_buffer:
|
||||
enable: False
|
||||
len: 0
|
||||
penalty_factor: 0.0
|
||||
log: False
|
||||
|
||||
algorithm:
|
||||
filter_groups:
|
||||
enable: False # set explicitly so it is not forgotten
|
||||
metric: null # acc / score / seq_reward / seq_final_reward / ...
|
||||
max_num_gen_batches: 0 # Non-positive values mean no upper limit
|
||||
|
||||
trainer:
|
||||
project_name: verl-entropy
|
||||
|
||||
actor_rollout_ref:
|
||||
actor:
|
||||
policy_loss:
|
||||
loss_mode: "vanilla" # /clip-cov / kl-cov from https://arxiv.org/abs/2505.
|
||||
clip_cov_ratio: 0.0002 # for clip-cov loss
|
||||
clip_cov_lb: 1.0 # for clip-cov loss
|
||||
clip_cov_ub: 5.0 # for clip-cov loss
|
||||
kl_cov_ratio: 0.0002 # for kl-cov loss
|
||||
ppo_kl_coef: 0.1 # for kl-cov loss
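
The `defaults` list at the top of this file layers the recipe on top of verl's base `ppo_trainer` config through Hydra's search path. Conceptually the composition behaves like an OmegaConf merge where later entries win; the sketch below uses simplified stand-in dicts (not the real config files) purely to illustrate the override semantics.

```python
from omegaconf import OmegaConf

# Simplified stand-ins for the base ppo_trainer config and this recipe's overrides.
base = OmegaConf.create({
    "actor_rollout_ref": {"actor": {"policy_loss": {"loss_mode": "vanilla", "ppo_kl_coef": 0.1}}},
    "trainer": {"project_name": "verl"},
})
recipe = OmegaConf.create({
    "actor_rollout_ref": {"actor": {"policy_loss": {"loss_mode": "kl_cov"}}},
    "trainer": {"project_name": "verl-entropy"},
})

cfg = OmegaConf.merge(base, recipe)  # later configs win, like `- _self_` after `- ppo_trainer`
print(cfg.actor_rollout_ref.actor.policy_loss.loss_mode)  # kl_cov
print(cfg.trainer.project_name)  # verl-entropy
```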
|
recipe/entropy/entropy_ray_trainer.py (new file, 311 lines)
@@ -0,0 +1,311 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
FSDP PPO Trainer with Ray-based single controller.
|
||||
This trainer supports model-agnostic model initialization with HuggingFace.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from pprint import pprint
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from verl import DataProto
|
||||
from verl.trainer.ppo.metric_utils import (
|
||||
compute_data_metrics,
|
||||
compute_throughout_metrics,
|
||||
compute_timing_metrics,
|
||||
reduce_metrics,
|
||||
)
|
||||
from verl.trainer.ppo.ray_trainer import AdvantageEstimator, RayPPOTrainer, apply_kl_penalty, compute_advantage, compute_response_mask
|
||||
from verl.utils.debug import simple_timer
|
||||
|
||||
|
||||
class RayEntropyTrainer(RayPPOTrainer):
|
||||
"""
|
||||
Note that this trainer runs on the driver process on a single CPU/GPU node.
|
||||
"""
|
||||
|
||||
def fit(self):
|
||||
"""
|
||||
The training loop of PPO.
|
||||
The driver process only needs to call the compute functions of the worker group through RPC
|
||||
to construct the PPO dataflow.
|
||||
The lightweight advantage computation is done on the driver process.
|
||||
"""
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from verl.utils.tracking import Tracking
|
||||
|
||||
logger = Tracking(
|
||||
project_name=self.config.trainer.project_name,
|
||||
experiment_name=self.config.trainer.experiment_name,
|
||||
default_backend=self.config.trainer.logger,
|
||||
config=OmegaConf.to_container(self.config, resolve=True),
|
||||
)
|
||||
|
||||
self.global_steps = 0
|
||||
|
||||
# load checkpoint before doing anything
|
||||
self._load_checkpoint()
|
||||
|
||||
# perform validation before training
|
||||
# currently, we only support validation using the reward_function.
|
||||
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
|
||||
val_metrics = self._validate()
|
||||
assert val_metrics, f"{val_metrics=}"
|
||||
pprint(f"Initial validation metrics: {val_metrics}")
|
||||
logger.log(data=val_metrics, step=self.global_steps)
|
||||
if self.config.trainer.get("val_only", False):
|
||||
return
|
||||
|
||||
# add tqdm
|
||||
progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
|
||||
|
||||
# we start from step 1
|
||||
self.global_steps += 1
|
||||
last_val_metrics = None
|
||||
|
||||
timing_raw = defaultdict(float)
|
||||
batch = None
|
||||
num_prompt_in_batch = 0
|
||||
num_gen_batches = 0
|
||||
for epoch in range(self.config.trainer.total_epochs):
|
||||
for batch_dict in self.train_dataloader:
|
||||
metrics = {}
|
||||
|
||||
new_batch: DataProto = DataProto.from_single_dict(batch_dict)
|
||||
num_gen_batches += 1
|
||||
# pop those keys for generation
|
||||
if "multi_modal_inputs" in new_batch.non_tensor_batch.keys():
|
||||
gen_batch = new_batch.pop(
|
||||
batch_keys=["input_ids", "attention_mask", "position_ids"],
|
||||
non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"],
|
||||
)
|
||||
else:
|
||||
gen_batch = new_batch.pop(
|
||||
batch_keys=["input_ids", "attention_mask", "position_ids"],
|
||||
non_tensor_batch_keys=["raw_prompt_ids"],
|
||||
)
|
||||
|
||||
is_last_step = self.global_steps >= self.total_training_steps
|
||||
|
||||
with simple_timer("step", timing_raw):
|
||||
# generate a batch
|
||||
# with simple_timer("gen", timing_raw):
|
||||
# gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
|
||||
with simple_timer("gen", timing_raw):
|
||||
if not self.async_rollout_mode:
|
||||
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
|
||||
else:
|
||||
self.async_rollout_manager.wake_up()
|
||||
gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
|
||||
self.async_rollout_manager.sleep()
|
||||
|
||||
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
|
||||
with simple_timer("gen_max", timing_raw):
|
||||
gen_baseline_batch = deepcopy(gen_batch)
|
||||
gen_baseline_batch.meta_info["do_sample"] = False
|
||||
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
|
||||
|
||||
new_batch = new_batch.union(gen_baseline_output)
|
||||
reward_baseline_tensor = self.reward_fn(new_batch)
|
||||
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
|
||||
|
||||
new_batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
|
||||
|
||||
new_batch.batch["reward_baselines"] = reward_baseline_tensor
|
||||
|
||||
del gen_baseline_batch, gen_baseline_output
|
||||
|
||||
new_batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object)
|
||||
# repeat to align with repeated responses in rollout
|
||||
new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
|
||||
new_batch = new_batch.union(gen_batch_output)
|
||||
|
||||
with simple_timer("reward", timing_raw):
|
||||
# compute scores. Support both model and function-based.
|
||||
# We first compute the scores using reward model. Then, we call reward_fn to combine
|
||||
# the results from reward model and rule-based results.
|
||||
if self.use_rm:
|
||||
# we first compute reward model score
|
||||
reward_tensor = self.rm_wg.compute_rm_score(new_batch)
|
||||
new_batch = new_batch.union(reward_tensor)
|
||||
|
||||
# we combine with rule-based rm
|
||||
reward_extra_infos_dict: dict[str, list]
|
||||
try:
|
||||
reward_result = self.reward_fn(new_batch, return_dict=True)
|
||||
reward_tensor = reward_result["reward_tensor"]
|
||||
reward_extra_infos_dict = reward_result["reward_extra_info"]
|
||||
except Exception as e:
|
||||
print(f"Error in reward_fn: {e}")
|
||||
reward_tensor = self.reward_fn(new_batch)
|
||||
reward_extra_infos_dict = {}
|
||||
|
||||
new_batch.batch["token_level_scores"] = reward_tensor
|
||||
|
||||
print(f"{list(reward_extra_infos_dict.keys())=}")
|
||||
if reward_extra_infos_dict:
|
||||
new_batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
|
||||
|
||||
# compute rewards. apply_kl_penalty if available
|
||||
if self.config.algorithm.use_kl_in_reward:
|
||||
new_batch, kl_metrics = apply_kl_penalty(new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty)
|
||||
metrics.update(kl_metrics)  # TODO: This will be cleared if we use multiple generation batches
|
||||
else:
|
||||
new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"]
|
||||
|
||||
if not self.config.algorithm.filter_groups.enable:
|
||||
batch = new_batch
|
||||
else:  # NOTE: when the number of prompts left after filtering is smaller than the train batch size,
|
||||
# we skip to the next generation batch
|
||||
metric_name = self.config.algorithm.filter_groups.metric
|
||||
if metric_name == "seq_final_reward":
|
||||
# Turn to numpy for easier filtering
|
||||
new_batch.non_tensor_batch["seq_final_reward"] = new_batch.batch["token_level_rewards"].sum(dim=-1).numpy()
|
||||
elif metric_name == "seq_reward":
|
||||
new_batch.non_tensor_batch["seq_reward"] = new_batch.batch["token_level_scores"].sum(dim=-1).numpy()
|
||||
|
||||
# Collect the sequence reward for each trajectory
|
||||
prompt_uid2metric_vals = defaultdict(list)
|
||||
for uid, metric_val in zip(new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name]):
|
||||
prompt_uid2metric_vals[uid].append(metric_val)
|
||||
|
||||
prompt_uid2metric_std = {}
|
||||
for prompt_uid, metric_vals in prompt_uid2metric_vals.items():
|
||||
prompt_uid2metric_std[prompt_uid] = np.std(metric_vals)
|
||||
|
||||
kept_prompt_uids = [uid for uid, std in prompt_uid2metric_std.items() if std > 0 or len(prompt_uid2metric_vals[uid]) == 1]
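# Drop prompts whose responses all received the same reward (std == 0): with group-normalized
# advantages they provide no learning signal, so only prompt groups with some reward variance
# (or a single sample) are kept.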
|
||||
num_prompt_in_batch += len(kept_prompt_uids)
|
||||
|
||||
kept_traj_idxs = []
|
||||
for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]):
|
||||
if traj_from_prompt_uid in kept_prompt_uids:
|
||||
kept_traj_idxs.append(idx)
|
||||
|
||||
new_batch = new_batch[kept_traj_idxs]
|
||||
batch = new_batch if batch is None else DataProto.concat([batch, new_batch])
|
||||
|
||||
prompt_bsz = self.config.data.train_batch_size
|
||||
if num_prompt_in_batch < prompt_bsz:
|
||||
print(f"{num_prompt_in_batch=} < {prompt_bsz=}")
|
||||
max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
|
||||
if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
|
||||
print(f"{num_gen_batches=}. Keep generating...")
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"{num_gen_batches=} >= {max_num_gen_batches=}." + " Generated too many. Please check if your data are too difficult." + " You could also try set max_num_gen_batches=0 to enable endless trials.")
|
||||
else:
|
||||
# Align the batch
|
||||
traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
|
||||
print(f"Collected {num_prompt_in_batch} / {self.config.data.train_batch_size} prompt. Collecting finished.")
|
||||
batch = batch[:traj_bsz]
|
||||
|
||||
# === Updating ===
|
||||
|
||||
batch.batch["response_mask"] = compute_response_mask(batch)
|
||||
|
||||
# balance the number of valid tokens on each dp rank.
|
||||
# Note that this breaks the order of data inside the batch.
|
||||
# Please take care when implementing group-based advantage computation such as GRPO and RLOO
|
||||
if self.config.trainer.balance_batch:
|
||||
self._balance_batch(batch, metrics=metrics)
|
||||
|
||||
# compute global_valid tokens
|
||||
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
|
||||
|
||||
# recompute old_log_probs
|
||||
with simple_timer("old_log_prob", timing_raw):
|
||||
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
|
||||
batch = batch.union(old_log_prob)
|
||||
|
||||
if self.use_reference_policy:
|
||||
# compute reference log_prob
|
||||
with simple_timer("ref", timing_raw):
|
||||
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
|
||||
batch = batch.union(ref_log_prob)
|
||||
|
||||
# compute values
|
||||
if self.use_critic:
|
||||
with simple_timer("values", timing_raw):
|
||||
values = self.critic_wg.compute_values(batch)
|
||||
batch = batch.union(values)
|
||||
|
||||
with simple_timer("adv", timing_raw):
|
||||
# compute advantages, executed on the driver process
|
||||
norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)
|
||||
batch = compute_advantage(
|
||||
batch,
|
||||
adv_estimator=self.config.algorithm.adv_estimator,
|
||||
gamma=self.config.algorithm.gamma,
|
||||
lam=self.config.algorithm.lam,
|
||||
num_repeat=self.config.actor_rollout_ref.rollout.n,
|
||||
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
|
||||
)
|
||||
|
||||
# update critic
|
||||
if self.use_critic:
|
||||
with simple_timer("update_critic", timing_raw):
|
||||
critic_output = self.critic_wg.update_critic(batch)
|
||||
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
|
||||
metrics.update(critic_output_metrics)
|
||||
|
||||
# implement critic warmup
|
||||
if self.config.trainer.critic_warmup <= self.global_steps:
|
||||
# update actor
|
||||
with simple_timer("update_actor", timing_raw):
|
||||
actor_output = self.actor_rollout_wg.update_actor(batch)
|
||||
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
|
||||
metrics.update(actor_output_metrics)
|
||||
|
||||
# validate
|
||||
if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
|
||||
with simple_timer("testing", timing_raw):
|
||||
val_metrics: dict = self._validate()
|
||||
if is_last_step:
|
||||
last_val_metrics = val_metrics
|
||||
metrics.update(val_metrics)
|
||||
|
||||
if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0):
|
||||
with simple_timer("save_checkpoint", timing_raw):
|
||||
self._save_checkpoint()
|
||||
|
||||
# collect metrics
|
||||
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
|
||||
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
|
||||
# TODO: implement actual TFLOPs and theoretical TFLOPs
|
||||
n_gpus = self.resource_pool_manager.get_n_gpus()
|
||||
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
|
||||
timing_raw = defaultdict(float) # clear timing
|
||||
|
||||
metrics["train/num_gen_batches"] = num_gen_batches
|
||||
batch = None
|
||||
num_prompt_in_batch = 0
|
||||
num_gen_batches = 0
|
||||
|
||||
# TODO: make a canonical logger that supports various backends
|
||||
logger.log(data=metrics, step=self.global_steps)
|
||||
|
||||
if is_last_step:
|
||||
pprint(f"Final validation metrics: {last_val_metrics}")
|
||||
progress_bar.close()
|
||||
return
|
||||
|
||||
progress_bar.update(1)
|
||||
self.global_steps += 1
|
recipe/entropy/main_entropy.py (new file, 227 lines)
@@ -0,0 +1,227 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Note that we don't combine this main entry point with ray_trainer, since ray_trainer is also used by other main scripts.
|
||||
"""
|
||||
|
||||
import hydra
|
||||
import ray
|
||||
|
||||
from verl.trainer.ppo.reward import load_reward_manager
|
||||
|
||||
from .entropy_ray_trainer import RayEntropyTrainer
|
||||
|
||||
|
||||
@hydra.main(config_path="config", config_name="entropy_trainer", version_base=None)
|
||||
def main(config):
|
||||
run_ppo(config)
|
||||
|
||||
|
||||
def run_ppo(config) -> None:
|
||||
if not ray.is_initialized():
|
||||
# this is for local ray cluster
|
||||
ray.init(
|
||||
runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN", "WANDB_API_KEY": "YOUR_WANDB_API_KEY"}},
|
||||
num_cpus=config.ray_init.num_cpus,
|
||||
)
|
||||
|
||||
runner = TaskRunner.remote()
|
||||
ray.get(runner.run.remote(config))
|
||||
|
||||
|
||||
def merge_dict(a: dict, b: dict) -> dict:
|
||||
"""Return a new dict that has `a` updated with `b` (b wins on conflicts).
|
||||
|
||||
Example::
|
||||
|
||||
>>> d1 = {"x": 1, "y": 2}
|
||||
>>> d2 = {"y": 20, "z": 3}
|
||||
>>> new_dict = merge_dict(d1, d2)
|
||||
>>> print(new_dict) # {'x': 1, 'y': 20, 'z': 3}
|
||||
>>> print(d1) # {"x": 1, "y": 2} (unchanged)
|
||||
>>> print(d2) # {"y": 20, "z": 3} (unchanged)
|
||||
"""
|
||||
return a | b
|
||||
|
||||
|
||||
@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
|
||||
class TaskRunner:
|
||||
def run(self, config):
|
||||
# print initial config
|
||||
from pprint import pprint
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from verl.utils.fs import copy_to_local
|
||||
|
||||
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
|
||||
OmegaConf.resolve(config)
|
||||
|
||||
# download the checkpoint from hdfs
|
||||
local_path = copy_to_local(config.actor_rollout_ref.model.path)
|
||||
print(f"{config.actor_rollout_ref.model.path}")
|
||||
# instantiate tokenizer
|
||||
from verl.utils import hf_processor, hf_tokenizer
|
||||
|
||||
trust_remote_code = config.data.get("trust_remote_code", False)
|
||||
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
|
||||
processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none
|
||||
|
||||
# define worker classes
|
||||
if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]:
|
||||
assert config.critic.strategy in ["fsdp", "fsdp2"]
|
||||
from verl.single_controller.ray import RayWorkerGroup
|
||||
from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker
|
||||
|
||||
actor_rollout_cls = AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker
|
||||
ray_worker_group_cls = RayWorkerGroup
|
||||
|
||||
elif config.actor_rollout_ref.actor.strategy == "megatron":
|
||||
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
|
||||
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
|
||||
from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
|
||||
|
||||
actor_rollout_cls = ActorRolloutRefWorker
|
||||
ray_worker_group_cls = NVMegatronRayWorkerGroup
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
|
||||
|
||||
role_worker_mapping = {
|
||||
Role.ActorRollout: ray.remote(actor_rollout_cls),
|
||||
Role.Critic: ray.remote(CriticWorker),
|
||||
}
|
||||
|
||||
global_pool_id = "global_pool"
|
||||
resource_pool_spec = {
|
||||
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
|
||||
}
|
||||
mapping = {
|
||||
Role.ActorRollout: global_pool_id,
|
||||
Role.Critic: global_pool_id,
|
||||
}
|
||||
|
||||
# we should adopt a multi-source reward function here
|
||||
# - for rule-based rm, we directly call a reward score
|
||||
# - for model-based rm, we call a model
|
||||
# - for code related prompt, we send to a sandbox if there are test cases
|
||||
# - finally, we combine all the rewards together
|
||||
# - The reward type depends on the tag of the data
|
||||
if config.reward_model.enable:
|
||||
if config.reward_model.strategy in ["fsdp", "fsdp2"]:
|
||||
from verl.workers.fsdp_workers import RewardModelWorker
|
||||
elif config.reward_model.strategy == "megatron":
|
||||
from verl.workers.megatron_workers import RewardModelWorker
|
||||
else:
|
||||
raise NotImplementedError
|
||||
role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
|
||||
mapping[Role.RewardModel] = global_pool_id
|
||||
|
||||
# use reference model
|
||||
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
|
||||
role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
|
||||
mapping[Role.RefPolicy] = global_pool_id
|
||||
|
||||
reward_kwargs = {"max_resp_len": config.data.max_response_length, "overlong_buffer_cfg": config.reward_model.overlong_buffer}
|
||||
cfg_reward_kwargs = config.reward_model.get("reward_kwargs", {})
|
||||
reward_fn = load_reward_manager(config, tokenizer, num_examine=0, **(merge_dict(reward_kwargs, cfg_reward_kwargs)))
|
||||
val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1, **reward_kwargs)
|
||||
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
|
||||
|
||||
from verl.utils.dataset.rl_dataset import collate_fn
|
||||
|
||||
train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
|
||||
val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
|
||||
train_sampler = create_rl_sampler(config.data, train_dataset)
|
||||
trainer = RayEntropyTrainer(
|
||||
config=config,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
role_worker_mapping=role_worker_mapping,
|
||||
resource_pool_manager=resource_pool_manager,
|
||||
ray_worker_group_cls=ray_worker_group_cls,
|
||||
reward_fn=reward_fn,
|
||||
val_reward_fn=val_reward_fn,
|
||||
train_dataset=train_dataset,
|
||||
val_dataset=val_dataset,
|
||||
collate_fn=collate_fn,
|
||||
train_sampler=train_sampler,
|
||||
)
|
||||
trainer.init_workers()
|
||||
trainer.fit()
|
||||
|
||||
|
||||
def create_rl_dataset(data_paths, data_config, tokenizer, processor):
|
||||
"""Create a dataset.
|
||||
|
||||
Arguments:
|
||||
data_paths: The path(s) to the data files.
data_config: The data config.
|
||||
tokenizer (Tokenizer): The tokenizer.
|
||||
processor (Processor): The processor.
|
||||
|
||||
Returns:
|
||||
dataset (Dataset): The dataset.
|
||||
"""
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from verl.utils.dataset.rl_dataset import RLHFDataset
|
||||
|
||||
if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None:
|
||||
from verl.utils.import_utils import load_extern_type
|
||||
|
||||
dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)
|
||||
if not issubclass(dataset_cls, Dataset):
|
||||
raise TypeError(f"The custom dataset class '{data_config.custom_cls.name}' from '{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset")
|
||||
else:
|
||||
dataset_cls = RLHFDataset
|
||||
print(f"Using dataset class: {dataset_cls.__name__}")
|
||||
|
||||
dataset = dataset_cls(
|
||||
data_files=data_paths,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
config=data_config,
|
||||
)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def create_rl_sampler(data_config, dataset):
|
||||
"""Create a sampler for the dataset.
|
||||
|
||||
Arguments:
|
||||
data_config: The data config.
|
||||
dataset (Dataset): The dataset.
|
||||
|
||||
Returns:
|
||||
sampler (Sampler): The sampler.
|
||||
"""
|
||||
import torch
|
||||
from torch.utils.data import RandomSampler, SequentialSampler
|
||||
|
||||
# use sampler for better ckpt resume
|
||||
if data_config.shuffle:
|
||||
train_dataloader_generator = torch.Generator()
|
||||
train_dataloader_generator.manual_seed(data_config.get("seed", 1))
|
||||
sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator)
|
||||
else:
|
||||
sampler = SequentialSampler(data_source=dataset)
|
||||
|
||||
return sampler
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -62,6 +62,13 @@ actor_rollout_ref:
|
||||
ppo_epochs: 1
|
||||
data_loader_seed: null
|
||||
shuffle: False
|
||||
policy_loss: # policy loss config
|
||||
loss_mode: "vanilla" # Loss function mode: vanilla / clip-cov / kl-cov from https://arxiv.org/abs/2505.22617
|
||||
clip_cov_ratio: 0.0002 # Ratio of tokens to be clipped for clip-cov loss
|
||||
clip_cov_lb: 1.0 # Lower bound for clip-cov loss
|
||||
clip_cov_ub: 5.0 # Upper bound for clip-cov loss
|
||||
kl_cov_ratio: 0.0002 # Ratio of tokens to be applied kl penalty for kl-cov loss
|
||||
ppo_kl_coef: 0.1 # KL divergence penalty coefficient
|
||||
optim:
|
||||
optimizer: adam
|
||||
lr: 1e-6
|
||||
|
@@ -177,6 +177,27 @@ actor_rollout_ref:
|
||||
# Upper bound for asymmetric clipping (used in dual-clip PPO)
|
||||
clip_ratio_high: 0.2
|
||||
|
||||
# policy loss config
|
||||
policy_loss:
|
||||
|
||||
# Loss function mode: vanilla / clip-cov / kl-cov from https://arxiv.org/abs/2505.22617
|
||||
loss_mode: "vanilla"
|
||||
|
||||
# Ratio of tokens to be clipped for clip-cov loss
|
||||
clip_cov_ratio: 0.0002
|
||||
|
||||
# Lower bound for clip-cov loss
|
||||
clip_cov_lb: 1.0
|
||||
|
||||
# Upper bound for clip-cov loss
|
||||
clip_cov_ub: 5.0
|
||||
|
||||
# Ratio of tokens to be applied kl penalty for kl-cov loss
|
||||
kl_cov_ratio: 0.0002
|
||||
|
||||
# KL divergence penalty coefficient
|
||||
ppo_kl_coef: 0.1
|
||||
|
||||
# Constant C in Dual-clip PPO; clips when advantage < -C
|
||||
clip_ratio_c: 3.0
|
||||
|
||||
|
@@ -28,6 +28,33 @@ import torch
|
||||
|
||||
import verl.utils.torch_functional as verl_F
|
||||
|
||||
POLICY_LOSS_REGISTRY = {}
|
||||
|
||||
|
||||
def register_policy_loss(name):
|
||||
def decorator(func):
|
||||
POLICY_LOSS_REGISTRY[name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_policy_loss_fn(name):
|
||||
"""Get the policy loss with a given name.
|
||||
|
||||
Args:
|
||||
name: `(str)`
|
||||
The name of the policy loss.
|
||||
|
||||
Returns:
|
||||
`(callable)`: The policy loss function.
|
||||
"""
|
||||
loss_name = name
|
||||
if loss_name not in POLICY_LOSS_REGISTRY:
|
||||
raise ValueError(f"Unsupported loss mode: {loss_name}. Supported modes are: {list(POLICY_LOSS_REGISTRY.keys())}")
|
||||
return POLICY_LOSS_REGISTRY[loss_name]
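# Illustrative usage of the registry (not part of the original code):
#
#   @register_policy_loss("my_loss")
#   def compute_policy_loss_my_loss(old_log_prob, log_prob, advantages, response_mask,
#                                   loss_agg_mode="token-mean", config=None):
#       ...
#
#   loss_fn = get_policy_loss_fn("my_loss")  # raises ValueError for unregistered names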
|
||||
|
||||
|
||||
ADV_ESTIMATOR_REGISTRY = {}
|
||||
|
||||
|
||||
@@ -596,6 +623,159 @@ def compute_policy_loss(
|
||||
return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
|
||||
|
||||
|
||||
@register_policy_loss("clip_cov")
|
||||
def compute_policy_loss_clip_cov(
|
||||
old_log_prob,
|
||||
log_prob,
|
||||
advantages,
|
||||
response_mask,
|
||||
loss_agg_mode="token-mean",
|
||||
config=None,
|
||||
):
|
||||
"""
|
||||
Compute the clipped policy objective and related metrics for Clip-Cov.
|
||||
|
||||
Adapted from
|
||||
https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py
|
||||
|
||||
Args:
|
||||
old_log_prob (torch.Tensor):
|
||||
Log-probabilities of actions under the old policy, shape (batch_size, response_length).
|
||||
log_prob (torch.Tensor):
|
||||
Log-probabilities of actions under the current policy, shape (batch_size, response_length).
|
||||
advantages (torch.Tensor):
|
||||
Advantage estimates for each action, shape (batch_size, response_length).
|
||||
response_mask (torch.Tensor):
|
||||
Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
|
||||
cliprange (float, optional):
|
||||
Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
|
||||
Defaults to None (must be provided).
|
||||
cliprange_low (float, optional):
|
||||
Lower clip range for dual-clip PPO. Defaults to same as `cliprange`.
|
||||
cliprange_high (float, optional):
|
||||
Upper clip range for dual-clip PPO. Defaults to same as `cliprange`.
|
||||
loss_agg_mode (str, optional):
|
||||
Aggregation mode for `agg_loss`. Defaults to "token-mean".
|
||||
clip_cov_ratio (float, optional):
|
||||
Ratio for clipping the covariance. Defaults to 0.0002.
|
||||
clip_cov_lb (float, optional):
|
||||
Lower bound for clipping covariance. Defaults to 1.0.
|
||||
clip_cov_ub (float, optional):
|
||||
Upper bound for clipping covariance. Defaults to 5.0.
|
||||
"""
|
||||
clip_cov_ratio = config.policy_loss.clip_cov_ratio if config.policy_loss.clip_cov_ratio is not None else 0.0002
|
||||
cliprange = config.clip_ratio
|
||||
cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else cliprange
|
||||
cliprange_high = config.clip_ratio_high if config.clip_ratio_high is not None else cliprange
|
||||
clip_cov_ub = config.policy_loss.clip_cov_ub if config.policy_loss.clip_cov_ub is not None else 5.0
|
||||
clip_cov_lb = config.policy_loss.clip_cov_lb if config.policy_loss.clip_cov_lb is not None else 1.0
|
||||
|
||||
assert clip_cov_ratio > 0, "clip_cov_ratio should be larger than 0."
|
||||
|
||||
negative_approx_kl = log_prob - old_log_prob
|
||||
ratio = torch.exp(negative_approx_kl)
|
||||
ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
|
||||
|
||||
pg_losses1 = -advantages * ratio
|
||||
|
||||
if cliprange_low is None:
|
||||
cliprange_low = cliprange
|
||||
if cliprange_high is None:
|
||||
cliprange_high = cliprange
|
||||
|
||||
corr = torch.ones_like(advantages)
|
||||
pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
|
||||
clip_by_origin = (pg_losses2 > pg_losses1) & (response_mask > 0)
|
||||
|
||||
cov_all = (advantages - verl_F.masked_mean(advantages, response_mask)) * (log_prob - verl_F.masked_mean(log_prob.detach(), response_mask))
|
||||
cov_all[response_mask == 0] = -torch.inf
|
||||
cov_all[clip_by_origin] = -torch.inf
|
||||
|
||||
clip_num = max(int(clip_cov_ratio * response_mask.sum().item()), 1)
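# clip_num is the per-batch budget of tokens whose gradient will be zeroed: a clip_cov_ratio
# fraction of the valid response tokens, but always at least one token.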
|
||||
top_k_idx = (cov_all < clip_cov_ub) & (cov_all > clip_cov_lb) & (response_mask > 0)
|
||||
top_k_idx = torch.nonzero(top_k_idx)
|
||||
|
||||
if len(top_k_idx) > 0:
|
||||
perm = torch.randperm(len(top_k_idx))
|
||||
top_k_idx = top_k_idx[perm[: min(clip_num, len(top_k_idx))]]
|
||||
else:
|
||||
top_k_idx = torch.empty((0, 2), device=cov_all.device, dtype=torch.long)
|
||||
|
||||
corr[top_k_idx[:, 0], top_k_idx[:, 1]] = 0
|
||||
|
||||
pg_clipfrac = verl_F.masked_mean((corr == 0).float(), response_mask)
|
||||
|
||||
pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr
|
||||
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
|
||||
|
||||
return pg_loss, pg_clipfrac, ppo_kl, torch.tensor(0.0)
|
||||
|
||||
|
||||
@register_policy_loss("kl_cov")
|
||||
def compute_policy_loss_kl_cov(
|
||||
old_log_prob,
|
||||
log_prob,
|
||||
advantages,
|
||||
response_mask,
|
||||
loss_agg_mode="token-mean",
|
||||
config=None,
|
||||
):
|
||||
"""
|
||||
Compute the policy objective and related metrics for KL-Cov.
|
||||
|
||||
Adapted from
|
||||
https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py
|
||||
|
||||
Args:
|
||||
old_log_prob (torch.Tensor):
|
||||
Log-probabilities of actions under the old policy, shape (batch_size, response_length).
|
||||
log_prob (torch.Tensor):
|
||||
Log-probabilities of actions under the current policy, shape (batch_size, response_length).
|
||||
advantages (torch.Tensor):
|
||||
Advantage estimates for each action, shape (batch_size, response_length).
|
||||
response_mask (torch.Tensor):
|
||||
Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
|
||||
loss_agg_mode (str, optional):
|
||||
Aggregation mode for `agg_loss`. Defaults to "token-mean".
|
||||
kl_cov_ratio (float, optional):
|
||||
Ratio for selecting the top-k covariance values. Defaults to 0.0002.
|
||||
ppo_kl_coef (float, optional):
|
||||
Coefficient for the KL penalty term in the loss. Defaults to 1.
|
||||
"""
|
||||
kl_cov_ratio = config.policy_loss.kl_cov_ratio if config.policy_loss.kl_cov_ratio is not None else 0.0002
|
||||
ppo_kl_coef = config.policy_loss.ppo_kl_coef if config.policy_loss.ppo_kl_coef is not None else 1.0
|
||||
|
||||
assert kl_cov_ratio > 0, "kl_cov_ratio should be larger than 0."
|
||||
|
||||
negative_approx_kl = log_prob - old_log_prob
|
||||
abs_kl = negative_approx_kl.abs()
|
||||
ratio = torch.exp(negative_approx_kl)
|
||||
ppo_kl_abs = verl_F.masked_mean(negative_approx_kl.abs(), response_mask)
|
||||
pg_losses1 = -advantages * ratio
|
||||
pg_losses_kl = -advantages * ratio + ppo_kl_coef * abs_kl
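# pg_losses_kl augments the vanilla policy-gradient term with an absolute-KL penalty; it is
# applied below only to the top kl_cov_ratio fraction of valid tokens ranked by covariance.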
|
||||
pg_losses = pg_losses1
|
||||
|
||||
all_valid = response_mask > 0
|
||||
all_valid_idx = torch.nonzero(all_valid.reshape(-1), as_tuple=True)[0]
|
||||
all_valid_adv = advantages[all_valid].detach().reshape(-1).cpu()
|
||||
all_valid_logp = log_prob[all_valid].detach().reshape(-1).cpu()
|
||||
|
||||
k = min(kl_cov_ratio, len(all_valid_adv))
|
||||
|
||||
if k != 0:
|
||||
cov_lst_all = (all_valid_adv - all_valid_adv.mean()) * (all_valid_logp - all_valid_logp.mean())
|
||||
k_percent_nums = max(1, int(len(cov_lst_all) * kl_cov_ratio))
|
||||
large_cov_idxs = torch.topk(cov_lst_all, k_percent_nums, largest=True).indices
|
||||
|
||||
if len(large_cov_idxs) != 0:
|
||||
large_cov_idxs = all_valid_idx[large_cov_idxs]
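# all_valid_idx maps positions in the flattened valid-token view back to flat indices of the
# full (batch, seq) grid; the div/mod below recovers the 2-D coordinates to overwrite in place.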
|
||||
pg_losses[large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]] = pg_losses_kl[large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]]
|
||||
|
||||
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
|
||||
|
||||
return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0)
|
||||
|
||||
|
||||
def compute_entropy_loss(logits, response_mask, loss_agg_mode: str = "token-mean"):
|
||||
"""Compute categorical entropy loss (For backward compatibility)
|
||||
|
||||
|
@@ -655,9 +655,11 @@ class RayPPOTrainer:
|
||||
sample_scores.extend(scores)
|
||||
|
||||
reward_extra_infos_dict["reward"].extend(scores)
|
||||
print(f"len reward_extra_infos_dict['reward']: {len(reward_extra_infos_dict['reward'])}")
|
||||
if "reward_extra_info" in result:
|
||||
for key, lst in result["reward_extra_info"].items():
|
||||
reward_extra_infos_dict[key].extend(lst)
|
||||
print(f"len reward_extra_infos_dict['{key}']: {len(reward_extra_infos_dict[key])}")
|
||||
|
||||
data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0]))
|
||||
|
||||
|
@@ -83,6 +83,7 @@ def default_compute_score(data_source, solution_str, ground_truth, extra_info=No
|
||||
from . import search_r1_like_qa_em
|
||||
|
||||
res = search_r1_like_qa_em.compute_score(solution_str, ground_truth)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Reward function is not implemented for {data_source=}")
|
||||
|
||||
|
verl/utils/reward_score/entropy_math/__init__.py (new file, 1051 lines)
File diff suppressed because it is too large.
verl/utils/reward_score/entropy_math/grader.py (new file, 342 lines)
@@ -0,0 +1,342 @@
|
||||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE
|
||||
|
||||
# Copyright (c) 2023 OpenAI
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
# Copyright (c) 2021 Dan Hendrycks
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
# Copyright 2024 PRIME team and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:
|
||||
- https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py
|
||||
- https://github.com/microsoft/ProphetNet/tree/master/CRITIC
|
||||
- https://github.com/openai/prm800k
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import math
|
||||
import re
|
||||
from math import isclose
|
||||
from typing import Union
|
||||
|
||||
# sympy related
|
||||
from sympy import N, simplify
|
||||
from sympy.parsing.latex import parse_latex
|
||||
from sympy.parsing.sympy_parser import parse_expr
|
||||
|
||||
# verl related
|
||||
from verl.utils.py_functional import timeout_limit
|
||||
|
||||
|
||||
def is_digit(s):
|
||||
try:
|
||||
if "{,}" in str(s):
|
||||
num = float(str(s).replace("{,}", ""))
|
||||
return True, num
|
||||
|
||||
num = float(str(s).replace(",", ""))
|
||||
return True, num
|
||||
except ValueError:
|
||||
return False, None
|
||||
|
||||
|
||||
def normalize(answer, pi) -> str:
|
||||
# checking if answer is $<number> and removing $ in that case to compare
|
||||
if isinstance(answer, str) and bool(re.match(r"\$\d+(\.\d+)?", answer)):
|
||||
return answer[1:]
|
||||
|
||||
# checking if answer is <number>% or <number>\\% and removing %
|
||||
if isinstance(answer, str) and (bool(re.match(r"^\d+(\.\d+)?%$", answer)) or bool(re.match(r"^\d+(\.\d+)?\\%$", answer))):
|
||||
return answer.replace("\\%", "").replace("%", "")
|
||||
|
||||
# handle base
|
||||
answer = handle_base(answer)
|
||||
|
||||
# handle pi
|
||||
answer = handle_pi(answer, pi)
|
||||
|
||||
return answer
|
||||
|
||||
|
||||
def handle_base(x) -> str:
|
||||
if isinstance(x, str) and "_" in x:
|
||||
# Due to base
|
||||
x = x.split("_")[0]
|
||||
x = float(x)
|
||||
return int(x)
|
||||
return x
|
||||
|
||||
|
||||
def handle_pi(string, pi):
|
||||
if isinstance(string, str) and "\pi" in string:
|
||||
# Find the first occurrence of "\pi"
|
||||
idx = string.find("\pi")
|
||||
|
||||
# Iterate over the string and find all occurrences of "\pi" with a valid previous character
|
||||
while idx != -1:
|
||||
if idx > 0 and string[idx - 1].isdigit():
|
||||
# Replace "\pi" with "*math.pi" if the previous character is a digit
|
||||
string = string[:idx] + f"*{pi}" + string[idx + 3 :]
|
||||
else:
|
||||
# Replace "\pi" with "1*math.pi" if the previous character is not a digit
|
||||
string = string[:idx] + f"1*{pi}" + string[idx + 3 :]
|
||||
|
||||
# Find the next occurrence of "\pi"
|
||||
idx = string.find("\pi", idx + 1)
|
||||
|
||||
# Evaluate the expression using eval() function
|
||||
with contextlib.suppress(Exception):
|
||||
string = eval(string)
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def math_equal(
|
||||
prediction: Union[bool, float, str],
|
||||
reference: Union[float, str],
|
||||
include_percentage: bool = True,
|
||||
tolerance: float = 1e-4,
|
||||
timeout: float = 10.0,
|
||||
pi: float = math.pi,
|
||||
) -> bool:
|
||||
"""
|
||||
Exact match of math if and only if:
|
||||
1. numerical equal: both can convert to float and are equal
|
||||
2. symbolic equal: both can convert to sympy expression and are equal
|
||||
"""
|
||||
|
||||
prediction = normalize(prediction, pi)
|
||||
reference = normalize(reference, pi)
|
||||
|
||||
if isinstance(prediction, str) and len(prediction) > 1000: # handling weird corner-cases
|
||||
prediction = prediction[:1000]
|
||||
|
||||
# 0. string comparison
|
||||
if isinstance(prediction, str) and isinstance(reference, str):
|
||||
if prediction.strip().lower() == reference.strip().lower():
|
||||
return True
|
||||
if prediction.replace(" ", "") == reference.replace(" ", ""):
|
||||
return True
|
||||
|
||||
try: # 1. numerical equal
|
||||
if is_digit(prediction)[0] and is_digit(reference)[0]:
|
||||
prediction = is_digit(prediction)[1]
|
||||
reference = is_digit(reference)[1]
|
||||
# number questions
|
||||
gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference]
|
||||
for item in gt_result:
|
||||
try:
|
||||
if isclose(item, prediction, rel_tol=tolerance):
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
return False
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not prediction and prediction not in [0, False]:
|
||||
return False
|
||||
|
||||
# 2. symbolic equal
|
||||
reference = str(reference).strip()
|
||||
prediction = str(prediction).strip()
|
||||
|
||||
## deal with [], (), {}
|
||||
prediction = format_intervals(prediction)
|
||||
|
||||
pred_str, ref_str = prediction, reference
|
||||
if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or (prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")):
|
||||
pred_str = pred_str.strip("[]()")
|
||||
ref_str = ref_str.strip("[]()")
|
||||
for s in ["{", "}", "(", ")"]:
|
||||
ref_str = ref_str.replace(s, "")
|
||||
pred_str = pred_str.replace(s, "")
|
||||
if pred_str == ref_str:
|
||||
return True
|
||||
|
||||
## [a, b] vs. [c, d], return a==c and b==d
|
||||
if prediction and reference and prediction[0] in "([" and prediction[-1] in ")]" and prediction[0] == reference[0] and prediction[-1] == reference[-1]:
|
||||
pred_parts = prediction[1:-1].split(",")
|
||||
ref_parts = reference[1:-1].split(",")
|
||||
if len(pred_parts) == len(ref_parts) and all([math_equal(pred_pt, ref_pt, include_percentage, tolerance) for pred_pt, ref_pt in zip(pred_parts, ref_parts)]):
|
||||
return True
|
||||
|
||||
if "," in prediction and "," in reference:
|
||||
pred_parts = [item.strip() for item in prediction.split(",")]
|
||||
ref_parts = [item.strip() for item in reference.split(",")]
|
||||
|
||||
if len(pred_parts) == len(ref_parts):
|
||||
return bool(all([math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance) for i in range(len(pred_parts))]))
|
||||
|
||||
# if we have point == tuple of values
|
||||
if prediction.startswith("Point") and reference[0] == "(" and reference[-1] == ")":
|
||||
pred_parts = prediction[prediction.find("(") + 1 : -1].split(",")
|
||||
ref_parts = reference[1:-1].split(",")
|
||||
if len(pred_parts) == len(ref_parts) and all([math_equal(pred_pt, ref_pt, include_percentage, tolerance) for pred_pt, ref_pt in zip(pred_parts, ref_parts)]):
|
||||
return True
|
||||
|
||||
# if reference is a matrix
|
||||
if "\begin{pmatrix}" in reference and prediction.startswith("Matrix"):
|
||||
try:
|
||||
pred_matrix = parse_expr(prediction)
|
||||
ref_matrix_items = reference.split()[1:-1:2]
|
||||
if len(pred_matrix) == len(ref_matrix_items) and all([math_equal(pred, ref, include_percentage, tolerance) for ref, pred in zip(ref_matrix_items, pred_matrix)]):
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
elif "\begin{pmatrix}" in reference and prediction.startswith("[") and prediction.endswith("]"):
|
||||
if isinstance(eval(prediction), list):
|
||||
try:
|
||||
pred_matrix = eval(prediction)
|
||||
# ref_matrix_items = reference.split()[1:-1:2]
|
||||
ref_matrix_items = reference.lstrip("\\begin{pmatrix}").lstrip("\begin{pmatrix}").rstrip("\\end{pmatrix}").rstrip("\end{pmatrix}") # noqa: B005
|
||||
ref_matrix_items = ref_matrix_items.split("\\")
|
||||
ref_matrix_items = [row.split("&") if "&" in row else row for row in ref_matrix_items]
|
||||
if len(pred_matrix) == len(ref_matrix_items) and all([math_equal(pred, ref, include_percentage, tolerance) for ref, pred in zip(ref_matrix_items, pred_matrix)]):
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return symbolic_equal(prediction, reference, tolerance, timeout)
|
||||
|
||||
|
||||
def symbolic_equal(a, b, tolerance, timeout=10.0):
|
||||
def _parse(s):
|
||||
for f in [parse_expr, parse_latex]:
|
||||
try:
|
||||
with timeout_limit(seconds=timeout):
|
||||
return f(s)
|
||||
except TimeoutError:
|
||||
print(f"Parsing timed out for {s}")
|
||||
continue
|
||||
except Exception:
|
||||
continue
|
||||
return s
|
||||
|
||||
a = _parse(a)
|
||||
b = _parse(b)
|
||||
|
||||
try:
|
||||
with timeout_limit(seconds=timeout):
|
||||
if simplify(a - b) == 0:
|
||||
return True
|
||||
except TimeoutError:
|
||||
print(f"Simplification timed out for {a} - {b}")
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
with timeout_limit(seconds=timeout):
|
||||
if isclose(N(a), N(b), rel_tol=tolerance):
|
||||
return True
|
||||
except TimeoutError:
|
||||
print(f"Numerical evaluation timed out for {a}, {b}")
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def format_intervals(prediction):
|
||||
patterns = {
|
||||
"Interval(": r"^Interval\((.*)\)$",
|
||||
"Interval.Ropen(": r"^Interval\.Ropen\((.*)\)$",
|
||||
"Interval.Lopen(": r"^Interval\.Lopen\((.*)\)$",
|
||||
"Interval.open(": r"^Interval\.open\((.*)\)$",
|
||||
}
|
||||
|
||||
for key, pattern in patterns.items():
|
||||
match = re.match(pattern, prediction)
|
||||
if match:
|
||||
inner_content = match.group(1)
|
||||
|
||||
if key == "Interval(": # Intarval(a, b) == [a, b]
|
||||
return f"[{inner_content}]"
|
||||
elif key == "Interval.Ropen(": # Intarval.Ropen(a, b) == [a, b)
|
||||
return f"[{inner_content})"
|
||||
elif key == "Interval.Lopen(": # Intarval.Lopen(a, b) == (a, b]
|
||||
return f"({inner_content}]"
|
||||
elif key == "Interval.open(": # Intarval.open(a, b) == (a, b)
|
||||
return f"({inner_content})"
|
||||
|
||||
return prediction
|
verl/utils/reward_score/entropy_math/math_normalize.py (new file, 191 lines)
@@ -0,0 +1,191 @@
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (c) 2021 Dan Hendrycks
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This logic is largely copied from the Hendrycks' MATH release (math_equivalence).

From: https://github.com/openai/prm800k/blob/main/prm800k/grading/math_normalize.py
"""

import re
from typing import Optional


def normalize_answer(answer: Optional[str]) -> Optional[str]:
    if answer is None:
        return None
    answer = answer.strip()
    try:
        # Remove enclosing `\text{}`.
        m = re.search("^\\\\text\{(?P<text>.+?)\}$", answer)
        if m is not None:
            answer = m.group("text").strip()
        return _strip_string(answer)
    except:  # noqa: E722
        return answer


def _fix_fracs(string):
    substrs = string.split("\\frac")
    new_str = substrs[0]
    if len(substrs) > 1:
        substrs = substrs[1:]
        for substr in substrs:
            new_str += "\\frac"
            if substr[0] == "{":
                new_str += substr
            else:
                try:
                    assert len(substr) >= 2
                except:  # noqa: E722
                    return string
                a = substr[0]
                b = substr[1]
                if b != "{":
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}{" + b + "}" + post_substr
                    else:
                        new_str += "{" + a + "}{" + b + "}"
                else:
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}" + b + post_substr
                    else:
                        new_str += "{" + a + "}" + b
    string = new_str
    return string


def _fix_a_slash_b(string):
    if len(string.split("/")) != 2:
        return string
    a = string.split("/")[0]
    b = string.split("/")[1]
    try:
        a = int(a)
        b = int(b)
        assert string == "{}/{}".format(a, b)
        new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
        return new_string
    except:  # noqa: E722
        return string


def _remove_right_units(string):
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    if "\\text{ " in string:
        splits = string.split("\\text{ ")
        assert len(splits) == 2
        return splits[0]
    else:
        return string


def _fix_sqrt(string):
    if "\\sqrt" not in string:
        return string
    splits = string.split("\\sqrt")
    new_string = splits[0]
    for split in splits[1:]:
        if split[0] != "{":
            a = split[0]
            new_substr = "\\sqrt{" + a + "}" + split[1:]
        else:
            new_substr = "\\sqrt" + split
        new_string += new_substr
    return new_string


def _strip_string(string):
    # linebreaks
    string = string.replace("\n", "")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # replace \\ with \
    string = string.replace("\\\\", "\\")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right)
    string = _remove_right_units(string)

    # remove percentage
    string = string.replace("\\%", "")
    string = string.replace("\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    if len(string.split("=")) == 2 and len(string.split("=")[0]) <= 2:
        string = string.split("=")[1]

    # fix sqrt3 --> sqrt{3}
    string = _fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = _fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string

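A quick usage sketch of the normalizer above, assuming the module is importable at the path shown in the diff header; the example answers are illustrative only:

```
from verl.utils.reward_score.entropy_math.math_normalize import normalize_answer

# "1/2", "0.5" and "\text{0.5}" all canonicalize to the same LaTeX string,
# which is what makes the downstream string comparison meaningful.
assert normalize_answer("1/2") == "\\frac{1}{2}"
assert normalize_answer("0.5") == "\\frac{1}{2}"
assert normalize_answer("\\text{0.5}") == "\\frac{1}{2}"
```
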
@@ -28,7 +28,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
 import verl.utils.torch_functional as verl_F
 from verl import DataProto
-from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, kl_penalty
+from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty
 from verl.utils.debug import GPUMemoryLogger
 from verl.utils.device import get_device_id, get_device_name, is_cuda_available, is_npu_available
 from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_
@@ -399,6 +399,9 @@ class DataParallelPPOActor(BasePPOActor):
                     calculate_entropy = True
                 entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature, calculate_entropy=calculate_entropy)
 
+                loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
+
+                if self.config.policy_loss.loss_mode == "vanilla":
                     pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
                         old_log_prob=old_log_prob,
                         log_prob=log_prob,
@@ -410,6 +413,9 @@ class DataParallelPPOActor(BasePPOActor):
                         clip_ratio_c=clip_ratio_c,
                         loss_agg_mode=loss_agg_mode,
                     )
+                else:
+                    policy_loss_fn = get_policy_loss_fn(loss_mode)
+                    pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(old_log_prob, log_prob, advantages, response_mask, loss_agg_mode, self.config)
 
                 if entropy_coeff != 0:
                     entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
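
The `loss_mode` dispatch above goes through `get_policy_loss_fn` from `core_algos.py`. Below is a simplified, hypothetical sketch of what such a registry-based lookup can look like; it is not verl's actual implementation, and every name other than `get_policy_loss_fn` is an assumption:

```
# Hypothetical, simplified registry sketch; not verl's actual code.
POLICY_LOSS_REGISTRY = {}

def register_policy_loss(name):
    """Decorator that registers a policy-loss function under a loss_mode name."""
    def decorator(fn):
        POLICY_LOSS_REGISTRY[name] = fn
        return fn
    return decorator

def get_policy_loss_fn(name):
    """Look up the loss function selected by config.policy_loss.loss_mode."""
    if name not in POLICY_LOSS_REGISTRY:
        raise ValueError(f"Unsupported loss mode: {name}. Supported: {list(POLICY_LOSS_REGISTRY)}")
    return POLICY_LOSS_REGISTRY[name]
```
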
@@ -38,7 +38,7 @@ from omegaconf import OmegaConf
 from torch import nn
 
 from verl import DataProto
-from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, kl_penalty
+from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty
 from verl.utils.debug import GPUMemoryLogger
 from verl.utils.debug.profile import Profiler
 from verl.utils.device import get_device_id, get_torch_device
@@ -352,10 +352,16 @@ class MegatronPPOActor(BasePPOActor):
            old_log_prob = data["old_log_probs"]
            advantages = data["advantages"]
 
-           clip_ratio = meta_info["clip_ratio"]
+           clip_ratio = self.config.clip_ratio
            clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio
            clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio
-           clip_ratio_c = meta_info["clip_ratio_c"]
+           clip_ratio_c = self.config.get("clip_ratio_c", 3.0)
            entropy_coeff = self.config.entropy_coeff
            loss_agg_mode = self.config.loss_agg_mode
 
+           loss_mode = self.config.get("loss_mode", "vanilla")
+
+           if self.config.policy_loss.loss_mode == "vanilla":
                pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
                    old_log_prob=old_log_prob,
                    log_prob=log_prob,
@@ -367,6 +373,9 @@ class MegatronPPOActor(BasePPOActor):
                    clip_ratio_c=clip_ratio_c,
                    loss_agg_mode=loss_agg_mode,
                )
+           else:
+               policy_loss_fn = get_policy_loss_fn(loss_mode)
+               pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(old_log_prob, log_prob, advantages, response_mask, loss_agg_mode, self.config)
            policy_loss = pg_loss
            if calculate_entropy:
                entropy = output["entropy"][:, -response_length - 1 : -1].contiguous()
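
Both actors resolve the new mode the same way and fall back to the vanilla PPO loss when nothing is set. An illustrative config fragment is shown below; the mode string "clip_cov" is an assumption (only the "vanilla" default is visible in the diffs above):

```
from omegaconf import OmegaConf

# Illustrative actor config fragment; "clip_cov" as the registered mode name is an assumption.
actor_config = OmegaConf.create({"policy_loss": {"loss_mode": "clip_cov"}})

# Both actors read the mode with a "vanilla" fallback, as in the diffs above.
loss_mode = actor_config.policy_loss.get("loss_mode", "vanilla")
print(loss_mode)  # -> clip_cov
```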