[misc] feat: prototype deprecate DataProto and replace with Tensordict: part 1 (#2733)

### What does this PR do?

- Add TensorDict utilities and tests to cover the current DataProto
functionalities.
- Add nested tensor example to remove padding throughout the system
- Add image example
- Upgrade tensordict to v0.10

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Chi Zhang
2025-09-09 14:47:32 +08:00
committed by GitHub
parent eaf20fff88
commit c4f4caf0cd
10 changed files with 776 additions and 16 deletions

View File

@ -10,7 +10,7 @@ peft>=0.15.2
pyarrow>=15.0.0 pyarrow>=15.0.0
pybind11 pybind11
pylatexenc pylatexenc
tensordict>=0.8.0,<=0.9.1,!=0.9.0 tensordict>=0.8.0,<=0.10.0,!=0.9.0
transformers==4.52.4 transformers==4.52.4
ray==2.46.0 ray==2.46.0
wandb wandb

View File

@ -14,7 +14,7 @@ pybind11
pylatexenc pylatexenc
pre-commit pre-commit
ray[default] ray[default]
tensordict>=0.8.0,<=0.9.1,!=0.9.0 tensordict>=0.8.0,<=0.10.0,!=0.9.0
torchdata torchdata
transformers transformers
# vllm==0.8.4 # vllm==0.8.4

View File

@ -12,7 +12,7 @@ pyarrow>=19.0.0
pybind11 pybind11
pylatexenc pylatexenc
ray[default]>=2.10 ray[default]>=2.10
tensordict>=0.8.0,<=0.9.1,!=0.9.0 tensordict>=0.8.0,<=0.10.0,!=0.9.0
torchdata torchdata
torchvision torchvision
transformers transformers

View File

@ -37,7 +37,7 @@ install_requires = [
"pylatexenc", "pylatexenc",
"ray[default]>=2.41.0", "ray[default]>=2.41.0",
"torchdata", "torchdata",
"tensordict>=0.8.0,<=0.9.1,!=0.9.0", "tensordict>=0.8.0,<=0.10.0,!=0.9.0",
"transformers", "transformers",
"wandb", "wandb",
"packaging>=20.0", "packaging>=20.0",
@ -49,9 +49,9 @@ PRIME_REQUIRES = ["pyext"]
GEO_REQUIRES = ["mathruler", "torchvision", "qwen_vl_utils"] GEO_REQUIRES = ["mathruler", "torchvision", "qwen_vl_utils"]
GPU_REQUIRES = ["liger-kernel", "flash-attn"] GPU_REQUIRES = ["liger-kernel", "flash-attn"]
MATH_REQUIRES = ["math-verify"] # Add math-verify as an optional dependency MATH_REQUIRES = ["math-verify"] # Add math-verify as an optional dependency
VLLM_REQUIRES = ["tensordict>=0.8.0,<=0.9.1,!=0.9.0", "vllm>=0.7.3,<=0.9.1"] VLLM_REQUIRES = ["tensordict>=0.8.0,<=0.10.0,!=0.9.0", "vllm>=0.7.3,<=0.9.1"]
SGLANG_REQUIRES = [ SGLANG_REQUIRES = [
"tensordict>=0.8.0,<=0.9.1,!=0.9.0", "tensordict>=0.8.0,<=0.10.0,!=0.9.0",
"sglang[srt,openai]==0.4.10.post2", "sglang[srt,openai]==0.4.10.post2",
"torch==2.7.1", "torch==2.7.1",
] ]

View File

@ -86,7 +86,11 @@ def main() -> None:
parser.add_argument( parser.add_argument(
"--allow-files", "--allow-files",
nargs="*", nargs="*",
default=["tests/test_protocol_on_cpu.py", "tests/test_base_config_on_cpu.py"], default=[
"tests/test_protocol_on_cpu.py",
"tests/test_base_config_on_cpu.py",
"tests/test_protocol_v2_on_cpu.py",
],
help="Extra top-level test folders that are exempt from the rule", help="Extra top-level test folders that are exempt from the rule",
) )
args = parser.parse_args() args = parser.parse_args()

View File

@ -16,7 +16,9 @@ import random
import numpy as np import numpy as np
import pytest import pytest
import tensordict
import torch import torch
from packaging.version import parse as parse_version
from tensordict import TensorDict from tensordict import TensorDict
from verl import DataProto from verl import DataProto
@ -598,3 +600,17 @@ def test_dataproto_chunk_after_index():
selected = data[torch_int_mask] selected = data[torch_int_mask]
assert isinstance(selected.batch.batch_size, torch.Size) assert isinstance(selected.batch.batch_size, torch.Size)
assert all(isinstance(d, int) for d in selected.batch.batch_size) assert all(isinstance(d, int) for d in selected.batch.batch_size)
@pytest.mark.skipif(
parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
)
def test_to_tensordict():
obs = torch.tensor([1, 2, 3, 4, 5, 6])
labels = ["a", "b", "c", "d", "e", "f"]
data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"name": "abdce"})
output = data.to_tensordict()
assert torch.all(torch.eq(output["obs"], obs)).item()
assert output["labels"] == labels
assert output["name"] == "abdce"

