From c4f4caf0cd7186a09173f312ab85bda2519d810d Mon Sep 17 00:00:00 2001 From: Chi Zhang Date: Tue, 9 Sep 2025 14:47:32 +0800 Subject: [PATCH] [misc] feat: prototype deprecate DataProto and replace with Tensordict: part 1 (#2733) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What does this PR do? - Add TensorDict utilities and tests to cover the current DataProto functionalities. - Add nested tensor example to remove padding throughout the system - Add image example - Upgrade tensordict to v0.10 ### Checklist Before Starting - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) 
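Usage sketch for the new TensorDict-based API (illustrative only; distilled from `tests/test_protocol_v2_on_cpu.py` added in this PR and assuming tensordict >= 0.10, so the exact helper signatures may still change in later parts of this series):

```python
# Illustrative sketch of the new tensordict utilities (mirrors the unit tests
# added in this PR; not a stable API reference).
import torch

from verl.utils import tensordict_utils as tu

obs = torch.randn(4, 10)          # per-sample tensor data
labels = ["a", "b", "c", "d"]     # per-sample non-tensor data (kept as a NonTensorStack)

data = tu.get_tensordict(
    tensor_dict={"obs": obs, "labels": labels},
    non_tensor_dict={"name": "demo"},  # batch-level metadata (kept as NonTensorData)
)

assert data.batch_size == torch.Size([4])
assert data[0]["labels"] == "a"   # plain TensorDict indexing replaces DataProto slicing
assert data["name"] == "demo"

# DataProto-style helpers now operate on TensorDict directly.
padded, pad_size = tu.pad_to_divisor(data, size_divisor=3)
restored = tu.unpad(padded, pad_size=pad_size)

left, right = data.chunk(2)                   # chunk/cat come from TensorDict itself
merged = torch.cat([left, right], dim=0)
```

Per-sample lists are stored as `NonTensorStack` entries and batch-level metadata as `NonTensorData`, so indexing, chunking, and concatenation go through the native TensorDict API instead of dedicated `DataProto` methods.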
--------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- requirements-npu.txt | 2 +- requirements.txt | 2 +- requirements_sglang.txt | 2 +- setup.py | 6 +- tests/special_sanity/validate_structure.py | 6 +- tests/test_protocol_on_cpu.py | 16 + tests/test_protocol_v2_on_cpu.py | 525 +++++++++++++++++++++ verl/__init__.py | 21 +- verl/protocol.py | 26 + verl/utils/tensordict_utils.py | 186 ++++++++ 10 files changed, 776 insertions(+), 16 deletions(-) create mode 100644 tests/test_protocol_v2_on_cpu.py create mode 100644 verl/utils/tensordict_utils.py diff --git a/requirements-npu.txt b/requirements-npu.txt index 958bb49b3..ae9ed1161 100644 --- a/requirements-npu.txt +++ b/requirements-npu.txt @@ -10,7 +10,7 @@ peft>=0.15.2 pyarrow>=15.0.0 pybind11 pylatexenc -tensordict>=0.8.0,<=0.9.1,!=0.9.0 +tensordict>=0.8.0,<=0.10.0,!=0.9.0 transformers==4.52.4 ray==2.46.0 wandb diff --git a/requirements.txt b/requirements.txt index 162022343..64dc7f585 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ pybind11 pylatexenc pre-commit ray[default] -tensordict>=0.8.0,<=0.9.1,!=0.9.0 +tensordict>=0.8.0,<=0.10.0,!=0.9.0 torchdata transformers # vllm==0.8.4 diff --git a/requirements_sglang.txt b/requirements_sglang.txt index c366ace43..34e23f4cd 100644 --- a/requirements_sglang.txt +++ b/requirements_sglang.txt @@ -12,7 +12,7 @@ pyarrow>=19.0.0 pybind11 pylatexenc ray[default]>=2.10 -tensordict>=0.8.0,<=0.9.1,!=0.9.0 +tensordict>=0.8.0,<=0.10.0,!=0.9.0 torchdata torchvision transformers diff --git a/setup.py b/setup.py index 5c10c1547..780d622e1 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ install_requires = [ "pylatexenc", "ray[default]>=2.41.0", "torchdata", - "tensordict>=0.8.0,<=0.9.1,!=0.9.0", + "tensordict>=0.8.0,<=0.10.0,!=0.9.0", "transformers", "wandb", "packaging>=20.0", @@ -49,9 +49,9 @@ PRIME_REQUIRES = ["pyext"] GEO_REQUIRES = ["mathruler", "torchvision", "qwen_vl_utils"] GPU_REQUIRES = ["liger-kernel", "flash-attn"] MATH_REQUIRES = ["math-verify"] # Add math-verify as an optional dependency -VLLM_REQUIRES = ["tensordict>=0.8.0,<=0.9.1,!=0.9.0", "vllm>=0.7.3,<=0.9.1"] +VLLM_REQUIRES = ["tensordict>=0.8.0,<=0.10.0,!=0.9.0", "vllm>=0.7.3,<=0.9.1"] SGLANG_REQUIRES = [ - "tensordict>=0.8.0,<=0.9.1,!=0.9.0", + "tensordict>=0.8.0,<=0.10.0,!=0.9.0", "sglang[srt,openai]==0.4.10.post2", "torch==2.7.1", ] diff --git a/tests/special_sanity/validate_structure.py b/tests/special_sanity/validate_structure.py index a5390b15a..56136b206 100644 --- a/tests/special_sanity/validate_structure.py +++ b/tests/special_sanity/validate_structure.py @@ -86,7 +86,11 @@ def main() -> None: parser.add_argument( "--allow-files", nargs="*", - default=["tests/test_protocol_on_cpu.py", "tests/test_base_config_on_cpu.py"], + default=[ + "tests/test_protocol_on_cpu.py", + "tests/test_base_config_on_cpu.py", + "tests/test_protocol_v2_on_cpu.py", + ], help="Extra top-level test folders that are exempt from the rule", ) args = parser.parse_args() diff --git a/tests/test_protocol_on_cpu.py b/tests/test_protocol_on_cpu.py index 42c5ed18c..3621a527f 100644 --- a/tests/test_protocol_on_cpu.py +++ b/tests/test_protocol_on_cpu.py @@ -16,7 +16,9 @@ import random import numpy as np import pytest +import tensordict import torch +from packaging.version import parse as parse_version from tensordict import TensorDict from verl import DataProto @@ -598,3 +600,17 @@ def test_dataproto_chunk_after_index(): selected = data[torch_int_mask] assert 
isinstance(selected.batch.batch_size, torch.Size) assert all(isinstance(d, int) for d in selected.batch.batch_size) + + +@pytest.mark.skipif( + parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10" +) +def test_to_tensordict(): + obs = torch.tensor([1, 2, 3, 4, 5, 6]) + labels = ["a", "b", "c", "d", "e", "f"] + data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"name": "abdce"}) + output = data.to_tensordict() + + assert torch.all(torch.eq(output["obs"], obs)).item() + assert output["labels"] == labels + assert output["name"] == "abdce" diff --git a/tests/test_protocol_v2_on_cpu.py b/tests/test_protocol_v2_on_cpu.py new file mode 100644 index 000000000..57a741e54 --- /dev/null +++ b/tests/test_protocol_v2_on_cpu.py @@ -0,0 +1,525 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Replace DataProto with raw TensorDict +""" + +import copy +import random + +import numpy as np +import pytest +import torch + +from verl.utils import tensordict_utils as tu + + +def test_union_tensor_dict(): + obs = torch.randn(100, 10) + + meta_info1 = {"top_p": 0.8} + meta_info2 = {"top_p": 0.9} + data1 = {"obs": obs, "act": torch.randn(100, 3), "data_sources": ["gsm8k"] * 100} + data2 = {"obs": obs, "next_obs": torch.randn(100, 10), "rew": torch.randn(100), "data_sources": ["gsm8k"] * 100} + + data_with_copied_obs = {"obs": obs.clone(), "next_obs": torch.randn(100, 10), "rew": torch.randn(100)} + + data1 = tu.get_tensordict(tensor_dict=data1) + data2 = tu.get_tensordict(tensor_dict=data2) + data_with_copied_obs = tu.get_tensordict(data_with_copied_obs) + + tu.union_tensor_dict(data1, data2) + with pytest.raises(AssertionError): + # conflict in tensor values + tu.union_tensor_dict(data1, data_with_copied_obs) + + data1 = tu.assign_non_tensor_dict(data1, meta_info1) + tu.union_tensor_dict(data1, data2) # works ok + + data2 = tu.assign_non_tensor_dict(data2, meta_info2) + + with pytest.raises(AssertionError): + # conflict in NonTensorData + tu.union_tensor_dict(data1, data2) + + data1.pop("top_p") + data2.pop("top_p") + + data2["data_sources"][0] = "math" + with pytest.raises(AssertionError): + # conflict in NonTensorData + tu.union_tensor_dict(data1, data2) + + +def test_tensor_dict_constructor(): + obs = torch.ones(100, 10) + act = torch.zeros(100, 10, 3) + data_source = ["gsm8k"] * 100 + non_tensor_dict = {"name": "abdce"} + + data = tu.get_tensordict( + tensor_dict={"obs": obs, "act": act, "data_source": data_source}, non_tensor_dict=non_tensor_dict + ) + + assert data.batch_size == torch.Size([100]) + + # test slicing + assert torch.all(torch.eq(data[0]["obs"], torch.ones(10))).item() + assert torch.all(torch.eq(data[0]["act"], torch.zeros(10, 3))).item() + assert data[0]["data_source"] == "gsm8k" + + assert torch.all(torch.eq(data[0:2]["obs"], torch.ones(2, 10))).item() + assert torch.all(torch.eq(data[0:2]["act"], torch.zeros(2, 10, 3))).item() + assert 
data[0:2]["data_source"] == ["gsm8k"] * 2 + + # test non tensor data + assert data["name"] == "abdce" + + +def test_tensordict_with_images(): + # each sample contains a sequence with multiple images of different sizes + vocab_size = 128 + a = torch.randint(low=0, high=vocab_size, size=(11,)) + b = torch.randint(low=0, high=vocab_size, size=(13,)) + input_ids = [a, b] + input_ids = torch.nested.as_nested_tensor(input_ids, layout=torch.jagged) + + # must be numpy + # TODO(vermouth1992). We may use nested tensor too. But this requires nested over nested + a_images = [ + torch.randint(low=0, high=255, size=(3, 256, 256), dtype=torch.uint8).numpy(), + torch.randint(low=0, high=255, size=(3, 128, 128), dtype=torch.uint8).numpy(), + ] + b_images = [ + torch.randint(low=0, high=255, size=(3, 256, 256), dtype=torch.uint8).numpy(), + torch.randint(low=0, high=255, size=(3, 128, 128), dtype=torch.uint8).numpy(), + torch.randint(low=0, high=255, size=(3, 64, 64), dtype=torch.uint8).numpy(), + ] + + images = [a_images, b_images] + + data = tu.get_tensordict({"input_ids": input_ids, "images": images}) + + assert np.all(np.equal(data[0]["images"][0], a_images[0])) + assert torch.all(torch.eq(data[0]["input_ids"], a)) + + +def test_tensordict_with_packing(): + vocab_size = 128 + a = torch.randint(low=0, high=vocab_size, size=(11,)) + b = torch.randint(low=0, high=vocab_size, size=(13,)) + input_ids = [a, b] + input_ids = torch.nested.as_nested_tensor(input_ids, layout=torch.jagged) + + data = tu.get_tensordict({"input_ids": input_ids}) + + # test cu_seqlens + cu_seqlens = torch.tensor([0, 11, 24]) + assert torch.all(torch.eq(cu_seqlens, data["input_ids"].offsets())) + + # test index + assert torch.all(torch.eq(data["input_ids"][0], a)) + assert torch.all(torch.eq(data["input_ids"][1], b)) + + assert torch.all(torch.eq(data[0]["input_ids"], a)) + assert torch.all(torch.eq(data[1]["input_ids"], b)) + + data_lst = data.chunk(2) + + assert torch.all(torch.eq(data_lst[0]["input_ids"][0], a)) + assert torch.all(torch.eq(data_lst[1]["input_ids"][0], b)) + + +def test_tensordict_eq(): + obs = torch.tensor([1, 2, 3, 4, 5, 6]) + data_sources = ["abc", "def", "abc", "def", "pol", "klj"] + non_tensor_dict = {"train_sample_kwargs": {"top_p": 1.0}, "val_sample_kwargs": {"top_p": 0.7}} + data = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict) + + obs = torch.tensor([1, 2, 3, 4, 5, 6]) + data_sources = ["abc", "def", "abc", "def", "pol", "klj"] + non_tensor_dict = {"train_sample_kwargs": {"top_p": 1.0}, "val_sample_kwargs": {"top_p": 0.7}} + data1 = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict) + + tu.assert_tensordict_eq(data, data1) + + data2 = copy.deepcopy(data1) + data2["obs"][0] += 1 + + with pytest.raises(AssertionError): + tu.assert_tensordict_eq(data, data2) + + data2 = copy.deepcopy(data1) + data2["data_sources"][0] = "math" + + with pytest.raises(AssertionError): + tu.assert_tensordict_eq(data, data2) + + data2 = copy.deepcopy(data1) + data2["train_sample_kwargs"]["top_p"] = 0.9 + + with pytest.raises(AssertionError): + tu.assert_tensordict_eq(data, data2) + + +def test_tensor_dict_make_iterator(): + obs = torch.tensor([1, 2, 3, 4, 5, 6]) + data_sources = ["abc", "def", "abc", "def", "pol", "klj"] + non_tensor_dict = {"train_sample_kwargs": {"top_p": 1.0}, "val_sample_kwargs": {"top_p": 0.7}} + dataset = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict) + + dataloader = 
tu.make_iterator( + dataset, mini_batch_size=2, epochs=2, seed=0, dataloader_kwargs={"shuffle": False, "drop_last": False} + ) + + expected_tensor_dict = [dataset[0:2], dataset[2:4], dataset[4:6], dataset[0:2], dataset[2:4], dataset[4:6]] + + i = 0 + + for d in dataloader: + tu.assert_tensordict_eq(d, expected_tensor_dict[i]) + i += 1 + + data_iter_1 = tu.make_iterator(dataset, mini_batch_size=3, epochs=1, seed=1, dataloader_kwargs={"shuffle": True}) + data_list_1 = [] + for data in data_iter_1: + data_list_1.append(data) + + data_iter_2 = tu.make_iterator(dataset, mini_batch_size=3, epochs=1, seed=1, dataloader_kwargs={"shuffle": True}) + data_list_2 = [] + for data in data_iter_2: + data_list_2.append(data) + + for data1, data2 in zip(data_list_1, data_list_2, strict=True): + tu.assert_tensordict_eq(data1, data2) + + +def test_reorder(): + obs = torch.tensor([1, 2, 3, 4, 5, 6]) + labels = ["a", "b", "c", "d", "e", "f"] + non_tensor_dict = {"name": "abdce"} + + data = tu.get_tensordict(tensor_dict={"obs": obs, "labels": labels}, non_tensor_dict=non_tensor_dict) + data = data[torch.tensor([3, 4, 2, 0, 1, 5])] + + assert torch.all(torch.eq(data["obs"], torch.tensor([4, 5, 3, 1, 2, 6]))) + assert np.all(data["labels"] == np.array(["d", "e", "c", "a", "b", "f"])) + assert data["name"] == "abdce" + + +def test_chunk_concat(): + obs = torch.tensor([1, 2, 3, 4, 5, 6]) + labels = ["a", "b", "c", "d", "e", "f"] + data = tu.get_tensordict({"obs": obs, "labels": labels}, non_tensor_dict={"name": "abcde"}) + + data_split = data.tensor_split(indices_or_sections=5, dim=0) + + expected_idx_lst = [[0, 1], [2], [3], [4], [5]] + + for d, expected_idx in zip(data_split, expected_idx_lst, strict=False): + tu.assert_tensordict_eq(d, data[expected_idx]) + + data_split = data.chunk(2) + assert len(data_split) == 2 + assert torch.all(torch.eq(data_split[0]["obs"], torch.tensor([1, 2, 3]))) + assert np.all(data_split[0]["labels"] == np.array(["a", "b", "c"])) + assert data_split[0]["name"] == "abcde" + + assert torch.all(torch.eq(data_split[1]["obs"], torch.tensor([4, 5, 6]))) + assert np.all(data_split[1]["labels"] == np.array(["d", "e", "f"])) + assert data_split[1]["name"] == "abcde" + + concat_data = torch.cat(data_split, dim=0) + assert torch.all(torch.eq(concat_data["obs"], data["obs"])) + assert np.all(concat_data["labels"] == data["labels"]) + assert concat_data["name"] == data["name"] + + +def test_pop(): + obs = torch.randn(100, 10) + act = torch.randn(100, 3) + dataset = tu.get_tensordict({"obs": obs, "act": act}, non_tensor_dict={"2": 2, "1": 1}) + + poped_dataset = tu.pop(dataset, keys=["obs", "2"]) + + assert poped_dataset.batch_size[0] == 100 + + assert poped_dataset.keys() == {"obs", "2"} + + assert dataset.keys() == {"act", "1"} + + +def test_repeat(): + # Create a DataProto object with some batch and non-tensor data + obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) + labels = ["a", "b", "c"] + data = tu.get_tensordict({"obs": obs, "labels": labels}, non_tensor_dict={"info": "test_info"}) + + # Test interleave=True + repeated_data_interleave = data.repeat_interleave(repeats=2) + expected_obs_interleave = torch.tensor([[1, 2], [1, 2], [3, 4], [3, 4], [5, 6], [5, 6]]) + expected_labels_interleave = ["a", "a", "b", "b", "c", "c"] + + assert torch.all(torch.eq(repeated_data_interleave["obs"], expected_obs_interleave)) + assert repeated_data_interleave["labels"] == expected_labels_interleave + assert repeated_data_interleave["info"] == "test_info" + + # Test interleave=False + repeated_data_no_interleave 
= data.repeat(2) + expected_obs_no_interleave = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6]]) + expected_labels_no_interleave = ["a", "b", "c", "a", "b", "c"] + + assert torch.all(torch.eq(repeated_data_no_interleave["obs"], expected_obs_no_interleave)) + assert repeated_data_no_interleave["labels"] == expected_labels_no_interleave + assert repeated_data_no_interleave["info"] == "test_info" + + +def test_dataproto_pad_unpad(): + obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) + labels = ["a", "b", "c"] + data = tu.get_tensordict(tensor_dict={"obs": obs, "labels": labels}, non_tensor_dict={"info": "test_info"}) + + padded_data, pad_size = tu.pad_to_divisor(data, size_divisor=2) + + assert pad_size == 1 + + expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2]]) + expected_labels = ["a", "b", "c", "a"] + + assert torch.all(torch.eq(padded_data["obs"], expected_obs)) + assert padded_data["labels"] == expected_labels + assert padded_data["info"] == "test_info" + + unpadd_data = tu.unpad(padded_data, pad_size=pad_size) + assert torch.all(torch.eq(unpadd_data["obs"], obs)) + assert unpadd_data["labels"] == labels + assert unpadd_data["info"] == "test_info" + + padded_data, pad_size = tu.pad_to_divisor(data, size_divisor=3) + assert pad_size == 0 + + expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) + expected_labels = ["a", "b", "c"] + + assert torch.all(torch.eq(padded_data["obs"], expected_obs)) + assert padded_data["labels"] == expected_labels + assert padded_data["info"] == "test_info" + + unpadd_data = tu.unpad(padded_data, pad_size=pad_size) + assert torch.all(torch.eq(unpadd_data["obs"], obs)) + assert unpadd_data["labels"] == labels + assert unpadd_data["info"] == "test_info" + + padded_data, pad_size = tu.pad_to_divisor(data, size_divisor=7) + assert pad_size == 4 + + expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6], [1, 2]]) + expected_labels = ["a", "b", "c", "a", "b", "c", "a"] + assert torch.all(torch.eq(padded_data["obs"], expected_obs)) + assert padded_data["labels"] == expected_labels + assert padded_data["info"] == "test_info" + + unpadd_data = tu.unpad(padded_data, pad_size=pad_size) + assert torch.all(torch.eq(unpadd_data["obs"], obs)) + assert unpadd_data["labels"] == labels + assert unpadd_data["info"] == "test_info" + + +def test_torch_save_data_proto(): + obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) + labels = ["a", "b", "c"] + data = tu.get_tensordict({"obs": obs, "labels": labels}, non_tensor_dict={"info": "test_info"}) + + filename = "test_data.pt" + torch.save(data, filename) + loaded_data = torch.load(filename, weights_only=False) + + assert torch.all(torch.eq(loaded_data["obs"], data["obs"])) + assert loaded_data["labels"] == data["labels"] + assert loaded_data["info"] == data["info"] + + import os + + os.remove(filename) + + +def test_len(): + obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) + labels = np.array(["a", "b", "c"], dtype=object) + + data = tu.get_tensordict({"obs": obs, "labels": labels.tolist()}, non_tensor_dict={"info": "test_info"}) + assert len(data) == 3 + + data = tu.get_tensordict({"labels": labels.tolist()}, non_tensor_dict={"info": "test_info"}) + assert len(data) == 3 + + data_item = data[0] + assert len(data_item) == 0 + + data = tu.get_tensordict({}, non_tensor_dict={"info": "test_info"}) + assert len(data) == 0 + + +def test_dataproto_index(): + data_len = 100 + idx_num = 10 + + obs = torch.randn(data_len, 10) + labels = [random.choice(["abc", "cde"]) for _ in range(data_len)] + + data = 
tu.get_tensordict({"obs": obs, "labels": labels}) + + labels_np = np.array(labels) + + idx_np_int = np.random.randint(0, data_len, size=(idx_num,)) + result_np_int = data[idx_np_int] + assert result_np_int.keys() == data.keys() + assert result_np_int["obs"].shape[0] == idx_num + assert len(result_np_int["labels"]) == idx_num + assert np.array_equal(result_np_int["obs"].cpu().numpy(), obs[idx_np_int].numpy()) + assert np.array_equal(result_np_int["labels"], labels_np[idx_np_int]) + + idx_torch_int = torch.randint(0, data_len, size=(idx_num,)) + result_torch_int = data[idx_torch_int] + assert result_torch_int.keys() == data.keys() + assert result_torch_int["obs"].shape[0] == idx_num + assert len(result_torch_int["labels"]) == idx_num + assert np.array_equal(result_torch_int["obs"].cpu().numpy(), obs[idx_torch_int].cpu().numpy()) + assert np.array_equal(result_torch_int["labels"], labels_np[idx_torch_int.cpu().numpy()]) + + idx_list_int = [np.random.randint(0, data_len) for _ in range(idx_num)] + result_list_int = data[idx_list_int] + assert result_list_int.keys() == data.keys() + assert result_list_int["obs"].shape[0] == idx_num + assert len(result_list_int["labels"]) == idx_num + assert np.array_equal(result_list_int["obs"].cpu().numpy(), obs[idx_list_int].cpu().numpy()) + assert np.array_equal(result_list_int["labels"], labels_np[idx_list_int]) + + # idx_np_bool = np.random.randint(0, 2, size=(data_len,), dtype=bool) + # result_np_bool = data[idx_np_bool] + # assert result_np_bool.keys() == data.keys() + # assert result_np_bool["obs"].shape[0] == idx_np_bool.sum() + # assert len(result_np_bool["labels"]) == idx_np_bool.sum() + # assert np.array_equal(result_np_bool["obs"].cpu().numpy(), obs[idx_np_bool].cpu().numpy()) + # assert np.array_equal(result_np_bool["labels"], labels_np[idx_np_bool]) + + idx_torch_bool = torch.randint(0, 2, size=(data_len,), dtype=torch.bool) + result_torch_bool = data[idx_torch_bool] + assert result_torch_bool.keys() == data.keys() + assert result_torch_bool["obs"].shape[0] == idx_torch_bool.sum().item() + assert len(result_torch_bool["labels"]) == idx_torch_bool.sum().item() + assert np.array_equal(result_torch_bool["obs"].cpu().numpy(), obs[idx_torch_bool].cpu().numpy()) + assert np.array_equal(result_torch_bool["labels"], labels_np[idx_torch_bool]) + + # idx_list_bool = [np.random.randint(0, 2, dtype=bool) for _ in range(data_len)] + # result_list_bool = data[idx_list_bool] + # assert result_list_bool.keys() == data.keys() + # assert result_list_bool["obs"].shape[0] == sum(idx_list_bool) + # assert len(result_list_bool["labels"]) == sum(idx_list_bool) + # assert np.array_equal(result_list_bool["obs"].cpu().numpy(), obs[idx_list_bool].cpu().numpy()) + # assert np.array_equal(result_list_bool["labels"], labels_np[idx_list_bool]) + + +def test_select(): + obs = torch.randn(100, 10) + act = torch.randn(100, 3) + dataset = tu.get_tensordict({"obs": obs, "act": act}, non_tensor_dict={"2": 2, "1": 1}) + + subset = dataset.select("obs", "2") + + assert torch.all(torch.eq(subset["obs"], dataset["obs"])) + assert subset["2"] == dataset["2"] + assert "act" not in subset.keys() + assert "1" not in subset.keys() + + +def test_dataproto_no_batch(): + labels = ["a", "b", "c"] + data = tu.get_tensordict(tensor_dict={"labels": labels}, non_tensor_dict={"info": "test_info"}) + selected = data.select("labels") + + assert selected["labels"] == labels + pop_data = tu.pop(data, keys=["labels"]) + assert pop_data["labels"] == labels + assert "labels" not in data + + +def 
test_sample_level_repeat(): + # Create a DataProto object with some batch and non-tensor data + obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) + labels = ["a", "b", "c"] + + data = tu.get_tensordict({"obs": obs, "labels": labels}, non_tensor_dict={"info": "test_info"}) + + # list + repeated_data_interleave = data.repeat_interleave(repeats=torch.tensor([3, 1, 2])) + expected_obs_interleave = torch.tensor([[1, 2], [1, 2], [1, 2], [3, 4], [5, 6], [5, 6]]) + expected_labels_interleave = ["a", "a", "a", "b", "c", "c"] + + assert torch.all(torch.eq(repeated_data_interleave["obs"], expected_obs_interleave)) + assert repeated_data_interleave["labels"] == expected_labels_interleave + assert repeated_data_interleave["info"] == "test_info" + + # torch.tensor + repeated_data_no_interleave = data.repeat_interleave(repeats=torch.tensor([1, 2, 3])) + expected_obs_no_interleave = torch.tensor([[1, 2], [3, 4], [3, 4], [5, 6], [5, 6], [5, 6]]) + expected_labels_no_interleave = ["a", "b", "b", "c", "c", "c"] + + assert torch.all(torch.eq(repeated_data_no_interleave["obs"], expected_obs_no_interleave)) + assert repeated_data_no_interleave["labels"] == expected_labels_no_interleave + assert repeated_data_no_interleave["info"] == "test_info" + + +def test_dataproto_chunk_after_index(): + data_len = 4 + obs = torch.randn(data_len, 4) + labels = [f"label_{i}" for i in range(data_len)] + + data = tu.get_tensordict(tensor_dict={"obs": obs, "labels": labels}, non_tensor_dict={"name": "abc"}) + # Test with boolean numpy array + bool_mask = torch.tensor([True, False, True, False]) + selected = data[bool_mask] + assert isinstance(selected.batch_size, torch.Size) + assert all(isinstance(d, int) for d in selected.batch_size) # int or List[int] + + # Test with integer numpy array + int_mask = torch.tensor([0, 2]) + selected = data[int_mask] + assert isinstance(selected.batch_size, torch.Size) + assert all(isinstance(d, int) for d in selected.batch_size) + + # Test with boolean list + list_mask = [True, False, True, False] + selected = data[list_mask] + assert isinstance(selected.batch_size, torch.Size) + assert all(isinstance(d, int) for d in selected.batch_size) + + # Test with list + list_mask = [0, 2] + selected = data[list_mask] + assert isinstance(selected.batch_size, torch.Size) + assert all(isinstance(d, int) for d in selected.batch_size) + + # Test with torch tensor (bool) + torch_bool_mask = torch.tensor([True, False, True, False]) + selected = data[torch_bool_mask] + assert isinstance(selected.batch_size, torch.Size) + assert all(isinstance(d, int) for d in selected.batch_size) + + # Test with torch tensor (int) + torch_int_mask = torch.tensor([0, 2]) + selected = data[torch_int_mask] + assert isinstance(selected.batch_size, torch.Size) + assert all(isinstance(d, int) for d in selected.batch_size) diff --git a/verl/__init__.py b/verl/__init__.py index 18d741a31..38f2e7cf9 100644 --- a/verl/__init__.py +++ b/verl/__init__.py @@ -74,16 +74,19 @@ if is_npu_available: # for third-party devices such as NPUs. This patch fixes this issue, and the relevant # modifications can be removed once the fix is merged into tensordict. 
- from tensordict.base import TensorDictBase + import tensordict - def _sync_all_patch(self): - from torch._utils import _get_available_device_type, _get_device_module + if parse_version(tensordict.__version__) < parse_version("0.10.0"): + from tensordict.base import TensorDictBase - device_type = _get_available_device_type() - if device_type is None: - return + def _sync_all_patch(self): + from torch._utils import _get_available_device_type, _get_device_module - device_module = _get_device_module(device_type) - device_module.synchronize() + device_type = _get_available_device_type() + if device_type is None: + return - TensorDictBase._sync_all = _sync_all_patch + device_module = _get_device_module(device_type) + device_module.synchronize() + + TensorDictBase._sync_all = _sync_all_patch diff --git a/verl/protocol.py b/verl/protocol.py index 7dee1fc19..8e2b7f996 100644 --- a/verl/protocol.py +++ b/verl/protocol.py @@ -31,6 +31,7 @@ import tensordict import torch import torch.distributed from packaging import version +from packaging.version import parse as parse_version from tensordict import TensorDict from torch.utils.data import DataLoader @@ -42,6 +43,8 @@ __all__ = ["DataProto", "union_tensor_dict"] with contextlib.suppress(Exception): tensordict.set_lazy_legacy(False).set() + if parse_version(tensordict.__version__) < parse_version("0.10.0"): + tensordict.set_list_to_stack(True).set() class _DataProtoConfigMeta(type): @@ -964,6 +967,29 @@ class DataProto: meta_info=self.meta_info, ) + def to_tensordict(self) -> TensorDict: + """Convert this DataProto to TensorDict. Note that this requires tensordict version at least 0.10 + + Returns: + + """ + assert parse_version(tensordict.__version__) >= parse_version("0.10"), ( + "Convert DataProto to TensorDict at least requires tensordict version 0.10" + ) + tensor_batch = self.batch.to_dict() + non_tensor_batch = self.non_tensor_batch + + from verl.utils import tensordict_utils as tu + + common_keys = set(tensor_batch.keys()) & set(non_tensor_batch.keys()) + assert len(common_keys) == 0, f"tensor_batch and non_tensor_batch have common keys {common_keys}" + + for key, val in non_tensor_batch.items(): + assert isinstance(val, np.ndarray) + tensor_batch[key] = val.tolist() + output = tu.get_tensordict(tensor_dict=tensor_batch, non_tensor_dict=self.meta_info) + return output + def get_data_info(self) -> str: """Return formatted information about stored data with nested type details. diff --git a/verl/utils/tensordict_utils.py b/verl/utils/tensordict_utils.py new file mode 100644 index 000000000..3e3f015ef --- /dev/null +++ b/verl/utils/tensordict_utils.py @@ -0,0 +1,186 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +from typing import Iterator + +import torch +from tensordict import TensorDict +from tensordict.tensorclass import NonTensorData, NonTensorStack + + +def assign_non_tensor_dict(tensor_dict: TensorDict, non_tensor_dict: dict): + for key, val in non_tensor_dict.items(): + assign_non_tensor_data(tensor_dict=tensor_dict, key=key, val=val) + return tensor_dict + + +def assign_non_tensor_data(tensor_dict: TensorDict, key, val): + tensor_dict[key] = NonTensorData(val) + + +def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict: dict = None) -> TensorDict: + """Construct a TensorDict from per-sample tensors/lists and batch-level non-tensor data. + + Args: + tensor_dict: mapping from key to a torch.Tensor or a list with one entry per sample. + non_tensor_dict: mapping from key to batch-level metadata, stored as NonTensorData. + + Returns: + A TensorDict whose batch size is the common length of the values in tensor_dict. + """ + if non_tensor_dict is None: + non_tensor_dict = {} + + batch_size = None + + for key, val in tensor_dict.items(): + if isinstance(val, list): + for v in val: + assert not isinstance(v, torch.Tensor), ( + "Passing a list makes the data NonTensorStack, " + "which doesn't support torch.Tensor. Please convert to numpy first" + ) + + assert isinstance(val, torch.Tensor | list) + + if batch_size is None: + batch_size = len(val) + else: + assert len(val) == batch_size + + if batch_size is None: + batch_size = [] + else: + batch_size = [batch_size] + + for key, val in non_tensor_dict.items(): + assert key not in tensor_dict + tensor_dict[key] = NonTensorData(val) + + return TensorDict(source=tensor_dict, batch_size=batch_size) + + +def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict: + """Union two tensordicts.""" + assert tensor_dict1.batch_size == tensor_dict2.batch_size, ( + f"Two tensordicts must have identical batch sizes. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}" + ) + for key in tensor_dict2.keys(): + if key not in tensor_dict1.keys(): + tensor_dict1[key] = tensor_dict2[key] + else: + if isinstance(tensor_dict2[key], torch.Tensor): + assert tensor_dict1[key].equal(tensor_dict2[key]), ( + f"{key} in tensor_dict1 and tensor_dict2 are not equal" + ) + else: + # non-tensor + assert tensor_dict1[key] == tensor_dict2[key], ( + f"{key} in tensor_dict1 and tensor_dict2 are not equal" + ) + + return tensor_dict1 + + +def make_iterator(tensordict: TensorDict, mini_batch_size, epochs, seed=None, dataloader_kwargs=None): + from torch.utils.data import DataLoader + + assert tensordict.batch_size[0] % mini_batch_size == 0, f"{tensordict.batch_size[0]} % {mini_batch_size} != 0" + # we can directly create a dataloader from a TensorDict + if dataloader_kwargs is None: + dataloader_kwargs = {} + + if seed is not None: + generator = torch.Generator() + generator.manual_seed(seed) + else: + generator = None + + assert isinstance(dataloader_kwargs, dict) + train_dataloader = DataLoader( + dataset=tensordict, batch_size=mini_batch_size, collate_fn=lambda x: x, generator=generator, **dataloader_kwargs + ) + + def get_data(): + for _ in range(epochs): + yield from train_dataloader + + return iter(get_data()) + + +def assert_tensordict_eq(tensordict1: TensorDict, tensordict2: TensorDict): + assert set(tensordict1.keys()) == set(tensordict2.keys()) + + for key in tensordict1.keys(): + val = tensordict1[key] + val2 = tensordict2[key] + + assert type(val) is type(val2), f"The type of {key} must be the same. Got {type(val)} vs {type(val2)}" + + if isinstance(val, torch.Tensor): + assert torch.all(torch.eq(val, val2)).item() + else: + assert val == val2 + + +def pop(tensordict: TensorDict, keys: Iterator[str]) -> TensorDict: + tensor_output = {} + non_tensor_output = {} + for key in keys: + output = tensordict.get(key) + if isinstance(output, torch.Tensor): + tensor_output[key] = tensordict.pop(key) + elif isinstance(output, NonTensorStack): + tensor_output[key] = tensordict.pop(key).tolist() + else: + assert isinstance(output, NonTensorData) + non_tensor_output[key] = tensordict.pop(key) + + return get_tensordict(tensor_output, non_tensor_output) + + +def pad_to_divisor(data: TensorDict, size_divisor: int): + """Pad a TensorDict so its length is divisible by size_divisor + + Args: + data (TensorDict): the TensorDict to pad + size_divisor (int): size divisor + + Returns: + data_padded (TensorDict): the padded TensorDict + pad_size (int): the number of repeated rows appended at the end + """ + assert isinstance(data, TensorDict), "data must be a TensorDict" + if len(data) % size_divisor != 0: + pad_size = size_divisor - len(data) % size_divisor + padding_protos = [] + remaining_pad = pad_size + while remaining_pad > 0: + take_size = min(remaining_pad, len(data)) + padding_protos.append(data[:take_size]) + remaining_pad -= take_size + data_padded = torch.cat([data] + padding_protos) + else: + if len(data) == 0: + logging.warning("padding a TensorDict with no items, no change made") + pad_size = 0 + data_padded = data + return data_padded, pad_size + + +def unpad(data: TensorDict, pad_size): + """Unpad the TensorDict by removing the last pad_size rows, i.e. `data[:-pad_size]`.""" + if pad_size != 0: + data = data[:-pad_size] + return data
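The helpers above are exercised end to end in `tests/test_protocol_v2_on_cpu.py`; a condensed workflow sketch (illustrative only, assuming tensordict >= 0.10) follows:

```python
# Condensed workflow sketch for verl.utils.tensordict_utils (illustrative;
# distilled from tests/test_protocol_v2_on_cpu.py, not a normative reference).
import torch

from verl.utils import tensordict_utils as tu

data = tu.get_tensordict(
    tensor_dict={"obs": torch.randn(6, 10), "labels": ["a", "b", "c", "d", "e", "f"]},
    non_tensor_dict={"top_p": 0.9},
)

# Seeded, epoch-aware mini-batch iteration directly over the TensorDict.
for mini_batch in tu.make_iterator(
    data, mini_batch_size=3, epochs=1, seed=0, dataloader_kwargs={"shuffle": True}
):
    assert mini_batch.batch_size == torch.Size([3])

# Union with another TensorDict of the same batch size; overlapping keys must match.
rewards = tu.get_tensordict(tensor_dict={"rew": torch.randn(6)})
data = tu.union_tensor_dict(data, rewards)

# Pop a subset of keys into a new TensorDict, removing them from the original.
popped = tu.pop(data, keys=["rew", "top_p"])
assert popped.keys() == {"rew", "top_p"}
assert "rew" not in data
```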