mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 13:43:50 +08:00
[misc] feat: prototype deprecate DataProto and replace with Tensordict: part 1 (#2733)
### What does this PR do? - Add TensorDict utilities and tests to cover the current DataProto functionalities. - Add nested tensor example to remove padding throughout the system - Add image example - Upgrade tensordict to v0.10 ### Checklist Before Starting - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). 
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@ -10,7 +10,7 @@ peft>=0.15.2
|
|||||||
pyarrow>=15.0.0
|
pyarrow>=15.0.0
|
||||||
pybind11
|
pybind11
|
||||||
pylatexenc
|
pylatexenc
|
||||||
tensordict>=0.8.0,<=0.9.1,!=0.9.0
|
tensordict>=0.8.0,<=0.10.0,!=0.9.0
|
||||||
transformers==4.52.4
|
transformers==4.52.4
|
||||||
ray==2.46.0
|
ray==2.46.0
|
||||||
wandb
|
wandb
|
||||||
|
@ -14,7 +14,7 @@ pybind11
|
|||||||
pylatexenc
|
pylatexenc
|
||||||
pre-commit
|
pre-commit
|
||||||
ray[default]
|
ray[default]
|
||||||
tensordict>=0.8.0,<=0.9.1,!=0.9.0
|
tensordict>=0.8.0,<=0.10.0,!=0.9.0
|
||||||
torchdata
|
torchdata
|
||||||
transformers
|
transformers
|
||||||
# vllm==0.8.4
|
# vllm==0.8.4
|
||||||
|
@ -12,7 +12,7 @@ pyarrow>=19.0.0
|
|||||||
pybind11
|
pybind11
|
||||||
pylatexenc
|
pylatexenc
|
||||||
ray[default]>=2.10
|
ray[default]>=2.10
|
||||||
tensordict>=0.8.0,<=0.9.1,!=0.9.0
|
tensordict>=0.8.0,<=0.10.0,!=0.9.0
|
||||||
torchdata
|
torchdata
|
||||||
torchvision
|
torchvision
|
||||||
transformers
|
transformers
|
||||||
|
6
setup.py
6
setup.py
@ -37,7 +37,7 @@ install_requires = [
|
|||||||
"pylatexenc",
|
"pylatexenc",
|
||||||
"ray[default]>=2.41.0",
|
"ray[default]>=2.41.0",
|
||||||
"torchdata",
|
"torchdata",
|
||||||
"tensordict>=0.8.0,<=0.9.1,!=0.9.0",
|
"tensordict>=0.8.0,<=0.10.0,!=0.9.0",
|
||||||
"transformers",
|
"transformers",
|
||||||
"wandb",
|
"wandb",
|
||||||
"packaging>=20.0",
|
"packaging>=20.0",
|
||||||
@ -49,9 +49,9 @@ PRIME_REQUIRES = ["pyext"]
|
|||||||
GEO_REQUIRES = ["mathruler", "torchvision", "qwen_vl_utils"]
|
GEO_REQUIRES = ["mathruler", "torchvision", "qwen_vl_utils"]
|
||||||
GPU_REQUIRES = ["liger-kernel", "flash-attn"]
|
GPU_REQUIRES = ["liger-kernel", "flash-attn"]
|
||||||
MATH_REQUIRES = ["math-verify"] # Add math-verify as an optional dependency
|
MATH_REQUIRES = ["math-verify"] # Add math-verify as an optional dependency
|
||||||
VLLM_REQUIRES = ["tensordict>=0.8.0,<=0.9.1,!=0.9.0", "vllm>=0.7.3,<=0.9.1"]
|
VLLM_REQUIRES = ["tensordict>=0.8.0,<=0.10.0,!=0.9.0", "vllm>=0.7.3,<=0.9.1"]
|
||||||
SGLANG_REQUIRES = [
|
SGLANG_REQUIRES = [
|
||||||
"tensordict>=0.8.0,<=0.9.1,!=0.9.0",
|
"tensordict>=0.8.0,<=0.10.0,!=0.9.0",
|
||||||
"sglang[srt,openai]==0.4.10.post2",
|
"sglang[srt,openai]==0.4.10.post2",
|
||||||
"torch==2.7.1",
|
"torch==2.7.1",
|
||||||
]
|
]
|
||||||
|
@ -86,7 +86,11 @@ def main() -> None:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--allow-files",
|
"--allow-files",
|
||||||
nargs="*",
|
nargs="*",
|
||||||
default=["tests/test_protocol_on_cpu.py", "tests/test_base_config_on_cpu.py"],
|
default=[
|
||||||
|
"tests/test_protocol_on_cpu.py",
|
||||||
|
"tests/test_base_config_on_cpu.py",
|
||||||
|
"tests/test_protocol_v2_on_cpu.py",
|
||||||
|
],
|
||||||
help="Extra top-level test folders that are exempt from the rule",
|
help="Extra top-level test folders that are exempt from the rule",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
@ -16,7 +16,9 @@ import random
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
import tensordict
|
||||||
import torch
|
import torch
|
||||||
|
from packaging.version import parse as parse_version
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
|
|
||||||
from verl import DataProto
|
from verl import DataProto
|
||||||
@ -598,3 +600,17 @@ def test_dataproto_chunk_after_index():
|
|||||||
selected = data[torch_int_mask]
|
selected = data[torch_int_mask]
|
||||||
assert isinstance(selected.batch.batch_size, torch.Size)
|
assert isinstance(selected.batch.batch_size, torch.Size)
|
||||||
assert all(isinstance(d, int) for d in selected.batch.batch_size)
|
assert all(isinstance(d, int) for d in selected.batch.batch_size)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
)
def test_to_tensordict():
    """Check that DataProto.to_tensordict exposes tensors, non-tensors and meta info by key."""
    obs_values = torch.tensor([1, 2, 3, 4, 5, 6])
    label_values = ["a", "b", "c", "d", "e", "f"]
    proto = DataProto.from_dict(
        tensors={"obs": obs_values},
        non_tensors={"labels": label_values},
        meta_info={"name": "abdce"},
    )

    converted = proto.to_tensordict()

    # Tensor entry, per-sample non-tensor entry and meta entry are all reachable by key.
    assert torch.all(torch.eq(converted["obs"], obs_values)).item()
    assert converted["labels"] == label_values
    assert converted["name"] == "abdce"
|
||||||
|
525
tests/test_protocol_v2_on_cpu.py
Normal file
525
tests/test_protocol_v2_on_cpu.py
Normal file
@ -0,0 +1,525 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Tests for the TensorDict utilities in verl.utils.tensordict_utils, which use a raw TensorDict in place of DataProto.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from verl.utils import tensordict_utils as tu
|
||||||
|
|
||||||
|
|
||||||
|
def test_union_tensor_dict():
    """Check that union_tensor_dict accepts agreeing inputs and asserts on conflicting ones."""
    shared_obs = torch.randn(100, 10)

    first_raw = {"obs": shared_obs, "act": torch.randn(100, 3), "data_sources": ["gsm8k"] * 100}
    second_raw = {
        "obs": shared_obs,
        "next_obs": torch.randn(100, 10),
        "rew": torch.randn(100),
        "data_sources": ["gsm8k"] * 100,
    }
    clashing_raw = {"obs": shared_obs.clone(), "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}

    first = tu.get_tensordict(tensor_dict=first_raw)
    second = tu.get_tensordict(tensor_dict=second_raw)
    clashing = tu.get_tensordict(clashing_raw)

    # Entries shared by both inputs agree, so the union goes through.
    tu.union_tensor_dict(first, second)

    # conflict in tensor values
    with pytest.raises(AssertionError):
        tu.union_tensor_dict(first, clashing)

    # A non-tensor entry present on one side only is not a conflict.
    first = tu.assign_non_tensor_dict(first, {"top_p": 0.8})
    tu.union_tensor_dict(first, second)

    # conflict in NonTensorData: same key, different value
    second = tu.assign_non_tensor_dict(second, {"top_p": 0.9})
    with pytest.raises(AssertionError):
        tu.union_tensor_dict(first, second)

    for td in (first, second):
        td.pop("top_p")

    # conflict in NonTensorData: one per-sample entry differs
    second["data_sources"][0] = "math"
    with pytest.raises(AssertionError):
        tu.union_tensor_dict(first, second)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_dict_constructor():
    """Check batch size, row access, slice access and metadata of a freshly built TensorDict."""
    td = tu.get_tensordict(
        tensor_dict={
            "obs": torch.ones(100, 10),
            "act": torch.zeros(100, 10, 3),
            "data_source": ["gsm8k"] * 100,
        },
        non_tensor_dict={"name": "abdce"},
    )

    assert td.batch_size == torch.Size([100])

    # Single-row access drops the batch dimension.
    assert torch.all(torch.eq(td[0]["obs"], torch.ones(10))).item()
    assert torch.all(torch.eq(td[0]["act"], torch.zeros(10, 3))).item()
    assert td[0]["data_source"] == "gsm8k"

    # Slice access keeps a batch dimension of the slice length.
    assert torch.all(torch.eq(td[0:2]["obs"], torch.ones(2, 10))).item()
    assert torch.all(torch.eq(td[0:2]["act"], torch.zeros(2, 10, 3))).item()
    assert td[0:2]["data_source"] == ["gsm8k"] * 2

    # Entries passed through non_tensor_dict are returned as given.
    assert td["name"] == "abdce"
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensordict_with_images():
    """Store jagged token ids next to per-sample lists of differently sized images."""
    vocab_size = 128
    seq_a = torch.randint(low=0, high=vocab_size, size=(11,))
    seq_b = torch.randint(low=0, high=vocab_size, size=(13,))
    token_ids = torch.nested.as_nested_tensor([seq_a, seq_b], layout=torch.jagged)

    def random_image(side):
        # One random uint8 image of shape (3, side, side), converted to numpy.
        return torch.randint(low=0, high=255, size=(3, side, side), dtype=torch.uint8).numpy()

    # must be numpy
    # TODO(vermouth1992). We may use nested tensor too. But this requires nested over nested
    images_a = [random_image(side) for side in (256, 128)]
    images_b = [random_image(side) for side in (256, 128, 64)]

    td = tu.get_tensordict({"input_ids": token_ids, "images": [images_a, images_b]})

    assert np.all(np.equal(td[0]["images"][0], images_a[0]))
    assert torch.all(torch.eq(td[0]["input_ids"], seq_a))
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensordict_with_packing():
    """Check offsets, row access and chunking of a jagged nested tensor inside a TensorDict."""
    vocab_size = 128
    seq_a = torch.randint(low=0, high=vocab_size, size=(11,))
    seq_b = torch.randint(low=0, high=vocab_size, size=(13,))
    packed = torch.nested.as_nested_tensor([seq_a, seq_b], layout=torch.jagged)

    td = tu.get_tensordict({"input_ids": packed})

    # The offsets are the cumulative sequence lengths: 0, 11, 11 + 13.
    expected_offsets = torch.tensor([0, 11, 24])
    assert torch.all(torch.eq(expected_offsets, td["input_ids"].offsets()))

    sequences = (seq_a, seq_b)

    # Row access: key first, then row.
    for row, expected in enumerate(sequences):
        assert torch.all(torch.eq(td["input_ids"][row], expected))

    # Row access: row first, then key.
    for row, expected in enumerate(sequences):
        assert torch.all(torch.eq(td[row]["input_ids"], expected))

    # Splitting into two chunks leaves one sequence per chunk.
    halves = td.chunk(2)
    assert torch.all(torch.eq(halves[0]["input_ids"][0], seq_a))
    assert torch.all(torch.eq(halves[1]["input_ids"][0], seq_b))
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensordict_eq():
    """Check that assert_tensordict_eq passes on equal inputs and fails on each kind of difference."""

    def build():
        # Fresh, independent objects on every call so the two instances share nothing.
        return tu.get_tensordict(
            {
                "obs": torch.tensor([1, 2, 3, 4, 5, 6]),
                "data_sources": ["abc", "def", "abc", "def", "pol", "klj"],
            },
            non_tensor_dict={"train_sample_kwargs": {"top_p": 1.0}, "val_sample_kwargs": {"top_p": 0.7}},
        )

    reference = build()
    twin = build()

    tu.assert_tensordict_eq(reference, twin)

    # A changed tensor element must be detected.
    altered = copy.deepcopy(twin)
    altered["obs"][0] += 1
    with pytest.raises(AssertionError):
        tu.assert_tensordict_eq(reference, altered)

    # A changed per-sample non-tensor element must be detected.
    altered = copy.deepcopy(twin)
    altered["data_sources"][0] = "math"
    with pytest.raises(AssertionError):
        tu.assert_tensordict_eq(reference, altered)

    # A changed value inside a nested metadata dict must be detected.
    altered = copy.deepcopy(twin)
    altered["train_sample_kwargs"]["top_p"] = 0.9
    with pytest.raises(AssertionError):
        tu.assert_tensordict_eq(reference, altered)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_dict_make_iterator():
    """Check that make_iterator yields the expected mini-batches and is reproducible per seed.

    The first part iterates without shuffling for two epochs and compares every
    mini-batch with the matching slice of the dataset. The second part builds
    two shuffled iterators with the same seed and requires identical output.
    """
    obs = torch.tensor([1, 2, 3, 4, 5, 6])
    data_sources = ["abc", "def", "abc", "def", "pol", "klj"]
    non_tensor_dict = {"train_sample_kwargs": {"top_p": 1.0}, "val_sample_kwargs": {"top_p": 0.7}}
    dataset = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)

    dataloader = tu.make_iterator(
        dataset, mini_batch_size=2, epochs=2, seed=0, dataloader_kwargs={"shuffle": False, "drop_last": False}
    )

    # Two epochs, each walking the six rows in order, two rows at a time.
    expected_tensor_dict = [dataset[0:2], dataset[2:4], dataset[4:6], dataset[0:2], dataset[2:4], dataset[4:6]]

    # strict=True makes the test fail when the iterator yields fewer or more
    # batches than expected, instead of passing on whatever prefix it produced.
    for batch, expected in zip(dataloader, expected_tensor_dict, strict=True):
        tu.assert_tensordict_eq(batch, expected)

    def shuffled_batches():
        # Materialize one shuffled epoch with a fixed seed.
        return list(
            tu.make_iterator(dataset, mini_batch_size=3, epochs=1, seed=1, dataloader_kwargs={"shuffle": True})
        )

    data_list_1 = shuffled_batches()
    data_list_2 = shuffled_batches()

    # Guard against a vacuous comparison of two empty lists.
    assert data_list_1

    for data1, data2 in zip(data_list_1, data_list_2, strict=True):
        tu.assert_tensordict_eq(data1, data2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_reorder():
    """Check that indexing with a permutation tensor reorders tensors and labels alike."""
    td = tu.get_tensordict(
        tensor_dict={
            "obs": torch.tensor([1, 2, 3, 4, 5, 6]),
            "labels": ["a", "b", "c", "d", "e", "f"],
        },
        non_tensor_dict={"name": "abdce"},
    )

    permutation = torch.tensor([3, 4, 2, 0, 1, 5])
    reordered = td[permutation]

    assert torch.all(torch.eq(reordered["obs"], torch.tensor([4, 5, 3, 1, 2, 6])))
    assert np.all(reordered["labels"] == np.array(["d", "e", "c", "a", "b", "f"]))
    # The metadata entry is unchanged by the reordering.
    assert reordered["name"] == "abdce"
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunk_concat():
    """Check tensor_split, chunk and torch.cat on a TensorDict with tensor and label entries."""
    obs = torch.tensor([1, 2, 3, 4, 5, 6])
    labels = ["a", "b", "c", "d", "e", "f"]
    data = tu.get_tensordict({"obs": obs, "labels": labels}, non_tensor_dict={"name": "abcde"})

    data_split = data.tensor_split(indices_or_sections=5, dim=0)

    # Six rows into five sections: the first section is expected to hold two rows.
    expected_idx_lst = [[0, 1], [2], [3], [4], [5]]

    # strict=True so that a wrong number of sections fails the test rather than
    # being silently truncated to the shorter sequence by zip.
    for d, expected_idx in zip(data_split, expected_idx_lst, strict=True):
        tu.assert_tensordict_eq(d, data[expected_idx])

    data_split = data.chunk(2)
    assert len(data_split) == 2

    # First half: rows 0..2; the metadata entry is carried along.
    assert torch.all(torch.eq(data_split[0]["obs"], torch.tensor([1, 2, 3])))
    assert np.all(data_split[0]["labels"] == np.array(["a", "b", "c"]))
    assert data_split[0]["name"] == "abcde"

    # Second half: rows 3..5.
    assert torch.all(torch.eq(data_split[1]["obs"], torch.tensor([4, 5, 6])))
    assert np.all(data_split[1]["labels"] == np.array(["d", "e", "f"]))
    assert data_split[1]["name"] == "abcde"

    # Concatenating the halves gives back the original content.
    concat_data = torch.cat(data_split, dim=0)
    assert torch.all(torch.eq(concat_data["obs"], data["obs"]))
    assert np.all(concat_data["labels"] == data["labels"])
    assert concat_data["name"] == data["name"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_pop():
    """Check that tu.pop returns the requested keys and removes them from the source."""
    source = tu.get_tensordict(
        {"obs": torch.randn(100, 10), "act": torch.randn(100, 3)},
        non_tensor_dict={"2": 2, "1": 1},
    )

    popped = tu.pop(source, keys=["obs", "2"])

    # The popped part keeps the batch dimension of the source.
    assert popped.batch_size[0] == 100

    # Keys are moved, not copied.
    assert popped.keys() == {"obs", "2"}
    assert source.keys() == {"act", "1"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_repeat():
    """Check repeat_interleave (row-wise duplication) and repeat (whole-batch tiling)."""
    base = tu.get_tensordict(
        {"obs": torch.tensor([[1, 2], [3, 4], [5, 6]]), "labels": ["a", "b", "c"]},
        non_tensor_dict={"info": "test_info"},
    )

    # Row-wise duplication: every row appears twice in a row.
    interleaved = base.repeat_interleave(repeats=2)
    want_obs = torch.tensor([[1, 2], [1, 2], [3, 4], [3, 4], [5, 6], [5, 6]])
    want_labels = ["a", "a", "b", "b", "c", "c"]

    assert torch.all(torch.eq(interleaved["obs"], want_obs))
    assert interleaved["labels"] == want_labels
    assert interleaved["info"] == "test_info"

    # Whole-batch tiling: the full batch is followed by a second copy of itself.
    tiled = base.repeat(2)
    want_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6]])
    want_labels = ["a", "b", "c", "a", "b", "c"]

    assert torch.all(torch.eq(tiled["obs"], want_obs))
    assert tiled["labels"] == want_labels
    assert tiled["info"] == "test_info"
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataproto_pad_unpad():
    """Check pad_to_divisor / unpad round trips for several divisors."""
    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    labels = ["a", "b", "c"]
    data = tu.get_tensordict(tensor_dict={"obs": obs, "labels": labels}, non_tensor_dict={"info": "test_info"})

    # Each case: (size_divisor, expected pad_size, expected padded obs rows, expected padded labels).
    # The expected padding repeats rows from the start of the batch.
    cases = [
        (2, 1, [[1, 2], [3, 4], [5, 6], [1, 2]], ["a", "b", "c", "a"]),
        (3, 0, [[1, 2], [3, 4], [5, 6]], ["a", "b", "c"]),
        (
            7,
            4,
            [[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6], [1, 2]],
            ["a", "b", "c", "a", "b", "c", "a"],
        ),
    ]

    for divisor, want_pad, want_obs_rows, want_labels in cases:
        padded, pad_size = tu.pad_to_divisor(data, size_divisor=divisor)
        assert pad_size == want_pad

        assert torch.all(torch.eq(padded["obs"], torch.tensor(want_obs_rows)))
        assert padded["labels"] == want_labels
        assert padded["info"] == "test_info"

        # Removing the padding gives back the original rows.
        restored = tu.unpad(padded, pad_size=pad_size)
        assert torch.all(torch.eq(restored["obs"], obs))
        assert restored["labels"] == labels
        assert restored["info"] == "test_info"
|
||||||
|
|
||||||
|
|
||||||
|
def test_torch_save_data_proto():
    """Round-trip a TensorDict through torch.save / torch.load and compare its entries.

    The file is written inside a temporary directory so that nothing is left in
    the working directory, even when saving, loading or an assertion fails.
    """
    import os
    import tempfile

    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    labels = ["a", "b", "c"]
    data = tu.get_tensordict({"obs": obs, "labels": labels}, non_tensor_dict={"info": "test_info"})

    with tempfile.TemporaryDirectory() as tmp_dir:
        filename = os.path.join(tmp_dir, "test_data.pt")
        torch.save(data, filename)
        # weights_only=False permits full unpickling of the saved object.
        loaded_data = torch.load(filename, weights_only=False)

    assert torch.all(torch.eq(loaded_data["obs"], data["obs"]))
    assert loaded_data["labels"] == data["labels"]
    assert loaded_data["info"] == data["info"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_len():
    """Check len() for a full TensorDict, a label-only one, a single item and an empty one."""
    label_array = np.array(["a", "b", "c"], dtype=object)
    meta = {"info": "test_info"}

    # Tensor plus labels: three rows.
    with_tensor = tu.get_tensordict(
        {"obs": torch.tensor([[1, 2], [3, 4], [5, 6]]), "labels": label_array.tolist()},
        non_tensor_dict=meta,
    )
    assert len(with_tensor) == 3

    # Labels only: still three rows.
    labels_only = tu.get_tensordict({"labels": label_array.tolist()}, non_tensor_dict=meta)
    assert len(labels_only) == 3

    # A single indexed item reports length 0.
    single_item = labels_only[0]
    assert len(single_item) == 0

    # No per-sample entries at all: length 0.
    metadata_only = tu.get_tensordict({}, non_tensor_dict=meta)
    assert len(metadata_only) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataproto_index():
    """Check row selection with numpy, torch and list indices against plain tensor/array indexing."""
    data_len = 100
    idx_num = 10

    obs = torch.randn(data_len, 10)
    labels = [random.choice(["abc", "cde"]) for _ in range(data_len)]
    data = tu.get_tensordict({"obs": obs, "labels": labels})
    labels_np = np.array(labels)

    def check(index, label_index, expected_rows):
        # Index the TensorDict and compare with indexing the raw tensor and label array.
        picked = data[index]
        assert picked.keys() == data.keys()
        assert picked["obs"].shape[0] == expected_rows
        assert len(picked["labels"]) == expected_rows
        assert np.array_equal(picked["obs"].cpu().numpy(), obs[index].cpu().numpy())
        assert np.array_equal(picked["labels"], labels_np[label_index])

    # numpy integer array
    np_int_index = np.random.randint(0, data_len, size=(idx_num,))
    check(np_int_index, np_int_index, idx_num)

    # torch integer tensor
    torch_int_index = torch.randint(0, data_len, size=(idx_num,))
    check(torch_int_index, torch_int_index.cpu().numpy(), idx_num)

    # Python list of integers
    list_int_index = [np.random.randint(0, data_len) for _ in range(idx_num)]
    check(list_int_index, list_int_index, idx_num)

    # torch boolean mask: the number of selected rows equals the number of True entries
    torch_bool_index = torch.randint(0, 2, size=(data_len,), dtype=torch.bool)
    check(torch_bool_index, torch_bool_index, torch_bool_index.sum().item())

    # NOTE(review): numpy boolean arrays and Python lists of bools are not exercised
    # as indices here; confirm whether those index types are expected to be supported.
|
||||||
|
|
||||||
|
|
||||||
|
def test_select():
    """Check that select keeps only the requested tensor and non-tensor keys."""
    full = tu.get_tensordict(
        {"obs": torch.randn(100, 10), "act": torch.randn(100, 3)},
        non_tensor_dict={"2": 2, "1": 1},
    )

    chosen = full.select("obs", "2")

    # Selected entries match the source.
    assert torch.all(torch.eq(chosen["obs"], full["obs"]))
    assert chosen["2"] == full["2"]

    # Everything else is left out.
    for missing_key in ("act", "1"):
        assert missing_key not in chosen.keys()
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataproto_no_batch():
    """Check select and tu.pop on a TensorDict that holds labels but no tensor entry."""
    label_values = ["a", "b", "c"]
    td = tu.get_tensordict(tensor_dict={"labels": label_values}, non_tensor_dict={"info": "test_info"})

    picked = td.select("labels")
    assert picked["labels"] == label_values

    # Popping returns the labels and removes the key from the source.
    popped = tu.pop(td, keys=["labels"])
    assert popped["labels"] == label_values
    assert "labels" not in td
|
||||||
|
|
||||||
|
|
||||||
|
def test_sample_level_repeat():
    """Check repeat_interleave with a per-sample repeat count given as a tensor."""
    base = tu.get_tensordict(
        {"obs": torch.tensor([[1, 2], [3, 4], [5, 6]]), "labels": ["a", "b", "c"]},
        non_tensor_dict={"info": "test_info"},
    )

    # Each case: (per-row repeat counts, expected obs rows, expected labels).
    cases = [
        (
            torch.tensor([3, 1, 2]),
            [[1, 2], [1, 2], [1, 2], [3, 4], [5, 6], [5, 6]],
            ["a", "a", "a", "b", "c", "c"],
        ),
        (
            torch.tensor([1, 2, 3]),
            [[1, 2], [3, 4], [3, 4], [5, 6], [5, 6], [5, 6]],
            ["a", "b", "b", "c", "c", "c"],
        ),
    ]

    for repeats, want_obs_rows, want_labels in cases:
        repeated = base.repeat_interleave(repeats=repeats)

        assert torch.all(torch.eq(repeated["obs"], torch.tensor(want_obs_rows)))
        assert repeated["labels"] == want_labels
        assert repeated["info"] == "test_info"
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataproto_chunk_after_index():
    """Check that batch_size stays a torch.Size of plain ints after each kind of indexing."""
    data_len = 4
    td = tu.get_tensordict(
        tensor_dict={
            "obs": torch.randn(data_len, 4),
            "labels": [f"label_{i}" for i in range(data_len)],
        },
        non_tensor_dict={"name": "abc"},
    )

    indices = [
        torch.tensor([True, False, True, False]),  # boolean tensor
        torch.tensor([0, 2]),  # integer tensor
        [True, False, True, False],  # list of bools
        [0, 2],  # list of ints
        torch.tensor([True, False, True, False]),  # boolean tensor, second pass
        torch.tensor([0, 2]),  # integer tensor, second pass
    ]

    for index in indices:
        selected = td[index]
        assert isinstance(selected.batch_size, torch.Size)
        # Every dimension must be a plain Python int.
        assert all(isinstance(dim, int) for dim in selected.batch_size)
|
@ -74,6 +74,9 @@ if is_npu_available:
|
|||||||
# for third-party devices such as NPUs. This patch fixes this issue, and the relevant
|
# for third-party devices such as NPUs. This patch fixes this issue, and the relevant
|
||||||
# modifications can be removed once the fix is merged into tensordict.
|
# modifications can be removed once the fix is merged into tensordict.
|
||||||
|
|
||||||
|
import tensordict
|
||||||
|
|
||||||
|
if parse_version(tensordict.__version__) < parse_version("0.10.0"):
|
||||||
from tensordict.base import TensorDictBase
|
from tensordict.base import TensorDictBase
|
||||||
|
|
||||||
def _sync_all_patch(self):
|
def _sync_all_patch(self):
|
||||||
|
@ -31,6 +31,7 @@ import tensordict
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
from packaging.version import parse as parse_version
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -42,6 +43,8 @@ __all__ = ["DataProto", "union_tensor_dict"]
|
|||||||
|
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
tensordict.set_lazy_legacy(False).set()
|
tensordict.set_lazy_legacy(False).set()
|
||||||
|
if parse_version(tensordict.__version__) < parse_version("0.10.0"):
|
||||||
|
tensordict.set_list_to_stack(True).set()
|
||||||
|
|
||||||
|
|
||||||
class _DataProtoConfigMeta(type):
|
class _DataProtoConfigMeta(type):
|
||||||
@ -964,6 +967,29 @@ class DataProto:
|
|||||||
meta_info=self.meta_info,
|
meta_info=self.meta_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def to_tensordict(self) -> TensorDict:
    """Convert this DataProto to a TensorDict.

    The entries of ``self.batch.to_dict()`` are kept as they are, every numpy
    array in ``self.non_tensor_batch`` is turned into a Python list via
    ``tolist()``, and ``self.meta_info`` is handed to
    ``tensordict_utils.get_tensordict`` as the non-tensor dict.

    Note that this requires tensordict version at least 0.10.

    Returns:
        TensorDict: the result of ``tensordict_utils.get_tensordict``.

    Raises:
        AssertionError: if the installed tensordict is older than 0.10, if
            ``batch`` and ``non_tensor_batch`` share a key, or if a
            ``non_tensor_batch`` value is not a ``np.ndarray``.
    """
    assert parse_version(tensordict.__version__) >= parse_version("0.10"), (
        "Convert DataProto to TensorDict at least requires tensordict version 0.10"
    )
    # NOTE(review): presumably ``batch`` can be unset when the DataProto only
    # carries non-tensor data; fall back to an empty dict instead of failing
    # with AttributeError on ``None.to_dict()`` - confirm against the class.
    tensor_batch = self.batch.to_dict() if self.batch is not None else {}
    non_tensor_batch = self.non_tensor_batch

    # Imported here, after the version check has passed.
    from verl.utils import tensordict_utils as tu

    common_keys = set(tensor_batch.keys()) & set(non_tensor_batch.keys())
    assert len(common_keys) == 0, f"tensor_batch and non_tensor_batch have common keys {common_keys}"

    for key, val in non_tensor_batch.items():
        assert isinstance(val, np.ndarray), f"non_tensor_batch[{key!r}] must be a np.ndarray, got {type(val)}"
        # get_tensordict accepts plain lists for non-tensor batched entries
        tensor_batch[key] = val.tolist()

    return tu.get_tensordict(tensor_dict=tensor_batch, non_tensor_dict=self.meta_info)
|
||||||
|
|
||||||
def get_data_info(self) -> str:
|
def get_data_info(self) -> str:
|
||||||
"""Return formatted information about stored data with nested type details.
|
"""Return formatted information about stored data with nested type details.
|
||||||
|
|
||||||
|
186
verl/utils/tensordict_utils.py
Normal file
186
verl/utils/tensordict_utils.py
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from tensordict import TensorDict
|
||||||
|
from tensordict.tensorclass import NonTensorData, NonTensorStack
|
||||||
|
|
||||||
|
|
||||||
|
def assign_non_tensor_dict(tensor_dict: TensorDict, non_tensor_dict: dict):
    """Store every item of ``non_tensor_dict`` in ``tensor_dict`` as non-tensor data.

    Args:
        tensor_dict: the TensorDict to modify in place.
        non_tensor_dict: key/value pairs forwarded one by one to
            ``assign_non_tensor_data``.

    Returns:
        The same ``tensor_dict`` object that was passed in.
    """
    for name in non_tensor_dict:
        assign_non_tensor_data(tensor_dict=tensor_dict, key=name, val=non_tensor_dict[name])
    return tensor_dict
|
||||||
|
|
||||||
|
|
||||||
|
def assign_non_tensor_data(tensor_dict: TensorDict, key, val):
    """Wrap ``val`` in ``NonTensorData`` and store it in ``tensor_dict`` under ``key``."""
    wrapped = NonTensorData(val)
    tensor_dict[key] = wrapped
|
||||||
|
|
||||||
|
|
||||||
|
def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict: dict = None) -> TensorDict:
    """Build a TensorDict from batched entries plus extra non-tensor entries.

    Args:
        tensor_dict: maps each key to a ``torch.Tensor`` or a ``list``. All
            values must have the same ``len``, which becomes the batch size.
            A list must not contain ``torch.Tensor`` elements. This dict is
            not modified.
        non_tensor_dict: optional entries that are each wrapped in
            ``NonTensorData``. Its keys must not appear in ``tensor_dict``.

    Returns:
        TensorDict: holds both kinds of entries, with ``batch_size=[N]`` where
        ``N`` is the common length, or ``batch_size=[]`` when ``tensor_dict``
        is empty.

    Raises:
        AssertionError: on an unsupported value type, a tensor inside a list,
            mismatching lengths, or a key present in both dicts.
    """
    if non_tensor_dict is None:
        non_tensor_dict = {}

    batch_size = None

    for key, val in tensor_dict.items():
        assert isinstance(val, torch.Tensor | list), (
            f"{key} must be a torch.Tensor or a list, got {type(val)}"
        )

        if isinstance(val, list):
            assert not any(isinstance(v, torch.Tensor) for v in val), (
                "Passing a list makes the data NonTensorStack, "
                "which doesn't support torch.Tensor. Please convert to numpy first"
            )

        if batch_size is None:
            batch_size = len(val)
        else:
            assert len(val) == batch_size, f"{key} has length {len(val)}, expected batch size {batch_size}"

    # Shallow copy so that adding the non-tensor entries below does not leak
    # new keys into the dict owned by the caller.
    source = dict(tensor_dict)
    for key, val in non_tensor_dict.items():
        assert key not in source, f"{key} is present in both tensor_dict and non_tensor_dict"
        source[key] = NonTensorData(val)

    return TensorDict(source=source, batch_size=[] if batch_size is None else [batch_size])
|
||||||
|
|
||||||
|
|
||||||
|
def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict:
    """Merge ``tensor_dict2`` into ``tensor_dict1`` in place and return ``tensor_dict1``.

    Keys only present in ``tensor_dict2`` are added to ``tensor_dict1``. For a
    key present in both, the two values must compare equal (``Tensor.equal``
    for tensors, ``==`` otherwise).

    Raises:
        AssertionError: if the batch sizes differ or a shared key holds
            different values.
    """
    assert tensor_dict1.batch_size == tensor_dict2.batch_size, (
        f"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}"
    )

    for key in tensor_dict2.keys():
        incoming = tensor_dict2[key]

        if key not in tensor_dict1.keys():
            tensor_dict1[key] = incoming
            continue

        existing = tensor_dict1[key]
        if isinstance(incoming, torch.Tensor):
            matches = existing.equal(incoming)
        else:
            # non-tensor
            matches = existing == incoming
        assert matches, f"{key} in tensor_dict1 and tensor_dict2 are not the same object"

    return tensor_dict1
|
||||||
|
|
||||||
|
|
||||||
|
def make_iterator(tensordict: TensorDict, mini_batch_size, epochs, seed=None, dataloader_kwargs=None):
    """Return an iterator over mini-batches of ``tensordict`` for ``epochs`` passes.

    Args:
        tensordict: the data; its first batch dimension must be divisible by
            ``mini_batch_size``.
        mini_batch_size: passed to the ``DataLoader`` as ``batch_size``.
        epochs: number of times the ``DataLoader`` is iterated.
        seed: when not ``None``, seeds a ``torch.Generator`` that is passed to
            the ``DataLoader``.
        dataloader_kwargs: optional dict of extra ``DataLoader`` keyword
            arguments.

    Returns:
        An iterator yielding whatever the ``DataLoader`` yields, epoch after epoch.
    """
    from torch.utils.data import DataLoader

    assert tensordict.batch_size[0] % mini_batch_size == 0, f"{tensordict.batch_size[0]} % {mini_batch_size} != 0"

    if dataloader_kwargs is None:
        dataloader_kwargs = {}

    generator = None
    if seed is not None:
        generator = torch.Generator()
        generator.manual_seed(seed)

    assert isinstance(dataloader_kwargs, dict)

    # The TensorDict is used directly as the dataset; the identity collate_fn
    # leaves each fetched batch untouched.
    loader = DataLoader(
        dataset=tensordict,
        batch_size=mini_batch_size,
        collate_fn=lambda x: x,
        generator=generator,
        **dataloader_kwargs,
    )

    def _epoch_stream():
        for _ in range(epochs):
            for mini_batch in loader:
                yield mini_batch

    return iter(_epoch_stream())
|
||||||
|
|
||||||
|
|
||||||
|
def assert_tensordict_eq(tensordict1: TensorDict, tensordict2: TensorDict):
    """Assert that two TensorDicts hold the same keys with equal values.

    For each key the two values must have exactly the same type. Tensors are
    compared element-wise with ``torch.eq``; everything else with ``==``.

    Raises:
        AssertionError: on the first mismatch found.
    """
    assert set(tensordict1.keys()) == set(tensordict2.keys())

    for key in tensordict1.keys():
        left, right = tensordict1[key], tensordict2[key]

        assert type(left) is type(right), f"The type of {key} must be the same. Got {type(left)} vs {type(right)}"

        if not isinstance(left, torch.Tensor):
            assert left == right
            continue

        assert torch.eq(left, right).all().item()
|
||||||
|
|
||||||
|
|
||||||
|
def pop(tensordict: TensorDict, keys: Iterator[str]) -> TensorDict:
    """Remove ``keys`` from ``tensordict`` and return them as a new TensorDict.

    Tensors are moved over unchanged, ``NonTensorStack`` entries are converted
    with ``tolist()``, and ``NonTensorData`` entries are passed on as the
    non-tensor dict of ``get_tensordict``.

    Raises:
        AssertionError: if an entry is none of the three supported kinds.
    """
    batched = {}
    unbatched = {}

    for key in keys:
        # Inspect the entry first so that its kind decides how it is popped.
        entry = tensordict.get(key)

        if isinstance(entry, torch.Tensor):
            batched[key] = tensordict.pop(key)
            continue

        if isinstance(entry, NonTensorStack):
            batched[key] = tensordict.pop(key).tolist()
            continue

        assert isinstance(entry, NonTensorData)
        unbatched[key] = tensordict.pop(key)

    return get_tensordict(batched, unbatched)
|
||||||
|
|
||||||
|
|
||||||
|
def pad_to_divisor(data: TensorDict, size_divisor: int):
    """Pad a TensorDict so that its length is divisible by ``size_divisor``.

    Padding is made of leading slices of ``data`` itself, repeated until
    enough entries have been collected, and concatenated after ``data``.

    Args:
        data (TensorDict): the TensorDict to pad
        size_divisor (int): size divisor

    Returns:
        data (TensorDict): the padded TensorDict (``data`` itself when no padding is needed)
        pad_size (int): number of entries appended
    """
    assert isinstance(data, TensorDict), "data must be a TensorDict"

    remainder = len(data) % size_divisor
    if remainder == 0:
        if len(data) == 0:
            logging.warning("padding a DataProto with no item, no changed made")
        return data, 0

    pad_size = size_divisor - remainder
    fillers = []
    still_needed = pad_size
    while still_needed > 0:
        # A single slice cannot be longer than data, hence the loop.
        chunk_len = min(still_needed, len(data))
        fillers.append(data[:chunk_len])
        still_needed -= chunk_len

    return torch.cat([data, *fillers]), pad_size
|
||||||
|
|
||||||
|
|
||||||
|
def unpad(data: TensorDict, pad_size):
    """Strip ``pad_size`` trailing entries, i.e. ``data[:-pad_size]``.

    When ``pad_size`` is 0, ``data`` is returned as is.
    """
    return data[:-pad_size] if pad_size != 0 else data
|
Reference in New Issue
Block a user