View File

@ -0,0 +1,525 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Replace DataProto with raw TensorDict
"""
import copy
import random
import numpy as np
import pytest
import torch
from verl.utils import tensordict_utils as tu
def test_union_tensor_dict():
obs = torch.randn(100, 10)
meta_info1 = {"top_p": 0.8}
meta_info2 = {"top_p": 0.9}
data1 = {"obs": obs, "act": torch.randn(100, 3), "data_sources": ["gsm8k"] * 100}
data2 = {"obs": obs, "next_obs": torch.randn(100, 10), "rew": torch.randn(100), "data_sources": ["gsm8k"] * 100}
data_with_copied_obs = {"obs": obs.clone(), "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}
data1 = tu.get_tensordict(tensor_dict=data1)
data2 = tu.get_tensordict(tensor_dict=data2)
data_with_copied_obs = tu.get_tensordict(data_with_copied_obs)
tu.union_tensor_dict(data1, data2)
with pytest.raises(AssertionError):
# conflict in tensor values
tu.union_tensor_dict(data1, data_with_copied_obs)
data1 = tu.assign_non_tensor_dict(data1, meta_info1)
tu.union_tensor_dict(data1, data2) # works ok
data2 = tu.assign_non_tensor_dict(data2, meta_info2)
with pytest.raises(AssertionError):
# conflict in NonTensorData
tu.union_tensor_dict(data1, data2)
data1.pop("top_p")
data2.pop("top_p")
data2["data_sources"][0] = "math"
with pytest.raises(AssertionError):
# conflict in NonTensorData
tu.union_tensor_dict(data1, data2)
def test_tensor_dict_constructor():
obs = torch.ones(100, 10)
act = torch.zeros(100, 10, 3)
data_source = ["gsm8k"] * 100
non_tensor_dict = {"name": "abdce"}
data = tu.get_tensordict(
tensor_dict={"obs": obs, "act": act, "data_source": data_source}, non_tensor_dict=non_tensor_dict
)
assert data.batch_size == torch.Size([100])
# test slicing
assert torch.all(torch.eq(data[0]["obs"], torch.ones(10))).item()
assert torch.all(torch.eq(data[0]["act"], torch.zeros(10, 3))).item()
assert data[0]["data_source"] == "gsm8k"
assert torch.all(torch.eq(data[0:2]["obs"], torch.ones(2, 10))).item()
assert torch.all(torch.eq(data[0:2]["act"], torch.zeros(2, 10, 3))).item()
assert data[0:2]["data_source"] == ["gsm8k"] * 2
# test non tensor data
assert data["name"] == "abdce"
def test_tensordict_with_images():
# each sample contains a sequence with multiple images of different sizes
vocab_size = 128
a = torch.randint(low=0, high=vocab_size, size=(11,))
b = torch.randint(low=0, high=vocab_size, size=(13,))
input_ids = [a, b]
input_ids = torch.nested.as_nested_tensor(input_ids, layout=torch.jagged)
# must be numpy
# TODO(vermouth1992). We may use nested tensor too. But this requires nested over nested
a_images = [
torch.randint(low=0, high=255, size=(3, 256, 256), dtype=torch.uint8).numpy(),
torch.randint(low=0, high=255, size=(3, 128, 128), dtype=torch.uint8).numpy(),
]
b_images = [
torch.randint(low=0, high=255, size=(3, 256, 256), dtype=torch.uint8).numpy(),
torch.randint(low=0, high=255, size=(3, 128, 128), dtype=torch.uint8).numpy(),
torch.randint(low=0, high=255, size=(3, 64, 64), dtype=torch.uint8).numpy(),
]
images = [a_images, b_images]
data = tu.get_tensordict({"input_ids": input_ids, "images": images})
assert np.all(np.equal(data[0]["images"][0], a_images[0]))
assert torch.all(torch.eq(data[0]["input_ids"], a))
def test_tensordict_with_packing():
vocab_size = 128
a = torch.randint(low=0, high=vocab_size, size=(11,))
b = torch.randint(low=0, high=vocab_size, size=(13,))
input_ids = [a, b]
input_ids = torch.nested.as_nested_tensor(input_ids, layout=torch.jagged)
data = tu.get_tensordict({"input_ids": input_ids})
# test cu_seqlens
cu_seqlens = torch.tensor([0, 11, 24])
assert torch.all(torch.eq(cu_seqlens, data["input_ids"].offsets()))
# test index
assert torch.all(torch.eq(data["input_ids"][0], a))
assert torch.all(torch.eq(data["input_ids"][1], b))
assert torch.all(torch.eq(data[0]["input_ids"], a))
assert torch.all(torch.eq(data[1]["input_ids"], b))
data_lst = data.chunk(2)
assert torch.all(torch.eq(data_lst[0]["input_ids"][0], a))
assert torch.all(torch.eq(data_lst[1]["input_ids"][0], b))
def test_tensordict_eq():
obs = torch.tensor([1, 2, 3, 4, 5, 6])
data_sources = ["abc", "def", "abc", "def", "pol", "klj"]
non_tensor_dict = {"train_sample_kwargs": {"top_p": 1.0}, "val_sample_kwargs": {"top_p": 0.7}}
data = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)
obs = torch.tensor([1, 2, 3, 4, 5, 6])
data_sources = ["abc", "def", "abc", "def", "pol", "klj"]
non_tensor_dict = {"train_sample_kwargs": {"top_p": 1.0}, "val_sample_kwargs": {"top_p": 0.7}}
data1 = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)
tu.assert_tensordict_eq(data, data1)
data2 = copy.deepcopy(data1)
data2["obs"][0] += 1
with pytest.raises(AssertionError):
tu.assert_tensordict_eq(data, data2)
data2 = copy.deepcopy(data1)
data2["data_sources"][0] = "math"
with pytest.raises(AssertionError):
tu.assert_tensordict_eq(data, data2)
data2 = copy.deepcopy(data1)
data2["train_sample_kwargs"]["top_p"] = 0.9
with pytest.raises(AssertionError):
tu.assert_tensordict_eq(data, data2)
def test_tensor_dict_make_iterator():
obs = torch.tensor([1, 2, 3, 4, 5, 6])
data_sources = ["abc", "def", "abc", "def", "pol", "klj"]
non_tensor_dict = {"train_sample_kwargs": {"top_p": 1.0}, "val_sample_kwargs": {"top_p": 0.7}}
dataset = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)
dataloader = tu.make_iterator(
dataset, mini_batch_size=2, epochs=2, seed=0, dataloader_kwargs={"shuffle": False, "drop_last": False}
)
expected_tensor_dict = [dataset[0:2], dataset[2:4], dataset[4:6], dataset[0:2], dataset[2:4], dataset[4:6]]
i = 0
for d in dataloader:
tu.assert_tensordict_eq(d, expected_tensor_dict[i])
i += 1
data_iter_1 = tu.make_iterator(dataset, mini_batch_size=3, epochs=1, seed=1, dataloader_kwargs={"shuffle": True})
data_list_1 = []
for data in data_iter_1:
data_list_1.append(data)
data_iter_2 = tu.make_iterator(dataset, mini_batch_size=3, epochs=1, seed=1, dataloader_kwargs={"shuffle": True})
data_list_2 = []
for data in data_iter_2:
data_list_2.append(data)
for data1, data2 in zip(data_list_1, data_list_2, strict=True):
tu.assert_tensordict_eq(data1, data2)
def test_reorder():
obs = torch.tensor([1, 2, 3, 4, 5, 6])
labels = ["a", "b", "c", "d", "e", "f"]
non_tensor_dict = {"name": "abdce"}
data = tu.get_tensordict(tensor_dict={"obs": obs, "labels": labels}, non_tensor_dict=non_tensor_dict)
data = data[torch.tensor([3, 4, 2, 0, 1, 5])]
assert torch.all(torch.eq(data["obs"], torch.tensor([4, 5, 3, 1, 2, 6])))
assert np.all(data["labels"] == np.array(["d", "e", "c", "a", "b", "f"]))
assert data["name"] == "abdce"
def test_chunk_concat():
obs = torch.tensor([1, 2, 3, 4, 5, 6])
labels = ["a", "b", "c", "d", "e", "f"]
data = tu.get_tensordict({"obs": obs, "labels": labels}, non_tensor_dict={"name": "abcde"})
data_split = data.tensor_split(indices_or_sections=5, dim=0)
expected_idx_lst = [[0, 1], [2], [3], [4], [5]]
for d, expected_idx in zip(data_split, expected_idx_lst, strict=False):
tu.assert_tensordict_eq(d, data[expected_idx])
data_split = data.chunk(2)
assert len(data_split) == 2
assert torch.all(torch.eq(data_split[0]["obs"], torch.tensor([1, 2, 3])))
assert np.all(data_split[0]["labels"] == np.array(["a", "b", "c"]))
assert data_split[0]["name"] == "abcde"
assert torch.all(torch.eq(data_split[1]["obs"], torch.tensor([4, 5, 6])))
assert np.all(data_split[1]["labels"] == np.array(["d", "e", "f"]))
assert data_split[1]["name"] == "abcde"
concat_data = torch.cat(data_split, dim=0)
assert torch.all(torch.eq(concat_data["obs"], data["obs"]))
assert np.all(concat_data["labels"] == data["labels"])
assert concat_data["name"] == data["name"]
def test_pop():
obs = torch.randn(100, 10)
act = torch.randn(100, 3)
dataset = tu.get_tensordict({"obs": obs, "act": act}, non_tensor_dict={"2": 2, "1": 1})
poped_dataset = tu.pop(dataset, keys=["obs", "2"])
assert poped_dataset.batch_size[0] == 100
assert poped_dataset.keys() == {"obs", "2"}
assert dataset.keys() == {"act", "1"}
def test_repeat():
# Create a DataProto object with some batch and non-tensor data
obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
labels = ["a", "b", "c"]
data = tu.get_tensordict({"obs": obs, "labels": labels}, non_tensor_dict={"info": "test_info"})
# Test interleave=True
repeated_data_interleave = data.repeat_interleave(repeats=2)
expected_obs_interleave = torch.tensor([[1, 2], [1, 2], [3, 4], [3, 4], [5, 6], [5, 6]])
expected_labels_interleave = ["a", "a", "b", "b", "c", "c"]
assert torch.all(torch.eq(repeated_data_interleave["obs"], expected_obs_interleave))
assert repeated_data_interleave["labels"] == expected_labels_interleave
assert repeated_data_interleave["info"] == "test_info"
# Test interleave=False
repeated_data_no_interleave = data.repeat(2)
expected_obs_no_interleave = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6]])
expected_labels_no_interleave = ["a", "b", "c", "a", "b", "c"]
assert torch.all(torch.eq(repeated_data_no_interleave["obs"], expected_obs_no_interleave))
assert repeated_data_no_interleave["labels"] == expected_labels_no_interleave
assert repeated_data_no_interleave["info"] == "test_info"
def test_dataproto_pad_unpad():
obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
labels = ["a", "b", "c"]
data = tu.get_tensordict(tensor_dict={"obs": obs, "labels": labels}, non_tensor_dict={"info": "test_info"})
padded_data, pad_size = tu.pad_to_divisor(data, size_divisor=2)
assert pad_size == 1
expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2]])
expected_labels = ["a", "b", "c", "a"]
assert torch.all(torch.eq(padded_data["obs"], expected_obs))
assert padded_data["labels"] == expected_labels
assert padded_data["info"] == "test_info"
unpadd_data = tu.unpad(padded_data, pad_size=pad_size)
assert torch.all(torch.eq(unpadd_data["obs"], obs))
assert unpadd_data["labels"] == labels
assert unpadd_data["info"] == "test_info"
padded_data, pad_size = tu.pad_to_divisor(data, size_divisor=3)
assert pad_size == 0
expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
expected_labels = ["a", "b", "c"]
assert torch.all(torch.eq(padded_data["obs"], expected_obs))
assert padded_data["labels"] == expected_labels
assert padded_data["info"] == "test_info"
unpadd_data = tu.unpad(padded_data, pad_size=pad_size)
assert torch.all(torch.eq(unpadd_data["obs"], obs))
assert unpadd_data["labels"] == labels
assert unpadd_data["info"] == "test_info"
padded_data, pad_size = tu.pad_to_divisor(data, size_divisor=7)
assert pad_size == 4
expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6], [1, 2]])
expected_labels = ["a", "b", "c", "a", "b", "c", "a"]
assert torch.all(torch.eq(padded_data["obs"], expected_obs))
assert padded_data["labels"] == expected_labels
assert padded_data["info"] == "test_info"
unpadd_data = tu.unpad(padded_data, pad_size=pad_size)
assert torch.all(torch.eq(unpadd_data["obs"], obs))
assert unpadd_data["labels"] == labels
assert unpadd_data["info"] == "test_info"
def test_torch_save_data_proto():
obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
labels = ["a", "b", "c"]
data = tu.get_tensordict({"obs": obs, "labels": labels}, non_tensor_dict={"info": "test_info"})
filename = "test_data.pt"
torch.save(data, filename)
loaded_data = torch.load(filename, weights_only=False)
assert torch.all(torch.eq(loaded_data["obs"], data["obs"]))
assert loaded_data["labels"] == data["labels"]
assert loaded_data["info"] == data["info"]
import os
os.remove(filename)
def test_len():
obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
labels = np.array(["a", "b", "c"], dtype=object)
data = tu.get_tensordict({"obs": obs, "labels": labels.tolist()}, non_tensor_dict={"info": "test_info"})
assert len(data) == 3
data = tu.get_tensordict({"labels": labels.tolist()}, non_tensor_dict={"info": "test_info"})
assert len(data) == 3
data_item = data[0]
assert len(data_item) == 0
data = tu.get_tensordict({}, non_tensor_dict={"info": "test_info"})
assert len(data) == 0
def test_dataproto_index():
data_len = 100
idx_num = 10
obs = torch.randn(data_len, 10)
labels = [random.choice(["abc", "cde"]) for _ in range(data_len)]
data = tu.get_tensordict({"obs": obs, "labels": labels})
labels_np = np.array(labels)
idx_np_int = np.random.randint(0, data_len, size=(idx_num,))
result_np_int = data[idx_np_int]
assert result_np_int.keys() == data.keys()
assert result_np_int["obs"].shape[0] == idx_num
assert len(result_np_int["labels"]) == idx_num
assert np.array_equal(result_np_int["obs"].cpu().numpy(), obs[idx_np_int].numpy())
assert np.array_equal(result_np_int["labels"], labels_np[idx_np_int])
idx_torch_int = torch.randint(0, data_len, size=(idx_num,))
result_torch_int = data[idx_torch_int]
assert result_torch_int.keys() == data.keys()
assert result_torch_int["obs"].shape[0] == idx_num
assert len(result_torch_int["labels"]) == idx_num
assert np.array_equal(result_torch_int["obs"].cpu().numpy(), obs[idx_torch_int].cpu().numpy())
assert np.array_equal(result_torch_int["labels"], labels_np[idx_torch_int.cpu().numpy()])
idx_list_int = [np.random.randint(0, data_len) for _ in range(idx_num)]
result_list_int = data[idx_list_int]
assert result_list_int.keys() == data.keys()
assert result_list_int["obs"].shape[0] == idx_num
assert len(result_list_int["labels"]) == idx_num
assert np.array_equal(result_list_int["obs"].cpu().numpy(), obs[idx_list_int].cpu().numpy())
assert np.array_equal(result_list_int["labels"], labels_np[idx_list_int])
# idx_np_bool = np.random.randint(0, 2, size=(data_len,), dtype=bool)
# result_np_bool = data[idx_np_bool]
# assert result_np_bool.keys() == data.keys()
# assert result_np_bool["obs"].shape[0] == idx_np_bool.sum()
# assert len(result_np_bool["labels"]) == idx_np_bool.sum()
# assert np.array_equal(result_np_bool["obs"].cpu().numpy(), obs[idx_np_bool].cpu().numpy())
# assert np.array_equal(result_np_bool["labels"], labels_np[idx_np_bool])
idx_torch_bool = torch.randint(0, 2, size=(data_len,), dtype=torch.bool)
result_torch_bool = data[idx_torch_bool]
assert result_torch_bool.keys() == data.keys()
assert result_torch_bool["obs"].shape[0] == idx_torch_bool.sum().item()
assert len(result_torch_bool["labels"]) == idx_torch_bool.sum().item()
assert np.array_equal(result_torch_bool["obs"].cpu().numpy(), obs[idx_torch_bool].cpu().numpy())
assert np.array_equal(result_torch_bool["labels"], labels_np[idx_torch_bool])
# idx_list_bool = [np.random.randint(0, 2, dtype=bool) for _ in range(data_len)]
# result_list_bool = data[idx_list_bool]
# assert result_list_bool.keys() == data.keys()
# assert result_list_bool["obs"].shape[0] == sum(idx_list_bool)
# assert len(result_list_bool["labels"]) == sum(idx_list_bool)
# assert np.array_equal(result_list_bool["obs"].cpu().numpy(), obs[idx_list_bool].cpu().numpy())
# assert np.array_equal(result_list_bool["labels"], labels_np[idx_list_bool])
def test_select():
obs = torch.randn(100, 10)
act = torch.randn(100, 3)
dataset = tu.get_tensordict({"obs": obs, "act": act}, non_tensor_dict={"2": 2, "1": 1})
subset = dataset.select("obs", "2")
assert torch.all(torch.eq(subset["obs"], dataset["obs"]))
assert subset["2"] == dataset["2"]
assert "act" not in subset.keys()
assert "1" not in subset.keys()
def test_dataproto_no_batch():
labels = ["a", "b", "c"]
data = tu.get_tensordict(tensor_dict={"labels": labels}, non_tensor_dict={"info": "test_info"})
selected = data.select("labels")
assert selected["labels"] == labels
pop_data = tu.pop(data, keys=["labels"])
assert pop_data["labels"] == labels
assert "labels" not in data
def test_sample_level_repeat():
# Create a DataProto object with some batch and non-tensor data
obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
labels = ["a", "b", "c"]
data = tu.get_tensordict({"obs": obs, "labels": labels}, non_tensor_dict={"info": "test_info"})
# list
repeated_data_interleave = data.repeat_interleave(repeats=torch.tensor([3, 1, 2]))
expected_obs_interleave = torch.tensor([[1, 2], [1, 2], [1, 2], [3, 4], [5, 6], [5, 6]])
expected_labels_interleave = ["a", "a", "a", "b", "c", "c"]
assert torch.all(torch.eq(repeated_data_interleave["obs"], expected_obs_interleave))
assert repeated_data_interleave["labels"] == expected_labels_interleave
assert repeated_data_interleave["info"] == "test_info"
# torch.tensor
repeated_data_no_interleave = data.repeat_interleave(repeats=torch.tensor([1, 2, 3]))
expected_obs_no_interleave = torch.tensor([[1, 2], [3, 4], [3, 4], [5, 6], [5, 6], [5, 6]])
expected_labels_no_interleave = ["a", "b", "b", "c", "c", "c"]
assert torch.all(torch.eq(repeated_data_no_interleave["obs"], expected_obs_no_interleave))
assert repeated_data_no_interleave["labels"] == expected_labels_no_interleave
assert repeated_data_no_interleave["info"] == "test_info"
def test_dataproto_chunk_after_index():
data_len = 4
obs = torch.randn(data_len, 4)
labels = [f"label_{i}" for i in range(data_len)]
data = tu.get_tensordict(tensor_dict={"obs": obs, "labels": labels}, non_tensor_dict={"name": "abc"})
# Test with boolean numpy array
bool_mask = torch.tensor([True, False, True, False])
selected = data[bool_mask]
assert isinstance(selected.batch_size, torch.Size)
assert all(isinstance(d, int) for d in selected.batch_size) # int or List[int]
# Test with integer numpy array
int_mask = torch.tensor([0, 2])
selected = data[int_mask]
assert isinstance(selected.batch_size, torch.Size)
assert all(isinstance(d, int) for d in selected.batch_size)
# Test with boolean list
list_mask = [True, False, True, False]
selected = data[list_mask]
assert isinstance(selected.batch_size, torch.Size)
assert all(isinstance(d, int) for d in selected.batch_size)
# Test with list
list_mask = [0, 2]
selected = data[list_mask]
assert isinstance(selected.batch_size, torch.Size)
assert all(isinstance(d, int) for d in selected.batch_size)
# Test with torch tensor (bool)
torch_bool_mask = torch.tensor([True, False, True, False])
selected = data[torch_bool_mask]
assert isinstance(selected.batch_size, torch.Size)
assert all(isinstance(d, int) for d in selected.batch_size)
# Test with torch tensor (int)
torch_int_mask = torch.tensor([0, 2])
selected = data[torch_int_mask]
assert isinstance(selected.batch_size, torch.Size)
assert all(isinstance(d, int) for d in selected.batch_size)

View File

@ -74,16 +74,19 @@ if is_npu_available:
# for third-party devices such as NPUs. This patch fixes this issue, and the relevant # for third-party devices such as NPUs. This patch fixes this issue, and the relevant
# modifications can be removed once the fix is merged into tensordict. # modifications can be removed once the fix is merged into tensordict.
from tensordict.base import TensorDictBase import tensordict
def _sync_all_patch(self): if parse_version(tensordict.__version__) < parse_version("0.10.0"):
from torch._utils import _get_available_device_type, _get_device_module from tensordict.base import TensorDictBase
device_type = _get_available_device_type() def _sync_all_patch(self):
if device_type is None: from torch._utils import _get_available_device_type, _get_device_module
return
device_module = _get_device_module(device_type) device_type = _get_available_device_type()
device_module.synchronize() if device_type is None:
return
TensorDictBase._sync_all = _sync_all_patch device_module = _get_device_module(device_type)
device_module.synchronize()
TensorDictBase._sync_all = _sync_all_patch

View File

@ -31,6 +31,7 @@ import tensordict
import torch import torch
import torch.distributed import torch.distributed
from packaging import version from packaging import version
from packaging.version import parse as parse_version
from tensordict import TensorDict from tensordict import TensorDict
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -42,6 +43,8 @@ __all__ = ["DataProto", "union_tensor_dict"]
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
tensordict.set_lazy_legacy(False).set() tensordict.set_lazy_legacy(False).set()
if parse_version(tensordict.__version__) < parse_version("0.10.0"):
tensordict.set_list_to_stack(True).set()
class _DataProtoConfigMeta(type): class _DataProtoConfigMeta(type):
@ -964,6 +967,29 @@ class DataProto:
meta_info=self.meta_info, meta_info=self.meta_info,
) )
def to_tensordict(self) -> TensorDict:
"""Convert this DataProto to TensorDict. Note that this requires tensordict version at least 0.10
Returns:
"""
assert parse_version(tensordict.__version__) >= parse_version("0.10"), (
"Convert DataProto to TensorDict at least requires tensordict version 0.10"
)
tensor_batch = self.batch.to_dict()
non_tensor_batch = self.non_tensor_batch
from verl.utils import tensordict_utils as tu
common_keys = set(tensor_batch.keys()) & set(non_tensor_batch.keys())
assert len(common_keys) == 0, f"tensor_batch and non_tensor_batch have common keys {common_keys}"
for key, val in non_tensor_batch.items():
assert isinstance(val, np.ndarray)
tensor_batch[key] = val.tolist()
output = tu.get_tensordict(tensor_dict=tensor_batch, non_tensor_dict=self.meta_info)
return output
def get_data_info(self) -> str: def get_data_info(self) -> str:
"""Return formatted information about stored data with nested type details. """Return formatted information about stored data with nested type details.

View File

@ -0,0 +1,186 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Iterator
import torch
from tensordict import TensorDict
from tensordict.tensorclass import NonTensorData, NonTensorStack
def assign_non_tensor_dict(tensor_dict: TensorDict, non_tensor_dict: dict):
for key, val in non_tensor_dict.items():
assign_non_tensor_data(tensor_dict=tensor_dict, key=key, val=val)
return tensor_dict
def assign_non_tensor_data(tensor_dict: TensorDict, key, val):
tensor_dict[key] = NonTensorData(val)
def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict: dict = None) -> TensorDict:
"""
Args:
data_dict:
meta_info:
Returns:
"""
if non_tensor_dict is None:
non_tensor_dict = {}
batch_size = None
for key, val in tensor_dict.items():
if isinstance(val, list):
for v in val:
assert not isinstance(v, torch.Tensor), (
"Passing a list makes the data NonTensorStack, "
"which doesn't support torch.Tensor. Please convert to numpy first"
)
assert isinstance(val, torch.Tensor | list)
if batch_size is None:
batch_size = len(val)
else:
assert len(val) == batch_size
if batch_size is None:
batch_size = []
else:
batch_size = [batch_size]
for key, val in non_tensor_dict.items():
assert key not in tensor_dict
tensor_dict[key] = NonTensorData(val)
return TensorDict(source=tensor_dict, batch_size=batch_size)
def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict:
"""Union two tensordicts."""
assert tensor_dict1.batch_size == tensor_dict2.batch_size, (
f"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}"
)
for key in tensor_dict2.keys():
if key not in tensor_dict1.keys():
tensor_dict1[key] = tensor_dict2[key]
else:
if isinstance(tensor_dict2[key], torch.Tensor):
assert tensor_dict1[key].equal(tensor_dict2[key]), (
f"{key} in tensor_dict1 and tensor_dict2 are not the same object"
)
else:
# non-tensor
assert tensor_dict1[key] == tensor_dict2[key], (
f"{key} in tensor_dict1 and tensor_dict2 are not the same object"
)
return tensor_dict1
def make_iterator(tensordict: TensorDict, mini_batch_size, epochs, seed=None, dataloader_kwargs=None):
from torch.utils.data import DataLoader
assert tensordict.batch_size[0] % mini_batch_size == 0, f"{tensordict.batch_size[0]} % {mini_batch_size} != 0"
# we can directly create a dataloader from TensorDict
if dataloader_kwargs is None:
dataloader_kwargs = {}
if seed is not None:
generator = torch.Generator()
generator.manual_seed(seed)
else:
generator = None
assert isinstance(dataloader_kwargs, dict)
train_dataloader = DataLoader(
dataset=tensordict, batch_size=mini_batch_size, collate_fn=lambda x: x, generator=generator, **dataloader_kwargs
)
def get_data():
for _ in range(epochs):
yield from train_dataloader
return iter(get_data())
def assert_tensordict_eq(tensordict1: TensorDict, tensordict2: TensorDict):
assert set(tensordict1.keys()) == set(tensordict2.keys())
for key in tensordict1.keys():
val = tensordict1[key]
val2 = tensordict2[key]
assert type(val) is type(val2), f"The type of {key} must be the same. Got {type(val)} vs {type(val2)}"
if isinstance(val, torch.Tensor):
assert torch.all(torch.eq(val, val2)).item()
else:
assert val == val2
def pop(tensordict: TensorDict, keys: Iterator[str]) -> TensorDict:
tensor_output = {}
non_tensor_output = {}
for key in keys:
output = tensordict.get(key)
if isinstance(output, torch.Tensor):
tensor_output[key] = tensordict.pop(key)
elif isinstance(output, NonTensorStack):
tensor_output[key] = tensordict.pop(key).tolist()
else:
assert isinstance(output, NonTensorData)
non_tensor_output[key] = tensordict.pop(key)
return get_tensordict(tensor_output, non_tensor_output)
def pad_to_divisor(data: TensorDict, size_divisor: int):
"""Pad a TensorDict to size divisible by size_divisor
Args:
size_divisor (int): size divisor
Returns:
data: (TensorDict): the padded TensorDict
pad_size (int)
"""
assert isinstance(data, TensorDict), "data must be a TensorDict"
if len(data) % size_divisor != 0:
pad_size = size_divisor - len(data) % size_divisor
padding_protos = []
remaining_pad = pad_size
while remaining_pad > 0:
take_size = min(remaining_pad, len(data))
padding_protos.append(data[:take_size])
remaining_pad -= take_size
data_padded = torch.cat([data] + padding_protos)
else:
if len(data) == 0:
logging.warning("padding a DataProto with no item, no changed made")
pad_size = 0
data_padded = data
return data_padded, pad_size
def unpad(data: TensorDict, pad_size):
"""Unpad the data proto with pad_size. i.e. `data[:-pad_size]`"""
if pad_size != 0:
data = data[:-pad_size]
return data