mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 13:43:50 +08:00
### What does this PR do? This PR continues the work started in PR #2733, it adds support for variable sequence lengths in MultiTurnSFTDataset by introducing a `no_padding` option for the pad_mode. When this mode is active, sequences are not padded to a fixed length. - Implement no-padding mode for FSDP engine using nested tensors in sft trainer - Add test for no-padding mode both enable/disable use_remove_padding - Fix FSDP2 gradnorm issue ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. 
- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: zhangchi.usc1992 <zhangchi.usc1992@bytedance.com>
596 lines
23 KiB
Python
596 lines
23 KiB
Python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
|
|
Replace DataProto with raw TensorDict
|
|
"""
|
|
|
|
import copy
|
|
import random
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
|
|
from verl.utils import tensordict_utils as tu
|
|
|
|
|
|
def test_union_tensor_dict():
    """union_tensor_dict merges compatible TensorDicts and rejects conflicting entries.

    The conflict checks below depend on state accumulated by earlier unions, so
    the statement order matters. Presumably ``tu.union_tensor_dict(left, right)``
    merges ``right`` into ``left`` in place — confirm against the helper.
    """
    shared_obs = torch.randn(100, 10)

    left = tu.get_tensordict(
        tensor_dict={"obs": shared_obs, "act": torch.randn(100, 3), "data_sources": ["gsm8k"] * 100}
    )
    right = tu.get_tensordict(
        tensor_dict={
            "obs": shared_obs,
            "next_obs": torch.randn(100, 10),
            "rew": torch.randn(100),
            "data_sources": ["gsm8k"] * 100,
        }
    )
    cloned = tu.get_tensordict(
        {"obs": shared_obs.clone(), "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}
    )

    # Every overlapping key carries identical values, so this union succeeds.
    tu.union_tensor_dict(left, right)

    # Conflict in tensor values: `cloned` has independently sampled next_obs/rew,
    # which disagree with those `left` presumably absorbed from `right` above.
    with pytest.raises(AssertionError):
        tu.union_tensor_dict(left, cloned)

    # Metadata present on only one side is not a conflict.
    left = tu.assign_non_tensor_dict(left, {"top_p": 0.8})
    tu.union_tensor_dict(left, right)

    # The same metadata key with different values is a conflict in NonTensorData.
    right = tu.assign_non_tensor_dict(right, {"top_p": 0.9})
    with pytest.raises(AssertionError):
        tu.union_tensor_dict(left, right)

    left.pop("top_p")
    right.pop("top_p")

    # A differing per-sample non-tensor entry is also a conflict in NonTensorData.
    right["data_sources"][0] = "math"
    with pytest.raises(AssertionError):
        tu.union_tensor_dict(left, right)
|
|
|
|
|
|
def test_tensor_dict_constructor():
    """get_tensordict builds a batched TensorDict from tensors, per-sample lists and metadata."""
    batch = 100
    data = tu.get_tensordict(
        tensor_dict={
            "obs": torch.ones(batch, 10),
            "act": torch.zeros(batch, 10, 3),
            "data_source": ["gsm8k"] * batch,
        },
        non_tensor_dict={"name": "abdce"},
    )

    assert data.batch_size == torch.Size([batch])

    # Integer indexing yields a single sample.
    first = data[0]
    assert torch.eq(first["obs"], torch.ones(10)).all().item()
    assert torch.eq(first["act"], torch.zeros(10, 3)).all().item()
    assert first["data_source"] == "gsm8k"

    # Slicing keeps the leading batch dimension.
    head = data[0:2]
    assert torch.eq(head["obs"], torch.ones(2, 10)).all().item()
    assert torch.eq(head["act"], torch.zeros(2, 10, 3)).all().item()
    assert head["data_source"] == ["gsm8k"] * 2

    # Metadata is readable by key.
    assert data["name"] == "abdce"
|
|
|
|
|
|
def test_index_select_tensor_dict():
    """index_select_tensor_dict picks the same rows from jagged and dense entries.

    The expected TensorDict is built from the raw per-sample sequences rather than
    by indexing the nested tensor under test. If nested indexing were broken, it
    could otherwise corrupt the expected and actual values in the same way.
    """
    vocab_size = 128
    seqs = [torch.randint(low=0, high=vocab_size, size=(length,)) for length in (11, 13, 12, 15)]
    input_ids = torch.nested.as_nested_tensor(seqs, layout=torch.jagged)

    padded_tensor = torch.randn(4, 10)
    non_tensor_dict = {"global_batch_size": "4"}

    data = tu.get_tensordict(
        tensor_dict={"input_ids": input_ids, "padded_tensor": padded_tensor},
        non_tensor_dict=non_tensor_dict,
    )
    assert data.batch_size == torch.Size([4])

    indices = torch.tensor([1, 3])
    selected_data = tu.index_select_tensor_dict(data, indices)
    assert selected_data.batch_size == torch.Size([2])

    # Expected value derived independently of the code under test.
    target_input_ids = torch.nested.as_nested_tensor([seqs[idx] for idx in indices.tolist()], layout=torch.jagged)
    target_select_data = tu.get_tensordict(
        tensor_dict={"input_ids": target_input_ids, "padded_tensor": padded_tensor[indices]},
        non_tensor_dict=non_tensor_dict,
    )
    tu.assert_tensordict_eq(selected_data, target_select_data)
|
|
|
|
|
|
def test_tensordict_with_images():
    """Per-sample lists of variably sized images survive get_tensordict and indexing.

    Each sample holds a token sequence plus a list of images of different sizes.
    Every sample and every image is verified, not just the first image of sample 0.
    """
    vocab_size = 128
    a = torch.randint(low=0, high=vocab_size, size=(11,))
    b = torch.randint(low=0, high=vocab_size, size=(13,))
    input_ids = torch.nested.as_nested_tensor([a, b], layout=torch.jagged)

    def _random_images(sides):
        # Images must be numpy arrays here.
        # TODO(vermouth1992). We may use nested tensor too. But this requires nested over nested
        return [torch.randint(low=0, high=255, size=(3, s, s), dtype=torch.uint8).numpy() for s in sides]

    a_images = _random_images((256, 128))
    b_images = _random_images((256, 128, 64))

    data = tu.get_tensordict({"input_ids": input_ids, "images": [a_images, b_images]})

    for idx, (seq, images) in enumerate(((a, a_images), (b, b_images))):
        sample = data[idx]
        assert torch.all(torch.eq(sample["input_ids"], seq))
        # Both the number of images and each image's content must round-trip.
        assert len(sample["images"]) == len(images)
        for got, expected in zip(sample["images"], images, strict=True):
            assert got.shape == expected.shape
            assert np.all(np.equal(got, expected))
|
|
|
|
|
|
def test_tensordict_with_packing():
    """A jagged input_ids entry exposes packed offsets and supports per-sample access."""
    vocab_size = 128
    seqs = [
        torch.randint(low=0, high=vocab_size, size=(11,)),
        torch.randint(low=0, high=vocab_size, size=(13,)),
    ]
    data = tu.get_tensordict({"input_ids": torch.nested.as_nested_tensor(seqs, layout=torch.jagged)})

    # The offsets of the packed buffer are the cumulative sequence lengths (cu_seqlens).
    expected_offsets = torch.tensor([0, 11, 24])
    assert torch.all(torch.eq(expected_offsets, data["input_ids"].offsets()))

    # Index the nested tensor directly ...
    for i, seq in enumerate(seqs):
        assert torch.all(torch.eq(data["input_ids"][i], seq))

    # ... and index the TensorDict first.
    for i, seq in enumerate(seqs):
        assert torch.all(torch.eq(data[i]["input_ids"], seq))

    # Chunking into two pieces keeps each sequence intact.
    pieces = data.chunk(2)
    for i, seq in enumerate(seqs):
        assert torch.all(torch.eq(pieces[i]["input_ids"][0], seq))
|
|
|
|
|
|
def test_tensordict_eq():
    """assert_tensordict_eq accepts equal TensorDicts and rejects any differing entry.

    Covers dense tensors, per-sample string lists, nested non-tensor metadata and
    jagged nested tensors (a changed sequence and swapped sequence order).
    """

    def build(obs):
        # Fresh list/dict containers on every call, so no two TensorDicts share them.
        return tu.get_tensordict(
            {"obs": obs, "data_sources": ["abc", "def", "abc", "def", "pol", "klj"]},
            non_tensor_dict={"train_sample_kwargs": {"top_p": 1.0}, "val_sample_kwargs": {"top_p": 0.7}},
        )

    reference = build(torch.tensor([1, 2, 3, 4, 5, 6]))
    twin = build(torch.tensor([1, 2, 3, 4, 5, 6]))
    tu.assert_tensordict_eq(reference, twin)

    def bump_first_obs(td):
        td["obs"][0] += 1

    def rename_first_source(td):
        td["data_sources"][0] = "math"

    def change_train_top_p(td):
        td["train_sample_kwargs"]["top_p"] = 0.9

    # Each mutation is applied to its own deep copy, so `twin` itself is never touched.
    for mutate in (bump_first_obs, rename_first_source, change_train_top_p):
        mutated = copy.deepcopy(twin)
        mutate(mutated)
        with pytest.raises(AssertionError):
            tu.assert_tensordict_eq(reference, mutated)

    def build_jagged(seqs):
        return build(torch.nested.as_nested_tensor(seqs, layout=torch.jagged))

    tail = [torch.tensor([10, 11, 12]), torch.tensor([13, 14, 15, 18]), torch.tensor([16, 17])]
    jagged_ref = build_jagged(
        [torch.tensor([1, 2, 3, 3, 2]), torch.tensor([4, 5]), torch.tensor([7, 8, 10, 14])] + tail
    )

    # Same values held in freshly created tensors -> equal.
    jagged_same = build_jagged(
        [torch.tensor([1, 2, 3, 3, 2]), torch.tensor([4, 5]), torch.tensor([7, 8, 10, 14])] + tail
    )
    tu.assert_tensordict_eq(jagged_ref, jagged_same)

    # First sequence changed (and shortened) -> not equal.
    jagged_changed = build_jagged(
        [torch.tensor([1, 2, 4]), torch.tensor([4, 5]), torch.tensor([7, 8, 10, 14])] + tail
    )
    with pytest.raises(AssertionError):
        tu.assert_tensordict_eq(jagged_ref, jagged_changed)

    # First two sequences swapped -> not equal.
    jagged_swapped = build_jagged(
        [torch.tensor([4, 5]), torch.tensor([1, 2, 3, 3, 2]), torch.tensor([7, 8, 10, 14])] + tail
    )
    with pytest.raises(AssertionError):
        tu.assert_tensordict_eq(jagged_ref, jagged_swapped)
|
|
|
|
|
|
def test_tensor_dict_make_iterator():
    """make_iterator yields mini-batches in order across epochs and is reproducible under a seed."""
    obs = torch.tensor([1, 2, 3, 4, 5, 6])
    data_sources = ["abc", "def", "abc", "def", "pol", "klj"]
    non_tensor_dict = {"train_sample_kwargs": {"top_p": 1.0}, "val_sample_kwargs": {"top_p": 0.7}}
    dataset = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)

    # Without shuffling, 2 epochs of mini_batch_size=2 over 6 samples should give
    # exactly these six batches, in this order.
    dataloader = tu.make_iterator(
        dataset, mini_batch_size=2, epochs=2, seed=0, dataloader_kwargs={"shuffle": False, "drop_last": False}
    )
    expected_batches = [dataset[0:2], dataset[2:4], dataset[4:6]] * 2

    batches = list(dataloader)
    # Check the count explicitly: a bare index loop would silently pass if the
    # iterator yielded fewer batches than expected.
    assert len(batches) == len(expected_batches)
    for batch, expected in zip(batches, expected_batches, strict=True):
        tu.assert_tensordict_eq(batch, expected)

    # The same seed with shuffling must reproduce the same batch sequence.
    data_list_1 = list(
        tu.make_iterator(dataset, mini_batch_size=3, epochs=1, seed=1, dataloader_kwargs={"shuffle": True})
    )
    data_list_2 = list(
        tu.make_iterator(dataset, mini_batch_size=3, epochs=1, seed=1, dataloader_kwargs={"shuffle": True})
    )
    # Guard against a vacuous pass where both runs are empty.
    assert len(data_list_1) > 0
    for data1, data2 in zip(data_list_1, data_list_2, strict=True):
        tu.assert_tensordict_eq(data1, data2)
|
|
|
|
|
|
def test_reorder():
    """Indexing with a permutation tensor reorders tensor and list entries consistently."""
    permutation = torch.tensor([3, 4, 2, 0, 1, 5])
    data = tu.get_tensordict(
        tensor_dict={"obs": torch.tensor([1, 2, 3, 4, 5, 6]), "labels": ["a", "b", "c", "d", "e", "f"]},
        non_tensor_dict={"name": "abdce"},
    )

    reordered = data[permutation]

    # Both entries follow the same permutation; metadata is unchanged.
    assert torch.all(torch.eq(reordered["obs"], torch.tensor([4, 5, 3, 1, 2, 6])))
    assert np.all(reordered["labels"] == np.array(["d", "e", "c", "a", "b", "f"]))
    assert reordered["name"] == "abdce"
|
|
|
|
|
|
def test_chunk_concat():
    """tensor_split and chunk split a TensorDict along the batch; torch.cat reassembles it."""
    obs = torch.tensor([1, 2, 3, 4, 5, 6])
    labels = ["a", "b", "c", "d", "e", "f"]
    data = tu.get_tensordict({"obs": obs, "labels": labels}, non_tensor_dict={"name": "abcde"})

    # Splitting 6 rows into 5 sections gives the first section the extra row.
    data_split = data.tensor_split(indices_or_sections=5, dim=0)
    expected_idx_lst = [[0, 1], [2], [3], [4], [5]]

    # strict=True (plus the explicit length check) fails on a wrong section count
    # instead of letting zip silently truncate.
    assert len(data_split) == len(expected_idx_lst)
    for d, expected_idx in zip(data_split, expected_idx_lst, strict=True):
        tu.assert_tensordict_eq(d, data[expected_idx])

    data_split = data.chunk(2)
    assert len(data_split) == 2
    assert torch.all(torch.eq(data_split[0]["obs"], torch.tensor([1, 2, 3])))
    assert np.all(data_split[0]["labels"] == np.array(["a", "b", "c"]))
    assert data_split[0]["name"] == "abcde"

    assert torch.all(torch.eq(data_split[1]["obs"], torch.tensor([4, 5, 6])))
    assert np.all(data_split[1]["labels"] == np.array(["d", "e", "f"]))
    assert data_split[1]["name"] == "abcde"

    # Concatenating the chunks restores the original data and metadata.
    concat_data = torch.cat(data_split, dim=0)
    assert torch.all(torch.eq(concat_data["obs"], data["obs"]))
    assert np.all(concat_data["labels"] == data["labels"])
    assert concat_data["name"] == data["name"]
|
|
|
|
|
|
def test_pop():
    """tu.pop moves the requested tensor and non-tensor keys into a new TensorDict."""
    dataset = tu.get_tensordict(
        {"obs": torch.randn(100, 10), "act": torch.randn(100, 3)},
        non_tensor_dict={"2": 2, "1": 1},
    )

    popped = tu.pop(dataset, keys=["obs", "2"])

    # The popped TensorDict keeps the source's batch dimension and holds only the popped keys.
    assert popped.batch_size[0] == 100
    assert popped.keys() == {"obs", "2"}

    # The popped keys no longer appear in the source.
    assert dataset.keys() == {"act", "1"}
|
|
|
|
|
|
def test_repeat():
    """repeat_interleave repeats each row in place; repeat tiles the whole batch."""
    # A TensorDict with a tensor entry, a per-sample list and metadata.
    data = tu.get_tensordict(
        {"obs": torch.tensor([[1, 2], [3, 4], [5, 6]]), "labels": ["a", "b", "c"]},
        non_tensor_dict={"info": "test_info"},
    )

    # Row-wise repetition: a, a, b, b, c, c.
    interleaved = data.repeat_interleave(repeats=2)
    assert torch.all(torch.eq(interleaved["obs"], torch.tensor([[1, 2], [1, 2], [3, 4], [3, 4], [5, 6], [5, 6]])))
    assert interleaved["labels"] == ["a", "a", "b", "b", "c", "c"]
    assert interleaved["info"] == "test_info"

    # Whole-batch tiling: a, b, c, a, b, c.
    tiled = data.repeat(2)
    assert torch.all(torch.eq(tiled["obs"], torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6]])))
    assert tiled["labels"] == ["a", "b", "c", "a", "b", "c"]
    assert tiled["info"] == "test_info"
|
|
|
|
|
|
def test_dataproto_pad_unpad():
    """pad_to_divisor appends rows up to a multiple of the divisor; unpad removes them.

    Per the expected values below, padding rows are taken cyclically from the
    start of the batch.
    """
    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    labels = ["a", "b", "c"]
    data = tu.get_tensordict(tensor_dict={"obs": obs, "labels": labels}, non_tensor_dict={"info": "test_info"})

    # (size_divisor, expected pad_size, expected padded obs, expected padded labels)
    cases = [
        (2, 1, [[1, 2], [3, 4], [5, 6], [1, 2]], ["a", "b", "c", "a"]),
        (3, 0, [[1, 2], [3, 4], [5, 6]], ["a", "b", "c"]),
        (
            7,
            4,
            [[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6], [1, 2]],
            ["a", "b", "c", "a", "b", "c", "a"],
        ),
    ]

    for divisor, expected_pad, expected_obs, expected_labels in cases:
        padded, pad_size = tu.pad_to_divisor(data, size_divisor=divisor)
        assert pad_size == expected_pad
        assert torch.all(torch.eq(padded["obs"], torch.tensor(expected_obs)))
        assert padded["labels"] == expected_labels
        assert padded["info"] == "test_info"

        # Unpadding must restore the original rows and keep the metadata.
        restored = tu.unpad(padded, pad_size=pad_size)
        assert torch.all(torch.eq(restored["obs"], obs))
        assert restored["labels"] == labels
        assert restored["info"] == "test_info"
|
|
|
|
|
|
def test_torch_save_data_proto():
    """A TensorDict with tensor, per-sample and metadata entries round-trips through torch.save/load."""
    import os
    import tempfile

    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    labels = ["a", "b", "c"]
    data = tu.get_tensordict({"obs": obs, "labels": labels}, non_tensor_dict={"info": "test_info"})

    # Write into a private temporary directory instead of the CWD. The file is
    # removed even if a later step fails, and concurrent runs cannot clobber it.
    with tempfile.TemporaryDirectory() as tmp_dir:
        filename = os.path.join(tmp_dir, "test_data.pt")
        torch.save(data, filename)
        # weights_only=False: the payload is a full TensorDict object (presumably
        # not loadable under the weights-only unpickler). Only load trusted files.
        loaded_data = torch.load(filename, weights_only=False)

    assert torch.all(torch.eq(loaded_data["obs"], data["obs"]))
    assert loaded_data["labels"] == data["labels"]
    assert loaded_data["info"] == data["info"]
|
|
|
|
|
|
def test_len():
    """len() of a TensorDict reports its batch size, including list-only and empty inputs."""
    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    labels = ["a", "b", "c"]

    # Tensor entry plus a per-sample list.
    with_tensor = tu.get_tensordict({"obs": obs, "labels": list(labels)}, non_tensor_dict={"info": "test_info"})
    assert len(with_tensor) == 3

    # Only a per-sample list: the length is still 3.
    labels_only = tu.get_tensordict({"labels": list(labels)}, non_tensor_dict={"info": "test_info"})
    assert len(labels_only) == 3

    # A single indexed sample has no batch dimension.
    assert len(labels_only[0]) == 0

    # No per-sample entries at all -> length 0.
    assert len(tu.get_tensordict({}, non_tensor_dict={"info": "test_info"})) == 0
|
|
|
|
|
|
def test_dataproto_index():
    """Index a TensorDict with numpy / torch / list integer indices and a torch bool mask.

    All randomness comes from locally seeded generators, so any failure is
    reproducible and no global RNG state is disturbed.
    """
    data_len = 100
    idx_num = 10

    py_rng = random.Random(0)
    np_rng = np.random.default_rng(0)
    torch_gen = torch.Generator().manual_seed(0)

    obs = torch.randn(data_len, 10, generator=torch_gen)
    labels = [py_rng.choice(["abc", "cde"]) for _ in range(data_len)]
    data = tu.get_tensordict({"obs": obs, "labels": labels})
    labels_np = np.array(labels)

    def _check(result, expected_len, obs_idx, labels_idx):
        # Shared assertions: the same keys, the expected row count, and rows matching
        # direct indexing of the source tensor / label array.
        assert result.keys() == data.keys()
        assert result["obs"].shape[0] == expected_len
        assert len(result["labels"]) == expected_len
        assert np.array_equal(result["obs"].cpu().numpy(), obs[obs_idx].cpu().numpy())
        assert np.array_equal(result["labels"], labels_np[labels_idx])

    # numpy integer array
    idx_np_int = np_rng.integers(0, data_len, size=(idx_num,))
    _check(data[idx_np_int], idx_num, idx_np_int, idx_np_int)

    # torch integer tensor
    idx_torch_int = torch.randint(0, data_len, size=(idx_num,), generator=torch_gen)
    _check(data[idx_torch_int], idx_num, idx_torch_int, idx_torch_int.cpu().numpy())

    # python list of ints
    idx_list_int = [int(i) for i in np_rng.integers(0, data_len, size=idx_num)]
    _check(data[idx_list_int], idx_num, idx_list_int, idx_list_int)

    # torch bool mask
    idx_torch_bool = torch.randint(0, 2, size=(data_len,), dtype=torch.bool, generator=torch_gen)
    num_selected = int(idx_torch_bool.sum().item())
    _check(data[idx_torch_bool], num_selected, idx_torch_bool, idx_torch_bool.cpu().numpy())

    # NOTE(review): indexing with a numpy bool array or a python list of bools is not
    # asserted here (those checks were previously commented out) — confirm whether
    # tensordict supports them before adding coverage.
|
|
|
|
|
|
def test_select():
    """select() keeps only the named tensor and non-tensor keys."""
    dataset = tu.get_tensordict(
        {"obs": torch.randn(100, 10), "act": torch.randn(100, 3)},
        non_tensor_dict={"2": 2, "1": 1},
    )

    subset = dataset.select("obs", "2")

    # Selected keys carry the same values as the source ...
    assert torch.all(torch.eq(subset["obs"], dataset["obs"]))
    assert subset["2"] == dataset["2"]

    # ... and everything else is dropped.
    for dropped in ("act", "1"):
        assert dropped not in subset.keys()
|
|
|
|
|
|
def test_dataproto_no_batch():
    """A TensorDict holding only a per-sample list supports select() and tu.pop()."""
    labels = ["a", "b", "c"]
    data = tu.get_tensordict(tensor_dict={"labels": labels}, non_tensor_dict={"info": "test_info"})

    assert data.select("labels")["labels"] == labels

    # Popping returns the entry and removes it from the source.
    popped = tu.pop(data, keys=["labels"])
    assert popped["labels"] == labels
    assert "labels" not in data
|
|
|
|
|
|
def test_sample_level_repeat():
    """repeat_interleave with a per-sample repeats tensor repeats each row its own number of times."""
    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    data = tu.get_tensordict({"obs": obs, "labels": ["a", "b", "c"]}, non_tensor_dict={"info": "test_info"})

    # Both cases pass `repeats` as a torch.Tensor holding one count per sample.
    cases = [
        (
            torch.tensor([3, 1, 2]),
            [[1, 2], [1, 2], [1, 2], [3, 4], [5, 6], [5, 6]],
            ["a", "a", "a", "b", "c", "c"],
        ),
        (
            torch.tensor([1, 2, 3]),
            [[1, 2], [3, 4], [3, 4], [5, 6], [5, 6], [5, 6]],
            ["a", "b", "b", "c", "c", "c"],
        ),
    ]

    for repeats, expected_obs, expected_labels in cases:
        repeated = data.repeat_interleave(repeats=repeats)
        assert torch.all(torch.eq(repeated["obs"], torch.tensor(expected_obs)))
        assert repeated["labels"] == expected_labels
        assert repeated["info"] == "test_info"
|
|
|
|
|
|
def test_dataproto_chunk_after_index():
    """After indexing with each index type, batch_size holds plain ints and chunk() works.

    Every index selects rows 0 and 2, so the selection must have batch size 2 and
    split into two single-row chunks.
    """
    data_len = 4
    obs = torch.randn(data_len, 4)
    labels = [f"label_{i}" for i in range(data_len)]
    data = tu.get_tensordict(tensor_dict={"obs": obs, "labels": labels}, non_tensor_dict={"name": "abc"})

    index_cases = {
        "torch bool tensor": torch.tensor([True, False, True, False]),
        "torch int tensor": torch.tensor([0, 2]),
        "numpy int array": np.array([0, 2]),
        "python bool list": [True, False, True, False],
        "python int list": [0, 2],
    }

    for desc, index in index_cases.items():
        selected = data[index]
        assert isinstance(selected.batch_size, torch.Size), desc
        # Every dim must be a python int (int or List[int]), not a tensor or numpy scalar.
        assert all(isinstance(d, int) for d in selected.batch_size), desc
        assert selected.batch_size == torch.Size([2]), desc

        # The scenario this test is named after: chunking an indexed selection.
        chunks = selected.chunk(2)
        assert len(chunks) == 2, desc
        assert all(chunk.batch_size == torch.Size([1]) for chunk in chunks), desc
|