[misc] feat: prototype deprecate DataProto and replace with Tensordict: part 3 (#3600)

### What does this PR do?

> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review.

This PR continues the work started in PR #3567 by deprecating and
removing the left_right padding mode:
1. Implements a no-padding mode for the Megatron engine using nested tensors in
the SFT trainer.
2. Deprecates the left_right padding mode for the FSDP/Megatron engines.
3. Introduces a transformation layer within the Actor/Critic workers; see more
[here](https://github.com/volcengine/verl/blob/main/docs/workers/model_engine.rst).
- **Input Format**: Actor/Critic workers continue to receive data in the
left_right-padded format.
- **Transformation**: This layer dynamically converts left_right-padded data
into the no-padding format using nested tensors (a minimal sketch follows below).
- **Engine Format**: The FSDP and Megatron engines now operate exclusively
on the no-padding data format by default.
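
Below is a minimal, illustrative sketch of that conversion (the toy batch and variable
names are assumptions; the actual implementation is `left_right_2_no_padding` in
`verl/workers/roles/utils/padding.py`, shown in the diff below):

```python
import torch

# Toy left_right-padded batch; 0 is the pad token id.
# Each row is [left padding | prompt | response | right padding].
input_ids = torch.tensor([
    [0, 0, 11, 12, 21, 22, 0, 0],    # prompt len 2, response len 2
    [0, 13, 14, 15, 23, 24, 25, 0],  # prompt len 3, response len 3
])
attention_mask = (input_ids != 0).int()

# Strip all padding and keep per-sample boundaries as offsets.
seq_lens = attention_mask.sum(dim=1)
flat = torch.cat([row[mask.bool()] for row, mask in zip(input_ids, attention_mask)])
offsets = torch.cat([torch.zeros(1, dtype=torch.int64), seq_lens.cumsum(0)])

# The no-padding format is a jagged nested tensor: one variable-length row per sample.
input_ids_nested = torch.nested.nested_tensor_from_jagged(flat, offsets=offsets)
print(input_ids_nested.offsets())  # tensor([0, 4, 10]) -> per-sample boundaries
```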


### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```
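
As a usage sketch, the updated `MultiTurnSFTDataset` test in this diff exercises the new
default; a hedged example along the same lines (module path, model name, and parquet path
are placeholders):

```python
from transformers import AutoTokenizer

# Module path assumed from the verl layout; see the MultiTurnSFTDataset diff below.
from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model

# With left_right removed, pad_mode is now either "right" or "no_padding".
config = {
    "max_length": 512,
    "truncation": "error",
    "multiturn": {"messages_key": "messages"},
    "pad_mode": "no_padding",
}
dataset = MultiTurnSFTDataset(parquet_files="path/to/train.parquet", tokenizer=tokenizer, config=config)

item = dataset[0]
# In no-padding mode each item carries variable-length 1-D tensors and no attention_mask.
for key in ("input_ids", "position_ids", "loss_mask"):
    print(key, item[key].shape)
```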

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.
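
At a high level, the engines now consume the no-padding format directly: nested (jagged)
tensors are flattened into packed token streams for the forward pass, and per-token
outputs are wrapped back into nested tensors with the same offsets. A minimal sketch of
the pattern used in `FSDPEngineWithLMHead` (toy tensors; the real code is in the diff
below):

```python
import torch

# Toy jagged batch standing in for the nested input_ids an engine receives.
input_ids = torch.nested.nested_tensor_from_jagged(
    torch.arange(10), offsets=torch.tensor([0, 4, 10])
)

# Flatten the jagged values into a (1, total_nnz) tensor for the packed forward pass.
input_ids_rmpad = input_ids.values().unsqueeze(0)  # shape (1, 10)

# Run the model on the packed tokens (omitted here), then wrap per-token outputs
# such as log-probs back into a nested tensor using the same offsets.
log_probs_rmpad = torch.randn(10)  # placeholder for model output
cu_seqlens = input_ids.offsets()
log_probs = torch.nested.nested_tensor_from_jagged(log_probs_rmpad, offsets=cu_seqlens)
```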

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
Author: Houmin Wei
Date: 2025-10-13 08:18:09 +08:00
Committed by: GitHub
Parent: 8cc9e3af67
Commit: 7ddb9b29f0
16 changed files with 482 additions and 262 deletions

View File

@ -29,7 +29,7 @@ PP_SIZE=${PP_SIZE:-1}
VPP_SIZE=${VPP_SIZE:-null}
CP_SIZE=${CP_SIZE:-1}
PAD_MODE=${PAD_MODE:-left_right}
PAD_MODE=${PAD_MODE:-no_padding}
USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True}
@ -80,8 +80,6 @@ torchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS} ${ENTRYPOINT} \
data.train_files="${TRAIN_FILES}" \
data.val_files="${VAL_FILES}" \
data.train_batch_size=256 \
data.max_prompt_length=1024 \
data.max_response_length=1024 \
data.pad_mode=${PAD_MODE} \
data.truncation=error \
data.use_dynamic_bsz=True \

View File

@ -9,15 +9,6 @@ echo "run with single gpu as golden"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp VERL_FILE_LOGGER_PATH=~/verl/test/log/golden.jsonl bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
# test with fsdp 1
echo "run with sp1 fsdp_size2 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp2 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp1 fsdp_size2 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
@ -27,18 +18,14 @@ BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_pa
echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
# test use_remove_padding and pad_mode left_right/no_padding
echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode left_right use_remove_padding False"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right USE_REMOVE_PADDING=False bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
# test use_remove_padding and pad_mode no_padding
echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode no_padding use_remove_padding False"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding USE_REMOVE_PADDING=False bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
# test with fsdp 2
echo "run with sp1 fsdp_size1 num_gpus1 fsdp_strategy fsdp2 pad_mode left_right"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp2 PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp1 fsdp_size1 num_gpus1 fsdp_strategy fsdp2 pad_mode no_padding"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp2 PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp1 fsdp_size1 num_gpus1 fsdp_strategy fsdp2"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp2"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
@ -50,6 +37,7 @@ BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/spe
# test with megatron
echo "run with tp1 pp1 cp1 num_gpus1"
BACKEND=megatron TP_SIZE=1 PP_SIZE=1 CP_SIZE=1 NUM_GPUS=1 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with tp2 pp2 vpp2 cp1 num_gpus8"
BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=2 CP_SIZE=1 NUM_GPUS=8 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh

View File

@ -177,28 +177,26 @@ def test_multiturn_sft_dataset():
assert torch.all(padded_item["attention_mask"][actual_length:] == 0), "Attention mask not set correctly for padding"
assert torch.all(padded_item["loss_mask"][actual_length:] == 0), "Loss mask not set correctly for padding"
# test left right padding
# test no-padding
config = {
"max_length": 512,
"truncation": "error",
"multiturn": {"messages_key": "messages"},
"pad_mode": "left_right",
"max_prompt_length": 64,
"max_response_length": 64,
"pad_mode": "no_padding",
}
dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=config)
item0 = dataset[0]
# make sure all the input_ids with attention_mask == 0 are all padding
assert torch.all(item0["input_ids"][item0["attention_mask"] == 0] == tokenizer.pad_token_id)
# Verify that the output contains expected keys for no-padding mode
required_keys = ["input_ids", "position_ids", "loss_mask"]
for key in required_keys:
assert key in item0, f"Missing key {key} in no-padding mode dataset item"
assert isinstance(item0[key], torch.Tensor), f"Expected torch.Tensor for {key} in no-padding mode"
# make sure assistant_text matches with expected
assistant_text = tokenizer.decode(item0["responses"][item0["response_mask"] == 1])
assistant_text = tokenizer.decode(item0["input_ids"][item0["loss_mask"] == 1])
assert assistant_text == "2+2 equals 4.<|im_end|>\n4+4 equals 8.<|im_end|>\n"
# make sure responses are part of input_ids
assert torch.all(item0["input_ids"][-item0["responses"].shape[0] :] == item0["responses"])
print("All tests passed!")
print("Starting test...")

View File

@ -16,6 +16,7 @@
from .registry import (
get_mcore_forward_fn,
get_mcore_forward_fused_fn,
get_mcore_forward_no_padding_fn,
get_mcore_weight_converter,
hf_to_mcore_config,
init_mcore_model,
@ -27,4 +28,5 @@ __all__ = [
"get_mcore_forward_fn",
"get_mcore_weight_converter",
"get_mcore_forward_fused_fn",
"get_mcore_forward_no_padding_fn",
]

View File

@ -16,7 +16,14 @@
from verl.utils.megatron_utils import unwrap_model
from .util import postprocess_packed_seqs, preprocess_packed_seqs, recover_left_padding, remove_left_padding
from .util import (
postprocess_packed_seqs,
postprocess_packed_seqs_no_padding,
preprocess_packed_seqs,
preprocess_packed_seqs_no_padding,
recover_left_padding,
remove_left_padding,
)
def gptmodel_forward(
@ -146,3 +153,55 @@ def gptmodel_forward_qwen2_5_vl(
if value_model and post_process:
output = output[..., 0]
return output
def gptmodel_forward_no_padding(
model,
input_ids,
value_model=False,
pack_seqs=True,
logits_processor=None,
logits_processor_args: dict = None,
**kwargs,
):
"""Default forward pass for GPT models with optional sequence packing."""
pre_process = unwrap_model(model).pre_process
post_process = unwrap_model(model).post_process
if pack_seqs:
batch_size = input_ids.shape[0]
input_ids_rmpad, packed_seq_params = preprocess_packed_seqs_no_padding(input_ids, pre_process=pre_process)
input_ids_rmpad = input_ids_rmpad.contiguous()
output_orig = model(
input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=None,
packed_seq_params=packed_seq_params,
)
if post_process and logits_processor is not None:
args = {
k: preprocess_packed_seqs_no_padding(v, pre_process=True)[0] for k, v in logits_processor_args.items()
}
output_dict = logits_processor(output_orig, **args)
# print(f'gptmodel_forward_no_padding: {output_dict=}')
output = {
k: postprocess_packed_seqs_no_padding(
v, packed_seq_params, input_ids, batch_size, post_process=post_process
)
for k, v in output_dict.items()
}
else:
output = postprocess_packed_seqs_no_padding(
output_orig, packed_seq_params, input_ids, batch_size, post_process=post_process
)
else:
raise NotImplementedError("gptmodel_forward_no_padding only supports packed sequences")
if value_model and post_process:
# output = output[..., 0]
# while using nested tensor, the advanced indexing operation above will result in an error at backward, i.e.
# ValueError: NestedTensor _nested_select_backward_default(grad_output: t, self: jt_all, dim: any, index: any)
# so we use `squeeze` to remove the last dimension
output = output.squeeze(-1)
return output

View File

@ -33,8 +33,15 @@ from .config_converter import (
hf_to_mcore_config_qwen2moe,
hf_to_mcore_config_qwen3moe,
)
from .model_forward import gptmodel_forward, gptmodel_forward_qwen2_5_vl
from .model_forward_fused import fused_forward_gptmodel, fused_forward_qwen2_5_vl
from .model_forward import (
gptmodel_forward,
gptmodel_forward_no_padding,
gptmodel_forward_qwen2_5_vl,
)
from .model_forward_fused import (
fused_forward_gptmodel,
fused_forward_qwen2_5_vl,
)
from .model_initializer import (
BaseModelInitializer,
DeepseekV3Model,
@ -116,6 +123,23 @@ MODEL_FORWARD_REGISTRY: dict[SupportedModel, Callable] = {
SupportedModel.QWEN3_TOKEN_CLASSIFICATION: gptmodel_forward,
}
# Registry for model forward functions
MODEL_FORWARD_NOPAD_REGISTRY: dict[SupportedModel, Callable] = {
SupportedModel.LLAMA: gptmodel_forward_no_padding,
SupportedModel.QWEN2: gptmodel_forward_no_padding,
SupportedModel.QWEN2_MOE: gptmodel_forward_no_padding,
SupportedModel.MIXTRAL: gptmodel_forward_no_padding,
SupportedModel.DEEPSEEK_V3: gptmodel_forward_no_padding,
SupportedModel.QWEN2_5_VL: gptmodel_forward_no_padding,
SupportedModel.LLAMA4: gptmodel_forward_no_padding,
SupportedModel.QWEN3: gptmodel_forward_no_padding,
SupportedModel.QWEN3_MOE: gptmodel_forward_no_padding,
# SupportedModel.QWEN2_5_VL: gptmodel_forward_qwen2_5_vl,
SupportedModel.DEEPSEEK_V3: gptmodel_forward_no_padding,
SupportedModel.GLM4_MOE: gptmodel_forward_no_padding,
SupportedModel.QWEN3_TOKEN_CLASSIFICATION: gptmodel_forward_no_padding,
}
# Registry for model forward functions
MODEL_FORWARD_FUSED_REGISTRY: dict[SupportedModel, Callable] = {
SupportedModel.LLAMA: fused_forward_gptmodel,
@ -220,6 +244,15 @@ def get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable:
return MODEL_FORWARD_REGISTRY[model]
def get_mcore_forward_no_padding_fn(hf_config: PretrainedConfig) -> Callable:
"""
Get the forward function for given model architecture.
"""
assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
model = get_supported_model(hf_config.architectures[0])
return MODEL_FORWARD_NOPAD_REGISTRY[model]
def get_mcore_forward_fused_fn(hf_config: PretrainedConfig) -> Callable:
"""
Get the forward function for given model architecture.

View File

@ -162,6 +162,151 @@ def postprocess_packed_seqs(
return output_new
def preprocess_packed_seqs_no_padding(
input_ids: torch.Tensor, pre_process: bool = True
) -> tuple[torch.Tensor, PackedSeqParams]:
"""
Preprocess packed sequences
CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1
gets second and second last chunks, and so on), this is for load balancing with causal masking.
See https://github.com/NVIDIA/TransformerEngine/issues/1368
"""
batch_size = input_ids.shape[0]
tp_size = mpu.get_tensor_model_parallel_world_size()
cp_size = mpu.get_context_parallel_world_size()
cp_rank = mpu.get_context_parallel_rank()
align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size
seqlens_in_batch = input_ids.offsets().diff()
pad_size = (align_size - seqlens_in_batch % align_size) % align_size
seqlens_in_batch_padded = seqlens_in_batch + pad_size
cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)
cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0)
cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)
cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0)
# ----------------------------------------------------------------------------
# Move the index information needed in the subsequent loop to the CPU at once,
# to avoid frequent .item() calls in the loop that cause D2H synchronization
# ----------------------------------------------------------------------------
seqlens_in_batch_cpu: list[int] = seqlens_in_batch.tolist() # original valid lengths
seqlens_in_batch_padded_cpu: list[int] = seqlens_in_batch_padded.tolist() # lengths after padding
cu_seqlens_padded_cpu: list[int] = cu_seqlens_padded.tolist() # start positions (after padding)
# Pure Python int calculation to avoid further synchronization
max_seqlen_in_batch = max(seqlens_in_batch_padded_cpu)
shape = list(input_ids.shape[1:])
shape[0] = sum(seqlens_in_batch_padded_cpu) // cp_size
if pre_process:
input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device)
for i in range(batch_size):
# Use Python int, so no GPU→CPU sync in the loop
if cp_size <= 1:
seqlen = seqlens_in_batch_cpu[i]
start_idx = cu_seqlens_padded_cpu[i]
input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i]
seqlen_padded_i = seqlens_in_batch_padded_cpu[i]
seqlen = seqlen_padded_i // cp_size
half_seqlen = seqlen // 2
start_idx = cu_seqlens_padded_cpu[i] // cp_size
# split to 2 chunks
d = input_ids[i]
input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[
half_seqlen * cp_rank : half_seqlen * (cp_rank + 1)
]
remain_start = seqlen_padded_i - half_seqlen * (cp_rank + 1)
remain_end = seqlen_padded_i - half_seqlen * cp_rank
remain_end = min(remain_end, d.shape[0])
remain_len = remain_end - remain_start
if remain_len > 0:
input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[
remain_start:remain_end
]
packed_seq_params = PackedSeqParams(
qkv_format="thd",
cu_seqlens_q=cu_seqlens_padded,
max_seqlen_q=max_seqlen_in_batch,
cu_seqlens_kv=cu_seqlens_padded,
max_seqlen_kv=max_seqlen_in_batch,
cu_seqlens_q_padded=cu_seqlens_padded,
cu_seqlens_kv_padded=cu_seqlens_padded,
)
if pre_process:
return input_ids_rmpad.unsqueeze(0), packed_seq_params
else:
return input_ids, packed_seq_params
def postprocess_packed_seqs_no_padding(
output: torch.Tensor,
packed_seq_params: PackedSeqParams,
input_ids: torch.Tensor,
batch_size: int,
post_process: bool = True,
) -> torch.Tensor:
"""
Postprocess packed sequences
"""
if not post_process:
return output
# -------------------------------------------------------------------------
# Move the lengths and offsets needed for subsequent Python-level indexing to the CPU in advance,
# to avoid a large number of .item() calls in the loop
# -------------------------------------------------------------------------
cu_padded_cpu: list[int] = packed_seq_params.cu_seqlens_q_padded.tolist()
# The reason why we use input_ids.offsets() instead of packed_seq_params.cu_seqlens_q.diff()
# is that the latter one is the padded length, while the former one is the original length.
cu_seqlens = input_ids.offsets()
seq_lens_cpu: list[int] = cu_seqlens.diff().tolist()
output_new = []
cp_size = mpu.get_context_parallel_world_size()
# all gather output across context parallel group
if cp_size > 1:
# output shape: [1, packed_len, hidden_dim]
# need to gather across cp group and concatenate in sequence dimension
output_list = [torch.empty_like(output) for _ in range(cp_size)]
torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group())
output_list[mpu.get_context_parallel_rank()] = output
else:
output_list = [output]
for i in range(batch_size):
if cp_size <= 1:
s = seq_lens_cpu[i]
start_idx = cu_padded_cpu[i]
output_new.append(output[0][start_idx : start_idx + s])
continue
s_len_padded_chunk = (cu_padded_cpu[i + 1] - cu_padded_cpu[i]) // cp_size
half_seqlen = s_len_padded_chunk // 2
s_len = seq_lens_cpu[i]
s_len_padded = s_len_padded_chunk * cp_size
tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device)
for j in range(cp_size):
o = output_list[j][0]
# split to 2 chunks
packed_start_idx = cu_padded_cpu[i] // cp_size
o0, o1 = (
o[packed_start_idx : packed_start_idx + half_seqlen],
o[packed_start_idx + half_seqlen : packed_start_idx + s_len_padded_chunk],
)
tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0
tmp[s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen] = o1
output_new.append(tmp[:s_len])
output_new_tensor = torch.nested.as_nested_tensor(output_new, layout=torch.jagged)
return output_new_tensor
def remove_left_padding(
input_ids: torch.Tensor,
attention_mask: torch.Tensor,

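For reference, a toy illustration (not part of the diff) of the context-parallel chunk
assignment described in the `preprocess_packed_seqs_no_padding` docstring above, assuming
`cp_size=2` and one padded sequence of length 8:

```python
import torch

cp_size = 2
seq = torch.arange(8)                  # one padded sequence, 2*cp_size = 4 chunks of 2 tokens
per_rank_len = seq.numel() // cp_size  # each rank holds 4 tokens
half = per_rank_len // 2               # i.e. one chunk

for cp_rank in range(cp_size):
    first = seq[half * cp_rank : half * (cp_rank + 1)]
    second = seq[seq.numel() - half * (cp_rank + 1) : seq.numel() - half * cp_rank]
    print(cp_rank, first.tolist(), second.tolist())
# rank 0 -> [0, 1] and [6, 7]  (first and last chunks)
# rank 1 -> [2, 3] and [4, 5]  (second and second-to-last chunks)
```
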
View File

@ -23,12 +23,9 @@ data:
messages_key: messages # Key for messages list in multi-turn mode
tools_key: tools # Key for tools list in multi-turn mode
enable_thinking_key: enable_thinking # Whether to enable thinking in multi-turn mode
pad_mode: left_right
pad_mode: no_padding
# for right padding
max_length: 1024
# for left right padding
max_prompt_length: 512
max_response_length: 512
truncation: error
balance_dp_token: False # to be implement
custom_cls:

View File

@ -29,8 +29,6 @@ from transformers import PreTrainedTokenizer
from verl.utils import hf_tokenizer
from verl.utils.dataset.dataset_utils import DatasetPadMode
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.model import compute_position_id_with_mask
from verl.utils.torch_functional import pad_sequence_to_length, postprocess_data
def convert_nested_value_to_list_recursive(data_item):
@ -55,15 +53,12 @@ class MultiTurnSFTDataset(Dataset):
# Set defaults and extract parameters from config if provided
config = config or {}
self.pad_mode = config.get("pad_mode", "right")
assert self.pad_mode in ["right", "left_right", "no_padding"], (
f"Expect pad_mode to be 'right', 'left_right' or 'no_padding'. Got {self.pad_mode}"
assert self.pad_mode in ["right", "no_padding"], (
f"Expect pad_mode to be 'right' or 'no_padding'. Got {self.pad_mode}"
)
self.truncation = config.get("truncation", "error")
# for right padding
self.max_length = config.get("max_length", 1024)
# for left right paddding to be consistent with RL
self.max_prompt_length = config.get("max_prompt_length", 512)
self.max_response_length = config.get("max_response_length", 512)
# Get messages_key from the new multiturn config structure
multiturn_config = config.get("multiturn", {})
self.messages_key = multiturn_config.get("messages_key", "messages")
@ -320,10 +315,8 @@ class MultiTurnSFTDataset(Dataset):
if messages[0]["role"] == "system":
assert messages[1]["role"] == "user"
assert messages[2]["role"] == "assistant"
prompt_message_length = 2
elif messages[0]["role"] == "user":
assert messages[1]["role"] == "assistant"
prompt_message_length = 1
else:
raise ValueError(f"Unknown role: {messages[0]['role']}")
@ -365,68 +358,6 @@ class MultiTurnSFTDataset(Dataset):
"position_ids": position_ids,
"loss_mask": loss_mask,
}
elif self.pad_mode == DatasetPadMode.LEFT_RIGHT:
assert self.truncation == "error", "Only support error truncation for left_right pad mode"
prompt_str = self.tokenizer.apply_chat_template(
messages[:prompt_message_length],
tools=tools,
tokenize=False,
add_generation_prompt=True,
enable_thinking=enable_thinking,
**self.apply_chat_template_kwargs,
)
prompt_ids = self.tokenizer.encode(prompt_str, add_special_tokens=False)
prompt_length = len(prompt_ids)
prompt_ids = input_ids[:prompt_length].unsqueeze(0)
prompt_attention_mask = attention_mask[:prompt_length].unsqueeze(0)
prompt_loss_mask = loss_mask[:prompt_length].unsqueeze(0)
response_ids = input_ids[prompt_length:].unsqueeze(0)
response_attention_mask = attention_mask[prompt_length:].unsqueeze(0)
response_loss_mask = loss_mask[prompt_length:].unsqueeze(0)
assert prompt_loss_mask.sum().item() == 0
prompt_ids, prompt_attention_mask = postprocess_data(
input_ids=prompt_ids,
attention_mask=prompt_attention_mask,
max_length=self.max_prompt_length,
pad_token_id=self.tokenizer.pad_token_id,
left_pad=True,
truncation=self.truncation,
)
response_ids, response_attention_mask = postprocess_data(
input_ids=response_ids,
attention_mask=response_attention_mask,
max_length=self.max_response_length,
pad_token_id=self.tokenizer.pad_token_id,
left_pad=False,
truncation=self.truncation,
)
response_loss_mask = pad_sequence_to_length(
response_loss_mask, max_seq_len=self.max_response_length, pad_token_id=0, left_pad=False
)
prompt_ids = prompt_ids[0]
prompt_attention_mask = prompt_attention_mask[0]
response_ids = response_ids[0]
response_attention_mask = response_attention_mask[0]
response_loss_mask = response_loss_mask[0]
assert response_attention_mask[0].item() == 1
assert response_loss_mask[0].item() == 1
input_ids = torch.cat((prompt_ids, response_ids), dim=0)
attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=0)
position_ids = compute_position_id_with_mask(attention_mask)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
"responses": response_ids,
"response_mask": response_loss_mask,
}
elif self.pad_mode == DatasetPadMode.NO_PADDING:
# truncate input_ids if it is longer than max_length
if len(input_ids) > self.max_length:
@ -440,3 +371,5 @@ class MultiTurnSFTDataset(Dataset):
"position_ids": position_ids,
"loss_mask": loss_mask,
}
else:
raise ValueError(f"Unknown pad mode {self.pad_mode}")

View File

@ -35,7 +35,6 @@ from verl.models.transformers.monkey_patch import apply_monkey_patch
from verl.trainer.config import CheckpointConfig
from verl.utils import tensordict_utils as tu
from verl.utils.activation_offload import enable_activation_offloading
from verl.utils.attention_utils import index_first_axis, pad_input, rearrange, unpad_input
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.dataset.dataset_utils import DatasetPadMode
from verl.utils.debug import log_gpu_memory_usage
@ -703,10 +702,12 @@ class EngineTrainModeCtx:
class FSDPEngineWithLMHead(FSDPEngine):
def prepare_model_inputs(self, micro_batch: TensorDict):
use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True)
pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.LEFT_RIGHT)
pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.NO_PADDING)
use_fused_kernels = tu.get_non_tensor_data(data=micro_batch, key="use_fused_kernels", default=False)
temperature = micro_batch["temperature"]
assert pad_mode == DatasetPadMode.NO_PADDING, f"pad_mode {pad_mode} not supported"
multi_modal_inputs = {}
if "multi_modal_inputs" in micro_batch.keys():
from verl.utils.model import extract_multi_modal_inputs
@ -726,32 +727,6 @@ class FSDPEngineWithLMHead(FSDPEngine):
if pad_mode == DatasetPadMode.NO_PADDING:
input_ids_rmpad = input_ids.values().unsqueeze(0) # (1, total_nnz)
position_ids_rmpad = position_ids.values().unsqueeze(0) # (1, total_nnz)
elif pad_mode == DatasetPadMode.LEFT_RIGHT:
attention_mask = micro_batch["attention_mask"]
input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask
) # input_ids_rmpad (total_nnz, ...)
output_args["indices"] = indices
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
if position_ids.dim() == 3:
position_ids_rmpad = (
index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
.transpose(0, 1)
.unsqueeze(1)
) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
else:
position_ids_rmpad = index_first_axis(
rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
).transpose(0, 1)
if "image_bound" in multi_modal_inputs:
from verl.utils.dataset.vision_utils import process_multi_modal_inputs_for_minicpmo
multi_modal_inputs = process_multi_modal_inputs_for_minicpmo(
input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs
)
else:
raise NotImplementedError(f"pad_mode {pad_mode} not implemented")
@ -821,13 +796,6 @@ class FSDPEngineWithLMHead(FSDPEngine):
attention_mask, padding=0, output_size=(batch_size, max_seq_len)
)
model_inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
elif pad_mode == DatasetPadMode.LEFT_RIGHT:
attention_mask = micro_batch["attention_mask"]
model_inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
@ -848,7 +816,7 @@ class FSDPEngineWithLMHead(FSDPEngine):
def prepare_model_outputs(self, output, output_args, micro_batch: TensorDict):
use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True)
pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.LEFT_RIGHT)
pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.NO_PADDING)
use_fused_kernels = tu.get_non_tensor_data(data=micro_batch, key="use_fused_kernels", default=False)
temperature = micro_batch["temperature"]
calculate_entropy = tu.get_non_tensor_data(data=micro_batch, key="calculate_entropy", default=False)
@ -906,33 +874,10 @@ class FSDPEngineWithLMHead(FSDPEngine):
if pad_mode == DatasetPadMode.NO_PADDING:
cu_seqlens = input_ids.offsets()
# (bsz, j1) for each sample, is the length of each sample: [real_prompt length + real_response length]
# (bsz, j1), for each sample, is the length of each sample: [real_prompt length + real_response length]
log_probs = torch.nested.nested_tensor_from_jagged(log_probs, cu_seqlens)
if calculate_entropy:
entropy = torch.nested.nested_tensor_from_jagged(entropy_rmpad, cu_seqlens)
elif pad_mode == DatasetPadMode.LEFT_RIGHT:
indices = output_args["indices"]
response_length = micro_batch["responses"].size(-1)
batch_size, seqlen = input_ids.shape
full_log_probs = pad_input(
hidden_states=log_probs.unsqueeze(-1),
indices=indices,
batch=batch_size,
seqlen=seqlen,
)
log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length)
# pad back to (bsz, seqlen)
if calculate_entropy:
full_entropy = pad_input(
hidden_states=entropy_rmpad.unsqueeze(-1),
indices=indices,
batch=batch_size,
seqlen=seqlen,
)
# only return response part:
if calculate_entropy:
entropy = full_entropy.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length)
else:
raise NotImplementedError(f"pad_mode {pad_mode} not implemented")
@ -960,17 +905,12 @@ class FSDPEngineWithLMHead(FSDPEngine):
logits_rmpad = torch.cat([t for t in logits.unbind()])
input_ids_rmpad_rolled = output_args["input_ids_rmpad_rolled"]
log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)
# (bsz, j1) for each sample, length of each sample: [real_prompt_length + real_response_length]
# (bsz, j1), for each sample, length of each sample: [real_prompt_length + real_response_length]
log_probs = torch.nested.nested_tensor_from_jagged(log_probs, cu_seqlens)
if calculate_entropy:
entropy = torch.nested.narrow(entropy, 1, starts, seq_lengths, layout=torch.jagged)
entropy_rmpad = torch.cat([t for t in entropy.unbind()])
entropy = torch.nested.nested_tensor_from_jagged(entropy_rmpad, cu_seqlens)
elif pad_mode == DatasetPadMode.LEFT_RIGHT:
logits = logits[:, -response_length - 1 : -1, :] # (bsz, response_length, vocab_size)
log_probs = logprobs_from_logits(logits, micro_batch["responses"])
if calculate_entropy:
entropy = entropy[:, -response_length - 1 : -1] # (bsz, response_length)
else:
raise NotImplementedError(f"pad_mode {pad_mode} not implemented")
@ -1022,7 +962,7 @@ class FSDPEngineWithValueHead(FSDPEngineWithLMHead):
def prepare_model_outputs(self, output, output_args, micro_batch: TensorDict):
use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True)
response_length = micro_batch["responses"].size(-1)
pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.NO_PADDING)
if use_remove_padding:
input_ids = micro_batch["input_ids"]
@ -1033,24 +973,38 @@ class FSDPEngineWithValueHead(FSDPEngineWithLMHead):
values_rmpad = output[2].squeeze(0).unsqueeze(-1)
else:
values_rmpad = output.logits
values_rmpad = values_rmpad.squeeze(0) # (total_nnz)
indices = output_args["indices"]
values_rmpad = values_rmpad.squeeze(0) # (total_nnz, 1)
# FIXME(houmin): confirm why should we squeeze here
values_rmpad = values_rmpad.squeeze(-1)
# gather output if sp > 1
if self.use_ulysses_sp:
pad_size = output_args["pad_size"]
values_rmpad = gather_outputs_and_unpad(values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size)
# pad it back
values = pad_input(values_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1)
values = values[:, -response_length - 1 : -1]
if pad_mode == DatasetPadMode.NO_PADDING:
cu_seqlens = input_ids.offsets()
# (bsz, j1), for each sample, is the length of each sample: [real_prompt length + real_response length]
values = torch.nested.nested_tensor_from_jagged(values_rmpad, cu_seqlens)
else:
raise NotImplementedError(f"pad_mode {pad_mode} not implemented")
else:
if hasattr(self.module, "v_head"):
# For trl.AutoModelForCausalLMWithValueHead
values = output[2]
else:
values = output.logits
values = values[:, -response_length - 1 : -1].squeeze(-1)
if pad_mode == DatasetPadMode.NO_PADDING:
cu_seqlens = input_ids.offsets()
seq_lengths = cu_seqlens.diff()
starts = torch.zeros_like(seq_lengths, dtype=torch.int64)
values = torch.nested.narrow(values, 1, starts, seq_lengths, layout=torch.jagged)
values_rmpad = torch.cat([t for t in values.unbind()])
# (bsz, j1), for each sample, length of each sample: [real_prompt_length + real_response_length]
values = torch.nested.nested_tensor_from_jagged(values_rmpad, cu_seqlens)
else:
raise NotImplementedError(f"pad_mode {pad_mode} not implemented")
return {"values": values}

View File

@ -28,6 +28,7 @@ from verl.models.mcore import get_mcore_weight_converter
from verl.trainer.config import CheckpointConfig
from verl.utils import tensordict_utils as tu
from verl.utils.checkpoint.megatron_checkpoint_manager import MegatronCheckpointManager
from verl.utils.dataset.dataset_utils import DatasetPadMode
from verl.utils.device import get_device_id, get_device_name
from verl.utils.megatron.pipeline_parallel import make_batch_generator
from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits
@ -512,11 +513,10 @@ class MegatronEngineWithLMHead(MegatronEngine):
batch = batch.to(get_device_id())
batch = batch.contiguous()
input_ids = batch["input_ids"]
attention_mask = batch["attention_mask"].to(bool)
loss_mask = batch["loss_mask"].to(bool)
position_ids = batch["position_ids"]
# process vlm inputs
batch["attention_mask"] = batch["attention_mask"].to(bool)
has_multi_modal_inputs = "multi_modal_inputs" in batch.keys()
if has_multi_modal_inputs:
batch["multi_modal_inputs"] = batch["multi_modal_inputs"]
@ -538,20 +538,18 @@ class MegatronEngineWithLMHead(MegatronEngine):
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"loss_mask": loss_mask,
"position_ids": position_ids,
"multi_modal_inputs": multi_modal_inputs,
}
def prepare_model_outputs(self, output: dict, data: TensorDict):
calculate_entropy = tu.get_non_tensor_data(data, key="calculate_entropy", default=False)
responses = data["responses"]
response_length = responses.size(1)
log_prob = output["log_probs"][:, -response_length - 1 : -1].contiguous()
log_prob = output["log_probs"]
model_output = {"log_probs": log_prob}
if calculate_entropy:
entropy = output["entropy"][:, -response_length - 1 : -1].contiguous()
entropy = output["entropy"]
model_output["entropy"] = entropy
return model_output
@ -561,74 +559,58 @@ class MegatronEngineWithLMHead(MegatronEngine):
batch = batch.to(get_device_id())
use_fused_kernels = tu.get_non_tensor_data(batch, key="use_fused_kernels", default=False)
calculate_entropy = tu.get_non_tensor_data(batch, key="calculate_entropy", default=False)
pad_mode = tu.get_non_tensor_data(batch, key="pad_mode", default=DatasetPadMode.NO_PADDING)
temperature = batch["temperature"]
model_inputs = self.prepare_model_inputs(batch)
input_ids = model_inputs["input_ids"]
attention_mask = model_inputs["attention_mask"]
position_ids = model_inputs["position_ids"]
multi_modal_inputs = model_inputs["multi_modal_inputs"]
responses = batch["responses"]
response_length = responses.size(1)
label = position_ids.clone()
label[:, -response_length - 1 : -1] = responses
label_mask = attention_mask.clone()
label_mask[:, : -response_length - 1] = False
label_mask[:, -1] = False
if pad_mode == DatasetPadMode.NO_PADDING:
label = input_ids.clone()
else:
raise NotImplementedError(f"Pad mode {pad_mode} is not supported for megatron engine")
from verl.models.mcore import get_mcore_forward_fn, get_mcore_forward_fused_fn
from verl.models.mcore import get_mcore_forward_no_padding_fn
if use_fused_kernels:
forward_fn = get_mcore_forward_fused_fn(self.model_config.hf_config)
# return dict of [logits, entropy]
output = forward_fn(
model,
input_ids,
position_ids,
attention_mask,
sequence_parallel=self.tf_config.sequence_parallel,
multi_modal_inputs=multi_modal_inputs,
labels=label,
labels_mask=label_mask,
temperature=temperature,
)
else:
forward_fn = get_mcore_forward_fn(self.model_config.hf_config)
raise NotImplementedError("Fused kernels are not supported for megatron engine")
def logits_processor(logits, label, label_mask):
assert logits.shape[:2] == label.shape[:2]
assert label.shape == label_mask.shape
logits.div_(temperature)
ret = {}
if calculate_entropy:
logits_bak = logits.clone()
if torch.distributed.get_rank() == 0:
logger.warning_once(
"For memory-efficient computation, enable fused kernels via "
"`actor_rollout_ref.model.use_fused_kernels=True`. "
"The current `clone()` operation ensures correctness but increases memory usage."
)
entropy = vocab_parallel_entropy(logits)
ret["entropy"] = entropy
else:
logits_bak = logits
log_probs = vocab_parallel_log_probs_from_logits(logits_bak, label)
log_probs = log_probs.masked_fill(~label_mask, 0.0)
ret["log_probs"] = log_probs
return ret
forward_fn = get_mcore_forward_no_padding_fn(self.model_config.hf_config)
logits_processor_args = {"label": label, "label_mask": label_mask}
output = forward_fn(
model,
input_ids,
attention_mask,
position_ids,
sequence_parallel=self.tf_config.sequence_parallel,
multi_modal_inputs=multi_modal_inputs,
logits_processor=logits_processor,
logits_processor_args=logits_processor_args,
)
def logits_processor(logits, label):
assert logits.shape[:2] == label.shape[:2]
logits.div_(temperature)
ret = {}
if calculate_entropy:
logits_bak = logits.clone()
if torch.distributed.get_rank() == 0:
logger.warning_once(
"For memory-efficient computation, enable fused kernels via "
"`actor_rollout_ref.model.use_fused_kernels=True`. "
"The current `clone()` operation ensures correctness but increases memory usage."
)
entropy = vocab_parallel_entropy(logits)
ret["entropy"] = entropy
else:
logits_bak = logits
# FIXME(houmin): maybe shift label in another place
label = torch.roll(label, shifts=-1, dims=1)
log_probs = vocab_parallel_log_probs_from_logits(logits_bak, label)
ret["log_probs"] = log_probs
return ret
logits_processor_args = {"label": label}
output = forward_fn(
model,
input_ids,
multi_modal_inputs=multi_modal_inputs,
logits_processor=logits_processor,
logits_processor_args=logits_processor_args,
)
return output, partial(postprocess_micro_batch_func, data=batch)
@ -671,19 +653,15 @@ class MegatronEngineWithValueHead(MegatronEngineWithLMHead):
batch = batch.to(get_device_id())
model_inputs = self.prepare_model_inputs(batch)
input_ids = model_inputs["input_ids"]
attention_mask = model_inputs["attention_mask"]
position_ids = model_inputs["position_ids"]
multi_modal_inputs = model_inputs["multi_modal_inputs"]
from verl.models.mcore import get_mcore_forward_fn
from verl.models.mcore import get_mcore_forward_no_padding_fn
forward_fn = get_mcore_forward_fn(self.model_config.hf_config)
forward_fn = get_mcore_forward_no_padding_fn(self.model_config.hf_config)
output = forward_fn(
model,
input_ids,
attention_mask,
position_ids,
sequence_parallel=self.tf_config.sequence_parallel,
multi_modal_inputs=multi_modal_inputs,
value_model=True,
@ -692,7 +670,4 @@ class MegatronEngineWithValueHead(MegatronEngineWithLMHead):
return output, partial(postprocess_micro_batch_func, data=batch)
def prepare_model_outputs(self, output: dict | torch.Tensor, data: TensorDict):
responses = data["responses"]
response_length = responses.size(1)
output = output[:, -response_length - 1 : -1].contiguous()
return {"values": output}

View File

@ -66,7 +66,8 @@ def postprocess_batch_func(output_lst, indices, data: TensorDict):
"""
use_dynamic_bsz = tu.get_non_tensor_data(data=data, key="use_dynamic_bsz", default=True)
pad_mode = tu.get_non_tensor_data(data=data, key="pad_mode", default=DatasetPadMode.LEFT_RIGHT)
pad_mode = tu.get_non_tensor_data(data=data, key="pad_mode", default=DatasetPadMode.NO_PADDING)
assert pad_mode == DatasetPadMode.NO_PADDING, "postprocess_batch_func only support NO_PADDING pad_mode"
# losses_reduced is a list of dict containing outputs for each micro-batch
# reorder entropy and outputs. Return None for other pp ranks
@ -92,8 +93,6 @@ def postprocess_batch_func(output_lst, indices, data: TensorDict):
if pad_mode == DatasetPadMode.NO_PADDING:
tensors = [tensor for nt in model_output[key] for tensor in nt.unbind()]
model_output[key] = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)
elif pad_mode == DatasetPadMode.LEFT_RIGHT:
model_output[key] = torch.cat(model_output[key], dim=0)
else:
raise NotImplementedError(f"pad_mode {pad_mode} not implemented")

View File

@ -33,6 +33,7 @@ from verl.utils.profiler import DistProfiler, DistProfilerExtension
from verl.utils.py_functional import append_to_dict
from verl.workers.config import ActorConfig
from verl.workers.roles.utils.losses import ppo_loss
from verl.workers.roles.utils.padding import left_right_2_no_padding, no_padding_2_padding
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
@ -116,16 +117,23 @@ class ActorWorker(Worker, DistProfilerExtension):
with self.engine.eval_mode():
# TODO: make worker API to accept TensorDict as well
data = data.to_tensordict()
data = left_right_2_no_padding(data)
output = self.engine.infer_batch(data)
if self.engine.is_mp_src_rank_with_outputs():
output = output["model_output"]
log_probs = output["log_probs"]
log_probs = no_padding_2_padding(log_probs, data) # (bsz, response_length)
entropy = output["entropy"]
if entropy is not None:
entropy = no_padding_2_padding(entropy, data) # (bsz, response_length)
# in megatron, only last pp contains valid data and returned to the single controller
output = DataProto.from_dict(
tensors={"old_log_probs": output["log_probs"].float(), "entropy": output["entropy"].float()},
tensors={"old_log_probs": log_probs.float(), "entropy": entropy.float()},
)
output = output.to("cpu")
return output
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
@ -155,6 +163,7 @@ class ActorWorker(Worker, DistProfilerExtension):
mini_batch.meta_info["global_batch_size"] = self.config.ppo_mini_batch_size
# TODO: make worker API to accept TensorDict as well
mini_batch = mini_batch.to_tensordict()
mini_batch = left_right_2_no_padding(mini_batch)
output = self.engine.train_batch(mini_batch, self.loss_fn)
mini_batch_metrics = output.get("metrics", {})
append_to_dict(metrics, mini_batch_metrics, prefix="actor/")

View File

@ -35,6 +35,7 @@ from verl.utils.profiler import DistProfiler, DistProfilerExtension
from verl.utils.py_functional import append_to_dict
from verl.workers.config import CriticConfig
from verl.workers.roles.utils.losses import value_loss
from verl.workers.roles.utils.padding import left_right_2_no_padding, no_padding_2_padding
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
@ -140,13 +141,17 @@ class CriticWorker(Worker, DistProfilerExtension):
with self.engine.eval_mode():
# TODO: make worker API to accept TensorDict as well
data = data.to_tensordict()
data = left_right_2_no_padding(data)
output = self.engine.infer_batch(data)
if self.engine.is_mp_src_rank_with_outputs():
# in megatron, only last pp contains valid data and returned to the single controller
output = output["model_output"]
values = output["values"]
values = no_padding_2_padding(values, data) # (bsz, response_length)
output = DataProto.from_dict(
tensors={"values": output["values"].float()},
tensors={"values": values.float()},
)
output = output.to("cpu")
@ -177,6 +182,7 @@ class CriticWorker(Worker, DistProfilerExtension):
mini_batch.meta_info["global_batch_size"] = self.config.ppo_mini_batch_size
# TODO: make worker API to accept TensorDict as well
mini_batch = mini_batch.to_tensordict()
mini_batch = left_right_2_no_padding(mini_batch)
output = self.engine.train_batch(mini_batch, self.loss_fn)
mini_batch_metrics = output.get("metrics", {})
append_to_dict(metrics, mini_batch_metrics, prefix="critic/")

View File

@ -21,10 +21,11 @@ from verl.utils import tensordict_utils as tu
from verl.utils.dataset.dataset_utils import DatasetPadMode
from verl.utils.torch_functional import masked_mean
from verl.workers.config import ActorConfig, CriticConfig
from verl.workers.roles.utils.padding import no_padding_2_padding
def sft_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None):
pad_mode = tu.get_non_tensor_data(data=data, key="pad_mode", default=DatasetPadMode.LEFT_RIGHT)
pad_mode = tu.get_non_tensor_data(data=data, key="pad_mode", default=DatasetPadMode.NO_PADDING)
log_prob = model_output["log_probs"]
@ -52,6 +53,10 @@ def ppo_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None)
log_prob = model_output["log_probs"]
entropy = model_output.get("entropy", None)
log_prob = no_padding_2_padding(log_prob, data) # (bsz, response_length)
if entropy is not None:
entropy = no_padding_2_padding(entropy, data) # (bsz, response_length)
metrics = {}
response_mask = data["response_mask"].to(bool)
@ -105,7 +110,7 @@ def ppo_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None)
def value_loss(config: CriticConfig, model_output, data: TensorDict, dp_group=None):
vpreds = model_output["values"]
values = data["values"]
vpreds = no_padding_2_padding(vpreds, data) # (bsz, response_length)
values = data["values"]
returns = data["returns"]

View File

@ -0,0 +1,119 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from tensordict import TensorDict
from verl.utils import tensordict_utils as tu
from verl.utils.device import (
is_cuda_available,
is_npu_available,
)
if is_cuda_available:
from flash_attn.bert_padding import pad_input, unpad_input
elif is_npu_available:
from transformers.integrations.npu_flash_attention import pad_input, unpad_input
def left_right_2_no_padding(data: TensorDict) -> TensorDict:
"""
Convert TensorDict from left-right padding to no-padding format.
Args:
data: TensorDict with "input_ids", "attention_mask", "response_mask", "position_ids"
Returns:
data: TensorDict with
- Tensor includes NestedTensors like "input_ids", "loss_mask", "position_ids"
- NonTensorData includes "max_seq_len", "max_response_len", "indices"
Note:
1. the return input_ids/position_ids/loss_mask are nested tensor.
2. we will remove "attention_mask", "response" in the return data, but "response_mask" is kept.
"""
assert "input_ids" in data, "input_ids is required in left-right padding data"
assert "attention_mask" in data, "attention_mask is required in left-right padding data"
assert "response_mask" in data, "response_mask is required in left-right padding data"
assert "position_ids" in data, "position_ids is required in left-right padding data"
input_ids = data.pop("input_ids")
attention_mask = data.pop("attention_mask")
response_mask = data["response_mask"]
if "responses" in data:
_ = data.pop("responses")
max_seq_len, max_response_len = input_ids.shape[1], response_mask.shape[1]
tu.assign_non_tensor_data(data, "max_seq_len", max_seq_len)
tu.assign_non_tensor_data(data, "max_response_len", max_response_len)
input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)
tu.assign_non_tensor_data(data, "indices", indices)
input_ids_nested = torch.nested.nested_tensor_from_jagged(input_ids_rmpad.squeeze(-1), offsets=cu_seqlens)
seq_lens = cu_seqlens.diff().tolist()
response_lens = response_mask.sum(dim=1).tolist()
position_ids_list = []
loss_mask_list = []
for seq_len, response_len in zip(seq_lens, response_lens, strict=False):
position_ids_list.append(torch.arange(seq_len, device=input_ids.device))
loss_mask = torch.zeros(seq_len, dtype=torch.bool, device=input_ids.device)
assert seq_len >= response_len, f"{seq_len=} is less than {response_len=}"
loss_mask[-response_len:] = 1
loss_mask_list.append(loss_mask)
position_ids_nested = torch.nested.as_nested_tensor(position_ids_list, layout=torch.jagged)
loss_mask_nested = torch.nested.as_nested_tensor(loss_mask_list, layout=torch.jagged)
data["input_ids"] = input_ids_nested
data["position_ids"] = position_ids_nested
data["loss_mask"] = loss_mask_nested
return data
def no_padding_2_padding(nested_tensor: torch.Tensor, data: TensorDict) -> torch.Tensor:
"""
Convert NestedTensor from no-padding to right padding format.
Args:
nested_tensor: NestedTensor with no-padding format
data: TensorDict with
- Tensor includes NestedTensors like "input_ids", "loss_mask", "position_ids"
- NonTensorData includes "max_seq_len", "max_response_len", "indices"
Returns:
values: regular tensor right padded to max_response_len
"""
assert "indices" in data, "indices is required in left-right padding data"
assert "max_seq_len" in data, "max_seq_len is required in left-right padding data"
assert "max_response_len" in data, "max_response_len is required in left-right padding data"
indices = tu.get_non_tensor_data(data=data, key="indices", default=None)
max_seq_len = tu.get_non_tensor_data(data=data, key="max_seq_len", default=2048)
max_response_len = tu.get_non_tensor_data(data=data, key="max_response_len", default=1024)
batch_size = nested_tensor.size(0)
values = nested_tensor.values()
full_values = pad_input(
hidden_states=values.unsqueeze(-1),
indices=indices,
batch=batch_size,
seqlen=max_seq_len,
)
values = full_values.squeeze(-1)[:, -max_response_len - 1 : -1] # (bsz, response_length)
return values