Mirror of https://github.com/volcengine/verl.git (synced 2025-10-20 13:43:50 +08:00)
[misc] feat: prototype deprecate DataProto and replace with Tensordict: part 2 (#3567)
### What does this PR do?

This PR continues the work started in PR #2733. It adds support for variable sequence lengths in `MultiTurnSFTDataset` by introducing a `no_padding` option for `pad_mode`. When this mode is active, sequences are not padded to a fixed length.

- Implement the no-padding mode for the FSDP engine using nested tensors in the SFT trainer
- Add tests for the no-padding mode with `use_remove_padding` both enabled and disabled
- Fix an FSDP2 grad-norm issue

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: zhangchi.usc1992 <zhangchi.usc1992@bytedance.com>
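As a hedged illustration of the new API (this snippet is not part of the commit; the token ids are made up), the `no_padding` mode flows through the `SFTTensorCollator` added below, which batches variable-length samples as one jagged `NestedTensor` per key instead of padding them:

```python
import torch

from verl.utils.dataset.dataset_utils import DatasetPadMode, SFTTensorCollator

# Two samples of different lengths; with pad_mode=no_padding the collator
# builds jagged NestedTensors rather than padding to a fixed max_length.
batch = [
    {"input_ids": torch.tensor([1, 2, 3]), "loss_mask": torch.tensor([0, 1, 1])},
    {"input_ids": torch.tensor([4, 5, 6, 7, 8]), "loss_mask": torch.tensor([0, 0, 1, 1, 1])},
]
collate_fn = SFTTensorCollator(DatasetPadMode.NO_PADDING)
out = collate_fn(batch)
assert out["input_ids"].is_nested                        # jagged layout
assert out["input_ids"].offsets().tolist() == [0, 3, 8]  # cumulative seq lengths
```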
tests/special_e2e/sft/run_sft_engine_gsm8k.sh:

```diff
@@ -29,6 +29,10 @@ PP_SIZE=${PP_SIZE:-1}
 VPP_SIZE=${VPP_SIZE:-null}
 CP_SIZE=${CP_SIZE:-1}
 
+PAD_MODE=${PAD_MODE:-left_right}
+
+USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True}
+
 FSDP_ENGINE_CONFIG="\
     engine=${backend} \
     optim=${backend} \
@@ -63,11 +67,11 @@ MEGATRON_ENGINE_CONFIG="\
 if [ "$backend" = "fsdp" ]; then
     ENGINE_CONFIG="$FSDP_ENGINE_CONFIG"
     echo "Using fsdp engine"
-    exp_name=gsm8k-${backend}-${FSDP_STRATEGY}-sp${SP_SIZE}-fsdp${FSDP_SIZE}
+    exp_name=gsm8k-${backend}-${FSDP_STRATEGY}-sp${SP_SIZE}-fsdp${FSDP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}
 else
     ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG"
     echo "Using megatron engine"
-    exp_name=gsm8k-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-vpp${VPP_SIZE}-cp${CP_SIZE}
+    exp_name=gsm8k-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-vpp${VPP_SIZE}-cp${CP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}
 fi
 
 mkdir -p "${ckpts_home}"
@@ -78,12 +82,13 @@ torchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS} ${ENTRYPOINT} \
     data.train_batch_size=256 \
     data.max_prompt_length=1024 \
     data.max_response_length=1024 \
-    data.pad_mode=left_right \
+    data.pad_mode=${PAD_MODE} \
     data.truncation=error \
     data.use_dynamic_bsz=True \
    data.max_token_len_per_gpu=8192 \
     data.messages_key=messages \
     model.path=$MODEL_PATH \
+    model.use_remove_padding=${USE_REMOVE_PADDING} \
     ${ENGINE_CONFIG} \
     trainer.test_freq=after_each_epoch \
     trainer.save_freq=-1 \
```
```diff
@@ -9,26 +9,43 @@ echo "run with single gpu as golden"
 BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp VERL_FILE_LOGGER_PATH=~/verl/test/log/golden.jsonl bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
 
 # test with fsdp 1
 echo "run with sp1 fsdp_size2 num_gpus8 fsdp_strategy fsdp"
 BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
 echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp"
 BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
 echo "run with sp2 fsdp_size-1 num_gpus8 fsdp_strategy fsdp"
 BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
 echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp"
 BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp1 fsdp_size2 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
+BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
+BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp2 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
+BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
+BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+
+echo "run with sp1 fsdp_size2 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
+BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
+BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp2 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
+BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
+BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+
+# test use_remove_padding and pad_mode left_right/no_padding
+echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode left_right use_remove_padding False"
+BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right USE_REMOVE_PADDING=False bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode no_padding use_remove_padding False"
+BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding USE_REMOVE_PADDING=False bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
 
 
 # test with fsdp 2
 echo "run with sp1 fsdp_size1 num_gpus1 fsdp_strategy fsdp2"
 BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp1 fsdp_size1 num_gpus1 fsdp_strategy fsdp2 pad_mode left_right"
+BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp2 PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp1 fsdp_size1 num_gpus1 fsdp_strategy fsdp2 pad_mode no_padding"
+BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp2 PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
 
-# TODO: toggle the follow tests when the grad norm of fsdp is fixed
-# echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp2"
-# BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
-# echo "run with sp2 fsdp_size-1 num_gpus8 fsdp_strategy fsdp2"
-# BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
-# BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
-# BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp2"
+BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp2 fsdp_size-1 num_gpus8 fsdp_strategy fsdp2"
+BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
 
 # test with megatron
 echo "run with tp1 pp1 cp1 num_gpus1"
```
```diff
@@ -88,6 +88,45 @@ def test_tensor_dict_constructor():
     assert data["name"] == "abdce"
 
 
+def test_index_select_tensor_dict():
+    vocab_size = 128
+    a = torch.randint(low=0, high=vocab_size, size=(11,))
+    b = torch.randint(low=0, high=vocab_size, size=(13,))
+    c = torch.randint(low=0, high=vocab_size, size=(12,))
+    d = torch.randint(low=0, high=vocab_size, size=(15,))
+    input_ids = [a, b, c, d]
+    input_ids = torch.nested.as_nested_tensor(input_ids, layout=torch.jagged)
+
+    padded_tensor = torch.randn(4, 10)
+    non_tensor_dict = {"global_batch_size": "4"}
+
+    data = tu.get_tensordict(
+        tensor_dict={
+            "input_ids": input_ids,
+            "padded_tensor": padded_tensor,
+        },
+        non_tensor_dict=non_tensor_dict,
+    )
+
+    assert data.batch_size == torch.Size([4])
+
+    # test index select
+    indices = torch.tensor([1, 3])
+    selected_data = tu.index_select_tensor_dict(data, indices)
+
+    assert selected_data.batch_size == torch.Size([2])
+
+    target_input_ids = torch.nested.as_nested_tensor([input_ids[idx] for idx in indices], layout=torch.jagged)
+    target_select_data = tu.get_tensordict(
+        tensor_dict={
+            "input_ids": target_input_ids,
+            "padded_tensor": padded_tensor[indices],
+        },
+        non_tensor_dict=non_tensor_dict,
+    )
+    tu.assert_tensordict_eq(selected_data, target_select_data)
+
+
 def test_tensordict_with_images():
     # each sample contains a sequence with multiple images of different sizes
     vocab_size = 128
@@ -173,6 +212,37 @@ def test_tensordict_eq():
     with pytest.raises(AssertionError):
         tu.assert_tensordict_eq(data, data2)
 
+    tensor_list = [
+        torch.tensor([1, 2, 3, 3, 2]),
+        torch.tensor([4, 5]),
+        torch.tensor([7, 8, 10, 14]),
+        torch.tensor([10, 11, 12]),
+        torch.tensor([13, 14, 15, 18]),
+        torch.tensor([16, 17]),
+    ]
+    obs = torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged)
+    data_sources = ["abc", "def", "abc", "def", "pol", "klj"]
+    non_tensor_dict = {"train_sample_kwargs": {"top_p": 1.0}, "val_sample_kwargs": {"top_p": 0.7}}
+    data3 = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)
+
+    tensor_list[0] = torch.tensor([1, 2, 3, 3, 2])
+    obs = torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged)
+    data4 = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)
+    tu.assert_tensordict_eq(data3, data4)
+
+    tensor_list[0] = torch.tensor([1, 2, 4])
+    obs = torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged)
+    data5 = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)
+    with pytest.raises(AssertionError):
+        tu.assert_tensordict_eq(data3, data5)
+
+    tensor_list[0] = torch.tensor([4, 5])
+    tensor_list[1] = torch.tensor([1, 2, 3, 3, 2])
+    obs = torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged)
+    data6 = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)
+    with pytest.raises(AssertionError):
+        tu.assert_tensordict_eq(data3, data6)
+
 
 def test_tensor_dict_make_iterator():
     obs = torch.tensor([1, 2, 3, 4, 5, 6])
```
```diff
@@ -32,6 +32,7 @@ from tqdm import tqdm
 
 from verl.utils import tensordict_utils as tu
 from verl.utils.checkpoint import CheckpointHandler
+from verl.utils.dataset.dataset_utils import SFTTensorCollator
 from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset
 from verl.utils.device import get_device_name, is_cuda_available, is_npu_available
 from verl.utils.distributed import destroy_global_process_group
@@ -167,11 +168,13 @@ class SFTTrainer:
 
         self.global_batch_size = config.data.train_batch_size
         self.train_batch_size_per_dp = self.global_batch_size // dp_size
+        self.collate_fn = SFTTensorCollator(config.data.pad_mode)
 
         self.train_dataloader = StatefulDataLoader(
             dataset=self.train_dataset,
             batch_size=self.train_batch_size_per_dp,
             sampler=self.train_sampler,
+            collate_fn=self.collate_fn,
             num_workers=8,
             pin_memory=True,
             drop_last=True,
@@ -185,6 +188,7 @@ class SFTTrainer:
             dataset=self.val_dataset,
             batch_size=self.train_batch_size_per_dp,
             sampler=self.val_sampler,
+            collate_fn=self.collate_fn,
             num_workers=8,
             pin_memory=True,
             drop_last=True,
@@ -227,11 +231,14 @@ class SFTTrainer:
         start_epoch = global_step // self.steps_per_epoch
 
         meta_info = {
             "use_remove_padding": self.config.model.use_remove_padding,
             "use_dynamic_bsz": self.config.data.use_dynamic_bsz,
             "max_token_len_per_gpu": self.config.data.max_token_len_per_gpu,
             "micro_batch_size_per_gpu": self.config.data.micro_batch_size_per_gpu,
             "temperature": 1.0,
+            "global_batch_size": self.global_batch_size,
+            "pad_mode": self.config.data.pad_mode,
+            "pad_token_id": self.model_config.tokenizer.pad_token_id,
         }
 
         train_time = 0
@@ -263,7 +270,12 @@ class SFTTrainer:
                 loss = torch.mean(torch.tensor(metrics["loss"], device=self.device_name))
 
                 # mean over dp group
-                batch_seqlens = data["attention_mask"].sum(dim=-1).to(self.device_name)  # (global_bsz // dp)
+                is_nested = data["input_ids"].is_nested
+                if is_nested:
+                    batch_seqlens: torch.Tensor = data["input_ids"].offsets().diff()
+                else:
+                    batch_seqlens: torch.Tensor = data["attention_mask"].sum(dim=-1)
+                batch_seqlens = batch_seqlens.to(self.device_name)  # (global_bsz // dp)
 
                 output_tensor = torch.randint(
                     0,
```
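A minimal sketch of the sequence-length logic the trainer hunk above introduces (the standalone function wrapper is illustrative, not part of the diff): nested batches read true lengths from the jagged offsets, while padded batches keep summing the attention mask:

```python
import torch

def batch_seqlens(data: dict) -> torch.Tensor:
    # no_padding batches carry jagged input_ids; offsets().diff() yields the
    # true per-sample lengths. Padded batches fall back to the mask sum.
    input_ids = data["input_ids"]
    if input_ids.is_nested:
        return input_ids.offsets().diff()
    return data["attention_mask"].sum(dim=-1)
```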
verl/utils/dataset/dataset_utils.py (new file, 70 lines):

```diff
@@ -0,0 +1,70 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from enum import Enum
+
+import torch
+
+
+class DatasetPadMode(str, Enum):
+    """Padding mode for dataset"""
+
+    RIGHT = "right"
+    LEFT_RIGHT = "left_right"
+    NO_PADDING = "no_padding"
+
+
+class SFTTensorCollator:
+    """
+    A custom collate_fn that handles batching of sequences.
+    1. for variable-length sequences, convert them into NestedTensors.
+    2. for fixed-length sequences, use default_collate.
+    """
+
+    def __init__(self, pad_mode: DatasetPadMode = DatasetPadMode.LEFT_RIGHT):
+        self.pad_mode = pad_mode
+
+    def __call__(self, batch: list[dict[str, any]]) -> dict[str, any]:
+        if self.pad_mode == DatasetPadMode.NO_PADDING:
+            return self.collate_variable_batch(batch)
+        elif self.pad_mode in [DatasetPadMode.RIGHT, DatasetPadMode.LEFT_RIGHT]:
+            from torch.utils.data import default_collate
+
+            return default_collate(batch)
+        else:
+            raise NotImplementedError(f"pad_mode {self.pad_mode} not implemented")
+
+    def collate_variable_batch(self, batch: list[dict[str, any]]) -> dict[str, any]:
+        """
+        Collates a list of samples into a single batch.
+
+        Args:
+            batch: A list of dictionary samples from the dataset.
+
+        Returns:
+            A dictionary representing the batched data, with variable-length
+            sequences converted to NestedTensors.
+        """
+
+        final_batch = {}
+
+        tensor_keys = [key for key in batch[0].keys() if isinstance(batch[0][key], torch.Tensor)]
+
+        # Handle tensor values by creating a NestedTensor.
+        for key in tensor_keys:
+            tensors = [item[key] for item in batch]
+            final_batch[key] = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)
+
+        return final_batch
```
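Downstream code in this PR unpacks these jagged batches through `values()` and `offsets()`; a small self-contained sketch of that round trip (the numbers are illustrative):

```python
import torch

nt = torch.nested.as_nested_tensor(
    [torch.tensor([1, 2, 3]), torch.tensor([4, 5])], layout=torch.jagged
)
flat = nt.values()           # packed (total_nnz,) buffer: tensor([1, 2, 3, 4, 5])
cu_seqlens = nt.offsets()    # cumulative boundaries: tensor([0, 3, 5])
seqlens = cu_seqlens.diff()  # per-sample lengths: tensor([3, 2])

# the packed buffer plus offsets reconstruct the same jagged view
rebuilt = torch.nested.nested_tensor_from_jagged(flat, cu_seqlens)
assert all(torch.equal(a, b) for a, b in zip(nt.unbind(), rebuilt.unbind()))
```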
```diff
@@ -27,6 +27,7 @@ from torch.utils.data import Dataset
 from transformers import PreTrainedTokenizer
 
 from verl.utils import hf_tokenizer
+from verl.utils.dataset.dataset_utils import DatasetPadMode
 from verl.utils.fs import copy_local_path_from_hdfs
 from verl.utils.model import compute_position_id_with_mask
 from verl.utils.torch_functional import pad_sequence_to_length, postprocess_data
@@ -54,8 +55,8 @@ class MultiTurnSFTDataset(Dataset):
         # Set defaults and extract parameters from config if provided
         config = config or {}
         self.pad_mode = config.get("pad_mode", "right")
-        assert self.pad_mode in ["right", "left_right"], (
-            f"Expect pad_mode to be 'right' or 'left_right'. Got {self.pad_mode}"
+        assert self.pad_mode in ["right", "left_right", "no_padding"], (
+            f"Expect pad_mode to be 'right', 'left_right' or 'no_padding'. Got {self.pad_mode}"
         )
         self.truncation = config.get("truncation", "error")
         # for right padding
@@ -328,7 +329,7 @@ class MultiTurnSFTDataset(Dataset):
 
         sequence_length = input_ids.shape[0]
         # Handle sequence length
-        if self.pad_mode == "right":
+        if self.pad_mode == DatasetPadMode.RIGHT:
             if sequence_length < self.max_length:
                 # Pad sequences
                 pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0
@@ -364,7 +365,7 @@ class MultiTurnSFTDataset(Dataset):
                 "position_ids": position_ids,
                 "loss_mask": loss_mask,
             }
-        elif self.pad_mode == "left_right":
+        elif self.pad_mode == DatasetPadMode.LEFT_RIGHT:
             assert self.truncation == "error", "Only support error truncation for left_right pad mode"
             prompt_str = self.tokenizer.apply_chat_template(
                 messages[:prompt_message_length],
@@ -426,3 +427,16 @@ class MultiTurnSFTDataset(Dataset):
                 "responses": response_ids,
                 "response_mask": response_loss_mask,
             }
+        elif self.pad_mode == DatasetPadMode.NO_PADDING:
+            # truncate input_ids if it is longer than max_length
+            if len(input_ids) > self.max_length:
+                input_ids = input_ids[: self.max_length]
+                loss_mask = loss_mask[: self.max_length]
+            # create position IDs
+            position_ids = torch.arange(len(input_ids), dtype=torch.long)
+            # return unpadded tensors; batching into nested tensors happens in the collator
+            return {
+                "input_ids": input_ids,
+                "position_ids": position_ids,
+                "loss_mask": loss_mask,
+            }
```
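For reference, a minimal sketch (not part of the diff; token ids are made up) of the single-sample dict the new `no_padding` branch returns: plain 1-D tensors at their true length, with jagged batching deferred to `SFTTensorCollator`:

```python
import torch

input_ids = torch.tensor([101, 7592, 2088, 102])  # illustrative token ids
sample = {
    "input_ids": input_ids,
    "position_ids": torch.arange(len(input_ids), dtype=torch.long),
    "loss_mask": torch.tensor([0, 0, 1, 1]),  # train only on response tokens
}
```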
```diff
@@ -20,6 +20,7 @@ import torch
 from torch import distributed as dist
 
 from verl.protocol import DataProto
+from verl.utils import tensordict_utils as tu
 from verl.utils.device import get_device_name
 
 
@@ -273,11 +274,17 @@ def rearrange_micro_batches(
         List[List[int]]: index lists mapping each micro-batch back to original positions.
     """
     # this is per local micro_bsz
-    max_seq_len = batch["attention_mask"].shape[-1]
+    input_ids = batch["input_ids"]
+    if input_ids.is_nested:
+        seq_len_effective: torch.Tensor = input_ids.offsets().diff()
+        max_seq_len = max(seq_len_effective)
+    else:
+        max_seq_len = batch["attention_mask"].shape[-1]
+        seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1)
 
     assert max_token_len >= max_seq_len, (
         f"max_token_len must be greater than the sequence length. Got {max_token_len=} and {max_seq_len=}"
     )
-    seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1)
     total_seqlen = seq_len_effective.sum().item()
     # NOTE: num_microbatches <= batch_size, so take the min of this two.
     num_micro_batches = min(len(seq_len_effective), ceildiv(total_seqlen, max_token_len))
@@ -309,11 +316,7 @@ def rearrange_micro_batches(
     micro_batches = []
 
     for partition in micro_bsz_idx:
-        curr_micro_batch = []
-        for idx in partition:
-            curr_micro_batch.append(batch[idx : idx + 1])
-        curr_micro_batch = torch.cat(curr_micro_batch)
-
+        curr_micro_batch = tu.index_select_tensor_dict(batch, partition)
         micro_batches.append(curr_micro_batch)
 
     return micro_batches, micro_bsz_idx
@@ -388,6 +391,14 @@ def restore_dynamic_batch(data: torch.Tensor, batch_idx_list: list[list[int]]) -
         torch.Tensor: The restored data.
     """
     indices = list(chain.from_iterable(batch_idx_list))
-    assert len(indices) == data.size(0), f"{len(indices)} vs. {data.size()}"
+    batch_size = data.shape[0]
+    assert len(indices) == batch_size, f"{len(indices)} vs. {batch_size}"
     revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
-    return data[revert_indices]
+
+    if data.is_nested:
+        tensors = [data[i] for i in revert_indices]
+        reverted_data = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)
+    else:
+        reverted_data = data[revert_indices]
+
+    return reverted_data
```
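A hedged sketch of the nested branch of `restore_dynamic_batch` above (sample values are made up): outputs arrive concatenated in micro-batch order, and the reverse index rebuilds the jagged tensor in dataset order:

```python
import torch

from verl.utils.seqlen_balancing import restore_dynamic_batch

# pretend dynamic batching visited samples as micro-batches [2] and [0, 1]
outputs_in_mb_order = torch.nested.as_nested_tensor(
    [torch.tensor([30, 31, 32]), torch.tensor([10, 11]), torch.tensor([20])],
    layout=torch.jagged,
)
batch_idx_list = [[2], [0, 1]]
restored = restore_dynamic_batch(outputs_in_mb_order, batch_idx_list)
assert torch.equal(restored[0], torch.tensor([10, 11]))  # dataset order again
```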
```diff
@@ -69,13 +69,16 @@ def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict:
                 "Passing a list makes the data NonTensorStack, "
                 "which doesn't support torch.Tensor. Please convert to numpy first"
             )
 
         assert isinstance(val, torch.Tensor | list)
 
         if batch_size is None:
-            batch_size = len(val)
+            batch_size = val.size(0) if isinstance(val, torch.Tensor) else len(val)
         else:
-            assert len(val) == batch_size
+            val_batch_size = val.size(0) if isinstance(val, torch.Tensor) else len(val)
+            assert val_batch_size == batch_size, (
+                f"Batch size of tensor {key} is not consistent with other tensors. "
+                f"Expected {batch_size}, got {val_batch_size}"
+            )
 
     if batch_size is None:
         batch_size = []
@@ -89,6 +92,35 @@
     return TensorDict(source=tensor_dict, batch_size=batch_size)
 
 
+def index_select_tensor_dict(batch: TensorDict, indices: torch.Tensor | list[int]) -> TensorDict:
+    """Index a tensor dict with a tensor of indices."""
+    if isinstance(indices, list):
+        indices = torch.tensor(indices)
+
+    assert indices.dim() == 1, "indices must be a 1D tensor"
+
+    data_dict = {}
+    batch_size = indices.shape[0]
+
+    if batch is not None:
+        for key, tensor in batch.items():
+            if isinstance(tensor, torch.Tensor) and not tensor.is_nested:
+                data_dict[key] = tensor[indices]
+            elif isinstance(tensor, torch.Tensor) and tensor.is_nested:
+                data_dict[key] = torch.nested.as_nested_tensor([tensor[idx] for idx in indices], layout=torch.jagged)
+            else:
+                # This handles NonTensorStack (indexable by batch dim) and NonTensorData (scalar metadata).
+                if tensor.shape:
+                    data_dict[key] = tensor[indices]
+                else:
+                    data_dict[key] = tensor
+        selected_batch = TensorDict(source=data_dict, batch_size=batch_size)
+    else:
+        selected_batch = None
+
+    return selected_batch
+
+
 def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict:
     """Union two tensordicts."""
     assert tensor_dict1.batch_size == tensor_dict2.batch_size, (
@@ -147,7 +179,16 @@ def assert_tensordict_eq(tensordict1: TensorDict, tensordict2: TensorDict):
         assert type(val) is type(val2), f"The type of {key} must be the same. Got {type(val)} vs {type(val2)}"
 
         if isinstance(val, torch.Tensor):
-            assert torch.all(torch.eq(val, val2)).item()
+            if val.is_nested:
+                assert val.is_nested and val2.is_nested, (
+                    f"Both tensors must be nested tensors. {val.is_nested=}, {val2.is_nested=}"
+                )
+                t1, t2 = val.unbind(), val2.unbind()
+                assert len(t1) == len(t2), f"Nested tensor should have the same lengths. {len(t1)=} vs {len(t2)=}"
+                for c1, c2 in zip(t1, t2, strict=True):
+                    assert torch.equal(c1, c2), f"Nested tensor components have different values. {c1=} vs {c2=}"
+            else:
+                assert torch.all(torch.eq(val, val2)).item()
         else:
             assert val == val2
```
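A short usage sketch for the new `index_select_tensor_dict` (values are illustrative; the empty `non_tensor_dict` is an assumption about the helper's signature), showing that nested entries are re-gathered component-wise:

```python
import torch

from verl.utils import tensordict_utils as tu

obs = torch.nested.as_nested_tensor(
    [torch.tensor([1, 2]), torch.tensor([3, 4, 5]), torch.tensor([6])],
    layout=torch.jagged,
)
td = tu.get_tensordict(tensor_dict={"obs": obs}, non_tensor_dict={})
sel = tu.index_select_tensor_dict(td, [0, 2])  # list indices are converted to a tensor
assert sel.batch_size == torch.Size([2])
assert torch.equal(sel["obs"][1], torch.tensor([6]))
```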
```diff
@@ -23,6 +23,7 @@ import os
 import torch
 from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.tensor import DTensor
 
 import verl.utils.torch_functional as verl_F
 from verl import DataProto
@@ -284,6 +285,9 @@ class DataParallelPPOActor(BasePPOActor):
         else:
             grad_norm = torch.nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip)
 
+        if isinstance(grad_norm, DTensor):
+            grad_norm = grad_norm.full_tensor()
+
         # if grad_norm is not finite, skip the update
         if not torch.isfinite(grad_norm):
             print(f"WARN: rank {torch.distributed.get_rank()} grad_norm is not finite: {grad_norm}")
```
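The grad-norm fix above in self-contained form (a sketch, not the PR's exact code): with FSDP2's DTensor parameters, `clip_grad_norm_` returns a DTensor that must be materialized before scalar checks:

```python
import torch
from torch import nn
from torch.distributed.tensor import DTensor

def clip_and_check(module: nn.Module, max_norm: float) -> torch.Tensor:
    grad_norm = nn.utils.clip_grad_norm_(module.parameters(), max_norm=max_norm)
    if isinstance(grad_norm, DTensor):
        # FSDP2 shards return a DTensor; gather the full scalar value first
        grad_norm = grad_norm.full_tensor()
    if not torch.isfinite(grad_norm):
        print(f"WARN: grad_norm is not finite: {grad_norm}")  # caller skips the step
    return grad_norm
```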
```diff
@@ -35,6 +35,7 @@ from verl.models.transformers.monkey_patch import apply_monkey_patch
 from verl.utils import tensordict_utils as tu
 from verl.utils.activation_offload import enable_activation_offloading
 from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
+from verl.utils.dataset.dataset_utils import DatasetPadMode
 from verl.utils.debug import log_gpu_memory_usage
 from verl.utils.device import (
     get_device_id,
@@ -483,7 +484,7 @@ class FSDPEngine(BaseEngine):
 
         if not forward_only:
             global_bsz = data["global_batch_size"]
-            local_micro_bsz = micro_batch["input_ids"].shape[0]
+            local_micro_bsz = micro_batch.batch_size[0]
             # metrics contain the output, loss is dummy
             loss_scale_factor = local_micro_bsz / (global_bsz / self.get_data_parallel_size())
             # scale loss
@@ -522,6 +523,9 @@ class FSDPEngine(BaseEngine):
                 self.module.parameters(), max_norm=self.optimizer_config.clip_grad
             )
 
+        if isinstance(grad_norm, DTensor):
+            grad_norm = grad_norm.full_tensor()
+
         # if grad_norm is not finite, skip the update
         if not torch.isfinite(grad_norm):
             print(f"WARN: grad_norm is not finite: {grad_norm}")
@@ -697,6 +701,7 @@ class EngineTrainModeCtx:
 class FSDPEngineWithLMHead(FSDPEngine):
     def prepare_model_inputs(self, micro_batch: TensorDict):
         use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True)
+        pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.LEFT_RIGHT)
         use_fused_kernels = tu.get_non_tensor_data(data=micro_batch, key="use_fused_kernels", default=False)
         temperature = micro_batch["temperature"]
 
@@ -707,8 +712,8 @@ class FSDPEngineWithLMHead(FSDPEngine):
             multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"])
 
         input_ids = micro_batch["input_ids"]
-        attention_mask = micro_batch["attention_mask"]
         position_ids = micro_batch["position_ids"]
 
         if position_ids.dim() == 3:  # qwen2vl mrope
             position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)
 
@@ -716,29 +721,37 @@ class FSDPEngineWithLMHead(FSDPEngine):
         output_args = {}
 
         if use_remove_padding:
-            input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(
-                input_ids.unsqueeze(-1), attention_mask
-            )  # input_ids_rmpad (total_nnz, ...)
-            input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)
+            if pad_mode == DatasetPadMode.NO_PADDING:
+                input_ids_rmpad = input_ids.values().unsqueeze(0)  # (1, total_nnz)
+                position_ids_rmpad = position_ids.values().unsqueeze(0)  # (1, total_nnz)
+            elif pad_mode == DatasetPadMode.LEFT_RIGHT:
+                attention_mask = micro_batch["attention_mask"]
+                input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(
+                    input_ids.unsqueeze(-1), attention_mask
+                )  # input_ids_rmpad (total_nnz, ...)
+                output_args["indices"] = indices
+                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)
 
-            # unpad the position_ids to align the rotary
-            if position_ids.dim() == 3:
-                position_ids_rmpad = (
-                    index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
-                    .transpose(0, 1)
-                    .unsqueeze(1)
-                )  # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
-            else:
-                position_ids_rmpad = index_first_axis(
-                    rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
-                ).transpose(0, 1)
-
-            if "image_bound" in multi_modal_inputs:
-                from verl.utils.dataset.vision_utils import process_multi_modal_inputs_for_minicpmo
-
-                multi_modal_inputs = process_multi_modal_inputs_for_minicpmo(
-                    input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs
-                )
+                # unpad the position_ids to align the rotary
+                if position_ids.dim() == 3:
+                    position_ids_rmpad = (
+                        index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
+                        .transpose(0, 1)
+                        .unsqueeze(1)
+                    )  # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
+                else:
+                    position_ids_rmpad = index_first_axis(
+                        rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
+                    ).transpose(0, 1)
+
+                if "image_bound" in multi_modal_inputs:
+                    from verl.utils.dataset.vision_utils import process_multi_modal_inputs_for_minicpmo
+
+                    multi_modal_inputs = process_multi_modal_inputs_for_minicpmo(
+                        input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs
+                    )
+            else:
+                raise NotImplementedError(f"pad_mode {pad_mode} not implemented")
 
             # for compute the log_prob
             input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)
@@ -768,9 +781,7 @@ class FSDPEngineWithLMHead(FSDPEngine):
                 output_args["pad_size"] = pad_size
 
             input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)  # ((total_nnz / sp) + pad)
-
             output_args["input_ids_rmpad_rolled"] = input_ids_rmpad_rolled
-            output_args["indices"] = indices
 
             # only pass input_ids and position_ids to enable flash_attn_varlen
 
@@ -781,11 +792,47 @@ class FSDPEngineWithLMHead(FSDPEngine):
             }
 
         else:
-            model_inputs = {
-                "input_ids": input_ids,
-                "attention_mask": attention_mask,
-                "position_ids": position_ids,
-            }
+            if pad_mode == DatasetPadMode.NO_PADDING:
+                input_ids = micro_batch["input_ids"]
+                position_ids = micro_batch["position_ids"]
+                loss_mask = micro_batch["loss_mask"]
+
+                pad_token_id = tu.get_non_tensor_data(data=micro_batch, key="pad_token_id", default=0)
+                batch_size = micro_batch.batch_size[0]
+                seq_len_effective = input_ids.offsets().diff()
+                max_seq_len = max(seq_len_effective)
+
+                input_ids_rmpad_rolled = torch.roll(input_ids.values(), shifts=-1, dims=0)
+                output_args["input_ids_rmpad_rolled"] = input_ids_rmpad_rolled
+
+                input_ids = torch.nested.to_padded_tensor(
+                    input_ids, padding=pad_token_id, output_size=(batch_size, max_seq_len)
+                )
+
+                position_ids = torch.nested.to_padded_tensor(
+                    position_ids, padding=0, output_size=(batch_size, max_seq_len)
+                )
+
+                attention_mask_list = [torch.ones_like(t, dtype=torch.int32) for t in loss_mask]
+                attention_mask = torch.nested.as_nested_tensor(attention_mask_list, layout=torch.jagged)
+                attention_mask = torch.nested.to_padded_tensor(
+                    attention_mask, padding=0, output_size=(batch_size, max_seq_len)
+                )
+
+                model_inputs = {
+                    "input_ids": input_ids,
+                    "attention_mask": attention_mask,
+                    "position_ids": position_ids,
+                }
+            elif pad_mode == DatasetPadMode.LEFT_RIGHT:
+                attention_mask = micro_batch["attention_mask"]
+                model_inputs = {
+                    "input_ids": input_ids,
+                    "attention_mask": attention_mask,
+                    "position_ids": position_ids,
+                }
+            else:
+                raise NotImplementedError(f"pad_mode {pad_mode} not implemented")
 
         extra_args = {}
         if use_fused_kernels:
@@ -799,18 +846,16 @@ class FSDPEngineWithLMHead(FSDPEngine):
 
     def prepare_model_outputs(self, output, output_args, micro_batch: TensorDict):
         use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True)
+        pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.LEFT_RIGHT)
         use_fused_kernels = tu.get_non_tensor_data(data=micro_batch, key="use_fused_kernels", default=False)
         temperature = micro_batch["temperature"]
         calculate_entropy = tu.get_non_tensor_data(data=micro_batch, key="calculate_entropy", default=False)
 
         model_output = {}
 
         input_ids = micro_batch["input_ids"]
-        batch_size, seqlen = input_ids.shape
-
-        response_length = micro_batch["responses"].size(-1)
 
         if use_remove_padding:
             input_ids_rmpad_rolled = output_args["input_ids_rmpad_rolled"]
-            indices = output_args["indices"]
 
             if use_fused_kernels:
                 log_probs = output.log_probs.squeeze(0)  # (total_nnz,)
@@ -856,44 +901,78 @@ class FSDPEngineWithLMHead(FSDPEngine):
                     unpad_dim=0,
                     padding_size=pad_size,
                 )
-            # pad back to (bsz, seqlen)
-            if calculate_entropy:
-                full_entropy = pad_input(
-                    hidden_states=entropy_rmpad.unsqueeze(-1),
-                    indices=indices,
-                    batch=batch_size,
-                    seqlen=seqlen,
-                )
-            full_log_probs = pad_input(
-                hidden_states=log_probs.unsqueeze(-1),
-                indices=indices,
-                batch=batch_size,
-                seqlen=seqlen,
-            )
-            log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1]  # (bsz, response_length)
-
-            # only return response part:
-            if calculate_entropy:
-                entropy = full_entropy.squeeze(-1)[:, -response_length - 1 : -1]  # (bsz, response_length)
+
+            if pad_mode == DatasetPadMode.NO_PADDING:
+                cu_seqlens = input_ids.offsets()
+                # (bsz, j1) for each sample, is the length of each sample: [real_prompt length + real_response length]
+                log_probs = torch.nested.nested_tensor_from_jagged(log_probs, cu_seqlens)
+                if calculate_entropy:
+                    entropy = torch.nested.nested_tensor_from_jagged(entropy_rmpad, cu_seqlens)
+            elif pad_mode == DatasetPadMode.LEFT_RIGHT:
+                indices = output_args["indices"]
+                response_length = micro_batch["responses"].size(-1)
+                batch_size, seqlen = input_ids.shape
+                full_log_probs = pad_input(
+                    hidden_states=log_probs.unsqueeze(-1),
+                    indices=indices,
+                    batch=batch_size,
+                    seqlen=seqlen,
+                )
+                log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1]  # (bsz, response_length)
+                # pad back to (bsz, seqlen)
+                if calculate_entropy:
+                    full_entropy = pad_input(
+                        hidden_states=entropy_rmpad.unsqueeze(-1),
+                        indices=indices,
+                        batch=batch_size,
+                        seqlen=seqlen,
+                    )
+                # only return response part:
+                if calculate_entropy:
+                    entropy = full_entropy.squeeze(-1)[:, -response_length - 1 : -1]  # (bsz, response_length)
+            else:
+                raise NotImplementedError(f"pad_mode {pad_mode} not implemented")
 
         else:  # not using rmpad and no ulysses sp
+            response_length = tu.get_non_tensor_data(data=micro_batch, key="max_response_length", default=1024)
             if use_fused_kernels:
                 log_probs = output.log_probs[:, -response_length - 1 : -1]
                 entropy = output.entropy[:, -response_length - 1 : -1]  # (bsz, response_length)
 
             else:
                 logits = output.logits
 
                 logits.div_(temperature)
-                logits = logits[:, -response_length - 1 : -1, :]  # (bsz, response_length, vocab_size)
-                log_probs = logprobs_from_logits(logits, micro_batch.batch["responses"])
 
                 if calculate_entropy:
                     if not self.engine_config.entropy_checkpointing:
-                        entropy = verl_F.entropy_from_logits(logits)  # (bsz, response_length)
+                        entropy = verl_F.entropy_from_logits(logits)
                     else:
                         entropy = torch.utils.checkpoint.checkpoint(verl_F.entropy_from_logits, logits)
 
-        model_output = {"log_probs": log_probs}
+            if pad_mode == DatasetPadMode.NO_PADDING:
+                cu_seqlens = input_ids.offsets()
+                seq_lengths = cu_seqlens.diff()
+                starts = torch.zeros_like(seq_lengths, dtype=torch.int64)
+                logits = torch.nested.narrow(logits, 1, starts, seq_lengths, layout=torch.jagged)
+                logits_rmpad = torch.cat([t for t in logits.unbind()])
+                input_ids_rmpad_rolled = output_args["input_ids_rmpad_rolled"]
+                log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)
+                # (bsz, j1) for each sample, length of each sample: [real_prompt_length + real_response_length]
+                log_probs = torch.nested.nested_tensor_from_jagged(log_probs, cu_seqlens)
+                if calculate_entropy:
+                    entropy = torch.nested.narrow(entropy, 1, starts, seq_lengths, layout=torch.jagged)
+                    entropy_rmpad = torch.cat([t for t in entropy.unbind()])
+                    entropy = torch.nested.nested_tensor_from_jagged(entropy_rmpad, cu_seqlens)
+            elif pad_mode == DatasetPadMode.LEFT_RIGHT:
+                logits = logits[:, -response_length - 1 : -1, :]  # (bsz, response_length, vocab_size)
+                log_probs = logprobs_from_logits(logits, micro_batch["responses"])
+                if calculate_entropy:
+                    entropy = entropy[:, -response_length - 1 : -1]  # (bsz, response_length)
+            else:
+                raise NotImplementedError(f"pad_mode {pad_mode} not implemented")
+
+        model_output["log_probs"] = log_probs
         if calculate_entropy:
             model_output["entropy"] = entropy
```
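A sketch of the densification used in the `use_remove_padding=False` branch above (values illustrative): the jagged batch is padded back to `(bsz, max_seq_len)` so a standard dense forward pass still works:

```python
import torch

input_ids = torch.nested.as_nested_tensor(
    [torch.tensor([5, 6, 7]), torch.tensor([8, 9])], layout=torch.jagged
)
batch_size = input_ids.size(0)
max_seq_len = int(input_ids.offsets().diff().max())
padded = torch.nested.to_padded_tensor(
    input_ids, padding=0, output_size=(batch_size, max_seq_len)
)  # pad_token_id 0 is illustrative
assert padded.tolist() == [[5, 6, 7], [8, 9, 0]]
```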
```diff
@@ -17,6 +17,7 @@ import torch
 from tensordict import TensorDict
 
 from verl.utils import tensordict_utils as tu
+from verl.utils.dataset.dataset_utils import DatasetPadMode
 from verl.utils.py_functional import append_to_dict
 from verl.utils.seqlen_balancing import rearrange_micro_batches, restore_dynamic_batch
 
@@ -65,6 +66,7 @@ def postprocess_batch_func(output_lst, indices, data: TensorDict):
     """
 
     use_dynamic_bsz = tu.get_non_tensor_data(data=data, key="use_dynamic_bsz", default=True)
+    pad_mode = tu.get_non_tensor_data(data=data, key="pad_mode", default=DatasetPadMode.LEFT_RIGHT)
 
     # losses_reduced is a list of dict containing outputs for each micro-batch
     # reorder entropy and outputs. Return None for other pp ranks
@@ -87,7 +89,14 @@ def postprocess_batch_func(output_lst, indices, data: TensorDict):
 
     # concat results from micro batches
     for key, val in model_output.items():
-        model_output[key] = torch.cat(model_output[key], dim=0)
+        if pad_mode == DatasetPadMode.NO_PADDING:
+            tensors = [tensor for nt in model_output[key] for tensor in nt.unbind()]
+            model_output[key] = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)
+        elif pad_mode == DatasetPadMode.LEFT_RIGHT:
+            model_output[key] = torch.cat(model_output[key], dim=0)
+        else:
+            raise NotImplementedError(f"pad_mode {pad_mode} not implemented")
 
         # reverse with dynamic bsz
         if use_dynamic_bsz:
             model_output[key] = restore_dynamic_batch(model_output[key], indices)
```
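The no_padding concat path above as a standalone sketch (made-up values): per-micro-batch jagged outputs are unbound into components and re-batched as one jagged tensor, since `torch.cat(dim=0)` only applies to the padded layout:

```python
import torch

mb1 = torch.nested.as_nested_tensor([torch.tensor([1.0, 2.0])], layout=torch.jagged)
mb2 = torch.nested.as_nested_tensor(
    [torch.tensor([3.0]), torch.tensor([4.0, 5.0, 6.0])], layout=torch.jagged
)
tensors = [t for nt in (mb1, mb2) for t in nt.unbind()]
merged = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)
assert merged.size(0) == 3  # one jagged batch across micro-batches
```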
```diff
@@ -17,14 +17,34 @@ import torch
 from tensordict import TensorDict
 
 from verl.trainer.ppo.core_algos import agg_loss, compute_value_loss, get_policy_loss_fn, kl_penalty
+from verl.utils import tensordict_utils as tu
+from verl.utils.dataset.dataset_utils import DatasetPadMode
 from verl.utils.torch_functional import masked_mean
 from verl.workers.config import ActorConfig, CriticConfig
 
 
 def sft_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None):
-    log_prob = model_output["log_probs"]  # [bsz, response_length]
-    response_mask = data["response_mask"].to(bool)
-    loss = -torch.mean(log_prob * response_mask)
+    pad_mode = tu.get_non_tensor_data(data=data, key="pad_mode", default=DatasetPadMode.LEFT_RIGHT)
+
+    log_prob = model_output["log_probs"]
+
+    if pad_mode == DatasetPadMode.NO_PADDING:
+        # log_prob and loss mask are nested tensors of shape [bsz, j1]
+        # for each sample, loss mask shape is [1, prompt_length + response_length]
+        loss_mask = data["loss_mask"]
+
+        log_prob_flatten = log_prob.values()
+        cu_seqlens = log_prob.offsets()
+        loss_mask_flatten = loss_mask.values()
+
+        # left-shift the loss mask by one token to align with log_prob
+        loss_mask_flatten = torch.roll(loss_mask_flatten, shifts=-1, dims=0)
+        loss_mask_flatten[cu_seqlens[1:] - 1] = 0
+        loss = -masked_mean(log_prob_flatten, loss_mask_flatten)
+    else:
+        response_mask = data["response_mask"].to(bool)
+        loss = -masked_mean(log_prob, response_mask)
 
     return loss, {"loss": loss.detach().item()}
```
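Finally, a worked sketch of the mask shift in `sft_loss` above (mask values are made up): position t of the packed log-prob buffer scores token t+1, so the loss mask is rolled left by one and zeroed at each sequence boundary (`cu_seqlens[1:] - 1`):

```python
import torch

loss_mask = torch.nested.as_nested_tensor(
    [torch.tensor([0, 1, 1]), torch.tensor([0, 0, 1, 1])], layout=torch.jagged
)
cu_seqlens = loss_mask.offsets()  # tensor([0, 3, 7])
mask = torch.roll(loss_mask.values(), shifts=-1, dims=0)
mask[cu_seqlens[1:] - 1] = 0      # never predict across a sequence boundary
# mask now aligns with the packed log-prob buffer: tensor([1, 1, 0, 0, 1, 1, 0])
```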