Mirror of https://github.com/volcengine/verl.git, synced 2025-10-20 21:53:50 +08:00
[megatron, fsdp, doc] feat: implement GPG loss. Add GPG advantage estimator implementation. (#2057)
…and integrate it into the PPO training scripts and core algorithms.

### Checklist Before Starting

- [x] Searched for similar PR(s).
- [x] Checked PR Title format
  - In format of: [modules] type: Title
  - modules are in `fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data`
  - type is in `feat, fix, refactor, chore`
  - can involve multiple modules, separated by `,` or space, like `[megatron, fsdp, doc] feat: xxx`

### What does this PR do?

Implement the GPG loss (GPG: A Simple and Strong Reinforcement Learning Baseline for Model Reasoning), which can achieve comparable performance in less training time.

### Test

Some training records (screenshots omitted).

### Specific Changes

- Add a doc for GPG in docs/algo/gpg.md.
- Add the GPG advantage estimation function (compute_gpg_outcome_advantage) in verl/trainer/ppo/core_algos.py.
- Add the GPG policy loss function (compute_policy_loss_gpg) in verl/trainer/ppo/core_algos.py.
- Add a conditional branch to determine whether to use the GPG loss in verl/workers/actor/dp_actor.py and megatron_actor.py.
- Add example scripts for GPG in examples/gpg_trainer.

### Usage Example

```shell
bash examples/gpg_trainer/run_qwen2-7b_math.sh
```

### Checklist Before Submitting

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [x] Add `[BREAKING]` to the PR title and `description` if it breaks any API.
- [x] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs).

---

Co-authored-by: H <linhaibin.eric@gmail.com>
```diff
@@ -36,6 +36,8 @@ Refer to the table below to reproduce RL training from different pre-trained checkpoints
 | NVIDIA GPU | Mixtral-8x22B-Instruct-v0.1 | Instruct model | 83.7 | [Qwen Blog](https://qwenlm.github.io/blog/qwen2.5-llm/) |
 | NVIDIA GPU | Mixtral-8x22B-Instruct-v0.1 | RLOO (Megatron) | 92.3 | [wandb](https://api.wandb.ai/links/ppo_dev/sbuiuf2d) |
 | NVIDIA GPU | Qwen/Qwen2.5-7B-Instruct | SPIN | 92 | [script](https://github.com/volcengine/verl/tree/main/recipe/spin/README.md) |
+| NVIDIA GPU | Qwen/Qwen2-7B-Instruct | GPG | 88 | [log](https://github.com/diqiuzhuanzhuan/verldata/blob/main/run_logs/qwen2-7b_math.log), [wandb](https://wandb.ai/diqiuzhuanzhuan/verl_gpg_example_gsm8k_math/runs/ab86c4va) |
+| NVIDIA GPU | Qwen/Qwen2-7B-Instruct | GPG (Megatron) | 88 | [log](https://github.com/diqiuzhuanzhuan/verldata/blob/main/run_logs/qwen2-7b_math_megatron.log), [wandb](https://wandb.ai/diqiuzhuanzhuan/verl_gpg_example_gsm8k_math/runs/yy8bheu8) |
 | NVIDIA GPU | Qwen/Qwen2.5-VL-7B-Instruct | GRPO (Megatron) | 65.4 (GEO3k) | [script](https://github.com/volcengine/verl/blob/main/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh), [wandb](https://api.wandb.ai/links/megatron-core-moe-dev/1yngvkek) |
 | AMD MI300 | deepseek-ai/deepseek-llm-7b-chat | PPO | 70.5 [1] | [log](https://github.com/yushengsu-thu/verl_training_log/blob/main/gsm8k/ppo_run_deepseek7b_llm.log) |
 | AMD MI300 | deepseek-ai/deepseek-llm-7b-chat | GRPO | 71.4 [1] | [log](https://github.com/yushengsu-thu/verl_training_log/blob/main/gsm8k/grpo_run_deepseek7b_llm.log) |
```
docs/algo/gpg.md (new file, 36 lines)
@@ -0,0 +1,36 @@
# GPG: Group Policy Gradient

Last updated: 07/03/2025.

Group Policy Gradient (GPG) is a minimalist reinforcement learning (RL) method that enhances the reasoning ability of large language models without relying on supervised fine-tuning or complex tricks. GPG revisits traditional policy gradients and directly optimizes the RL objective: no surrogate losses, no KL penalties, no critic, and no reference model. Compared to GRPO, GPG is simpler, more efficient, and achieves better results on many tasks. For more details, please refer to the original paper [GPG: A Simple and Strong Reinforcement Learning Baseline for Model Reasoning](https://arxiv.org/abs/2504.02546).

## Key Components

- Uses a corrected advantage function to improve policy gradient accuracy and training efficiency.
- Eliminates the critic and reference models and avoids KL divergence constraints, which significantly simplifies the training process compared to Group Relative Policy Optimization (GRPO).

## Configuration

To configure GPG within the framework, use the following YAML settings.

```yaml
algorithm:
  adv_estimator: gpg
actor_rollout_ref:
  actor:
    policy_loss:
      loss_mode: "gpg"
```

## Advanced Extensions

GPG is a simple and strong baseline for model reasoning. Although it avoids a KL loss in its original form, you can still add a KL loss to further improve performance.

```yaml
algorithm:
  adv_estimator: gpg
actor_rollout_ref:
  actor:
    use_kl_loss: True # enable kl regularization
    kl_loss_coef: 0.01
    policy_loss:
      loss_mode: "gpg"
```
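For reference, the estimator and loss that this configuration selects can be summarized as follows. This is a sketch derived from the `compute_gpg_outcome_advantage` and `compute_policy_loss_gpg` functions added later in this diff, assuming the default `f_norm = 1` and token-mean loss aggregation.

```latex
% Group-normalized GPG advantage for response i with outcome reward r_i,
% where G(q_i) is the set of responses sampled for the same prompt and
% alpha corrects for responses that receive zero reward (batch size B):
\alpha = \frac{B}{\max\bigl(1,\ \lvert\{\, i : r_i \neq 0 \,\}\rvert\bigr)},
\qquad
\hat{A}_i = \frac{\alpha \left( r_i - \operatorname{mean}_{j \in G(q_i)} r_j \right)}{F_{\text{norm}}}

% GPG policy loss with response mask m_{i,t} (token-mean aggregation):
\mathcal{L}_{\text{GPG}}(\theta)
  = - \frac{\sum_{i,t} m_{i,t}\, \hat{A}_i \log \pi_\theta\!\left(o_{i,t} \mid q_i, o_{i,<t}\right)}
          {\sum_{i,t} m_{i,t}}
```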
```diff
@@ -74,6 +74,7 @@ verl is fast with:
    algo/entropy.md
    algo/opo.md
    algo/baseline.md
+   algo/gpg.md

 .. toctree::
    :maxdepth: 1
```
examples/gpg_trainer/gpg.md (new file, 34 lines)
@@ -0,0 +1,34 @@
# GPG: Group Policy Gradient

Group Policy Gradient (GPG) is a minimalist reinforcement learning (RL) method that enhances the reasoning ability of large language models without relying on supervised fine-tuning or complex tricks. GPG revisits traditional policy gradients and directly optimizes the RL objective: no surrogate losses, no KL penalties, no critic, and no reference model. Compared to GRPO, GPG is simpler, more efficient, and achieves better results on many tasks. For more details, please refer to the original paper [GPG: A Simple and Strong Reinforcement Learning Baseline for Model Reasoning](https://arxiv.org/abs/2504.02546).

## Key Components

- Uses a corrected advantage function to improve policy gradient accuracy and training efficiency.
- Eliminates the critic and reference models and avoids KL divergence constraints, which significantly simplifies the training process compared to Group Relative Policy Optimization (GRPO).

## Configuration

To configure GPG within the framework, use the following YAML settings.

```yaml
algorithm:
  adv_estimator: gpg
actor_rollout_ref:
  actor:
    policy_loss:
      loss_mode: "gpg"
```

## Advanced Extensions

GPG is a simple and strong baseline for model reasoning. Although it avoids a KL loss in its original form, you can still add a KL loss to further improve performance.

```yaml
algorithm:
  adv_estimator: gpg
actor_rollout_ref:
  actor:
    use_kl_loss: True # enable kl regularization
    kl_loss_coef: 0.01
    policy_loss:
      loss_mode: "gpg"
```
examples/gpg_trainer/run_qwen2-7b_math.sh (new executable file, 52 lines)
@@ -0,0 +1,52 @@
```shell
set -x

# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:
# export VLLM_ATTENTION_BACKEND=XFORMERS

gsm8k_train_path=$HOME/data/gsm8k/train.parquet
gsm8k_test_path=$HOME/data/gsm8k/test.parquet
math_train_path=$HOME/data/math/train.parquet
math_test_path=$HOME/data/math/test.parquet

train_files="['$gsm8k_train_path', '$math_train_path']"
test_files="['$gsm8k_test_path', '$math_test_path']"

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=gpg \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.train_batch_size=1024 \
    data.max_prompt_length=1024 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.policy_loss.loss_mode=gpg \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='verl_gpg_example_gsm8k_math' \
    trainer.experiment_name='qwen2_7b_function_rm' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 $@
```
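The dotted arguments above are Hydra-style overrides on top of verl's default PPO config. As a rough illustration of how the GPG-specific flags nest, here is a small sketch using OmegaConf (the library Hydra builds on); this is illustrative only and not how verl itself parses its configuration.

```python
from omegaconf import OmegaConf

# Hydra-style dotted overrides (as in the script above) resolve into a nested config.
cfg = OmegaConf.from_dotlist(
    [
        "algorithm.adv_estimator=gpg",
        "actor_rollout_ref.actor.policy_loss.loss_mode=gpg",
    ]
)
print(OmegaConf.to_yaml(cfg))
# Roughly:
# algorithm:
#   adv_estimator: gpg
# actor_rollout_ref:
#   actor:
#     policy_loss:
#       loss_mode: gpg
```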
examples/gpg_trainer/run_qwen2-7b_math_megatron.sh (new executable file, 54 lines)
@@ -0,0 +1,54 @@
```shell
set -x

# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:
# export VLLM_ATTENTION_BACKEND=XFORMERS
export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping

gsm8k_train_path=$HOME/data/gsm8k/train.parquet
gsm8k_test_path=$HOME/data/gsm8k/test.parquet
math_train_path=$HOME/data/math/train.parquet
math_test_path=$HOME/data/math/test.parquet

train_files="['$gsm8k_train_path', '$math_train_path']"
test_files="['$gsm8k_test_path', '$math_test_path']"

python3 -m verl.trainer.main_ppo --config-path=config \
    --config-name='ppo_megatron_trainer.yaml' \
    algorithm.adv_estimator=gpg \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.train_batch_size=1024 \
    data.max_prompt_length=1024 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \
    actor_rollout_ref.actor.policy_loss.loss_mode=gpg \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \
    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='verl_gpg_example_gsm8k_math' \
    trainer.experiment_name='qwen2_7b_megatron' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 $@
```
```diff
@@ -64,7 +64,7 @@ actor_rollout_ref:
     data_loader_seed: null
     shuffle: False
     policy_loss:  # policy loss config
-      loss_mode: "vanilla"  # Loss function mode: vanilla / clip-cov / kl-cov from https://arxiv.org/abs/2505.22617
+      loss_mode: "vanilla"  # Loss function mode: vanilla / clip-cov / kl-cov / gpg from https://arxiv.org/abs/2505.22617,
       clip_cov_ratio: 0.0002  # Ratio of tokens to be clipped for clip-cov loss
       clip_cov_lb: 1.0  # Lower bound for clip-cov loss
       clip_cov_ub: 5.0  # Upper bound for clip-cov loss
```
```diff
@@ -187,7 +187,7 @@ actor_rollout_ref:
     # policy loss config
     policy_loss:
 
-      # Loss function mode: vanilla / clip-cov / kl-cov from https://arxiv.org/abs/2505.22617
+      # Loss function mode: vanilla / clip-cov / kl-cov / gpg from https://arxiv.org/abs/2505.22617
       loss_mode: "vanilla"
 
       # Ratio of tokens to be clipped for clip-cov loss
```
```diff
@@ -114,6 +114,7 @@ class AdvantageEstimator(str, Enum):
     RLOO = "rloo"
     OPO = "opo"
     GRPO_PASSK = "grpo_passk"
+    GPG = "gpg"
 
 
 class AdaptiveKLController:
```
```diff
@@ -555,6 +556,68 @@ def compute_remax_outcome_advantage(
     return advantages, returns
 
 
+@register_adv_est(AdvantageEstimator.GPG)  # or simply: @register_adv_est("gpg")
+def compute_gpg_outcome_advantage(
+    token_level_rewards: torch.Tensor,
+    response_mask: torch.Tensor,
+    index: np.ndarray,
+    epsilon: float = 1e-6,
+    f_norm: float = 1.0,
+    alpha: float = 1.0,
+    config=None,
+    **kwargs,
+):
+    """
+    Compute advantage for GPG, operating only on Outcome reward
+    (with only one scalar reward for each response).
+
+    Args:
+        token_level_rewards: `(torch.Tensor)`
+            shape: (bs, response_length)
+        response_mask: `(torch.Tensor)`
+            shape: (bs, response_length)
+        index: `(np.ndarray)`
+            shape: (bs,)
+        epsilon: (float)
+        f_norm: (float)
+        alpha: (float)
+        config: (dict) algorithm config
+
+    Returns:
+        advantages: `(torch.Tensor)`
+            shape: (bs, response_length)
+        returns: `(torch.Tensor)`
+            shape: (bs, response_length)
+    """
+    scores = token_level_rewards.sum(dim=-1)
+
+    id2score = defaultdict(list)
+    id2mean = {}
+    id2std = {}
+
+    with torch.no_grad():
+        bsz = scores.shape[0]
+        m = torch.count_nonzero(scores)
+        alpha = bsz / m.clamp(min=1)
+
+        for i in range(bsz):
+            id2score[index[i]].append(scores[i])
+
+        for idx in id2score:
+            if len(id2score[idx]) == 1:
+                id2mean[idx] = torch.tensor(0.0)
+                id2std[idx] = torch.tensor(1.0)
+            elif len(id2score[idx]) > 1:
+                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
+                id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
+            else:
+                raise ValueError(f"no score in prompt index: {idx}")
+        for i in range(bsz):
+            scores[i] = alpha * (scores[i] - id2mean[index[i]]) / (f_norm)
+        scores = scores.unsqueeze(-1) * response_mask
+
+    return scores, scores
 
 
 def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
     kl = old_log_prob - ref_log_prob
     return token_level_scores - kl * kl_ratio
```
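To make the behavior of the estimator above concrete, here is a small self-contained sketch that mirrors its group-normalization logic on a toy batch. `toy_gpg_advantage` is an illustrative name rather than part of verl's API, and the default `f_norm=1.0` is assumed.

```python
from collections import defaultdict

import numpy as np
import torch


def toy_gpg_advantage(token_level_rewards, response_mask, index, f_norm=1.0):
    # One scalar outcome reward per response.
    scores = token_level_rewards.sum(dim=-1)
    bsz = scores.shape[0]
    # Correction factor: batch size over the number of responses with non-zero reward.
    alpha = bsz / torch.count_nonzero(scores).clamp(min=1)

    # Group responses by prompt index and compute each group's mean reward.
    id2score = defaultdict(list)
    for i in range(bsz):
        id2score[index[i]].append(scores[i])
    id2mean = {
        k: torch.stack(v).mean() if len(v) > 1 else torch.tensor(0.0)
        for k, v in id2score.items()
    }

    # Group-normalized advantage, broadcast to every response token.
    adv = torch.stack(
        [alpha * (scores[i] - id2mean[index[i]]) / f_norm for i in range(bsz)]
    )
    return adv.unsqueeze(-1) * response_mask


# Two prompts with two rollouts each; token-level rewards sum to the outcome reward.
rewards = torch.tensor([[0.0, 1.0], [0.0, 0.0], [1.0, 0.0], [1.0, 0.0]])
mask = torch.ones_like(rewards)
prompt_index = np.array(["p0", "p0", "p1", "p1"])
print(toy_gpg_advantage(rewards, mask, prompt_index))
```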
```diff
@@ -671,6 +734,27 @@ def compute_policy_loss(
     return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
 
 
+@register_policy_loss("gpg")
+def compute_policy_loss_gpg(old_log_prob, log_prob, advantages, response_mask, loss_agg_mode="token-mean", config=None):
+    """Adapted from
+    https://github.com/AMAP-ML/GPG/blob/main/VisualThinker-R1-Zero/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py#L495
+
+    Args:
+        log_prob: `(torch.Tensor)`
+            shape: (bs, response_length)
+        advantages: `(torch.Tensor)`
+            shape: (bs, response_length)
+        response_mask: `(torch.Tensor)`
+            shape: (bs, response_length)
+
+    return:
+        pg_loss: `a scalar torch.Tensor`
+            policy gradient loss computed via GPG
+    """
+    pg_losses = -log_prob * advantages
+
+    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
+    return pg_loss, torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
+
+
 @register_policy_loss("clip_cov")
 def compute_policy_loss_clip_cov(
     old_log_prob,
```
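A minimal standalone illustration of the loss above: the REINFORCE-style term `-log_prob * advantages` reduced with a masked token mean. `masked_token_mean` is a stand-in for verl's `agg_loss(..., loss_agg_mode="token-mean")` and is assumed for illustration, not verl code.

```python
import torch


def masked_token_mean(loss_mat, loss_mask):
    # Average the per-token losses over the valid (unmasked) response tokens.
    return (loss_mat * loss_mask).sum() / loss_mask.sum().clamp(min=1)


log_prob = torch.tensor([[-0.5, -1.0], [-0.2, -0.8]])   # per-token log-probs
advantages = torch.tensor([[0.7, 0.7], [-0.7, -0.7]])   # GPG outcome advantages
response_mask = torch.tensor([[1.0, 1.0], [1.0, 0.0]])  # last token of sample 2 padded

pg_losses = -log_prob * advantages                      # REINFORCE-style per-token term
pg_loss = masked_token_mean(pg_losses, response_mask)
print(pg_loss)  # scalar policy-gradient loss
```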
```diff
@@ -362,6 +362,7 @@ class RayPPOTrainer:
             AdvantageEstimator.RLOO,
             AdvantageEstimator.OPO,
             AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE,
+            AdvantageEstimator.GPG,
         ]:
             self.use_critic = False
         else:
```
```diff
@@ -503,10 +503,16 @@ class DataParallelPPOActor(BasePPOActor):
                         clip_ratio_c=clip_ratio_c,
                         loss_agg_mode=loss_agg_mode,
                     )
 
                 else:
                     policy_loss_fn = get_policy_loss_fn(loss_mode)
                     pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
-                        old_log_prob, log_prob, advantages, response_mask, loss_agg_mode, self.config
+                        old_log_prob=old_log_prob,
+                        log_prob=log_prob,
+                        advantages=advantages,
+                        response_mask=response_mask,
+                        loss_agg_mode=loss_agg_mode,
+                        config=self.config,
                     )
 
                 if entropy_coeff != 0:
```
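The `else` branch above dispatches by name through a policy-loss registry: `@register_policy_loss("gpg")` stores the function under its string key and `get_policy_loss_fn(loss_mode)` looks it up. A minimal sketch of that pattern is shown below; verl's actual `register_policy_loss` / `get_policy_loss_fn` may differ in detail.

```python
# Name-based dispatch sketch (illustrative only, not verl's implementation).
POLICY_LOSS_REGISTRY = {}


def register_policy_loss(name):
    def decorator(fn):
        POLICY_LOSS_REGISTRY[name] = fn  # register under the configured loss_mode string
        return fn
    return decorator


def get_policy_loss_fn(name):
    if name not in POLICY_LOSS_REGISTRY:
        raise ValueError(f"Unknown policy loss mode: {name}")
    return POLICY_LOSS_REGISTRY[name]


@register_policy_loss("gpg")
def compute_policy_loss_gpg(old_log_prob, log_prob, advantages, response_mask,
                            loss_agg_mode="token-mean", config=None):
    ...  # body as in the core_algos.py hunk above


loss_fn = get_policy_loss_fn("gpg")  # selected when policy_loss.loss_mode == "gpg"
```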
```diff
@@ -396,6 +396,7 @@ class MegatronPPOActor(BasePPOActor):
             # compute policy loss
             log_prob = output["log_probs"][:, -response_length - 1 : -1].contiguous()
             ret_entropy = None
+            stats = {}
             if not forward_only:
                 old_log_prob = data["old_log_probs"]
                 advantages = data["advantages"]
@@ -403,11 +404,12 @@ class MegatronPPOActor(BasePPOActor):
                 clip_ratio = self.config.clip_ratio
                 clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio
                 clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio
 
                 clip_ratio_c = self.config.get("clip_ratio_c", 3.0)
                 entropy_coeff = self.config.entropy_coeff
                 loss_agg_mode = self.config.loss_agg_mode
 
-                loss_mode = self.config.get("loss_mode", "vanilla")
+                loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
+
                 if self.config.policy_loss.loss_mode == "vanilla":
                     pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
@@ -421,12 +423,28 @@ class MegatronPPOActor(BasePPOActor):
                         clip_ratio_c=clip_ratio_c,
                         loss_agg_mode=loss_agg_mode,
                     )
 
                 else:
                     policy_loss_fn = get_policy_loss_fn(loss_mode)
                     pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
-                        old_log_prob, log_prob, advantages, response_mask, loss_agg_mode, self.config
+                        old_log_prob=old_log_prob,
+                        log_prob=log_prob,
+                        advantages=advantages,
+                        response_mask=response_mask,
+                        loss_agg_mode=loss_agg_mode,
+                        config=self.config,
                     )
 
+                stats.update(
+                    {
+                        "actor/pg_loss": pg_loss.detach().item(),
+                        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
+                        "actor/ppo_kl": ppo_kl.detach().item(),
+                        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
+                    }
+                )
                 policy_loss = pg_loss
 
             if calculate_entropy:
                 entropy = output["entropy"][:, -response_length - 1 : -1].contiguous()
                 if not forward_only:
@@ -436,7 +454,6 @@ class MegatronPPOActor(BasePPOActor):
                 else:
                     ret_entropy = entropy
 
-            stats = {}
             if forward_only:
                 policy_loss = torch.tensor(1.0, device=device)
             else:
@@ -451,14 +468,6 @@ class MegatronPPOActor(BasePPOActor):
                     metrics["actor/kl_coef"] = self.config.kl_loss_coef
 
                 # return loss and stats
-                stats.update(
-                    {
-                        "actor/pg_loss": pg_loss.detach().item(),
-                        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
-                        "actor/ppo_kl": ppo_kl.detach().item(),
-                        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
-                    }
-                )
 
             append_to_dict(metrics, stats)
             return policy_loss, [metrics, ret_entropy]
```