[misc] fix: SFT E2E CI test failure due to megatron engine (#3786)

Houmin Wei authored on 2025-10-17 06:27:39 +08:00; committed by GitHub
parent acfcf98ed0
commit 65b8bf1bc0
5 changed files with 12 additions and 3 deletions

View File

@@ -1,3 +1,5 @@
#!/usr/bin/env bash
set -xeuo pipefail
rm -rf ~/verl/test/log
mkdir -p ~/verl/test/log

View File

@@ -183,7 +183,6 @@ def gptmodel_forward_no_padding(
k: preprocess_packed_seqs_no_padding(v, pre_process=True)[0] for k, v in logits_processor_args.items()
}
output_dict = logits_processor(output_orig, **args)
# print(f'gptmodel_forward_no_padding: {output_dict=}')
output = {
k: postprocess_packed_seqs_no_padding(
v, packed_seq_params, input_ids, batch_size, post_process=post_process

View File

@@ -208,6 +208,7 @@ def preprocess_packed_seqs_no_padding(
seqlen = seqlens_in_batch_cpu[i]
start_idx = cu_seqlens_padded_cpu[i]
input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i]
continue
seqlen_padded_i = seqlens_in_batch_padded_cpu[i]
seqlen = seqlen_padded_i // cp_size
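
For reference, a minimal sketch of the packing step this hunk touches, assuming plain torch tensors and no context-parallel padding (the helper name `pack_sequences` is illustrative, not the verl API):

import torch

def pack_sequences(input_ids_list):
    # Copy each variable-length sequence into one flat buffer, tracking
    # offsets via cumulative sequence lengths, similar to the direct-copy
    # path above (presumably taken when no context-parallel padding is needed).
    seqlens = torch.tensor([len(s) for s in input_ids_list])
    cu_seqlens = torch.zeros(len(input_ids_list) + 1, dtype=torch.long)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    packed = torch.empty(int(cu_seqlens[-1]), dtype=torch.long)
    for i, seq in enumerate(input_ids_list):
        start = int(cu_seqlens[i])
        packed[start : start + len(seq)] = seq
    return packed, cu_seqlens

# Example: sequences of lengths 3 and 2 pack into a flat buffer of length 5.
packed, cu = pack_sequences([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
assert cu.tolist() == [0, 3, 5] and packed.tolist() == [1, 2, 3, 4, 5]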

View File

@@ -974,7 +974,8 @@ class FSDPEngineWithValueHead(FSDPEngineWithLMHead):
else:
values_rmpad = output.logits
values_rmpad = values_rmpad.squeeze(0) # (total_nnz, 1)
# FIXME(houmin): confirm why should we squeeze here
# critic model arch is like Qwen3ForTokenClassification and num_labels=1
# so we squeeze the last dimension here to get the value for each token
values_rmpad = values_rmpad.squeeze(-1)
# gather output if sp > 1
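
As a shape check on the comment above, a minimal sketch assuming a token-classification style value head with num_labels=1 (tensor values are illustrative):

import torch

total_nnz = 7  # number of unpadded tokens in the packed batch
values_rmpad = torch.randn(1, total_nnz, 1)  # (1, total_nnz, num_labels=1)
values_rmpad = values_rmpad.squeeze(0)       # (total_nnz, 1)
values_rmpad = values_rmpad.squeeze(-1)      # (total_nnz,) one scalar value per token
assert values_rmpad.shape == (total_nnz,)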

View File

@@ -595,7 +595,13 @@ class MegatronEngineWithLMHead(MegatronEngine):
else:
logits_bak = logits
# FIXME(houmin): maybe shift label in another place
# Create the final labels for next-token prediction.
# The `label` tensor starts as a clone of `input_ids`. `torch.roll` is not applied
# earlier because `input_ids` is a nested tensor, which is incompatible with the operation.
# The `preprocess_packed_seqs_no_padding` function unnests and flattens the tensor
# into `input_ids_rmpad` (shape: [1, total_seqlen]).
# Now, on this simple, unpadded tensor, we can perform the standard left shift
# to align the target token `t+1` with the prediction for token `t`.
label = torch.roll(label, shifts=-1, dims=1)
log_probs = vocab_parallel_log_probs_from_logits(logits_bak, label)
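
A standalone sketch of the label shift described in the comment above, using plain torch only (`vocab_parallel_log_probs_from_logits` is Megatron-specific and not reproduced here):

import torch

input_ids_rmpad = torch.tensor([[11, 12, 13, 14]])  # (1, total_seqlen), already unpadded
label = torch.roll(input_ids_rmpad.clone(), shifts=-1, dims=1)
# Position t now holds token t+1, so the logits at t are scored against the next token.
# The wrapped-around last position is typically excluded by the loss mask.
assert label.tolist() == [[12, 13, 14, 11]]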