[misc] feat: prototype deprecate DataProto and replace with Tensordict: part 2 (#3567)

### What does this PR do?

This PR continues the work started in PR #2733. It adds support for
variable sequence lengths in `MultiTurnSFTDataset` by introducing a
`no_padding` option for `pad_mode`. When this mode is active, sequences
are not padded to a fixed length.
- Implement the no-padding mode for the FSDP engine in the SFT trainer
using nested tensors (see the sketch below)
- Add tests for the no-padding mode with `use_remove_padding` both
enabled and disabled
- Fix the FSDP2 grad-norm issue
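
A minimal, self-contained sketch (not part of this PR's code) of the
underlying representation: variable-length samples are packed into
PyTorch nested tensors with the jagged layout, which is what the new
collator and engine paths build on.

```python
import torch

# Three samples of different lengths, as the dataset would emit in no_padding mode.
seqs = [torch.arange(5), torch.arange(3), torch.arange(7)]

# Pack into a single jagged nested tensor instead of padding to a fixed max length.
input_ids = torch.nested.as_nested_tensor(seqs, layout=torch.jagged)

print(input_ids.is_nested)         # True
print(input_ids.offsets().diff())  # tensor([5, 3, 7]): per-sample lengths
print(input_ids.values().shape)    # torch.Size([15]): flat packed values, no pad tokens
```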

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
    - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```
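
The template placeholder above is left as-is; as a concrete
illustration, the sketch below wires together `SFTTensorCollator` and
`DatasetPadMode` from `verl.utils.dataset.dataset_utils` (both added in
this diff). The toy dataset stands in for `MultiTurnSFTDataset`
configured with `pad_mode: no_padding`; treat it as a sketch rather
than an authoritative usage.

```python
import torch
from torch.utils.data import DataLoader, Dataset

from verl.utils.dataset.dataset_utils import DatasetPadMode, SFTTensorCollator


class ToyVarLenDataset(Dataset):
    """Stand-in for MultiTurnSFTDataset with pad_mode="no_padding": each sample
    returns unpadded 1-D tensors of different lengths."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        length = 4 + idx  # variable sequence length per sample
        return {
            "input_ids": torch.randint(0, 128, (length,)),
            "position_ids": torch.arange(length),
            "loss_mask": torch.ones(length, dtype=torch.long),
        }


# The collator packs variable-length tensor fields into jagged nested tensors.
collate_fn = SFTTensorCollator(DatasetPadMode.NO_PADDING)
loader = DataLoader(ToyVarLenDataset(), batch_size=4, collate_fn=collate_fn)

batch = next(iter(loader))
assert batch["input_ids"].is_nested          # no fixed-length padding applied
print(batch["input_ids"].offsets().diff())   # per-sample lengths, e.g. tensor([4, 5, 6, 7])
```

End to end, the same mode is exercised by the updated test script via
`PAD_MODE=no_padding` (optionally with `USE_REMOVE_PADDING=False`), as
shown in the diff below.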

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.
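
At a high level, the diff threads a `pad_mode` flag from the
dataset/collator through the FSDP engine into the loss: in `no_padding`
mode the model inputs are rebuilt from jagged nested tensors, per-token
log-probs are re-packed per sample with
`torch.nested.nested_tensor_from_jagged`, and the SFT loss masks tokens
on the flat packed values. The standalone sketch below mirrors that
masking step (simplified from the `sft_loss` change in this diff; the
numbers are made up).

```python
import torch

# Two samples of per-token log-probs and loss masks, packed as jagged nested tensors.
log_prob = torch.nested.as_nested_tensor(
    [torch.tensor([-0.1, -0.2, -0.3, -0.4]), torch.tensor([-0.5, -0.6])], layout=torch.jagged
)
loss_mask = torch.nested.as_nested_tensor(
    [torch.tensor([0.0, 1.0, 1.0, 1.0]), torch.tensor([0.0, 1.0])], layout=torch.jagged
)

log_prob_flat = log_prob.values()            # packed values, shape (total_tokens,)
loss_mask_flat = loss_mask.values().clone()  # clone so the roll does not alias the buffer
cu_seqlens = log_prob.offsets()              # cumulative sequence lengths, shape (bsz + 1,)

# Shift the mask left by one token so it lines up with next-token log-probs,
# then zero the last position of every sample (no label beyond the sequence end).
loss_mask_flat = torch.roll(loss_mask_flat, shifts=-1, dims=0)
loss_mask_flat[cu_seqlens[1:] - 1] = 0

# Masked mean over supervised tokens only.
loss = -(log_prob_flat * loss_mask_flat).sum() / loss_mask_flat.sum().clamp(min=1)
print(loss)
```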

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: zhangchi.usc1992 <zhangchi.usc1992@bytedance.com>
Author: Houmin Wei
Date: 2025-09-24 17:12:31 +08:00
Committed by: GitHub
Parent: 1985eb14ff
Commit: 69b0127b74
12 changed files with 448 additions and 96 deletions

View File

@ -29,6 +29,10 @@ PP_SIZE=${PP_SIZE:-1}
VPP_SIZE=${VPP_SIZE:-null}
CP_SIZE=${CP_SIZE:-1}
PAD_MODE=${PAD_MODE:-left_right}
USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True}
FSDP_ENGINE_CONFIG="\
engine=${backend} \
optim=${backend} \
@ -63,11 +67,11 @@ MEGATRON_ENGINE_CONFIG="\
if [ "$backend" = "fsdp" ]; then
ENGINE_CONFIG="$FSDP_ENGINE_CONFIG"
echo "Using fsdp engine"
exp_name=gsm8k-${backend}-${FSDP_STRATEGY}-sp${SP_SIZE}-fsdp${FSDP_SIZE}
exp_name=gsm8k-${backend}-${FSDP_STRATEGY}-sp${SP_SIZE}-fsdp${FSDP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}
else
ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG"
echo "Using megatron engine"
exp_name=gsm8k-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-vpp${VPP_SIZE}-cp${CP_SIZE}
exp_name=gsm8k-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-vpp${VPP_SIZE}-cp${CP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}
fi
mkdir -p "${ckpts_home}"
@ -78,12 +82,13 @@ torchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS} ${ENTRYPOINT} \
data.train_batch_size=256 \
data.max_prompt_length=1024 \
data.max_response_length=1024 \
data.pad_mode=left_right \
data.pad_mode=${PAD_MODE} \
data.truncation=error \
data.use_dynamic_bsz=True \
data.max_token_len_per_gpu=8192 \
data.messages_key=messages \
model.path=$MODEL_PATH \
model.use_remove_padding=${USE_REMOVE_PADDING} \
${ENGINE_CONFIG} \
trainer.test_freq=after_each_epoch \
trainer.save_freq=-1 \

View File

@ -9,26 +9,43 @@ echo "run with single gpu as golden"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp VERL_FILE_LOGGER_PATH=~/verl/test/log/golden.jsonl bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
# test with fsdp 1
echo "run with sp1 fsdp_size2 num_gpus8 fsdp_strategy fsdp"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp2 fsdp_size-1 num_gpus8 fsdp_strategy fsdp"
BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp"
BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp1 fsdp_size2 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp2 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp1 fsdp_size2 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp2 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
# test use_remove_padding and pad_mode left_right/no_padding
echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode left_right use_remove_padding False"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right USE_REMOVE_PADDING=False bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode no_padding use_remove_padding False"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding USE_REMOVE_PADDING=False bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
# test with fsdp 2
echo "run with sp1 fsdp_size1 num_gpus1 fsdp_strategy fsdp2"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp1 fsdp_size1 num_gpus1 fsdp_strategy fsdp2 pad_mode left_right"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp2 PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp1 fsdp_size1 num_gpus1 fsdp_strategy fsdp2 pad_mode no_padding"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp2 PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
# TODO: toggle the follow tests when the grad norm of fsdp is fixed
# echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp2"
# BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
# echo "run with sp2 fsdp_size-1 num_gpus8 fsdp_strategy fsdp2"
# BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
# BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
# BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp2"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
echo "run with sp2 fsdp_size-1 num_gpus8 fsdp_strategy fsdp2"
BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
# test with megatron
echo "run with tp1 pp1 cp1 num_gpus1"

View File

@ -88,6 +88,45 @@ def test_tensor_dict_constructor():
assert data["name"] == "abdce"
def test_index_select_tensor_dict():
vocab_size = 128
a = torch.randint(low=0, high=vocab_size, size=(11,))
b = torch.randint(low=0, high=vocab_size, size=(13,))
c = torch.randint(low=0, high=vocab_size, size=(12,))
d = torch.randint(low=0, high=vocab_size, size=(15,))
input_ids = [a, b, c, d]
input_ids = torch.nested.as_nested_tensor(input_ids, layout=torch.jagged)
padded_tensor = torch.randn(4, 10)
non_tensor_dict = {"global_batch_size": "4"}
data = tu.get_tensordict(
tensor_dict={
"input_ids": input_ids,
"padded_tensor": padded_tensor,
},
non_tensor_dict=non_tensor_dict,
)
assert data.batch_size == torch.Size([4])
# test index select
indices = torch.tensor([1, 3])
selected_data = tu.index_select_tensor_dict(data, indices)
assert selected_data.batch_size == torch.Size([2])
target_input_ids = torch.nested.as_nested_tensor([input_ids[idx] for idx in indices], layout=torch.jagged)
target_select_data = tu.get_tensordict(
tensor_dict={
"input_ids": target_input_ids,
"padded_tensor": padded_tensor[indices],
},
non_tensor_dict=non_tensor_dict,
)
tu.assert_tensordict_eq(selected_data, target_select_data)
def test_tensordict_with_images():
# each sample contains a sequence with multiple images of different sizes
vocab_size = 128
@ -173,6 +212,37 @@ def test_tensordict_eq():
with pytest.raises(AssertionError):
tu.assert_tensordict_eq(data, data2)
tensor_list = [
torch.tensor([1, 2, 3, 3, 2]),
torch.tensor([4, 5]),
torch.tensor([7, 8, 10, 14]),
torch.tensor([10, 11, 12]),
torch.tensor([13, 14, 15, 18]),
torch.tensor([16, 17]),
]
obs = torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged)
data_sources = ["abc", "def", "abc", "def", "pol", "klj"]
non_tensor_dict = {"train_sample_kwargs": {"top_p": 1.0}, "val_sample_kwargs": {"top_p": 0.7}}
data3 = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)
tensor_list[0] = torch.tensor([1, 2, 3, 3, 2])
obs = torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged)
data4 = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)
tu.assert_tensordict_eq(data3, data4)
tensor_list[0] = torch.tensor([1, 2, 4])
obs = torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged)
data5 = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)
with pytest.raises(AssertionError):
tu.assert_tensordict_eq(data3, data5)
tensor_list[0] = torch.tensor([4, 5])
tensor_list[1] = torch.tensor([1, 2, 3, 3, 2])
obs = torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged)
data6 = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)
with pytest.raises(AssertionError):
tu.assert_tensordict_eq(data3, data6)
def test_tensor_dict_make_iterator():
obs = torch.tensor([1, 2, 3, 4, 5, 6])

View File

@ -32,6 +32,7 @@ from tqdm import tqdm
from verl.utils import tensordict_utils as tu
from verl.utils.checkpoint import CheckpointHandler
from verl.utils.dataset.dataset_utils import SFTTensorCollator
from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset
from verl.utils.device import get_device_name, is_cuda_available, is_npu_available
from verl.utils.distributed import destroy_global_process_group
@ -167,11 +168,13 @@ class SFTTrainer:
self.global_batch_size = config.data.train_batch_size
self.train_batch_size_per_dp = self.global_batch_size // dp_size
self.collate_fn = SFTTensorCollator(config.data.pad_mode)
self.train_dataloader = StatefulDataLoader(
dataset=self.train_dataset,
batch_size=self.train_batch_size_per_dp,
sampler=self.train_sampler,
collate_fn=self.collate_fn,
num_workers=8,
pin_memory=True,
drop_last=True,
@ -185,6 +188,7 @@ class SFTTrainer:
dataset=self.val_dataset,
batch_size=self.train_batch_size_per_dp,
sampler=self.val_sampler,
collate_fn=self.collate_fn,
num_workers=8,
pin_memory=True,
drop_last=True,
@ -227,11 +231,14 @@ class SFTTrainer:
start_epoch = global_step // self.steps_per_epoch
meta_info = {
"use_remove_padding": self.config.model.use_remove_padding,
"use_dynamic_bsz": self.config.data.use_dynamic_bsz,
"max_token_len_per_gpu": self.config.data.max_token_len_per_gpu,
"micro_batch_size_per_gpu": self.config.data.micro_batch_size_per_gpu,
"temperature": 1.0,
"global_batch_size": self.global_batch_size,
"pad_mode": self.config.data.pad_mode,
"pad_token_id": self.model_config.tokenizer.pad_token_id,
}
train_time = 0
@ -263,7 +270,12 @@ class SFTTrainer:
loss = torch.mean(torch.tensor(metrics["loss"], device=self.device_name))
# mean over dp group
batch_seqlens = data["attention_mask"].sum(dim=-1).to(self.device_name) # (global_bsz // dp)
is_nested = data["input_ids"].is_nested
if is_nested:
batch_seqlens: torch.Tensor = data["input_ids"].offsets().diff()
else:
batch_seqlens: torch.Tensor = data["attention_mask"].sum(dim=-1)
batch_seqlens = batch_seqlens.to(self.device_name) # (global_bsz // dp)
output_tensor = torch.randint(
0,

View File

@ -0,0 +1,70 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
import torch
class DatasetPadMode(str, Enum):
"""Padding mode for dataset"""
RIGHT = "right"
LEFT_RIGHT = "left_right"
NO_PADDING = "no_padding"
class SFTTensorCollator:
"""
A custom collate_fn that handles batching of sequences.
1. for variable-length sequences, convert them into NestedTensors.
2. for fixed-length sequences, use default_collate.
"""
def __init__(self, pad_mode: DatasetPadMode = DatasetPadMode.LEFT_RIGHT):
self.pad_mode = pad_mode
def __call__(self, batch: list[dict[str, any]]) -> dict[str, any]:
if self.pad_mode == DatasetPadMode.NO_PADDING:
return self.collate_variable_batch(batch)
elif self.pad_mode in [DatasetPadMode.RIGHT, DatasetPadMode.LEFT_RIGHT]:
from torch.utils.data import default_collate
return default_collate(batch)
else:
raise NotImplementedError(f"pad_mode {self.pad_mode} not implemented")
def collate_variable_batch(self, batch: list[dict[str, any]]) -> dict[str, any]:
"""
Collates a list of samples into a single batch.
Args:
batch: A list of dictionary samples from the dataset.
Returns:
A dictionary representing the batched data, with variable-length
sequences converted to NestedTensors.
"""
final_batch = {}
tensor_keys = [key for key in batch[0].keys() if isinstance(batch[0][key], torch.Tensor)]
# Handle tensor values by creating a NestedTensor.
for key in tensor_keys:
tensors = [item[key] for item in batch]
final_batch[key] = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)
return final_batch

View File

@ -27,6 +27,7 @@ from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer
from verl.utils import hf_tokenizer
from verl.utils.dataset.dataset_utils import DatasetPadMode
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.model import compute_position_id_with_mask
from verl.utils.torch_functional import pad_sequence_to_length, postprocess_data
@ -54,8 +55,8 @@ class MultiTurnSFTDataset(Dataset):
# Set defaults and extract parameters from config if provided
config = config or {}
self.pad_mode = config.get("pad_mode", "right")
assert self.pad_mode in ["right", "left_right"], (
f"Expect pad_mode to be 'right' or 'left_right'. Got {self.pad_mode}"
assert self.pad_mode in ["right", "left_right", "no_padding"], (
f"Expect pad_mode to be 'right', 'left_right' or 'no_padding'. Got {self.pad_mode}"
)
self.truncation = config.get("truncation", "error")
# for right padding
@ -328,7 +329,7 @@ class MultiTurnSFTDataset(Dataset):
sequence_length = input_ids.shape[0]
# Handle sequence length
if self.pad_mode == "right":
if self.pad_mode == DatasetPadMode.RIGHT:
if sequence_length < self.max_length:
# Pad sequences
pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0
@ -364,7 +365,7 @@ class MultiTurnSFTDataset(Dataset):
"position_ids": position_ids,
"loss_mask": loss_mask,
}
elif self.pad_mode == "left_right":
elif self.pad_mode == DatasetPadMode.LEFT_RIGHT:
assert self.truncation == "error", "Only support error truncation for left_right pad mode"
prompt_str = self.tokenizer.apply_chat_template(
messages[:prompt_message_length],
@ -426,3 +427,16 @@ class MultiTurnSFTDataset(Dataset):
"responses": response_ids,
"response_mask": response_loss_mask,
}
elif self.pad_mode == DatasetPadMode.NO_PADDING:
# truncate input_ids if it is longer than max_length
if len(input_ids) > self.max_length:
input_ids = input_ids[: self.max_length]
loss_mask = loss_mask[: self.max_length]
# create position IDs
position_ids = torch.arange(len(input_ids), dtype=torch.long)
# return nested tensor without padding
return {
"input_ids": input_ids,
"position_ids": position_ids,
"loss_mask": loss_mask,
}

View File

@ -20,6 +20,7 @@ import torch
from torch import distributed as dist
from verl.protocol import DataProto
from verl.utils import tensordict_utils as tu
from verl.utils.device import get_device_name
@ -273,11 +274,17 @@ def rearrange_micro_batches(
List[List[int]]: index lists mapping each micro-batch back to original positions.
"""
# this is per local micro_bsz
max_seq_len = batch["attention_mask"].shape[-1]
input_ids = batch["input_ids"]
if input_ids.is_nested:
seq_len_effective: torch.Tensor = input_ids.offsets().diff()
max_seq_len = max(seq_len_effective)
else:
max_seq_len = batch["attention_mask"].shape[-1]
seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1)
assert max_token_len >= max_seq_len, (
f"max_token_len must be greater than the sequence length. Got {max_token_len=} and {max_seq_len=}"
)
seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1)
total_seqlen = seq_len_effective.sum().item()
# NOTE: num_microbatches <= batch_size, so take the min of this two.
num_micro_batches = min(len(seq_len_effective), ceildiv(total_seqlen, max_token_len))
@ -309,11 +316,7 @@ def rearrange_micro_batches(
micro_batches = []
for partition in micro_bsz_idx:
curr_micro_batch = []
for idx in partition:
curr_micro_batch.append(batch[idx : idx + 1])
curr_micro_batch = torch.cat(curr_micro_batch)
curr_micro_batch = tu.index_select_tensor_dict(batch, partition)
micro_batches.append(curr_micro_batch)
return micro_batches, micro_bsz_idx
@ -388,6 +391,14 @@ def restore_dynamic_batch(data: torch.Tensor, batch_idx_list: list[list[int]]) -
torch.Tensor: The restored data.
"""
indices = list(chain.from_iterable(batch_idx_list))
assert len(indices) == data.size(0), f"{len(indices)} vs. {data.size()}"
batch_size = data.shape[0]
assert len(indices) == batch_size, f"{len(indices)} vs. {batch_size}"
revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
return data[revert_indices]
if data.is_nested:
tensors = [data[i] for i in revert_indices]
reverted_data = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)
else:
reverted_data = data[revert_indices]
return reverted_data

View File

@ -69,13 +69,16 @@ def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict:
"Passing a list makes the data NonTensorStack, "
"which doesn't support torch.Tensor. Please convert to numpy first"
)
assert isinstance(val, torch.Tensor | list)
if batch_size is None:
batch_size = len(val)
batch_size = val.size(0) if isinstance(val, torch.Tensor) else len(val)
else:
assert len(val) == batch_size
val_batch_size = val.size(0) if isinstance(val, torch.Tensor) else len(val)
assert val_batch_size == batch_size, (
f"Batch size of tensor {key} is not consistent with other tensors. "
f"Expected {batch_size}, got {val_batch_size}"
)
if batch_size is None:
batch_size = []
@ -89,6 +92,35 @@ def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict:
return TensorDict(source=tensor_dict, batch_size=batch_size)
def index_select_tensor_dict(batch: TensorDict, indices: torch.Tensor | list[int]) -> TensorDict:
"""Index a tensor dict with a tensor of indices."""
if isinstance(indices, list):
indices = torch.tensor(indices)
assert indices.dim() == 1, "indices must be a 1D tensor"
data_dict = {}
batch_size = indices.shape[0]
if batch is not None:
for key, tensor in batch.items():
if isinstance(tensor, torch.Tensor) and not tensor.is_nested:
data_dict[key] = tensor[indices]
elif isinstance(tensor, torch.Tensor) and tensor.is_nested:
data_dict[key] = torch.nested.as_nested_tensor([tensor[idx] for idx in indices], layout=torch.jagged)
else:
# This handles NonTensorStack (indexable by batch dim) and NonTensorData (scalar metadata).
if tensor.shape:
data_dict[key] = tensor[indices]
else:
data_dict[key] = tensor
selected_batch = TensorDict(source=data_dict, batch_size=batch_size)
else:
selected_batch = None
return selected_batch
def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict:
"""Union two tensordicts."""
assert tensor_dict1.batch_size == tensor_dict2.batch_size, (
@ -147,7 +179,16 @@ def assert_tensordict_eq(tensordict1: TensorDict, tensordict2: TensorDict):
assert type(val) is type(val2), f"The type of {key} must be the same. Got {type(val)} vs {type(val2)}"
if isinstance(val, torch.Tensor):
assert torch.all(torch.eq(val, val2)).item()
if val.is_nested:
assert val.is_nested and val2.is_nested, (
f"Both tensors must be nested tensors. {val.is_nested=}, {val2.is_nested=}"
)
t1, t2 = val.unbind(), val2.unbind()
assert len(t1) == len(t2), f"Nested tensor should have the same lengths. {len(t1)=} vs {len(t2)=}"
for c1, c2 in zip(t1, t2, strict=True):
assert torch.equal(c1, c2), f"Nested tensor components have different values. {c1=} vs {c2=}"
else:
assert torch.all(torch.eq(val, val2)).item()
else:
assert val == val2

View File

@ -23,6 +23,7 @@ import os
import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor import DTensor
import verl.utils.torch_functional as verl_F
from verl import DataProto
@ -284,6 +285,9 @@ class DataParallelPPOActor(BasePPOActor):
else:
grad_norm = torch.nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip)
if isinstance(grad_norm, DTensor):
grad_norm = grad_norm.full_tensor()
# if grad_norm is not finite, skip the update
if not torch.isfinite(grad_norm):
print(f"WARN: rank {torch.distributed.get_rank()} grad_norm is not finite: {grad_norm}")

View File

@ -35,6 +35,7 @@ from verl.models.transformers.monkey_patch import apply_monkey_patch
from verl.utils import tensordict_utils as tu
from verl.utils.activation_offload import enable_activation_offloading
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.dataset.dataset_utils import DatasetPadMode
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.device import (
get_device_id,
@ -483,7 +484,7 @@ class FSDPEngine(BaseEngine):
if not forward_only:
global_bsz = data["global_batch_size"]
local_micro_bsz = micro_batch["input_ids"].shape[0]
local_micro_bsz = micro_batch.batch_size[0]
# metrics contain the output, loss is dummy
loss_scale_factor = local_micro_bsz / (global_bsz / self.get_data_parallel_size())
# scale loss
@ -522,6 +523,9 @@ class FSDPEngine(BaseEngine):
self.module.parameters(), max_norm=self.optimizer_config.clip_grad
)
if isinstance(grad_norm, DTensor):
grad_norm = grad_norm.full_tensor()
# if grad_norm is not finite, skip the update
if not torch.isfinite(grad_norm):
print(f"WARN: grad_norm is not finite: {grad_norm}")
@ -697,6 +701,7 @@ class EngineTrainModeCtx:
class FSDPEngineWithLMHead(FSDPEngine):
def prepare_model_inputs(self, micro_batch: TensorDict):
use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True)
pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.LEFT_RIGHT)
use_fused_kernels = tu.get_non_tensor_data(data=micro_batch, key="use_fused_kernels", default=False)
temperature = micro_batch["temperature"]
@ -707,8 +712,8 @@ class FSDPEngineWithLMHead(FSDPEngine):
multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"])
input_ids = micro_batch["input_ids"]
attention_mask = micro_batch["attention_mask"]
position_ids = micro_batch["position_ids"]
if position_ids.dim() == 3: # qwen2vl mrope
position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen)
@ -716,29 +721,37 @@ class FSDPEngineWithLMHead(FSDPEngine):
output_args = {}
if use_remove_padding:
input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask
) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
if pad_mode == DatasetPadMode.NO_PADDING:
input_ids_rmpad = input_ids.values().unsqueeze(0) # (1, total_nnz)
position_ids_rmpad = position_ids.values().unsqueeze(0) # (1, total_nnz)
elif pad_mode == DatasetPadMode.LEFT_RIGHT:
attention_mask = micro_batch["attention_mask"]
input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask
) # input_ids_rmpad (total_nnz, ...)
output_args["indices"] = indices
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
if position_ids.dim() == 3:
position_ids_rmpad = (
index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
.transpose(0, 1)
.unsqueeze(1)
) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
# unpad the position_ids to align the rotary
if position_ids.dim() == 3:
position_ids_rmpad = (
index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
.transpose(0, 1)
.unsqueeze(1)
) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
else:
position_ids_rmpad = index_first_axis(
rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
).transpose(0, 1)
if "image_bound" in multi_modal_inputs:
from verl.utils.dataset.vision_utils import process_multi_modal_inputs_for_minicpmo
multi_modal_inputs = process_multi_modal_inputs_for_minicpmo(
input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs
)
else:
position_ids_rmpad = index_first_axis(
rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
).transpose(0, 1)
if "image_bound" in multi_modal_inputs:
from verl.utils.dataset.vision_utils import process_multi_modal_inputs_for_minicpmo
multi_modal_inputs = process_multi_modal_inputs_for_minicpmo(
input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs
)
raise NotImplementedError(f"pad_mode {pad_mode} not implemented")
# for compute the log_prob
input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz)
@ -768,9 +781,7 @@ class FSDPEngineWithLMHead(FSDPEngine):
output_args["pad_size"] = pad_size
input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad)
output_args["input_ids_rmpad_rolled"] = input_ids_rmpad_rolled
output_args["indices"] = indices
# only pass input_ids and position_ids to enable flash_attn_varlen
@ -781,11 +792,47 @@ class FSDPEngineWithLMHead(FSDPEngine):
}
else:
model_inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
if pad_mode == DatasetPadMode.NO_PADDING:
input_ids = micro_batch["input_ids"]
position_ids = micro_batch["position_ids"]
loss_mask = micro_batch["loss_mask"]
pad_token_id = tu.get_non_tensor_data(data=micro_batch, key="pad_token_id", default=0)
batch_size = micro_batch.batch_size[0]
seq_len_effective = input_ids.offsets().diff()
max_seq_len = max(seq_len_effective)
input_ids_rmpad_rolled = torch.roll(input_ids.values(), shifts=-1, dims=0)
output_args["input_ids_rmpad_rolled"] = input_ids_rmpad_rolled
input_ids = torch.nested.to_padded_tensor(
input_ids, padding=pad_token_id, output_size=(batch_size, max_seq_len)
)
position_ids = torch.nested.to_padded_tensor(
position_ids, padding=0, output_size=(batch_size, max_seq_len)
)
attention_mask_list = [torch.ones_like(t, dtype=torch.int32) for t in loss_mask]
attention_mask = torch.nested.as_nested_tensor(attention_mask_list, layout=torch.jagged)
attention_mask = torch.nested.to_padded_tensor(
attention_mask, padding=0, output_size=(batch_size, max_seq_len)
)
model_inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
elif pad_mode == DatasetPadMode.LEFT_RIGHT:
attention_mask = micro_batch["attention_mask"]
model_inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
else:
raise NotImplementedError(f"pad_mode {pad_mode} not implemented")
extra_args = {}
if use_fused_kernels:
@ -799,18 +846,16 @@ class FSDPEngineWithLMHead(FSDPEngine):
def prepare_model_outputs(self, output, output_args, micro_batch: TensorDict):
use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True)
pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.LEFT_RIGHT)
use_fused_kernels = tu.get_non_tensor_data(data=micro_batch, key="use_fused_kernels", default=False)
temperature = micro_batch["temperature"]
calculate_entropy = tu.get_non_tensor_data(data=micro_batch, key="calculate_entropy", default=False)
model_output = {}
input_ids = micro_batch["input_ids"]
batch_size, seqlen = input_ids.shape
response_length = micro_batch["responses"].size(-1)
if use_remove_padding:
input_ids_rmpad_rolled = output_args["input_ids_rmpad_rolled"]
indices = output_args["indices"]
if use_fused_kernels:
log_probs = output.log_probs.squeeze(0) # (total_nnz,)
@ -856,44 +901,78 @@ class FSDPEngineWithLMHead(FSDPEngine):
unpad_dim=0,
padding_size=pad_size,
)
# pad back to (bsz, seqlen)
if calculate_entropy:
full_entropy = pad_input(
hidden_states=entropy_rmpad.unsqueeze(-1),
if pad_mode == DatasetPadMode.NO_PADDING:
cu_seqlens = input_ids.offsets()
# (bsz, j1) for each sample, is the length of each sample: [real_prompt length + real_response length]
log_probs = torch.nested.nested_tensor_from_jagged(log_probs, cu_seqlens)
if calculate_entropy:
entropy = torch.nested.nested_tensor_from_jagged(entropy_rmpad, cu_seqlens)
elif pad_mode == DatasetPadMode.LEFT_RIGHT:
indices = output_args["indices"]
response_length = micro_batch["responses"].size(-1)
batch_size, seqlen = input_ids.shape
full_log_probs = pad_input(
hidden_states=log_probs.unsqueeze(-1),
indices=indices,
batch=batch_size,
seqlen=seqlen,
)
full_log_probs = pad_input(
hidden_states=log_probs.unsqueeze(-1),
indices=indices,
batch=batch_size,
seqlen=seqlen,
)
log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length)
# only return response part:
if calculate_entropy:
entropy = full_entropy.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length)
log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length)
# pad back to (bsz, seqlen)
if calculate_entropy:
full_entropy = pad_input(
hidden_states=entropy_rmpad.unsqueeze(-1),
indices=indices,
batch=batch_size,
seqlen=seqlen,
)
# only return response part:
if calculate_entropy:
entropy = full_entropy.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length)
else:
raise NotImplementedError(f"pad_mode {pad_mode} not implemented")
else: # not using rmpad and no ulysses sp
response_length = tu.get_non_tensor_data(data=micro_batch, key="max_response_length", default=1024)
if use_fused_kernels:
log_probs = output.log_probs[:, -response_length - 1 : -1]
entropy = output.entropy[:, -response_length - 1 : -1] # (bsz, response_length)
else:
logits = output.logits
logits.div_(temperature)
logits = logits[:, -response_length - 1 : -1, :] # (bsz, response_length, vocab_size)
log_probs = logprobs_from_logits(logits, micro_batch.batch["responses"])
if calculate_entropy:
if not self.engine_config.entropy_checkpointing:
entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length)
entropy = verl_F.entropy_from_logits(logits)
else:
entropy = torch.utils.checkpoint.checkpoint(verl_F.entropy_from_logits, logits)
model_output = {"log_probs": log_probs}
if pad_mode == DatasetPadMode.NO_PADDING:
cu_seqlens = input_ids.offsets()
seq_lengths = cu_seqlens.diff()
starts = torch.zeros_like(seq_lengths, dtype=torch.int64)
logits = torch.nested.narrow(logits, 1, starts, seq_lengths, layout=torch.jagged)
logits_rmpad = torch.cat([t for t in logits.unbind()])
input_ids_rmpad_rolled = output_args["input_ids_rmpad_rolled"]
log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)
# (bsz, j1) for each sample, length of each sample: [real_prompt_length + real_response_length]
log_probs = torch.nested.nested_tensor_from_jagged(log_probs, cu_seqlens)
if calculate_entropy:
entropy = torch.nested.narrow(entropy, 1, starts, seq_lengths, layout=torch.jagged)
entropy_rmpad = torch.cat([t for t in entropy.unbind()])
entropy = torch.nested.nested_tensor_from_jagged(entropy_rmpad, cu_seqlens)
elif pad_mode == DatasetPadMode.LEFT_RIGHT:
logits = logits[:, -response_length - 1 : -1, :] # (bsz, response_length, vocab_size)
log_probs = logprobs_from_logits(logits, micro_batch["responses"])
if calculate_entropy:
entropy = entropy[:, -response_length - 1 : -1] # (bsz, response_length)
else:
raise NotImplementedError(f"pad_mode {pad_mode} not implemented")
model_output["log_probs"] = log_probs
if calculate_entropy:
model_output["entropy"] = entropy

View File

@ -17,6 +17,7 @@ import torch
from tensordict import TensorDict
from verl.utils import tensordict_utils as tu
from verl.utils.dataset.dataset_utils import DatasetPadMode
from verl.utils.py_functional import append_to_dict
from verl.utils.seqlen_balancing import rearrange_micro_batches, restore_dynamic_batch
@ -65,6 +66,7 @@ def postprocess_batch_func(output_lst, indices, data: TensorDict):
"""
use_dynamic_bsz = tu.get_non_tensor_data(data=data, key="use_dynamic_bsz", default=True)
pad_mode = tu.get_non_tensor_data(data=data, key="pad_mode", default=DatasetPadMode.LEFT_RIGHT)
# losses_reduced is a list of dict containing outputs for each micro-batch
# reorder entropy and outputs. Return None for other pp ranks
@ -87,7 +89,14 @@ def postprocess_batch_func(output_lst, indices, data: TensorDict):
# concat results from micro batches
for key, val in model_output.items():
model_output[key] = torch.cat(model_output[key], dim=0)
if pad_mode == DatasetPadMode.NO_PADDING:
tensors = [tensor for nt in model_output[key] for tensor in nt.unbind()]
model_output[key] = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)
elif pad_mode == DatasetPadMode.LEFT_RIGHT:
model_output[key] = torch.cat(model_output[key], dim=0)
else:
raise NotImplementedError(f"pad_mode {pad_mode} not implemented")
# reverse with dynamic bsz
if use_dynamic_bsz:
model_output[key] = restore_dynamic_batch(model_output[key], indices)

View File

@ -17,14 +17,34 @@ import torch
from tensordict import TensorDict
from verl.trainer.ppo.core_algos import agg_loss, compute_value_loss, get_policy_loss_fn, kl_penalty
from verl.utils import tensordict_utils as tu
from verl.utils.dataset.dataset_utils import DatasetPadMode
from verl.utils.torch_functional import masked_mean
from verl.workers.config import ActorConfig, CriticConfig
def sft_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None):
log_prob = model_output["log_probs"] # [bsz, response_length]
response_mask = data["response_mask"].to(bool)
loss = -torch.mean(log_prob * response_mask)
pad_mode = tu.get_non_tensor_data(data=data, key="pad_mode", default=DatasetPadMode.LEFT_RIGHT)
log_prob = model_output["log_probs"]
if pad_mode == DatasetPadMode.NO_PADDING:
# log_prob and loss mask are nested tensors of shape [bsz, j1]
# for each sample, loss mask shape is [1, prompt_length + response_length]
loss_mask = data["loss_mask"]
log_prob_flatten = log_prob.values()
cu_seqlens = log_prob.offsets()
loss_mask_flatten = loss_mask.values()
# left-shift the loss mask by one token to align with log_prob
loss_mask_flatten = torch.roll(loss_mask_flatten, shifts=-1, dims=0)
loss_mask_flatten[cu_seqlens[1:] - 1] = 0
loss = -masked_mean(log_prob_flatten, loss_mask_flatten)
else:
response_mask = data["response_mask"].to(bool)
loss = -masked_mean(log_prob, response_mask)
return loss, {"loss": loss.detach().item()}