Files
verl/tests/test_protocol_v2_on_cpu.py
Houmin Wei 69b0127b74 [misc] feat: prototype deprecate DataProto and replace with Tensordict: part 2 (#3567)
### What does this PR do?

This PR continues the work started in PR #2733. It adds support for
variable sequence lengths in MultiTurnSFTDataset by introducing a
`no_padding` option for the pad_mode. When this mode is active,
sequences are not padded to a fixed length.
- Implement no-padding mode for FSDP engine using nested tensors in sft
trainer
- Add tests for the no-padding mode with `use_remove_padding` both enabled and disabled
- Fix FSDP2 gradnorm issue

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: zhangchi.usc1992 <zhangchi.usc1992@bytedance.com>
2025-09-24 17:12:31 +08:00

596 lines
23 KiB
Python

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Replace DataProto with raw TensorDict
"""
import copy
import random
import numpy as np
import pytest
import torch
from verl.utils import tensordict_utils as tu
def test_union_tensor_dict():
    """Union of two tensordicts succeeds only when overlapping entries agree."""
    shared_obs = torch.randn(100, 10)
    td_a = tu.get_tensordict(
        tensor_dict={"obs": shared_obs, "act": torch.randn(100, 3), "data_sources": ["gsm8k"] * 100}
    )
    td_b = tu.get_tensordict(
        tensor_dict={
            "obs": shared_obs,
            "next_obs": torch.randn(100, 10),
            "rew": torch.randn(100),
            "data_sources": ["gsm8k"] * 100,
        }
    )
    td_conflicting = tu.get_tensordict(
        {"obs": shared_obs.clone(), "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}
    )

    tu.union_tensor_dict(td_a, td_b)
    with pytest.raises(AssertionError):
        # conflict in tensor values
        tu.union_tensor_dict(td_a, td_conflicting)

    td_a = tu.assign_non_tensor_dict(td_a, {"top_p": 0.8})
    tu.union_tensor_dict(td_a, td_b)  # works ok: td_b carries no "top_p" yet
    td_b = tu.assign_non_tensor_dict(td_b, {"top_p": 0.9})
    with pytest.raises(AssertionError):
        # conflict in NonTensorData: "top_p" differs
        tu.union_tensor_dict(td_a, td_b)

    td_a.pop("top_p")
    td_b.pop("top_p")
    td_b["data_sources"][0] = "math"
    with pytest.raises(AssertionError):
        # conflict in NonTensorData: per-sample entries differ
        tu.union_tensor_dict(td_a, td_b)
def test_tensor_dict_constructor():
    """Constructing a tensordict mixes tensors, per-sample lists and global metadata."""
    data = tu.get_tensordict(
        tensor_dict={
            "obs": torch.ones(100, 10),
            "act": torch.zeros(100, 10, 3),
            "data_source": ["gsm8k"] * 100,
        },
        non_tensor_dict={"name": "abdce"},
    )
    assert data.batch_size == torch.Size([100])

    # single-sample indexing drops the batch dimension
    sample = data[0]
    assert torch.all(torch.eq(sample["obs"], torch.ones(10))).item()
    assert torch.all(torch.eq(sample["act"], torch.zeros(10, 3))).item()
    assert sample["data_source"] == "gsm8k"

    # slicing keeps the batch dimension
    window = data[0:2]
    assert torch.all(torch.eq(window["obs"], torch.ones(2, 10))).item()
    assert torch.all(torch.eq(window["act"], torch.zeros(2, 10, 3))).item()
    assert window["data_source"] == ["gsm8k"] * 2

    # global (non-tensor) metadata is exposed as a plain value
    assert data["name"] == "abdce"
def test_index_select_tensor_dict():
    """index_select keeps jagged (nested) and regular tensors aligned on the batch dim."""
    vocab_size = 128
    seq_lens = [11, 13, 12, 15]
    sequences = [torch.randint(low=0, high=vocab_size, size=(n,)) for n in seq_lens]
    input_ids = torch.nested.as_nested_tensor(sequences, layout=torch.jagged)
    padded_tensor = torch.randn(4, 10)
    non_tensor_dict = {"global_batch_size": "4"}
    data = tu.get_tensordict(
        tensor_dict={"input_ids": input_ids, "padded_tensor": padded_tensor},
        non_tensor_dict=non_tensor_dict,
    )
    assert data.batch_size == torch.Size([4])

    # select rows 1 and 3
    indices = torch.tensor([1, 3])
    selected = tu.index_select_tensor_dict(data, indices)
    assert selected.batch_size == torch.Size([2])

    # build the expected result by hand and compare field-by-field
    expected = tu.get_tensordict(
        tensor_dict={
            "input_ids": torch.nested.as_nested_tensor([sequences[i] for i in indices], layout=torch.jagged),
            "padded_tensor": padded_tensor[indices],
        },
        non_tensor_dict=non_tensor_dict,
    )
    tu.assert_tensordict_eq(selected, expected)
def test_tensordict_with_images():
    """Each sample carries a token sequence plus a variable number of variable-size images."""
    vocab_size = 128
    seq_a = torch.randint(low=0, high=vocab_size, size=(11,))
    seq_b = torch.randint(low=0, high=vocab_size, size=(13,))
    input_ids = torch.nested.as_nested_tensor([seq_a, seq_b], layout=torch.jagged)

    def _rand_image(side):
        # images must be numpy arrays
        # TODO(vermouth1992). We may use nested tensor too. But this requires nested over nested
        return torch.randint(low=0, high=255, size=(3, side, side), dtype=torch.uint8).numpy()

    a_images = [_rand_image(256), _rand_image(128)]
    b_images = [_rand_image(256), _rand_image(128), _rand_image(64)]

    data = tu.get_tensordict({"input_ids": input_ids, "images": [a_images, b_images]})
    assert np.all(np.equal(data[0]["images"][0], a_images[0]))
    assert torch.all(torch.eq(data[0]["input_ids"], seq_a))
def test_tensordict_with_packing():
    """Jagged nested tensors expose packing metadata (cu_seqlens) and per-sample views."""
    vocab_size = 128
    first = torch.randint(low=0, high=vocab_size, size=(11,))
    second = torch.randint(low=0, high=vocab_size, size=(13,))
    data = tu.get_tensordict({"input_ids": torch.nested.as_nested_tensor([first, second], layout=torch.jagged)})

    # offsets() yields the cumulative sequence lengths: [0, 11, 11 + 13]
    assert torch.all(torch.eq(torch.tensor([0, 11, 24]), data["input_ids"].offsets()))

    # indexing through the field or through the batch dim yields the same rows
    assert torch.all(torch.eq(data["input_ids"][0], first))
    assert torch.all(torch.eq(data["input_ids"][1], second))
    assert torch.all(torch.eq(data[0]["input_ids"], first))
    assert torch.all(torch.eq(data[1]["input_ids"], second))

    # chunking splits along the batch dimension
    chunks = data.chunk(2)
    assert torch.all(torch.eq(chunks[0]["input_ids"][0], first))
    assert torch.all(torch.eq(chunks[1]["input_ids"][0], second))
def test_tensordict_eq():
    """assert_tensordict_eq detects differences in tensors, per-sample and global non-tensor data."""

    def _make(obs_value):
        # fresh per-sample and global non-tensor data for every构 call
        return tu.get_tensordict(
            {"obs": obs_value, "data_sources": ["abc", "def", "abc", "def", "pol", "klj"]},
            non_tensor_dict={"train_sample_kwargs": {"top_p": 1.0}, "val_sample_kwargs": {"top_p": 0.7}},
        )

    data = _make(torch.tensor([1, 2, 3, 4, 5, 6]))
    data1 = _make(torch.tensor([1, 2, 3, 4, 5, 6]))
    tu.assert_tensordict_eq(data, data1)

    # tensor value mismatch
    data2 = copy.deepcopy(data1)
    data2["obs"][0] += 1
    with pytest.raises(AssertionError):
        tu.assert_tensordict_eq(data, data2)

    # per-sample non-tensor mismatch
    data2 = copy.deepcopy(data1)
    data2["data_sources"][0] = "math"
    with pytest.raises(AssertionError):
        tu.assert_tensordict_eq(data, data2)

    # global non-tensor mismatch
    data2 = copy.deepcopy(data1)
    data2["train_sample_kwargs"]["top_p"] = 0.9
    with pytest.raises(AssertionError):
        tu.assert_tensordict_eq(data, data2)

    # same checks with jagged nested tensors
    tensor_list = [
        torch.tensor([1, 2, 3, 3, 2]),
        torch.tensor([4, 5]),
        torch.tensor([7, 8, 10, 14]),
        torch.tensor([10, 11, 12]),
        torch.tensor([13, 14, 15, 18]),
        torch.tensor([16, 17]),
    ]
    data3 = _make(torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged))

    # an equal-valued copy of row 0 keeps equality
    tensor_list[0] = torch.tensor([1, 2, 3, 3, 2])
    data4 = _make(torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged))
    tu.assert_tensordict_eq(data3, data4)

    # changing row 0's values (and length) breaks equality
    tensor_list[0] = torch.tensor([1, 2, 4])
    data5 = _make(torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged))
    with pytest.raises(AssertionError):
        tu.assert_tensordict_eq(data3, data5)

    # swapping the first two rows also breaks equality
    tensor_list[0] = torch.tensor([4, 5])
    tensor_list[1] = torch.tensor([1, 2, 3, 3, 2])
    data6 = _make(torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged))
    with pytest.raises(AssertionError):
        tu.assert_tensordict_eq(data3, data6)
def test_tensor_dict_make_iterator():
    """make_iterator walks the dataset for several epochs and is reproducible given a seed."""
    dataset = tu.get_tensordict(
        {"obs": torch.tensor([1, 2, 3, 4, 5, 6]), "data_sources": ["abc", "def", "abc", "def", "pol", "klj"]},
        non_tensor_dict={"train_sample_kwargs": {"top_p": 1.0}, "val_sample_kwargs": {"top_p": 0.7}},
    )

    # without shuffling, two epochs yield the mini-batches in order, twice
    dataloader = tu.make_iterator(
        dataset, mini_batch_size=2, epochs=2, seed=0, dataloader_kwargs={"shuffle": False, "drop_last": False}
    )
    expected = [dataset[0:2], dataset[2:4], dataset[4:6]] * 2
    for batch, target in zip(dataloader, expected, strict=True):
        tu.assert_tensordict_eq(batch, target)

    # with shuffling, the same seed must reproduce the same stream of batches
    def _collect():
        it = tu.make_iterator(dataset, mini_batch_size=3, epochs=1, seed=1, dataloader_kwargs={"shuffle": True})
        return list(it)

    for first, second in zip(_collect(), _collect(), strict=True):
        tu.assert_tensordict_eq(first, second)
def test_reorder():
    """Indexing with a permutation tensor reorders rows; global metadata is untouched."""
    data = tu.get_tensordict(
        tensor_dict={"obs": torch.tensor([1, 2, 3, 4, 5, 6]), "labels": ["a", "b", "c", "d", "e", "f"]},
        non_tensor_dict={"name": "abdce"},
    )
    permutation = torch.tensor([3, 4, 2, 0, 1, 5])
    data = data[permutation]
    assert torch.all(torch.eq(data["obs"], torch.tensor([4, 5, 3, 1, 2, 6])))
    assert np.all(data["labels"] == np.array(["d", "e", "c", "a", "b", "f"]))
    assert data["name"] == "abdce"
def test_chunk_concat():
    """tensor_split/chunk split along the batch dim and torch.cat restores the original.

    Fix: the tensor_split comparison loop used ``zip(..., strict=False)``, which would
    silently pass if the split produced fewer pieces than expected. The piece count is
    known (5), so ``strict=True`` turns a length mismatch into a hard failure.
    """
    obs = torch.tensor([1, 2, 3, 4, 5, 6])
    labels = ["a", "b", "c", "d", "e", "f"]
    data = tu.get_tensordict({"obs": obs, "labels": labels}, non_tensor_dict={"name": "abcde"})

    # tensor_split(5) on 6 rows -> uneven pieces; the first piece gets the extra row
    data_split = data.tensor_split(indices_or_sections=5, dim=0)
    expected_idx_lst = [[0, 1], [2], [3], [4], [5]]
    for d, expected_idx in zip(data_split, expected_idx_lst, strict=True):
        tu.assert_tensordict_eq(d, data[expected_idx])

    # chunk(2) -> two equal halves; global metadata is carried into every piece
    data_split = data.chunk(2)
    assert len(data_split) == 2
    assert torch.all(torch.eq(data_split[0]["obs"], torch.tensor([1, 2, 3])))
    assert np.all(data_split[0]["labels"] == np.array(["a", "b", "c"]))
    assert data_split[0]["name"] == "abcde"
    assert torch.all(torch.eq(data_split[1]["obs"], torch.tensor([4, 5, 6])))
    assert np.all(data_split[1]["labels"] == np.array(["d", "e", "f"]))
    assert data_split[1]["name"] == "abcde"

    # concatenating the pieces reproduces the original data
    concat_data = torch.cat(data_split, dim=0)
    assert torch.all(torch.eq(concat_data["obs"], data["obs"]))
    assert np.all(concat_data["labels"] == data["labels"])
    assert concat_data["name"] == data["name"]
def test_pop():
    """tu.pop removes the requested tensor and non-tensor keys and returns them."""
    dataset = tu.get_tensordict(
        {"obs": torch.randn(100, 10), "act": torch.randn(100, 3)}, non_tensor_dict={"2": 2, "1": 1}
    )
    popped = tu.pop(dataset, keys=["obs", "2"])
    assert popped.batch_size[0] == 100
    assert popped.keys() == {"obs", "2"}
    # the source dataset only keeps the keys that were not popped
    assert dataset.keys() == {"act", "1"}
def test_repeat():
    """repeat_interleave duplicates each row in place; repeat tiles the whole batch."""
    data = tu.get_tensordict(
        {"obs": torch.tensor([[1, 2], [3, 4], [5, 6]]), "labels": ["a", "b", "c"]},
        non_tensor_dict={"info": "test_info"},
    )

    # repeat_interleave: a a b b c c
    interleaved = data.repeat_interleave(repeats=2)
    assert torch.all(torch.eq(interleaved["obs"], torch.tensor([[1, 2], [1, 2], [3, 4], [3, 4], [5, 6], [5, 6]])))
    assert interleaved["labels"] == ["a", "a", "b", "b", "c", "c"]
    assert interleaved["info"] == "test_info"

    # repeat: a b c a b c
    tiled = data.repeat(2)
    assert torch.all(torch.eq(tiled["obs"], torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6]])))
    assert tiled["labels"] == ["a", "b", "c", "a", "b", "c"]
    assert tiled["info"] == "test_info"
def test_dataproto_pad_unpad():
    """pad_to_divisor pads by cycling rows from the start; unpad restores the original.

    Refactor: the original repeated the same 12-line check three times for divisors
    2, 3 and 7; the expected padded content is just the rows cycled to the padded
    length, so a single helper covers all cases.
    """
    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    labels = ["a", "b", "c"]
    data = tu.get_tensordict(tensor_dict={"obs": obs, "labels": labels}, non_tensor_dict={"info": "test_info"})

    def _check_pad_unpad(size_divisor, expected_pad_size):
        # pad, verify the cyclically-extended content, then unpad back to the original
        padded_data, pad_size = tu.pad_to_divisor(data, size_divisor=size_divisor)
        assert pad_size == expected_pad_size
        total = len(labels) + pad_size
        expected_obs = torch.stack([obs[i % len(labels)] for i in range(total)])
        expected_labels = [labels[i % len(labels)] for i in range(total)]
        assert torch.all(torch.eq(padded_data["obs"], expected_obs))
        assert padded_data["labels"] == expected_labels
        assert padded_data["info"] == "test_info"

        unpadded = tu.unpad(padded_data, pad_size=pad_size)
        assert torch.all(torch.eq(unpadded["obs"], obs))
        assert unpadded["labels"] == labels
        assert unpadded["info"] == "test_info"

    _check_pad_unpad(size_divisor=2, expected_pad_size=1)
    _check_pad_unpad(size_divisor=3, expected_pad_size=0)  # already divisible -> no padding
    _check_pad_unpad(size_divisor=7, expected_pad_size=4)  # divisor larger than the batch
def test_torch_save_data_proto():
    """A tensordict survives a torch.save/torch.load round trip.

    Fix: the original wrote ``test_data.pt`` into the working directory and only
    removed it after the assertions, leaking the file whenever an assertion failed.
    A temporary directory guarantees cleanup on any exit path.
    """
    import os
    import tempfile

    data = tu.get_tensordict(
        {"obs": torch.tensor([[1, 2], [3, 4], [5, 6]]), "labels": ["a", "b", "c"]},
        non_tensor_dict={"info": "test_info"},
    )

    with tempfile.TemporaryDirectory() as tmpdir:
        filename = os.path.join(tmpdir, "test_data.pt")
        torch.save(data, filename)
        # weights_only=False: the payload contains non-tensor python objects
        loaded_data = torch.load(filename, weights_only=False)

    assert torch.all(torch.eq(loaded_data["obs"], data["obs"]))
    assert loaded_data["labels"] == data["labels"]
    assert loaded_data["info"] == data["info"]
def test_len():
    """len() returns the batch size, also when only non-tensor per-sample data is present."""
    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    labels = ["a", "b", "c"]

    data = tu.get_tensordict({"obs": obs, "labels": labels}, non_tensor_dict={"info": "test_info"})
    assert len(data) == 3

    # non-tensor per-sample data alone still defines the batch size
    data = tu.get_tensordict({"labels": labels}, non_tensor_dict={"info": "test_info"})
    assert len(data) == 3

    # a single item has no batch dimension
    assert len(data[0]) == 0

    # an empty tensordict has length 0
    data = tu.get_tensordict({}, non_tensor_dict={"info": "test_info"})
    assert len(data) == 0
def test_dataproto_index():
    """Fancy indexing with numpy/torch/list indices selects matching rows of every field."""
    data_len = 100
    idx_num = 10
    obs = torch.randn(data_len, 10)
    labels = [random.choice(["abc", "cde"]) for _ in range(data_len)]
    data = tu.get_tensordict({"obs": obs, "labels": labels})
    labels_np = np.array(labels)
    obs_np = obs.numpy()

    def _verify(result, np_idx):
        # result must keep every key and select exactly the rows in np_idx, in order
        assert result.keys() == data.keys()
        assert result["obs"].shape[0] == len(np_idx)
        assert len(result["labels"]) == len(np_idx)
        assert np.array_equal(result["obs"].cpu().numpy(), obs_np[np_idx])
        assert np.array_equal(result["labels"], labels_np[np_idx])

    # numpy integer indices
    idx_np_int = np.random.randint(0, data_len, size=(idx_num,))
    _verify(data[idx_np_int], idx_np_int)

    # torch integer indices
    idx_torch_int = torch.randint(0, data_len, size=(idx_num,))
    _verify(data[idx_torch_int], idx_torch_int.cpu().numpy())

    # python list of ints
    idx_list_int = [np.random.randint(0, data_len) for _ in range(idx_num)]
    _verify(data[idx_list_int], np.asarray(idx_list_int))

    # torch boolean mask (compare against the equivalent integer selection)
    idx_torch_bool = torch.randint(0, 2, size=(data_len,), dtype=torch.bool)
    _verify(data[idx_torch_bool], idx_torch_bool.cpu().numpy().nonzero()[0])

    # NOTE: numpy bool arrays and python lists of bools are not supported yet.
def test_select():
    """select restricts the tensordict to the requested tensor and non-tensor keys."""
    dataset = tu.get_tensordict(
        {"obs": torch.randn(100, 10), "act": torch.randn(100, 3)}, non_tensor_dict={"2": 2, "1": 1}
    )
    subset = dataset.select("obs", "2")
    assert torch.all(torch.eq(subset["obs"], dataset["obs"]))
    assert subset["2"] == dataset["2"]
    # keys that were not selected must be absent
    assert "act" not in subset.keys()
    assert "1" not in subset.keys()
def test_dataproto_no_batch():
    """select and pop also work when the tensordict has no tensor fields."""
    labels = ["a", "b", "c"]
    data = tu.get_tensordict(tensor_dict={"labels": labels}, non_tensor_dict={"info": "test_info"})
    assert data.select("labels")["labels"] == labels

    popped = tu.pop(data, keys=["labels"])
    assert popped["labels"] == labels
    # pop removes the key from the source tensordict
    assert "labels" not in data
def test_sample_level_repeat():
    """repeat_interleave with a per-sample repeat tensor duplicates each row its own number of times."""
    data = tu.get_tensordict(
        {"obs": torch.tensor([[1, 2], [3, 4], [5, 6]]), "labels": ["a", "b", "c"]},
        non_tensor_dict={"info": "test_info"},
    )

    # repeats [3, 1, 2] -> a a a b c c
    repeated = data.repeat_interleave(repeats=torch.tensor([3, 1, 2]))
    assert torch.all(torch.eq(repeated["obs"], torch.tensor([[1, 2], [1, 2], [1, 2], [3, 4], [5, 6], [5, 6]])))
    assert repeated["labels"] == ["a", "a", "a", "b", "c", "c"]
    assert repeated["info"] == "test_info"

    # repeats [1, 2, 3] -> a b b c c c
    repeated = data.repeat_interleave(repeats=torch.tensor([1, 2, 3]))
    assert torch.all(torch.eq(repeated["obs"], torch.tensor([[1, 2], [3, 4], [3, 4], [5, 6], [5, 6], [5, 6]])))
    assert repeated["labels"] == ["a", "b", "b", "c", "c", "c"]
    assert repeated["info"] == "test_info"
def test_dataproto_chunk_after_index():
    """After masked/integer indexing, batch_size stays a torch.Size of plain ints."""
    data_len = 4
    data = tu.get_tensordict(
        tensor_dict={"obs": torch.randn(data_len, 4), "labels": [f"label_{i}" for i in range(data_len)]},
        non_tensor_dict={"name": "abc"},
    )

    masks = [
        torch.tensor([True, False, True, False]),  # boolean tensor
        torch.tensor([0, 2]),  # integer tensor
        [True, False, True, False],  # boolean list
        [0, 2],  # integer list
        torch.tensor([True, False, True, False]),  # boolean torch tensor
        torch.tensor([0, 2]),  # integer torch tensor
    ]
    for mask in masks:
        selected = data[mask]
        assert isinstance(selected.batch_size, torch.Size)
        # every dimension must be a plain int (not a tensor or symbolic size)
        assert all(isinstance(d, int) for d in selected.batch_size)