mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 13:43:50 +08:00
### What does this PR do?

> Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

This PR continues the work started in PR #3567 by deprecating and removing the left_right padding mode:

1. Implement no-padding mode for the Megatron engine using nested tensors in the SFT trainer.
2. Deprecate the left_right padding mode for the FSDP/Megatron engines.
3. Introduce a transformation layer within the Actor/Critic workers; see more [here](https://github.com/volcengine/verl/blob/main/docs/workers/model_engine.rst).
   - **Input Format**: Actor/Critic workers continue to receive data in left_right-padded format.
   - **Transformation**: This layer dynamically converts left_right-padded data into the no-padding format using nested tensors (see the illustrative sketch after this description).
   - **Engine Format**: The FSDP and Megatron engines now operate exclusively on the no-padding data format by default.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
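Below is a minimal sketch of the kind of conversion such a transformation layer performs. It is illustrative only and not verl's actual API: the helper name `to_no_padding` and its signature are hypothetical, and it assumes a recent PyTorch with `torch.nested` and the `torch.jagged` layout available.

```python
import torch


def to_no_padding(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    """Hypothetical helper: strip padding from each row of a padded batch
    and repack the variable-length rows as a nested (jagged) tensor."""
    # Keep only the positions the attention mask marks as real tokens.
    rows = [ids[mask.bool()] for ids, mask in zip(input_ids, attention_mask)]
    # A nested tensor stores variable-length sequences with no pad tokens.
    return torch.nested.nested_tensor(rows, layout=torch.jagged)


# Example: a batch of two right-padded sequences (0 = pad token id).
ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
nested = to_no_padding(ids, mask)  # rows of lengths 3 and 2
```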
203 lines
8.8 KiB
Python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test the MultiTurnSFTDataset implementation
"""

import os

import pandas as pd
import torch
from transformers import AutoTokenizer

from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset

def test_multiturn_sft_dataset():
    print("Starting test...")
    # Create a temporary parquet file with test data
    test_data = {
        "messages": [
            [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "What is 2+2?"},
                {"role": "assistant", "content": "2+2 equals 4."},
                {"role": "user", "content": "And what is 4+4?"},
                {"role": "assistant", "content": "4+4 equals 8."},
            ],
            [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Tell me a joke."},
                {"role": "assistant", "content": "Why did the chicken cross the road?"},
                {"role": "user", "content": "Why?"},
                {"role": "assistant", "content": "To get to the other side!"},
            ],
        ]
    }

    # Create test directory if it doesn't exist
    os.makedirs("test_data", exist_ok=True)
    test_file = "test_data/test.parquet"

    # Save test data to parquet
    df = pd.DataFrame(test_data)
    df.to_parquet(test_file)

    # Initialize tokenizer and dataset
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct")
    config = {"max_length": 512, "truncation": "error", "multiturn": {"messages_key": "messages"}}
    dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=config)

    # Test 1: Dataset Length
    assert len(dataset) == 2, f"Expected dataset length 2, got {len(dataset)}"

    # Get items for testing
    item0 = dataset[0]  # Math conversation
    item1 = dataset[1]  # Joke conversation

    # Test 2: Required Keys and Types
    required_keys = ["input_ids", "attention_mask", "position_ids", "loss_mask"]
    for key in required_keys:
        assert key in item0, f"Missing key {key} in dataset item"
        assert isinstance(item0[key], torch.Tensor), f"Expected torch.Tensor for {key}"
        assert item0[key].dtype == torch.long, f"Expected torch.long for {key}, got {item0[key].dtype}"

    # Test 3: Shape Consistency
    assert item0["loss_mask"].shape == item0["input_ids"].shape, "Loss mask shape doesn't match input_ids shape"
    assert item0["attention_mask"].shape == item0["input_ids"].shape, (
        "Attention mask shape doesn't match input_ids shape"
    )
    assert item0["position_ids"].shape == item0["input_ids"].shape, "Position IDs shape doesn't match input_ids shape"

    # Test 4: Loss Mask Pattern - Math Conversation
    loss_mask0 = item0["loss_mask"]
    input_ids0 = item0["input_ids"]

    # Find assistant response positions
    assistant_positions0 = torch.where(loss_mask0 == 1)[0]
    assert len(assistant_positions0) > 0, "No assistant positions found in loss mask"

    # Decode and verify assistant responses
    assistant_text0 = tokenizer.decode(input_ids0[loss_mask0 == 1])
    print(f"Math conversation assistant text: {assistant_text0}")
    assert "2+2 equals 4" in assistant_text0, "First assistant response not found"
    assert "4+4 equals 8" in assistant_text0, "Second assistant response not found"

    # Test 5: Loss Mask Pattern - Joke Conversation
    loss_mask1 = item1["loss_mask"]
    input_ids1 = item1["input_ids"]

    # Find assistant response positions
    assistant_positions1 = torch.where(loss_mask1 == 1)[0]
    assert len(assistant_positions1) > 0, "No assistant positions found in loss mask"

    # Decode and verify assistant responses
    assistant_text1 = tokenizer.decode(input_ids1[loss_mask1 == 1])
    print(f"Joke conversation assistant text: {assistant_text1}")
    assert "chicken cross the road" in assistant_text1, "First assistant response not found"
    assert "other side" in assistant_text1, "Second assistant response not found"

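    # With the left_right padding mode removed, padded batches are laid out
    # right-padded: real tokens first, then padding, as the checks below assert.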
    # Test 6: Attention Mask Pattern
    attention_mask0 = item0["attention_mask"]
    sequence_length = torch.sum(attention_mask0)
    assert sequence_length > 0, "No tokens marked as attended in attention mask"
    assert torch.all(attention_mask0[:sequence_length] == 1), "Incorrect attention mask pattern"
    if sequence_length < len(attention_mask0):
        assert torch.all(attention_mask0[sequence_length:] == 0), "Padding not properly masked"

    # Test 7: Position IDs Pattern
    position_ids0 = item0["position_ids"]
    assert torch.equal(position_ids0[:sequence_length], torch.arange(sequence_length)), (
        "Position IDs not sequential for non-padded tokens"
    )
    if sequence_length < len(position_ids0):
        assert torch.all(position_ids0[sequence_length:] == 0), "Padding position IDs not zero"

    # Test 8: Verify loss mask for assistant responses
    # Get the full conversation text
    full_text = tokenizer.decode(input_ids0)
    print(f"\nFull conversation text:\n{full_text}")

    # Get the assistant responses
    assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 1])
    print(f"\nAssistant responses (from loss mask):\n{assistant_text}")

    # Verify that loss mask is set for all assistant responses
    for msg in test_data["messages"][0]:  # First conversation
        if msg["role"] == "assistant":
            # The content should appear in the masked text
            assert msg["content"] in assistant_text, f"Assistant message '{msg['content']}' not found in masked text"

            # The content should NOT appear in the non-masked text
            non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0])
            assert msg["content"] not in non_assistant_text, (
                f"Assistant message '{msg['content']}' found in non-assistant text"
            )

    # Test 9: Verify non-assistant parts have loss_mask=0
    # Get non-assistant text
    non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0])
    print(f"\nNon-assistant text (from loss mask):\n{non_assistant_text}")

    # Verify that system and user messages are in the non-assistant text
    for msg in test_data["messages"][0]:  # First conversation
        if msg["role"] in ["system", "user"]:
            assert msg["content"] in non_assistant_text, (
                f"{msg['role'].title()} message '{msg['content']}' not found in non-assistant text"
            )

            # And verify they're NOT in the assistant text
            assert msg["content"] not in assistant_text, (
                f"{msg['role'].title()} message '{msg['content']}' found in assistant text"
            )

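    # A larger max_length guarantees that both short conversations end with
    # padding, which exercises the padding checks below.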
    # Test 10: Verify padding behavior
    padding_config = {"max_length": 1024, "truncation": "error", "multiturn": {"messages_key": "messages"}}
    small_dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=padding_config)
    padded_item = small_dataset[0]

    # Get actual sequence length (before padding)
    actual_length = torch.sum(padded_item["attention_mask"])

    # Verify padding tokens
    assert torch.all(padded_item["input_ids"][actual_length:] == tokenizer.pad_token_id), (
        "Padding tokens not set correctly"
    )
    assert torch.all(padded_item["attention_mask"][actual_length:] == 0), "Attention mask not set correctly for padding"
    assert torch.all(padded_item["loss_mask"][actual_length:] == 0), "Loss mask not set correctly for padding"

    # Test 11: no-padding mode
    config = {
        "max_length": 512,
        "truncation": "error",
        "multiturn": {"messages_key": "messages"},
        "pad_mode": "no_padding",
    }
    dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=config)
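    # In no-padding mode the item comes back unpadded, so no attention_mask
    # key is expected: every token in input_ids is a real token.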

    item0 = dataset[0]

    # Verify that the output contains expected keys for no-padding mode
    required_keys = ["input_ids", "position_ids", "loss_mask"]
    for key in required_keys:
        assert key in item0, f"Missing key {key} in no-padding mode dataset item"
        assert isinstance(item0[key], torch.Tensor), f"Expected torch.Tensor for {key} in no-padding mode"

    # Make sure the decoded assistant text matches the expected string
    assistant_text = tokenizer.decode(item0["input_ids"][item0["loss_mask"] == 1])
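    # Qwen's ChatML-style chat template closes each turn with "<|im_end|>\n",
    # so the loss-masked assistant spans include the end-of-turn marker.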
    assert assistant_text == "2+2 equals 4.<|im_end|>\n4+4 equals 8.<|im_end|>\n"

    print("All tests passed!")