Mirror of https://github.com/volcengine/verl.git (synced 2025-10-20 13:43:50 +08:00)

fix: misleading eos_mask->response_mask (#878)

Follow-up to the review discussion at https://github.com/volcengine/verl/pull/868#discussion_r2024416560
@@ -17,7 +17,7 @@ import verl
 import verl.utils.torch_functional as verl_F
 
 
-def compute_rloo_advantage_return(data: verl.DataProto, eos_mask: torch.Tensor, n_samples, config):
+def compute_rloo_advantage_return(data: verl.DataProto, response_mask: torch.Tensor, n_samples, config):
     # calculate rloo reward on different reward sources, and sum again
     def masked_rloo(reward_tensor_original, mask_tensor):
         reward_tensor = reward_tensor_original.clone()
@@ -44,13 +44,13 @@ def compute_rloo_advantage_return(data: verl.DataProto, eos_mask: torch.Tensor,
 
     if 'rm_scores' in data.batch.keys() and config.algorithm.reward_dpo_coef != 0.:
         reward_tensor = data.batch['rm_scores']
-        reward_mask = eos_mask.bool()
+        reward_mask = response_mask.bool()
 
         reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_dpo_coef)
 
     if 'acc' in data.batch.keys() and config.algorithm.reward_gt_coef != 0.:
-        reward_tensor = torch.zeros_like(eos_mask, dtype=torch.float32)
-        reward_mask = torch.zeros_like(eos_mask, dtype=torch.bool)
+        reward_tensor = torch.zeros_like(response_mask, dtype=torch.float32)
+        reward_mask = torch.zeros_like(response_mask, dtype=torch.bool)
 
         prompt_ids = data.batch['prompts']
         prompt_length = prompt_ids.shape[-1]
@@ -67,25 +67,25 @@ def compute_rloo_advantage_return(data: verl.DataProto, eos_mask: torch.Tensor,
 
     final_reward_tensor = sum(reward_tensors)
 
-    returns = (final_reward_tensor * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
+    returns = (final_reward_tensor * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
 
     advantages = returns.clone()
-    advantages = verl_F.masked_whiten(advantages, eos_mask)
+    advantages = verl_F.masked_whiten(advantages, response_mask)
 
     return advantages, returns
 
 
-def compute_ce_dpo_loss_rm(token_level_scores, acc, eos_mask, beta):
-    cur_scores = ((token_level_scores * eos_mask).sum(dim=1) * beta).sigmoid()
+def compute_ce_dpo_loss_rm(token_level_scores, acc, response_mask, beta):
+    cur_scores = ((token_level_scores * response_mask).sum(dim=1) * beta).sigmoid()
     cur_dpo_loss = torch.nn.functional.binary_cross_entropy(cur_scores, acc)
     return cur_dpo_loss
 
 
-def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, eos_mask, beta, bon_mode='none'):
+def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, response_mask, beta, bon_mode='none'):
     # we always assume that the BoN size equals n_samples
     # mode1: use acc as rm
     # mode2: use Q as rm
-    cur_Q = (token_level_scores * eos_mask).sum(dim=1) * beta
+    cur_Q = (token_level_scores * response_mask).sum(dim=1) * beta
     other_Q = torch.zeros_like(cur_Q)
     for i in range(token_level_scores.shape[0]):
         if acc[i] > 0:
@@ -115,11 +115,11 @@ def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, eos_mask,
     return dpo_loss
 
 
-def compute_dpo_accuracy(token_level_scores, acc, eos_mask, n_samples):
+def compute_dpo_accuracy(token_level_scores, acc, response_mask, n_samples):
     dpo_acc = []
     for start_id in range(0, token_level_scores.shape[0], n_samples):
         cur_scores = (token_level_scores[start_id:start_id + n_samples] *
-                      eos_mask[start_id:start_id + n_samples]).sum(dim=1)
+                      response_mask[start_id:start_id + n_samples]).sum(dim=1)
 
         def get_upper_triangle(tensor_x):
             diff_matrix = tensor_x.unsqueeze(1) - tensor_x.unsqueeze(0)
@@ -140,5 +140,5 @@ def compute_dpo_accuracy(token_level_scores, acc, eos_mask, n_samples):
     return torch.cat(dpo_acc, dim=0).mean()
 
 
-def compute_dpo_abs_accuracy(token_level_scores, acc, eos_mask, n_samples):
-    return (torch.sign((token_level_scores * eos_mask).sum(dim=-1)) == torch.sign(acc * 2 - 1)).float().mean()
+def compute_dpo_abs_accuracy(token_level_scores, acc, response_mask, n_samples):
+    return (torch.sign((token_level_scores * response_mask).sum(dim=-1)) == torch.sign(acc * 2 - 1)).float().mean()
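Note: the masked sum in compute_ce_dpo_loss_rm above reduces per-token scores to one sequence-level score before the BCE; a minimal standalone sketch with made-up tensors (shapes and values here are purely illustrative):

import torch

# token_level_scores: (bs, response_length); response_mask zeroes out padding tokens
token_level_scores = torch.tensor([[0.2, -0.1, 0.4, 0.0],
                                   [0.1, 0.3, 0.0, 0.0]])
response_mask = torch.tensor([[1., 1., 1., 0.],
                              [1., 1., 0., 0.]])
acc = torch.tensor([1.0, 0.0])   # binary outcome labels per sequence
beta = 0.05

# sequence-level score = beta-scaled sum of token scores over response tokens only
cur_scores = ((token_level_scores * response_mask).sum(dim=1) * beta).sigmoid()
loss = torch.nn.functional.binary_cross_entropy(cur_scores, acc)
print(cur_scores, loss)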
@@ -277,7 +277,7 @@ class DataParallelPRIMERewardModel:
             prompt_ids = data['prompts']
             prompt_length = prompt_ids.shape[-1]
 
-            eos_mask = attention_mask[:, prompt_length:]
+            response_mask = attention_mask[:, prompt_length:]
 
             rm_score, q = self._forward_micro_batch(data, prompt_length)
 
@@ -285,14 +285,14 @@ class DataParallelPRIMERewardModel:
             q_lst.append(q.detach())
 
             if self.config.model.loss_type == 'ce':
-                dpo_loss = compute_ce_dpo_loss_rm(q, acc, eos_mask=eos_mask, beta=beta)
+                dpo_loss = compute_ce_dpo_loss_rm(q, acc, response_mask=response_mask, beta=beta)
             elif self.config.model.loss_type == 'dpo':
                 # the implementation of dpo is actually detached, which means we have to know the average value of w/l reward before the update.
                 dpo_loss = compute_detach_dpo_loss_rm(q,
                                                       acc,
                                                       Q_bc=data['Q_bc'],
                                                       acc_bc=data['acc_bc'],
-                                                      eos_mask=eos_mask,
+                                                      response_mask=response_mask,
                                                       beta=beta)
             elif self.config.model.loss_type == 'bon_acc':
                 # change the original distribution of each sample to BoN distribution, then update reward model
@@ -300,7 +300,7 @@ class DataParallelPRIMERewardModel:
                                                       acc,
                                                       Q_bc=data['Q_bc'],
                                                       acc_bc=data['acc_bc'],
-                                                      eos_mask=eos_mask,
+                                                      response_mask=response_mask,
                                                       beta=beta,
                                                       bon_mode='bon_acc')
             elif self.config.model.loss_type == 'bon_rm':
@@ -308,7 +308,7 @@ class DataParallelPRIMERewardModel:
                                                       acc,
                                                       Q_bc=data['Q_bc'],
                                                       acc_bc=data['acc_bc'],
-                                                      eos_mask=eos_mask,
+                                                      response_mask=response_mask,
                                                       beta=beta,
                                                       bon_mode='bon_rm')
             else:
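Note: the renamed response_mask in the reward-model code above is just the per-token attention mask sliced past the prompt; a small sketch of that slicing on dummy shapes (values are illustrative only):

import torch

batch_size, prompt_length, response_length = 2, 3, 4
# attention_mask covers prompt + response; 0 marks padding
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 0],
                               [0, 1, 1, 1, 1, 0, 0]])

# keep only the response part, exactly as in the worker code above
response_mask = attention_mask[:, prompt_length:]
assert response_mask.shape == (batch_size, response_length)
print(response_mask)   # tensor([[1, 1, 1, 0], [1, 1, 0, 0]])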
@@ -261,11 +261,11 @@ class PRIMERewardModelWorker(Worker):
             rm_scores, q, metrics = self.rm.compute_rm_score(data=data)
 
             prompt_length = data.batch['prompts'].shape[-1]
-            eos_mask = data.batch['attention_mask'][:, prompt_length:]
+            response_mask = data.batch['attention_mask'][:, prompt_length:]
             acc = data.batch['acc']
 
-            dpo_acc = compute_dpo_accuracy(rm_scores, acc, eos_mask=eos_mask, n_samples=data.meta_info['n'])
-            dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, eos_mask, n_samples=data.meta_info['n'])
+            dpo_acc = compute_dpo_accuracy(rm_scores, acc, response_mask=response_mask, n_samples=data.meta_info['n'])
+            dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info['n'])
 
             metrics['reward_model/dpo_acc'] = dpo_acc.detach().item()
             metrics['reward_model/dpo_acc_abs'] = dpo_acc_abs.detach().item()
@@ -299,11 +299,14 @@ class PRIMERewardModelWorker(Worker):
             metrics['rm/lr'] = lr
 
             prompt_length = data.batch['prompts'].shape[-1]
-            eos_mask = data.batch['attention_mask'][:, prompt_length:]
+            response_mask = data.batch['attention_mask'][:, prompt_length:]
             acc = data.batch['acc']
 
-            dpo_acc_before = compute_dpo_accuracy(rm_scores, acc, eos_mask=eos_mask, n_samples=data.meta_info['n'])
-            dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, eos_mask, n_samples=data.meta_info['n'])
+            dpo_acc_before = compute_dpo_accuracy(rm_scores,
+                                                  acc,
+                                                  response_mask=response_mask,
+                                                  n_samples=data.meta_info['n'])
+            dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info['n'])
 
             metrics['reward_model/dpo_acc_before'] = dpo_acc_before.detach().item()
             metrics['reward_model/dpo_acc_abs_before'] = dpo_acc_abs.detach().item()
@@ -63,7 +63,7 @@ def get_kl_controller(kl_ctrl):
         raise NotImplementedError
 
 
-def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torch.Tensor, eos_mask: torch.Tensor,
+def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torch.Tensor, response_mask: torch.Tensor,
                                  gamma: torch.Tensor, lam: torch.Tensor):
     """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py
 
@@ -72,7 +72,7 @@ def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torc
             shape: (bs, response_length)
         values: `(torch.Tensor)`
             shape: (bs, response_length)
-        eos_mask: `(torch.Tensor)`
+        response_mask: `(torch.Tensor)`
             shape: (bs, response_length). [EOS] mask. The token after [EOS] have mask zero.
         gamma: `(float)`
             discounted factor used in RL
@@ -99,13 +99,13 @@ def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torc
         advantages = torch.stack(advantages_reversed[::-1], dim=1)
 
         returns = advantages + values
-        advantages = verl_F.masked_whiten(advantages, eos_mask)
+        advantages = verl_F.masked_whiten(advantages, response_mask)
     return advantages, returns
 
 
 # NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar.
 def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
-                                   eos_mask: torch.Tensor,
+                                   response_mask: torch.Tensor,
                                    index: torch.Tensor,
                                    epsilon: float = 1e-6):
     """
@@ -114,7 +114,7 @@ def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
     Args:
         token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
-        eos_mask: `(torch.Tensor)`
+        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
 
     Returns:
@@ -145,13 +145,13 @@ def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
             raise ValueError(f"no score in prompt index: {idx}")
         for i in range(bsz):
             scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
-        scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
+        scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask
 
     return scores, scores
 
 
 def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor,
-                                   eos_mask: torch.Tensor,
+                                   response_mask: torch.Tensor,
                                    index: torch.Tensor,
                                    epsilon: float = 1e-6):
     """
@@ -159,7 +159,7 @@ def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor,
     Args:
         token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
-        eos_mask: `(torch.Tensor)`
+        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
 
     Returns:
@@ -190,12 +190,12 @@ def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor,
             if response_num > 1:
                 scores[i] = scores[i] * response_num / (response_num -
                                                         1) - id2mean[index[i]] * response_num / (response_num - 1)
-        scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
+        scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask
 
     return scores, scores
 
 
-def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Tensor, eos_mask: torch.Tensor,
+def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor,
                                                   gamma: torch.Tensor):
     """
     Compute advantage for REINFORCE++.
@@ -203,7 +203,7 @@ def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Ten
     Args:
         token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
-        eos_mask: `(torch.Tensor)`
+        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
 
     Returns:
@@ -221,16 +221,16 @@ def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Ten
             running_return = token_level_rewards[:, t] + gamma * running_return
             returns[:, t] = running_return
             # Reset after EOS
-            running_return = running_return * eos_mask[:, t]
+            running_return = running_return * response_mask[:, t]
 
-        advantages = verl_F.masked_whiten(returns, eos_mask)
-        advantages = advantages * eos_mask
+        advantages = verl_F.masked_whiten(returns, response_mask)
+        advantages = advantages * response_mask
 
     return advantages, returns
 
 
 def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor,
-                                    eos_mask: torch.Tensor):
+                                    response_mask: torch.Tensor):
     """
     Compute advantage for ReMax, operating only on Outcome reward
     This implementation is based on the paper: https://arxiv.org/abs/2310.10505
@@ -241,7 +241,7 @@ def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, reward_ba
            shape: (bs, response_length)
         reward_baselines: `(torch.Tensor)`
            shape: (bs,)
-        eos_mask: `(torch.Tensor)`
+        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
 
     Returns:
@@ -254,8 +254,8 @@ def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, reward_ba
     scores = token_level_rewards.sum(dim=-1)
 
     with torch.no_grad():
-        returns = (token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
-        advantages = returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask
+        returns = (token_level_rewards * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
+        advantages = returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * response_mask
 
     return advantages, returns
 
@@ -265,7 +265,7 @@ def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
     return token_level_scores - kl * kl_ratio
 
 
-def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange, clip_ratio_c=3.0):
+def compute_policy_loss(old_log_prob, log_prob, advantages, response_mask, cliprange, clip_ratio_c=3.0):
     """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122
 
     Args:
@@ -275,7 +275,7 @@ def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange,
            shape: (bs, response_length)
         advantages: `(torch.Tensor)`
            shape: (bs, response_length)
-        eos_mask: `(torch.Tensor)`
+        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
         cliprange: (float)
            The clip range used in PPO. See https://arxiv.org/abs/1707.06347
@@ -294,32 +294,33 @@ def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange,
 
     negative_approx_kl = log_prob - old_log_prob
     ratio = torch.exp(negative_approx_kl)
-    ppo_kl = verl_F.masked_mean(-negative_approx_kl, eos_mask)
+    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
 
     pg_losses = -advantages * ratio
     pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
 
     clip_pg_losses1 = torch.max(pg_losses, pg_losses2)
-    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
+    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), response_mask)
 
     pg_losses3 = -advantages * clip_ratio_c
     clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
-    pg_clipfrac_lower = verl_F.masked_mean(torch.gt(clip_pg_losses2, pg_losses3) * (advantages < 0).float(), eos_mask)
+    pg_clipfrac_lower = verl_F.masked_mean(
+        torch.gt(clip_pg_losses2, pg_losses3) * (advantages < 0).float(), response_mask)
     # We only apply the dual-clip when the advantage is negative.
     pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
 
-    pg_loss = verl_F.masked_mean(pg_losses, eos_mask)
+    pg_loss = verl_F.masked_mean(pg_losses, response_mask)
 
     return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
 
 
-def compute_entropy_loss(logits, eos_mask):
+def compute_entropy_loss(logits, response_mask):
     """Compute Categorical entropy loss
 
     Args:
         logits: `(torch.Tensor)`
            shape: (bs, response_length, vocab_size)
-        eos_mask: `(torch.Tensor)`
+        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
 
     Returns:
@@ -328,11 +329,11 @@ def compute_entropy_loss(logits, eos_mask):
     """
     # compute entropy
     entropy = verl_F.entropy_from_logits(logits)  # (bs, response_len)
-    entropy_loss = verl_F.masked_mean(entropy, mask=eos_mask)
+    entropy_loss = verl_F.masked_mean(entropy, mask=response_mask)
     return entropy_loss
 
 
-def compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value):
+def compute_value_loss(vpreds, returns, values, response_mask, cliprange_value):
     """Compute the value loss. Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151
 
     Args:
@@ -353,8 +354,8 @@ def compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value):
     vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
     vf_losses1 = (vpreds - returns)**2
     vf_losses2 = (vpredclipped - returns)**2
-    vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask)
-    vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask)
+    vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), response_mask)
+    vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask)
     return vf_loss, vf_clipfrac
 
 
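Note: several estimators above build token-level returns with the flip/cumsum/flip idiom, which is simply a reversed running sum over the masked response; a toy sketch with invented numbers:

import torch

token_level_rewards = torch.tensor([[0., 0., 1., 0.],
                                    [0., 2., 0., 0.]])
response_mask = torch.tensor([[1., 1., 1., 0.],
                              [1., 1., 0., 0.]])

# suffix sum: returns[:, t] = sum of masked rewards from position t to the end
returns = (token_level_rewards * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
print(returns)
# tensor([[1., 1., 1., 0.],
#         [2., 2., 0., 0.]])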
@@ -181,7 +181,7 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
         advantages, returns = core_algos.compute_gae_advantage_return(
             token_level_rewards=data.batch['token_level_rewards'],
             values=data.batch['values'],
-            eos_mask=data.batch['response_mask'],
+            response_mask=data.batch['response_mask'],
             gamma=gamma,
             lam=lam)
         data.batch['advantages'] = advantages
@@ -189,27 +189,29 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
     elif adv_estimator == AdvantageEstimator.GRPO:
         advantages, returns = core_algos.compute_grpo_outcome_advantage(
             token_level_rewards=data.batch['token_level_rewards'],
-            eos_mask=data.batch['response_mask'],
+            response_mask=data.batch['response_mask'],
             index=data.non_tensor_batch['uid'])
         data.batch['advantages'] = advantages
         data.batch['returns'] = returns
     elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS:
         advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
-            token_level_rewards=data.batch['token_level_rewards'], eos_mask=data.batch['response_mask'], gamma=gamma)
+            token_level_rewards=data.batch['token_level_rewards'],
+            response_mask=data.batch['response_mask'],
+            gamma=gamma)
         data.batch['advantages'] = advantages
         data.batch['returns'] = returns
     elif adv_estimator == AdvantageEstimator.REMAX:
         advantages, returns = core_algos.compute_remax_outcome_advantage(
             token_level_rewards=data.batch['token_level_rewards'],
             reward_baselines=data.batch['reward_baselines'],
-            eos_mask=data.batch['response_mask'])
+            response_mask=data.batch['response_mask'])
 
         data.batch['advantages'] = advantages
         data.batch['returns'] = returns
     elif adv_estimator == AdvantageEstimator.RLOO:
         advantages, returns = core_algos.compute_rloo_outcome_advantage(
             token_level_rewards=data.batch['token_level_rewards'],
-            eos_mask=data.batch['response_mask'],
+            response_mask=data.batch['response_mask'],
             index=data.non_tensor_batch['uid'])
         data.batch['advantages'] = advantages
         data.batch['returns'] = returns
@@ -170,13 +170,13 @@ def vocab_parallel_log_probs_from_logits_response_rmpad(input_ids, attention_mas
     return output
 
 
-def vocab_parallel_compute_entropy_loss(logits, eos_mask):
+def vocab_parallel_compute_entropy_loss(logits, response_mask):
     """Compute Categorical entropy loss
 
     Args:
         logits: `(torch.Tensor)`
            shape: (bs, response_length, vocab_size)
-        eos_mask: `(torch.Tensor)`
+        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
 
     Returns:
@@ -185,5 +185,5 @@ def vocab_parallel_compute_entropy_loss(logits, eos_mask):
     """
     # compute entropy
     entropy = vocab_parallel_entropy(logits)
-    entropy_loss = verl_F.masked_mean(entropy, mask=eos_mask)
+    entropy_loss = verl_F.masked_mean(entropy, mask=response_mask)
     return entropy_loss
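Note: the quantity masked here is the standard categorical entropy of the per-token distribution; a standalone sketch below uses a plain dense entropy and a locally defined masked_mean as stand-ins for vocab_parallel_entropy / verl_F.masked_mean (which this file computes in a tensor-parallel way), so the helpers and values are assumptions for illustration only:

import torch
import torch.nn.functional as F

def entropy_from_logits_sketch(logits):
    # H = -sum_v p_v * log p_v, computed per position
    logp = F.log_softmax(logits, dim=-1)
    return -(logp.exp() * logp).sum(dim=-1)

def masked_mean(values, mask):
    # mean over positions where mask == 1
    return (values * mask).sum() / mask.sum()

logits = torch.randn(2, 4, 8)                 # (bs, response_length, vocab_size)
response_mask = torch.tensor([[1., 1., 1., 0.],
                              [1., 1., 0., 0.]])
entropy = entropy_from_logits_sketch(logits)  # (bs, response_length)
entropy_loss = masked_mean(entropy, response_mask)
print(entropy_loss)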
@@ -147,24 +147,27 @@ def masked_whiten(values, mask, shift_mean=True):
     return whitened
 
 
-def get_eos_mask(response_id: torch.Tensor, eos_token: Union[int, List[int]] = 2, dtype=torch.int64):
+def get_response_mask(response_id: torch.Tensor, eos_token: Union[int, List[int]] = 2, dtype=torch.int64):
     '''
     end of sentence token can be int or list: 1 or [1, 2]
-    e.g. eos_token=1
-    response_id: [0, 0, 2, 42, 3, 5, 1, 0, 0]
-    eos_mask: [1, 1, 1, 1, 1, 1, 1, 0, 0]
+    e.g.
+    response_id = torch.tensor([[20, 10, 34, 1, 0, 0, 0],
+                                [78, 0, 76, 2, 1, 0, 0],
+                                [23, 98, 1, 0, 0, 0, 0],
+                                [33, 3, 98, 45, 1, 0, 0]])
+    #eos_token=1
+    response_mask: tensor([[1, 1, 1, 1, 0, 0, 0],
+                           [1, 1, 1, 1, 1, 0, 0],
+                           [1, 1, 1, 0, 0, 0, 0],
+                           [1, 1, 1, 1, 1, 0, 0]])
+    #eos_token=[1,2]
+    response_mask: tensor([[1, 1, 1, 1, 0, 0, 0],
+                           [1, 1, 1, 1, 0, 0, 0],
+                           [1, 1, 1, 0, 0, 0, 0],
+                           [1, 1, 1, 1, 1, 0, 0]])
     '''
     if isinstance(eos_token, int):
         eos_token = [eos_token]
-
-    eos_mask = torch.zeros_like(response_id, dtype=torch.bool)
-    for token in eos_token:
-        eos_mask |= response_id.eq(token)
-
-    eos_mask = eos_mask.long()
-    eos_mask = (torch.cumsum(eos_mask, dim=1) - eos_mask).bool()
-    eos_mask = torch.logical_not(eos_mask).to(dtype)
-    return eos_mask
+    eos_mask = torch.isin(response_id, torch.tensor(eos_token, device=response_id.device)).int()
+    return (eos_mask.cumsum(dim=1) - eos_mask).eq(0).to(dtype)
 
 
 def compute_grad_norm(model: nn.Module):
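Note: as a sanity check on the vectorized masking logic above, a minimal standalone sketch (response_mask_sketch is a local, hypothetical name mirroring the two-line implementation; the tensors come from the docstring example):

import torch

def response_mask_sketch(response_id: torch.Tensor, eos_token, dtype=torch.int64):
    # mark every position holding an EOS token, then keep everything up to and
    # including the first EOS: cumsum - eos_hits is 0 only at/before that point
    if isinstance(eos_token, int):
        eos_token = [eos_token]
    eos_hits = torch.isin(response_id, torch.tensor(eos_token, device=response_id.device)).int()
    return (eos_hits.cumsum(dim=1) - eos_hits).eq(0).to(dtype)

response_id = torch.tensor([[20, 10, 34, 1, 0, 0, 0],
                            [78, 0, 76, 2, 1, 0, 0]])
print(response_mask_sketch(response_id, eos_token=1))
# expected: [[1, 1, 1, 1, 0, 0, 0],
#            [1, 1, 1, 1, 1, 0, 0]]
print(response_mask_sketch(response_id, eos_token=[1, 2]))
# expected: [[1, 1, 1, 1, 0, 0, 0],
#            [1, 1, 1, 1, 0, 0, 0]]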
@@ -296,7 +296,7 @@ class DataParallelPPOActor(BasePPOActor):
                     old_log_prob=old_log_prob,
                     log_prob=log_prob,
                     advantages=advantages,
-                    eos_mask=response_mask,
+                    response_mask=response_mask,
                     cliprange=clip_ratio,
                     clip_ratio_c=clip_ratio_c)
                 # compute entropy loss from entropy
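Note: the dual-clip computation the actor delegates to compute_policy_loss (shown in full in the core_algos hunk above) can be reproduced standalone; a sketch with dummy tensors and a locally defined masked_mean assumed to match verl_F.masked_mean's usual definition:

import torch

def masked_mean(values, mask):
    # mean over positions where mask == 1 (assumed equivalent to verl_F.masked_mean)
    return (values * mask).sum() / mask.sum()

old_log_prob = torch.tensor([[-1.0, -0.5, -0.7]])
log_prob = torch.tensor([[-0.8, -0.9, -0.7]])
advantages = torch.tensor([[1.0, -2.0, 0.5]])
response_mask = torch.tensor([[1., 1., 0.]])
cliprange, clip_ratio_c = 0.2, 3.0

ratio = torch.exp(log_prob - old_log_prob)
pg_losses = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
clip_pg_losses1 = torch.max(pg_losses, pg_losses2)
# dual clip: bound the loss from above only where the advantage is negative
pg_losses3 = -advantages * clip_ratio_c
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
pg_loss = masked_mean(pg_losses, response_mask)
print(pg_loss)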
@@ -285,13 +285,14 @@ class MegatronPPOActor(BasePPOActor):
             logits_back = logits.clone()
             log_prob = vocab_parallel_log_probs_from_logits(logits, responses)
             logits = logits_back
-            pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = core_algos.compute_policy_loss(old_log_prob=old_log_prob,
-                                                                                             log_prob=log_prob,
-                                                                                             advantages=advantages,
-                                                                                             eos_mask=response_mask,
-                                                                                             cliprange=clip_ratio,
-                                                                                             clip_ratio_c=clip_ratio_c)
-            entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=response_mask)
+            pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = core_algos.compute_policy_loss(
+                old_log_prob=old_log_prob,
+                log_prob=log_prob,
+                advantages=advantages,
+                response_mask=response_mask,
+                cliprange=clip_ratio,
+                clip_ratio_c=clip_ratio_c)
+            entropy_loss = vocab_parallel_compute_entropy_loss(logits, response_mask=response_mask)
             policy_loss = pg_loss - entropy_loss * entropy_coeff
 
             metrics = {}
@@ -217,7 +217,7 @@ class DataParallelPPOCritic(BasePPOCritic):
                 returns = data['returns']
                 response_length = responses.size(1)
 
-                eos_mask = attention_mask[:, -response_length - 1:-1]
+                response_mask = attention_mask[:, -response_length - 1:-1]
 
                 vpreds = self._forward_micro_batch(data)
 
@@ -226,7 +226,7 @@ class DataParallelPPOCritic(BasePPOCritic):
                 vf_loss, vf_clipfrac = core_algos.compute_value_loss(vpreds=vpreds,
                                                                      values=values,
                                                                      returns=returns,
-                                                                     eos_mask=eos_mask,
+                                                                     response_mask=response_mask,
                                                                      cliprange_value=self.config.cliprange_value)
                 if self.config.use_dynamic_bsz:
                     # relative to the dynamic bsz
@@ -239,7 +239,7 @@ class DataParallelPPOCritic(BasePPOCritic):
                 data = {
                     'critic/vf_loss': vf_loss.detach().item(),
                     'critic/vf_clipfrac': vf_clipfrac.detach().item(),
-                    'critic/vpred_mean': masked_mean(vpreds, eos_mask).detach().item(),
+                    'critic/vpred_mean': masked_mean(vpreds, response_mask).detach().item(),
                 }
 
                 append_to_dict(metrics, data)
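Note: the clipped value loss the critic delegates to compute_value_loss boils down to the following; a sketch with made-up numbers, a local masked_mean assumed to mirror verl's helper, and torch.clamp standing in for verl_F.clip_by_value:

import torch

def masked_mean(values, mask):
    return (values * mask).sum() / mask.sum()

vpreds = torch.tensor([[0.9, 0.2, 0.0]])      # new value predictions
values = torch.tensor([[0.5, 0.1, 0.0]])      # old value predictions
returns = torch.tensor([[1.0, 0.0, 0.0]])
response_mask = torch.tensor([[1., 1., 0.]])
cliprange_value = 0.2

# clip new predictions to stay within cliprange_value of the old values
vpredclipped = torch.clamp(vpreds, values - cliprange_value, values + cliprange_value)
vf_losses1 = (vpreds - returns) ** 2
vf_losses2 = (vpredclipped - returns) ** 2
vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), response_mask)
vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask)
print(vf_loss, vf_clipfrac)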
@@ -138,7 +138,7 @@ class MegatronPPOCritic(BasePPOCritic):
             returns = data['returns']
             response_length = responses.size(1)
 
-            eos_mask = attention_mask[:, -response_length:]
+            response_mask = attention_mask[:, -response_length:]
 
             cliprange_value = self.config.cliprange_value
 
@@ -148,12 +148,12 @@ class MegatronPPOCritic(BasePPOCritic):
             vf_loss, vf_clipfrac = core_algos.compute_value_loss(vpreds=vpreds,
                                                                  values=values,
                                                                  returns=returns,
-                                                                 eos_mask=eos_mask,
+                                                                 response_mask=response_mask,
                                                                  cliprange_value=cliprange_value)
             stats = {
                 'critic/vf_loss': vf_loss.detach().item(),
                 'critic/vf_clipfrac': vf_clipfrac.detach().item(),
-                'critic/vpred_mean': masked_mean(vpreds, eos_mask).detach().item(),
+                'critic/vpred_mean': masked_mean(vpreds, response_mask).detach().item(),
             }
 
             return vf_loss, stats
@@ -24,7 +24,7 @@ from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
 from verl import DataProto
-from verl.utils.torch_functional import get_eos_mask
+from verl.utils.torch_functional import get_response_mask
 from .base import BaseRollout
 
 from transformers import GenerationConfig
@@ -120,7 +120,9 @@ class HFRollout(BaseRollout):
         response_position_ids = position_ids[:, -1:] + delta_position_id
         position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
 
-        response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
+        response_attention_mask = get_response_mask(response_id=response,
+                                                    eos_token=eos_token_id,
+                                                    dtype=attention_mask.dtype)
         attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
 
         batch = TensorDict(
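Note: every rollout backend touched below follows the same pattern: build a response-side mask from the sampled token ids and append it to the prompt attention mask. A rough sketch, under the assumption that get_response_mask behaves as documented in torch_functional above and that verl is importable; the tensors and eos id are invented for illustration:

import torch
from verl.utils.torch_functional import get_response_mask

prompt_attention_mask = torch.tensor([[0, 1, 1]])   # left-padded prompt
response = torch.tensor([[42, 7, 1, 0]])            # generated ids, eos id = 1
eos_token_id = 1

# mask covers generated tokens up to and including the first EOS
response_attention_mask = get_response_mask(response_id=response,
                                            eos_token=eos_token_id,
                                            dtype=prompt_attention_mask.dtype)
attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1)
print(attention_mask)   # tensor([[0, 1, 1, 1, 1, 1, 0]])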
@@ -33,7 +33,7 @@ from omegaconf import DictConfig
 from tensordict import TensorDict
 from verl import DataProto
 from verl.workers.rollout.base import BaseRollout
-from verl.utils.torch_functional import get_eos_mask, pad_sequence_to_length, pad_2d_list_to_length
+from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length, pad_2d_list_to_length
 from sglang.srt.entrypoints.verl_engine import VerlEngine
 from torch.distributed.device_mesh import init_device_mesh
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -268,7 +268,9 @@ class SGLangRollout(BaseRollout):
         # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
         response_position_ids = position_ids[:, -1:] + delta_position_id
         position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
-        response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
+        response_attention_mask = get_response_mask(response_id=response,
+                                                    eos_token=eos_token_id,
+                                                    dtype=attention_mask.dtype)
         attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
 
         # all the tp ranks should contain the same data here. data in all ranks are valid
|
@ -33,7 +33,7 @@ from tensordict import TensorDict
|
||||
from torch import nn
|
||||
|
||||
from verl import DataProto
|
||||
from verl.utils.torch_functional import get_eos_mask, pad_sequence_to_length
|
||||
from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length
|
||||
from verl.workers.rollout.base import BaseRollout
|
||||
from verl.workers.rollout.vllm_rollout.vllm_rollout import vLLMRollout
|
||||
from verl.third_party.vllm import LLM, vllm_version
|
||||
@ -192,7 +192,9 @@ class FIREvLLMRollout(vLLMRollout):
|
||||
# position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
|
||||
response_position_ids = position_ids[:, -1:] + delta_position_id
|
||||
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
|
||||
response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
|
||||
response_attention_mask = get_response_mask(response_id=response,
|
||||
eos_token=eos_token_id,
|
||||
dtype=attention_mask.dtype)
|
||||
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
|
||||
|
||||
# all the tp ranks should contain the same data here. data in all ranks are valid
|
||||
|
@ -33,7 +33,7 @@ from tensordict import TensorDict
|
||||
from torch import nn
|
||||
|
||||
from verl import DataProto
|
||||
from verl.utils.torch_functional import get_eos_mask, pad_sequence_to_length
|
||||
from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length
|
||||
from verl.workers.rollout.base import BaseRollout
|
||||
from verl.third_party.vllm import LLM, vllm_version
|
||||
from verl.third_party.vllm import parallel_state as vllm_ps
|
||||
@ -229,7 +229,9 @@ class vLLMRollout(BaseRollout):
|
||||
# position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
|
||||
response_position_ids = position_ids[:, -1:] + delta_position_id
|
||||
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
|
||||
response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
|
||||
response_attention_mask = get_response_mask(response_id=response,
|
||||
eos_token=eos_token_id,
|
||||
dtype=attention_mask.dtype)
|
||||
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
|
||||
|
||||
# all the tp ranks should contain the same data here. data in all ranks are valid
|
||||
|
@ -34,7 +34,7 @@ from tensordict import TensorDict
|
||||
from torch import nn
|
||||
from typing import Any, Union
|
||||
from verl import DataProto
|
||||
from verl.utils.torch_functional import get_eos_mask, pad_2d_list_to_length
|
||||
from verl.utils.torch_functional import get_response_mask, pad_2d_list_to_length
|
||||
from verl.workers.rollout.base import BaseRollout
|
||||
from vllm.distributed import parallel_state as vllm_ps
|
||||
from vllm import LLM, SamplingParams
|
||||
@ -269,7 +269,9 @@ class vLLMRollout(BaseRollout):
|
||||
# position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
|
||||
response_position_ids = position_ids[:, -1:] + delta_position_id
|
||||
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
|
||||
response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
|
||||
response_attention_mask = get_response_mask(response_id=response,
|
||||
eos_token=eos_token_id,
|
||||
dtype=attention_mask.dtype)
|
||||
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
|
||||
|
||||
# all the tp ranks should contain the same data here. data in all ranks are valid
|
||||
|