mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 13:43:50 +08:00
[misc] feat: support build DataProto from TensordDict (#3726)
### What does this PR do? Add a utility function to support building DataProto from TensorDict, which helps integrate TransferQueue into verl. ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
This commit is contained in:
@ -30,6 +30,7 @@ from verl.protocol import (
|
||||
union_numpy_dict,
|
||||
union_tensor_dict,
|
||||
)
|
||||
from verl.utils import tensordict_utils as tu
|
||||
|
||||
|
||||
def test_union_tensor_dict():
|
||||
@ -761,6 +762,23 @@ def test_to_tensordict():
|
||||
assert output["name"] == "abdce"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
|
||||
)
|
||||
def test_from_tensordict():
|
||||
tensor_dict = {
|
||||
"obs": torch.tensor([1, 2, 3, 4, 5, 6]),
|
||||
"labels": ["a", "b", "c", "d", "e", "f"],
|
||||
}
|
||||
non_tensor_dict = {"name": "abdce"}
|
||||
tensordict = tu.get_tensordict(tensor_dict, non_tensor_dict)
|
||||
data = DataProto.from_tensordict(tensordict)
|
||||
|
||||
assert data.non_tensor_batch["labels"].tolist() == tensor_dict["labels"]
|
||||
assert torch.all(torch.eq(data.batch["obs"], tensor_dict["obs"])).item()
|
||||
assert data.meta_info["name"] == "abdce"
|
||||
|
||||
|
||||
def test_serialize_deserialize_single_tensor():
|
||||
"""Test serialization and deserialization of a single tensor"""
|
||||
# Create test tensor
|
||||
|
@ -549,6 +549,47 @@ class DataProto:
|
||||
meta_info[DataProtoConfig.auto_padding_key] = True
|
||||
return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info)
|
||||
|
||||
@classmethod
|
||||
def from_tensordict(
|
||||
cls,
|
||||
tensor_dict: TensorDict = None,
|
||||
meta_info=None,
|
||||
num_batch_dims=1,
|
||||
):
|
||||
"""Create a DataProto from a TensorDict. This assumes that
|
||||
1. All the tensor in tensor_dict have the same dim0
|
||||
2. Only dim0 is the batch dim
|
||||
"""
|
||||
assert version.parse(tensordict.__version__) >= version.parse("0.10.0"), (
|
||||
"Build DataProto from TensorDict at least requires tensordict version 0.10.0"
|
||||
)
|
||||
from tensordict import NonTensorData, NonTensorStack
|
||||
|
||||
assert num_batch_dims > 0, "num_batch_dims must be greater than zero"
|
||||
if not all(isinstance(val, torch.Tensor) for val in tensor_dict.values()):
|
||||
assert num_batch_dims == 1, "only support num_batch_dims=1 when tensor_dict contains non tensor data."
|
||||
|
||||
if meta_info is None:
|
||||
meta_info = {}
|
||||
batch = {}
|
||||
non_tensor_batch = {}
|
||||
batch_size = None
|
||||
for key, val in tensor_dict.items():
|
||||
if isinstance(val, torch.Tensor):
|
||||
batch[key] = val
|
||||
if batch_size is None:
|
||||
batch_size = val.shape[:num_batch_dims]
|
||||
elif isinstance(val, NonTensorStack):
|
||||
non_tensor_batch[key] = np.array([elem.data for elem in val], dtype=object)
|
||||
elif isinstance(val, NonTensorData):
|
||||
meta_info[key] = val.data
|
||||
|
||||
return cls(
|
||||
batch=TensorDict(batch, batch_size=batch_size),
|
||||
non_tensor_batch=non_tensor_batch,
|
||||
meta_info=meta_info,
|
||||
)
|
||||
|
||||
def to(self, device) -> "DataProto":
|
||||
"""move the batch to device
|
||||
|
||||
|
Reference in New Issue
Block a user