fix: misleading eos_mask->response_mask (#878)

https://github.com/volcengine/verl/pull/868#discussion_r2024416560
Lumeng Wu
2025-04-03 13:01:07 +08:00
committed by GitHub
parent 7895c1f472
commit 8cae42dc29
16 changed files with 122 additions and 102 deletions

View File

@@ -17,7 +17,7 @@ import verl
import verl.utils.torch_functional as verl_F
def compute_rloo_advantage_return(data: verl.DataProto, eos_mask: torch.Tensor, n_samples, config):
def compute_rloo_advantage_return(data: verl.DataProto, response_mask: torch.Tensor, n_samples, config):
# calculate rloo reward on different reward sources, and sum again
def masked_rloo(reward_tensor_original, mask_tensor):
reward_tensor = reward_tensor_original.clone()
@@ -44,13 +44,13 @@ def compute_rloo_advantage_return(data: verl.DataProto, eos_mask: torch.Tensor,
if 'rm_scores' in data.batch.keys() and config.algorithm.reward_dpo_coef != 0.:
reward_tensor = data.batch['rm_scores']
reward_mask = eos_mask.bool()
reward_mask = response_mask.bool()
reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_dpo_coef)
if 'acc' in data.batch.keys() and config.algorithm.reward_gt_coef != 0.:
reward_tensor = torch.zeros_like(eos_mask, dtype=torch.float32)
reward_mask = torch.zeros_like(eos_mask, dtype=torch.bool)
reward_tensor = torch.zeros_like(response_mask, dtype=torch.float32)
reward_mask = torch.zeros_like(response_mask, dtype=torch.bool)
prompt_ids = data.batch['prompts']
prompt_length = prompt_ids.shape[-1]
@@ -67,25 +67,25 @@ def compute_rloo_advantage_return(data: verl.DataProto, eos_mask: torch.Tensor,
final_reward_tensor = sum(reward_tensors)
returns = (final_reward_tensor * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
returns = (final_reward_tensor * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
advantages = returns.clone()
advantages = verl_F.masked_whiten(advantages, eos_mask)
advantages = verl_F.masked_whiten(advantages, response_mask)
return advantages, returns
def compute_ce_dpo_loss_rm(token_level_scores, acc, eos_mask, beta):
cur_scores = ((token_level_scores * eos_mask).sum(dim=1) * beta).sigmoid()
def compute_ce_dpo_loss_rm(token_level_scores, acc, response_mask, beta):
cur_scores = ((token_level_scores * response_mask).sum(dim=1) * beta).sigmoid()
cur_dpo_loss = torch.nn.functional.binary_cross_entropy(cur_scores, acc)
return cur_dpo_loss
def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, eos_mask, beta, bon_mode='none'):
def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, response_mask, beta, bon_mode='none'):
# we always assume that the BoN size equals n_samples
# mode1: use acc as rm
# mode2: use Q as rm
cur_Q = (token_level_scores * eos_mask).sum(dim=1) * beta
cur_Q = (token_level_scores * response_mask).sum(dim=1) * beta
other_Q = torch.zeros_like(cur_Q)
for i in range(token_level_scores.shape[0]):
if acc[i] > 0:
@@ -115,11 +115,11 @@ def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, eos_mask,
return dpo_loss
def compute_dpo_accuracy(token_level_scores, acc, eos_mask, n_samples):
def compute_dpo_accuracy(token_level_scores, acc, response_mask, n_samples):
dpo_acc = []
for start_id in range(0, token_level_scores.shape[0], n_samples):
cur_scores = (token_level_scores[start_id:start_id + n_samples] *
eos_mask[start_id:start_id + n_samples]).sum(dim=1)
response_mask[start_id:start_id + n_samples]).sum(dim=1)
def get_upper_triangle(tensor_x):
diff_matrix = tensor_x.unsqueeze(1) - tensor_x.unsqueeze(0)
@@ -140,5 +140,5 @@ def compute_dpo_accuracy(token_level_scores, acc, eos_mask, n_samples):
return torch.cat(dpo_acc, dim=0).mean()
def compute_dpo_abs_accuracy(token_level_scores, acc, eos_mask, n_samples):
return (torch.sign((token_level_scores * eos_mask).sum(dim=-1)) == torch.sign(acc * 2 - 1)).float().mean()
def compute_dpo_abs_accuracy(token_level_scores, acc, response_mask, n_samples):
return (torch.sign((token_level_scores * response_mask).sum(dim=-1)) == torch.sign(acc * 2 - 1)).float().mean()
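Note: the helpers above all reduce token-level scores with the renamed mask in the same way. A minimal standalone sketch of that reduction with made-up tensors (illustration only, not taken from the repo):

import torch

# Hypothetical token-level scores for two responses of length 4.
token_level_scores = torch.tensor([[0.5, 1.0, -0.2, 0.3],
                                   [0.1, 0.4,  0.0, 0.0]])
# 1 marks response tokens up to and including EOS, 0 marks padding.
response_mask = torch.tensor([[1., 1., 1., 1.],
                              [1., 1., 0., 0.]])

# Same masked reduction used by compute_ce_dpo_loss_rm / compute_detach_dpo_loss_rm:
# padded positions contribute nothing to the per-sample score.
seq_scores = (token_level_scores * response_mask).sum(dim=1)
print(seq_scores)  # tensor([1.6000, 0.5000])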

View File

@@ -277,7 +277,7 @@ class DataParallelPRIMERewardModel:
prompt_ids = data['prompts']
prompt_length = prompt_ids.shape[-1]
eos_mask = attention_mask[:, prompt_length:]
response_mask = attention_mask[:, prompt_length:]
rm_score, q = self._forward_micro_batch(data, prompt_length)
@@ -285,14 +285,14 @@ class DataParallelPRIMERewardModel:
q_lst.append(q.detach())
if self.config.model.loss_type == 'ce':
dpo_loss = compute_ce_dpo_loss_rm(q, acc, eos_mask=eos_mask, beta=beta)
dpo_loss = compute_ce_dpo_loss_rm(q, acc, response_mask=response_mask, beta=beta)
elif self.config.model.loss_type == 'dpo':
# the implementation of dpo is actually detached, which means we have to know the average value of w/l reward before the update.
dpo_loss = compute_detach_dpo_loss_rm(q,
acc,
Q_bc=data['Q_bc'],
acc_bc=data['acc_bc'],
eos_mask=eos_mask,
response_mask=response_mask,
beta=beta)
elif self.config.model.loss_type == 'bon_acc':
# change the original distribution of each sample to BoN distribution, then update reward model
@@ -300,7 +300,7 @@ class DataParallelPRIMERewardModel:
acc,
Q_bc=data['Q_bc'],
acc_bc=data['acc_bc'],
eos_mask=eos_mask,
response_mask=response_mask,
beta=beta,
bon_mode='bon_acc')
elif self.config.model.loss_type == 'bon_rm':
@@ -308,7 +308,7 @@ class DataParallelPRIMERewardModel:
acc,
Q_bc=data['Q_bc'],
acc_bc=data['acc_bc'],
eos_mask=eos_mask,
response_mask=response_mask,
beta=beta,
bon_mode='bon_rm')
else:

View File

@@ -261,11 +261,11 @@ class PRIMERewardModelWorker(Worker):
rm_scores, q, metrics = self.rm.compute_rm_score(data=data)
prompt_length = data.batch['prompts'].shape[-1]
eos_mask = data.batch['attention_mask'][:, prompt_length:]
response_mask = data.batch['attention_mask'][:, prompt_length:]
acc = data.batch['acc']
dpo_acc = compute_dpo_accuracy(rm_scores, acc, eos_mask=eos_mask, n_samples=data.meta_info['n'])
dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, eos_mask, n_samples=data.meta_info['n'])
dpo_acc = compute_dpo_accuracy(rm_scores, acc, response_mask=response_mask, n_samples=data.meta_info['n'])
dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info['n'])
metrics['reward_model/dpo_acc'] = dpo_acc.detach().item()
metrics['reward_model/dpo_acc_abs'] = dpo_acc_abs.detach().item()
@@ -299,11 +299,14 @@ class PRIMERewardModelWorker(Worker):
metrics['rm/lr'] = lr
prompt_length = data.batch['prompts'].shape[-1]
eos_mask = data.batch['attention_mask'][:, prompt_length:]
response_mask = data.batch['attention_mask'][:, prompt_length:]
acc = data.batch['acc']
dpo_acc_before = compute_dpo_accuracy(rm_scores, acc, eos_mask=eos_mask, n_samples=data.meta_info['n'])
dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, eos_mask, n_samples=data.meta_info['n'])
dpo_acc_before = compute_dpo_accuracy(rm_scores,
acc,
response_mask=response_mask,
n_samples=data.meta_info['n'])
dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info['n'])
metrics['reward_model/dpo_acc_before'] = dpo_acc_before.detach().item()
metrics['reward_model/dpo_acc_abs_before'] = dpo_acc_abs.detach().item()
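Note: in the reward-model and critic workers the response mask is simply the response slice of the attention mask. A toy sketch of that slicing with hypothetical prompt/response lengths:

import torch

# Hypothetical batch: prompt length 3, response length 4; 0 marks padding.
attention_mask = torch.tensor([[0, 1, 1,  1, 1, 1, 0],
                               [1, 1, 1,  1, 1, 1, 1]])
prompt_length = 3

# Same slicing as in the worker code above.
response_mask = attention_mask[:, prompt_length:]
print(response_mask)
# tensor([[1, 1, 1, 0],
#         [1, 1, 1, 1]])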

View File

@@ -63,7 +63,7 @@ def get_kl_controller(kl_ctrl):
raise NotImplementedError
def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torch.Tensor, eos_mask: torch.Tensor,
def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torch.Tensor, response_mask: torch.Tensor,
gamma: torch.Tensor, lam: torch.Tensor):
"""Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py
@@ -72,7 +72,7 @@ def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torc
shape: (bs, response_length)
values: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
response_mask: `(torch.Tensor)`
shape: (bs, response_length). [EOS] mask. The token after [EOS] have mask zero.
gamma: `(float)`
discounted factor used in RL
@@ -99,13 +99,13 @@ def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torc
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values
advantages = verl_F.masked_whiten(advantages, eos_mask)
advantages = verl_F.masked_whiten(advantages, response_mask)
return advantages, returns
# NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar.
def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
eos_mask: torch.Tensor,
response_mask: torch.Tensor,
index: torch.Tensor,
epsilon: float = 1e-6):
"""
@@ -114,7 +114,7 @@ def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
@@ -145,13 +145,13 @@ def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask
return scores, scores
def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor,
eos_mask: torch.Tensor,
response_mask: torch.Tensor,
index: torch.Tensor,
epsilon: float = 1e-6):
"""
@@ -159,7 +159,7 @@ def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor,
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
@@ -190,12 +190,12 @@ def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor,
if response_num > 1:
scores[i] = scores[i] * response_num / (response_num -
1) - id2mean[index[i]] * response_num / (response_num - 1)
scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask
return scores, scores
def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Tensor, eos_mask: torch.Tensor,
def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor,
gamma: torch.Tensor):
"""
Compute advantage for REINFORCE++.
@@ -203,7 +203,7 @@ def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Ten
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
@@ -221,16 +221,16 @@ def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Ten
running_return = token_level_rewards[:, t] + gamma * running_return
returns[:, t] = running_return
# Reset after EOS
running_return = running_return * eos_mask[:, t]
running_return = running_return * response_mask[:, t]
advantages = verl_F.masked_whiten(returns, eos_mask)
advantages = advantages * eos_mask
advantages = verl_F.masked_whiten(returns, response_mask)
advantages = advantages * response_mask
return advantages, returns
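Note: a toy trace of the discounted-return recursion above (whitening omitted; the zero initialisation of returns and running_return is assumed, since only the loop body appears in this hunk):

import torch

token_level_rewards = torch.tensor([[0.0, 0.0, 1.0, 0.0]])
response_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])  # last position is padding
gamma = 0.9

returns = torch.zeros_like(token_level_rewards)
running_return = torch.zeros(token_level_rewards.shape[0])
for t in reversed(range(token_level_rewards.shape[1])):
    running_return = token_level_rewards[:, t] + gamma * running_return
    returns[:, t] = running_return
    # Reset after EOS so masked-out positions do not leak reward backwards.
    running_return = running_return * response_mask[:, t]

print(returns)  # tensor([[0.8100, 0.9000, 1.0000, 0.0000]])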
def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor,
eos_mask: torch.Tensor):
response_mask: torch.Tensor):
"""
Compute advantage for ReMax, operating only on Outcome reward
This implementation is based on the paper: https://arxiv.org/abs/2310.10505
@@ -241,7 +241,7 @@ def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, reward_ba
shape: (bs, response_length)
reward_baselines: `(torch.Tensor)`
shape: (bs,)
eos_mask: `(torch.Tensor)`
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
@@ -254,8 +254,8 @@ def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, reward_ba
scores = token_level_rewards.sum(dim=-1)
with torch.no_grad():
returns = (token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
advantages = returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask
returns = (token_level_rewards * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
advantages = returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * response_mask
return advantages, returns
@@ -265,7 +265,7 @@ def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
return token_level_scores - kl * kl_ratio
def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange, clip_ratio_c=3.0):
def compute_policy_loss(old_log_prob, log_prob, advantages, response_mask, cliprange, clip_ratio_c=3.0):
"""Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122
Args:
@@ -275,7 +275,7 @@ def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange,
shape: (bs, response_length)
advantages: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
cliprange: (float)
The clip range used in PPO. See https://arxiv.org/abs/1707.06347
@@ -294,32 +294,33 @@ def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange,
negative_approx_kl = log_prob - old_log_prob
ratio = torch.exp(negative_approx_kl)
ppo_kl = verl_F.masked_mean(-negative_approx_kl, eos_mask)
ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
pg_losses = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
clip_pg_losses1 = torch.max(pg_losses, pg_losses2)
pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), response_mask)
pg_losses3 = -advantages * clip_ratio_c
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
pg_clipfrac_lower = verl_F.masked_mean(torch.gt(clip_pg_losses2, pg_losses3) * (advantages < 0).float(), eos_mask)
pg_clipfrac_lower = verl_F.masked_mean(
torch.gt(clip_pg_losses2, pg_losses3) * (advantages < 0).float(), response_mask)
# We only apply the dual-clip when the advantage is negative.
pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
pg_loss = verl_F.masked_mean(pg_losses, eos_mask)
pg_loss = verl_F.masked_mean(pg_losses, response_mask)
return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
def compute_entropy_loss(logits, eos_mask):
def compute_entropy_loss(logits, response_mask):
"""Compute Categorical entropy loss
Args:
logits: `(torch.Tensor)`
shape: (bs, response_length, vocab_size)
eos_mask: `(torch.Tensor)`
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
@@ -328,11 +329,11 @@ def compute_entropy_loss(logits, eos_mask):
"""
# compute entropy
entropy = verl_F.entropy_from_logits(logits) # (bs, response_len)
entropy_loss = verl_F.masked_mean(entropy, mask=eos_mask)
entropy_loss = verl_F.masked_mean(entropy, mask=response_mask)
return entropy_loss
def compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value):
def compute_value_loss(vpreds, returns, values, response_mask, cliprange_value):
"""Compute the value loss. Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151
Args:
@@ -353,8 +354,8 @@ def compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value):
vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
vf_losses1 = (vpreds - returns)**2
vf_losses2 = (vpredclipped - returns)**2
vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask)
vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask)
vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), response_mask)
vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask)
return vf_loss, vf_clipfrac
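Note: every loss and statistic in this file is averaged with the response mask rather than over all positions. A standalone sketch of that masked mean, assuming the usual (values * mask).sum() / mask.sum() definition behind verl_F.masked_mean (not copied from the repo):

import torch

def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Average only over positions where mask == 1; padding is excluded.
    return (values * mask).sum() / mask.sum()

values = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
response_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
print(masked_mean(values, response_mask))  # tensor(1.5000), not 2.5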

View File

@@ -181,7 +181,7 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
advantages, returns = core_algos.compute_gae_advantage_return(
token_level_rewards=data.batch['token_level_rewards'],
values=data.batch['values'],
eos_mask=data.batch['response_mask'],
response_mask=data.batch['response_mask'],
gamma=gamma,
lam=lam)
data.batch['advantages'] = advantages
@@ -189,27 +189,29 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
elif adv_estimator == AdvantageEstimator.GRPO:
advantages, returns = core_algos.compute_grpo_outcome_advantage(
token_level_rewards=data.batch['token_level_rewards'],
eos_mask=data.batch['response_mask'],
response_mask=data.batch['response_mask'],
index=data.non_tensor_batch['uid'])
data.batch['advantages'] = advantages
data.batch['returns'] = returns
elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS:
advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
token_level_rewards=data.batch['token_level_rewards'], eos_mask=data.batch['response_mask'], gamma=gamma)
token_level_rewards=data.batch['token_level_rewards'],
response_mask=data.batch['response_mask'],
gamma=gamma)
data.batch['advantages'] = advantages
data.batch['returns'] = returns
elif adv_estimator == AdvantageEstimator.REMAX:
advantages, returns = core_algos.compute_remax_outcome_advantage(
token_level_rewards=data.batch['token_level_rewards'],
reward_baselines=data.batch['reward_baselines'],
eos_mask=data.batch['response_mask'])
response_mask=data.batch['response_mask'])
data.batch['advantages'] = advantages
data.batch['returns'] = returns
elif adv_estimator == AdvantageEstimator.RLOO:
advantages, returns = core_algos.compute_rloo_outcome_advantage(
token_level_rewards=data.batch['token_level_rewards'],
eos_mask=data.batch['response_mask'],
response_mask=data.batch['response_mask'],
index=data.non_tensor_batch['uid'])
data.batch['advantages'] = advantages
data.batch['returns'] = returns

View File

@@ -170,13 +170,13 @@ def vocab_parallel_log_probs_from_logits_response_rmpad(input_ids, attention_mas
return output
def vocab_parallel_compute_entropy_loss(logits, eos_mask):
def vocab_parallel_compute_entropy_loss(logits, response_mask):
"""Compute Categorical entropy loss
Args:
logits: `(torch.Tensor)`
shape: (bs, response_length, vocab_size)
eos_mask: `(torch.Tensor)`
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
@@ -185,5 +185,5 @@ def vocab_parallel_compute_entropy_loss(logits, eos_mask):
"""
# compute entropy
entropy = vocab_parallel_entropy(logits)
entropy_loss = verl_F.masked_mean(entropy, mask=eos_mask)
entropy_loss = verl_F.masked_mean(entropy, mask=response_mask)
return entropy_loss

View File

@@ -147,24 +147,27 @@ def masked_whiten(values, mask, shift_mean=True):
return whitened
def get_eos_mask(response_id: torch.Tensor, eos_token: Union[int, List[int]] = 2, dtype=torch.int64):
def get_response_mask(response_id: torch.Tensor, eos_token: Union[int, List[int]] = 2, dtype=torch.int64):
'''
end of sentence token can be int or list: 1 or [1, 2]
e.g. eos_token=1
response_id: [0, 0, 2, 42, 3, 5, 1, 0, 0]
eos_mask: [1, 1, 1, 1, 1, 1, 1, 0, 0]
e.g.
response_id = torch.tensor([[20, 10, 34, 1, 0, 0, 0],
[78, 0, 76, 2, 1, 0, 0],
[23, 98, 1, 0, 0, 0, 0],
[33, 3, 98, 45, 1, 0, 0]])
#eos_token=1
response_mask: tensor([[1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0]])
#eos_token=[1,2]
response_mask: tensor([[1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0]])
'''
if isinstance(eos_token, int):
eos_token = [eos_token]
eos_mask = torch.zeros_like(response_id, dtype=torch.bool)
for token in eos_token:
eos_mask |= response_id.eq(token)
eos_mask = eos_mask.long()
eos_mask = (torch.cumsum(eos_mask, dim=1) - eos_mask).bool()
eos_mask = torch.logical_not(eos_mask).to(dtype)
return eos_mask
eos_mask = torch.isin(response_id, torch.tensor(eos_token, device=response_id.device)).int()
return (eos_mask.cumsum(dim=1) - eos_mask).eq(0).to(dtype)
def compute_grad_norm(model: nn.Module):
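Note: the renamed helper can be sanity-checked against the docstring example above. A standalone sketch that mirrors the new two-line body for eos_token=1, rather than importing verl:

import torch

response_id = torch.tensor([[20, 10, 34, 1, 0, 0, 0],
                            [78, 0, 76, 2, 1, 0, 0],
                            [23, 98, 1, 0, 0, 0, 0],
                            [33, 3, 98, 45, 1, 0, 0]])
eos_hits = torch.isin(response_id, torch.tensor([1])).int()
# Keep every token up to and including the first EOS; zero out the rest.
response_mask = (eos_hits.cumsum(dim=1) - eos_hits).eq(0).to(torch.int64)
print(response_mask)
# tensor([[1, 1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 1, 0, 0]])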

View File

@@ -296,7 +296,7 @@ class DataParallelPPOActor(BasePPOActor):
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
eos_mask=response_mask,
response_mask=response_mask,
cliprange=clip_ratio,
clip_ratio_c=clip_ratio_c)
# compute entropy loss from entropy

View File

@@ -285,13 +285,14 @@ class MegatronPPOActor(BasePPOActor):
logits_back = logits.clone()
log_prob = vocab_parallel_log_probs_from_logits(logits, responses)
logits = logits_back
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = core_algos.compute_policy_loss(old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
eos_mask=response_mask,
cliprange=clip_ratio,
clip_ratio_c=clip_ratio_c)
entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=response_mask)
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = core_algos.compute_policy_loss(
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
response_mask=response_mask,
cliprange=clip_ratio,
clip_ratio_c=clip_ratio_c)
entropy_loss = vocab_parallel_compute_entropy_loss(logits, response_mask=response_mask)
policy_loss = pg_loss - entropy_loss * entropy_coeff
metrics = {}

View File

@@ -217,7 +217,7 @@ class DataParallelPPOCritic(BasePPOCritic):
returns = data['returns']
response_length = responses.size(1)
eos_mask = attention_mask[:, -response_length - 1:-1]
response_mask = attention_mask[:, -response_length - 1:-1]
vpreds = self._forward_micro_batch(data)
@@ -226,7 +226,7 @@ class DataParallelPPOCritic(BasePPOCritic):
vf_loss, vf_clipfrac = core_algos.compute_value_loss(vpreds=vpreds,
values=values,
returns=returns,
eos_mask=eos_mask,
response_mask=response_mask,
cliprange_value=self.config.cliprange_value)
if self.config.use_dynamic_bsz:
# relative to the dynamic bsz
@@ -239,7 +239,7 @@ class DataParallelPPOCritic(BasePPOCritic):
data = {
'critic/vf_loss': vf_loss.detach().item(),
'critic/vf_clipfrac': vf_clipfrac.detach().item(),
'critic/vpred_mean': masked_mean(vpreds, eos_mask).detach().item(),
'critic/vpred_mean': masked_mean(vpreds, response_mask).detach().item(),
}
append_to_dict(metrics, data)

View File

@@ -138,7 +138,7 @@ class MegatronPPOCritic(BasePPOCritic):
returns = data['returns']
response_length = responses.size(1)
eos_mask = attention_mask[:, -response_length:]
response_mask = attention_mask[:, -response_length:]
cliprange_value = self.config.cliprange_value
@@ -148,12 +148,12 @@ class MegatronPPOCritic(BasePPOCritic):
vf_loss, vf_clipfrac = core_algos.compute_value_loss(vpreds=vpreds,
values=values,
returns=returns,
eos_mask=eos_mask,
response_mask=response_mask,
cliprange_value=cliprange_value)
stats = {
'critic/vf_loss': vf_loss.detach().item(),
'critic/vf_clipfrac': vf_clipfrac.detach().item(),
'critic/vpred_mean': masked_mean(vpreds, eos_mask).detach().item(),
'critic/vpred_mean': masked_mean(vpreds, response_mask).detach().item(),
}
return vf_loss, stats

View File

@@ -24,7 +24,7 @@ from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from verl import DataProto
from verl.utils.torch_functional import get_eos_mask
from verl.utils.torch_functional import get_response_mask
from .base import BaseRollout
from transformers import GenerationConfig
@@ -120,7 +120,9 @@ class HFRollout(BaseRollout):
response_position_ids = position_ids[:, -1:] + delta_position_id
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
response_attention_mask = get_response_mask(response_id=response,
eos_token=eos_token_id,
dtype=attention_mask.dtype)
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
batch = TensorDict(

View File

@@ -33,7 +33,7 @@ from omegaconf import DictConfig
from tensordict import TensorDict
from verl import DataProto
from verl.workers.rollout.base import BaseRollout
from verl.utils.torch_functional import get_eos_mask, pad_sequence_to_length, pad_2d_list_to_length
from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length, pad_2d_list_to_length
from sglang.srt.entrypoints.verl_engine import VerlEngine
from torch.distributed.device_mesh import init_device_mesh
from sglang.srt.sampling.sampling_params import SamplingParams
@@ -268,7 +268,9 @@ class SGLangRollout(BaseRollout):
# position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
response_position_ids = position_ids[:, -1:] + delta_position_id
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
response_attention_mask = get_response_mask(response_id=response,
eos_token=eos_token_id,
dtype=attention_mask.dtype)
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
# all the tp ranks should contain the same data here. data in all ranks are valid

View File

@@ -33,7 +33,7 @@ from tensordict import TensorDict
from torch import nn
from verl import DataProto
from verl.utils.torch_functional import get_eos_mask, pad_sequence_to_length
from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length
from verl.workers.rollout.base import BaseRollout
from verl.workers.rollout.vllm_rollout.vllm_rollout import vLLMRollout
from verl.third_party.vllm import LLM, vllm_version
@@ -192,7 +192,9 @@ class FIREvLLMRollout(vLLMRollout):
# position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
response_position_ids = position_ids[:, -1:] + delta_position_id
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
response_attention_mask = get_response_mask(response_id=response,
eos_token=eos_token_id,
dtype=attention_mask.dtype)
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
# all the tp ranks should contain the same data here. data in all ranks are valid

View File

@@ -33,7 +33,7 @@ from tensordict import TensorDict
from torch import nn
from verl import DataProto
from verl.utils.torch_functional import get_eos_mask, pad_sequence_to_length
from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length
from verl.workers.rollout.base import BaseRollout
from verl.third_party.vllm import LLM, vllm_version
from verl.third_party.vllm import parallel_state as vllm_ps
@@ -229,7 +229,9 @@ class vLLMRollout(BaseRollout):
# position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
response_position_ids = position_ids[:, -1:] + delta_position_id
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
response_attention_mask = get_response_mask(response_id=response,
eos_token=eos_token_id,
dtype=attention_mask.dtype)
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
# all the tp ranks should contain the same data here. data in all ranks are valid

View File

@@ -34,7 +34,7 @@ from tensordict import TensorDict
from torch import nn
from typing import Any, Union
from verl import DataProto
from verl.utils.torch_functional import get_eos_mask, pad_2d_list_to_length
from verl.utils.torch_functional import get_response_mask, pad_2d_list_to_length
from verl.workers.rollout.base import BaseRollout
from vllm.distributed import parallel_state as vllm_ps
from vllm import LLM, SamplingParams
@@ -269,7 +269,9 @@ class vLLMRollout(BaseRollout):
# position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
response_position_ids = position_ids[:, -1:] + delta_position_id
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
response_attention_mask = get_response_mask(response_id=response,
eos_token=eos_token_id,
dtype=attention_mask.dtype)
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
# all the tp ranks should contain the same data here. data in all ranks are valid