[Breaking] dataset: support customized datasets for RayPPOTrainer (#924)

This PR enables users to specify a customized dataset class for
`RayPPOTrainer`.

NOTE: this is a breaking change — the `RLHFDataset` constructor signature is now:
```
RLHFDataset(
    data_files: Union[str, List[str]],
    tokenizer: PreTrainedTokenizer,
    config: DictConfig,
    processor: Optional[ProcessorMixin] = None,
)
```

Any custom dataset class MUST use this same interface.
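
For illustration, a conforming custom dataset might look like the following minimal sketch. Only the constructor signature is mandated by this PR; the class name and the `__getitem__` payload here are hypothetical:

```python
from typing import List, Optional, Union

from omegaconf import DictConfig
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin


class MyCustomDataset(Dataset):  # hypothetical name; must subclass torch.utils.data.Dataset
    def __init__(
        self,
        data_files: Union[str, List[str]],
        tokenizer: PreTrainedTokenizer,
        config: DictConfig,
        processor: Optional[ProcessorMixin] = None,
    ):
        self.data_files = [data_files] if isinstance(data_files, str) else list(data_files)
        self.tokenizer = tokenizer
        self.processor = processor
        # All dataset options now arrive on `config` instead of individual kwargs:
        self.prompt_key = config.get("prompt_key", "prompt")
        self.samples = []  # populate by reading self.data_files

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Must return a dict compatible with the trainer's collate_fn.
        return self.samples[idx]
```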

cc @eric-haibin-lin
Author: Qunhong Zeng
Date: 2025-04-11 13:07:42 +08:00 (committed by GitHub)
Parent: c9e3c57cf8 · Commit: 3256142434
8 changed files with 76 additions and 83 deletions


```diff
@@ -21,12 +21,9 @@ then tokenize.

 .. code:: python

-   self.train_dataset = RLHFDataset(parquet_files=self.config.data.train_files,
+   self.train_dataset = RLHFDataset(data_files=self.config.data.train_files,
                                     tokenizer=self.tokenizer,
-                                    prompt_key=self.config.data.prompt_key,
-                                    max_prompt_length=self.config.data.max_prompt_length,
-                                    return_raw_chat=self.config.data.get('return_raw_chat', False),
-                                    truncation='error')
+                                    config=self.config.data)

 Then, the dataloader will iterate the dataset under PPO mini batch size.
```
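
The same migration applies to any downstream code that constructs `RLHFDataset` directly: keyword arguments that were previously passed one by one now live as keys on the `config` object. A sketch of the mapping, using only keys from the old signature:

```python
from omegaconf import OmegaConf

# Old interface (removed by this PR):
#   RLHFDataset(parquet_files=files, tokenizer=tok, prompt_key='prompt',
#               max_prompt_length=1024, return_raw_chat=False, truncation='error')

# New interface -- the same settings carried on a DictConfig:
config = OmegaConf.create({
    "prompt_key": "prompt",
    "max_prompt_length": 1024,
    "return_raw_chat": False,
    "truncation": "error",
})
# RLHFDataset(data_files=files, tokenizer=tok, config=config)
```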


```diff
@@ -15,6 +15,9 @@ data:
   filter_overlong_prompts: False  # for large-scale datasets, filtering overlong prompts can be time-consuming; you should disable this and set `truncation='left'`
   truncation: error
   image_key: images
+  custom_cls:
+    path: null
+    name: null
 actor_rollout_ref:
   hybrid_engine: True
```
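
The new `custom_cls` block is how a user points the trainer at their own dataset class. As a sketch of an override (the file path and class name below are placeholders, not part of this PR):

```python
from omegaconf import OmegaConf

# Placeholder values: any Python file on disk and a Dataset subclass defined in it.
overrides = OmegaConf.create({
    "data": {
        "custom_cls": {
            "path": "/path/to/my_dataset.py",  # file that defines the class
            "name": "MyCustomDataset",         # class name inside that file
        }
    }
})
```

Leaving both keys at `null` (the default) keeps the built-in `RLHFDataset`.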


```diff
@@ -177,14 +177,9 @@ class RayPRIMETrainer(RayPPOTrainer):
     def _create_dataloader(self):
         from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
         # TODO: we have to make sure the batch size is divisible by the dp size
-        self.train_dataset = RLHFDataset(parquet_files=self.config.data.train_files,
+        self.train_dataset = RLHFDataset(data_files=self.config.data.train_files,
                                          tokenizer=self.tokenizer,
-                                         prompt_key=self.config.data.prompt_key,
-                                         max_prompt_length=self.config.data.max_prompt_length,
-                                         return_raw_chat=self.config.data.get('return_raw_chat', False),
-                                         truncation='error',
-                                         filter_overlong_prompts=self.config.data.get('filter_overlong_prompts', False),
-                                         num_workers=self.config.data.get('filter_overlong_prompts_workers', None))
+                                         config=self.config.data)
         # use sampler for better ckpt resume
         if self.config.data.shuffle:
             train_dataloader_generator = torch.Generator()
@@ -200,14 +195,9 @@ class RayPRIMETrainer(RayPPOTrainer):
                                            collate_fn=collate_fn,
                                            sampler=sampler)
-        self.val_dataset = RLHFDataset(parquet_files=self.config.data.val_files,
+        self.val_dataset = RLHFDataset(data_files=self.config.data.val_files,
                                        tokenizer=self.tokenizer,
-                                       prompt_key=self.config.data.prompt_key,
-                                       max_prompt_length=self.config.data.max_prompt_length,
-                                       return_raw_chat=self.config.data.get('return_raw_chat', False),
-                                       truncation='error',
-                                       filter_overlong_prompts=self.config.data.get('filter_overlong_prompts', False),
-                                       num_workers=self.config.data.get('filter_overlong_prompts_workers', None))
+                                       config=self.config.data)
         self.val_dataloader = DataLoader(dataset=self.val_dataset,
                                          batch_size=len(self.val_dataset),
                                          shuffle=True,
```


```diff
@@ -15,6 +15,7 @@ import os
 import torch
 from torch.utils.data import DataLoader
 from transformers import AutoTokenizer
+from omegaconf import OmegaConf


 def get_gsm8k_data():
@@ -31,7 +32,13 @@ def test_rl_dataset():
     from verl.utils import hf_tokenizer
     tokenizer = hf_tokenizer('deepseek-ai/deepseek-coder-1.3b-instruct')
     local_path = get_gsm8k_data()
-    dataset = RLHFDataset(parquet_files=local_path, tokenizer=tokenizer, prompt_key='prompt', max_prompt_length=256)
+    config = OmegaConf.create({
+        "prompt_key": "prompt",
+        "max_prompt_length": 256,
+        "filter_overlong_prompts": True,
+        "filter_overlong_prompts_workers": 2,
+    })
+    dataset = RLHFDataset(data_files=local_path, tokenizer=tokenizer, config=config)
     dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn)
```


```diff
@@ -15,6 +15,9 @@ data:
   filter_overlong_prompts: False  # for large-scale datasets, filtering overlong prompts can be time-consuming; you can set `filter_overlong_prompts_workers` to use multiprocessing to speed it up.
   filter_overlong_prompts_workers: 1
   truncation: error
+  custom_cls:
+    path: null
+    name: null
 actor_rollout_ref:
   hybrid_engine: True
```


```diff
@@ -15,6 +15,9 @@ data:
   filter_overlong_prompts_workers: 1
   truncation: error
   image_key: images
+  custom_cls:
+    path: null
+    name: null
 actor_rollout_ref:
   hybrid_engine: True
```


```diff
@@ -18,6 +18,7 @@ This trainer supports model-agnostic model initialization with huggingface
 import os
 import uuid
 import warnings
+from contextlib import contextmanager
 from dataclasses import dataclass, field
 from enum import Enum
@@ -43,7 +44,7 @@ from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seql
 from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
 from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
 from verl.utils.tracking import ValidationGenerationsLogger
-from torch.utils.data import RandomSampler, SequentialSampler
+from torch.utils.data import Dataset, RandomSampler, SequentialSampler
 from torchdata.stateful_dataloader import StatefulDataLoader

 WorkerType = Type[Worker]
@@ -408,19 +409,22 @@ class RayPPOTrainer(object):
     def _create_dataloader(self):
         # TODO: we have to make sure the batch size is divisible by the dp size
-        self.train_dataset = RLHFDataset(parquet_files=self.config.data.train_files,
-                                         tokenizer=self.tokenizer,
-                                         processor=self.processor,
-                                         prompt_key=self.config.data.prompt_key,
-                                         image_key=self.config.data.get('image_key', 'images'),
-                                         max_prompt_length=self.config.data.max_prompt_length,
-                                         return_raw_chat=self.config.data.get('return_raw_chat', False),
-                                         truncation=self.config.data.get('truncation', 'error'),
-                                         filter_overlong_prompts=self.config.data.filter_overlong_prompts,
-                                         num_workers=self.config.data.get('filter_overlong_prompts_workers', None))
-        assert self.train_dataset.truncation == self.config.data.get(
-            'truncation', 'error'
-        ), f'dataset truncation {self.train_dataset.truncation} must be the same as config {self.config.data.get("truncation", "error")}'
+        from verl.utils.import_utils import load_extern_type
+        if "custom_cls" in self.config.data and self.config.data.custom_cls.get("path", None) is not None:
+            dataset_cls = load_extern_type(self.config.data.custom_cls.path, self.config.data.custom_cls.name)
+            if not issubclass(dataset_cls, Dataset):
+                raise TypeError(f"The custom dataset class '{self.config.data.custom_cls.name}' from "
+                                f"'{self.config.data.custom_cls.path}' must inherit from torch.utils.data.Dataset")
+        else:
+            dataset_cls = RLHFDataset
+
+        self.train_dataset = dataset_cls(
+            data_files=self.config.data.train_files,
+            tokenizer=self.tokenizer,
+            processor=self.processor,
+            config=self.config.data,
+        )
         # use sampler for better ckpt resume
         if self.config.data.shuffle:
             train_dataloader_generator = torch.Generator()
@@ -437,19 +441,12 @@ class RayPPOTrainer(object):
                                                collate_fn=collate_fn,
                                                sampler=sampler)

-        self.val_dataset = RLHFDataset(parquet_files=self.config.data.val_files,
-                                       tokenizer=self.tokenizer,
-                                       processor=self.processor,
-                                       prompt_key=self.config.data.prompt_key,
-                                       image_key=self.config.data.get('image_key', 'images'),
-                                       max_prompt_length=self.config.data.max_prompt_length,
-                                       return_raw_chat=self.config.data.get('return_raw_chat', False),
-                                       truncation=self.config.data.get('truncation', 'error'),
-                                       filter_overlong_prompts=self.config.data.filter_overlong_prompts,
-                                       num_workers=self.config.data.get('filter_overlong_prompts_workers', None))
-        assert self.val_dataset.truncation == self.config.data.get(
-            'truncation', 'error'
-        ), f'dataset truncation {self.val_dataset.truncation} must be the same as config {self.config.data.get("truncation", "error")}'
+        self.val_dataset = dataset_cls(
+            data_files=self.config.data.val_files,
+            tokenizer=self.tokenizer,
+            processor=self.processor,
+            config=self.config.data,
+        )
         self.val_dataloader = StatefulDataLoader(
             dataset=self.val_dataset,
             # Validation datasets are sent to inference engines as a whole batch,
```
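
`load_extern_type` itself is not shown in this diff. A minimal sketch of how such a loader could be implemented with importlib, assuming it takes a source-file path and an attribute name (the real verl helper may differ in caching and error handling):

```python
import importlib.util
import os


def load_extern_type(file_path: str, type_name: str):
    """Load a named attribute (e.g. a Dataset subclass) from a Python source file.

    Sketch only: illustrates the mechanism, not verl's exact implementation.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"custom class file '{file_path}' not found")
    spec = importlib.util.spec_from_file_location("verl_custom_module", file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # execute the file so its classes get defined
    if not hasattr(module, type_name):
        raise AttributeError(f"'{type_name}' not found in '{file_path}'")
    return getattr(module, type_name)
```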


```diff
@@ -12,9 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from omegaconf import ListConfig
 import os
-from typing import List, Union, Optional, Callable
+from typing import List, Union, Optional
 import copy
 import datasets
 from collections import defaultdict
@@ -23,6 +22,7 @@ import torch
 import numpy as np
 from torch.utils.data import Dataset
 from transformers import PreTrainedTokenizer, ProcessorMixin
+from omegaconf import ListConfig, DictConfig

 from verl.utils.model import compute_position_id_with_mask
 import verl.utils.torch_functional as verl_F
@@ -77,40 +77,33 @@ class RLHFDataset(Dataset):
     We assume the dataset contains a column that contains prompts and other information
     """

-    def __init__(self,
-                 parquet_files: Union[str, List[str]],
-                 tokenizer: PreTrainedTokenizer,
-                 processor: Optional[ProcessorMixin] = None,
-                 prompt_key: str = 'prompt',
-                 image_key: str = 'images',
-                 max_prompt_length: int = 1024,
-                 cache_dir: str = '~/.cache/verl/rlhf',
-                 chat_template_func: Optional[Callable] = None,
-                 return_raw_chat: bool = False,
-                 truncation: str = 'error',
-                 filter_overlong_prompts: bool = False,
-                 num_workers: Optional[int] = None):
-        if not isinstance(parquet_files, (List, ListConfig)):
-            parquet_files = [parquet_files]
+    def __init__(
+        self,
+        data_files: Union[str, List[str]],
+        tokenizer: PreTrainedTokenizer,
+        config: DictConfig,
+        processor: Optional[ProcessorMixin] = None,
+    ):
+        if not isinstance(data_files, (List, ListConfig)):
+            data_files = [data_files]

-        self.parquet_files = copy.deepcopy(parquet_files)
-        self.original_parquet_files = copy.deepcopy(parquet_files)  # used for resume
-        self.cache_dir = os.path.expanduser(cache_dir)
+        self.data_files = copy.deepcopy(data_files)
+        self.original_data_files = copy.deepcopy(data_files)  # used for resume
         self.tokenizer = tokenizer
         self.processor = processor
+        self.config = config

-        self.prompt_key = prompt_key
-        self.image_key = image_key
-        self.max_prompt_length = max_prompt_length
+        self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf"))
+        self.prompt_key = config.get("prompt_key", "prompt")
+        self.image_key = config.get("image_key", "images")
+        self.max_prompt_length = config.get("max_prompt_length", 1024)

-        self.return_raw_chat = return_raw_chat
-        self.chat_template_func = chat_template_func
-        self.truncation = truncation
-        self.filter_overlong_prompts = filter_overlong_prompts
-
-        if num_workers is None:
-            self.num_workers = max(1, os.cpu_count() // 4)
-        else:
-            self.num_workers = min(num_workers, os.cpu_count())
+        self.return_raw_chat = config.get('return_raw_chat', False)
+        self.truncation = config.get('truncation', 'error')
+        self.filter_overlong_prompts = config.get("filter_overlong_prompts", True)
+        self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4))
+        self.num_workers = min(self.num_workers, os.cpu_count())

         # whether to store the dataset in state_dict()
         # default not store
@@ -120,13 +113,13 @@ class RLHFDataset(Dataset):
     def _download(self, use_origin_parquet=False):
         from verl.utils.fs import copy_to_local
-        parquet_files = self.parquet_files if not use_origin_parquet else self.original_parquet_files
-        for i, parquet_file in enumerate(parquet_files):
-            self.parquet_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir)
+        data_files = self.data_files if not use_origin_parquet else self.original_data_files
+        for i, parquet_file in enumerate(data_files):
+            self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir)

     def _read_files_and_tokenize(self):
         dataframes = []
-        for parquet_file in self.parquet_files:
+        for parquet_file in self.data_files:
             # read parquet files and cache
             dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"]
             dataframes.append(dataframe)
@@ -147,7 +140,7 @@ class RLHFDataset(Dataset):
         print(f'filter dataset len: {len(self.dataframe)}')

     def resume_dataset_state(self):
-        self.serialize_dataset = False if hasattr(self, 'original_parquet_files') else True
+        self.serialize_dataset = False if hasattr(self, 'original_data_files') else True
         # resume dataframe if it's not serialized in data.pt
         if not self.serialize_dataset:
             self._download(use_origin_parquet=True)  # download and resume from original parquet files
```
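
With the new constructor, any key omitted from `config` falls back to the default baked into `__init__` via `config.get(...)`, e.g. `prompt_key='prompt'`, `max_prompt_length=1024`, `truncation='error'`. A minimal construction, assuming a local parquet file (the path is a placeholder):

```python
from omegaconf import OmegaConf
from transformers import AutoTokenizer

from verl.utils.dataset.rl_dataset import RLHFDataset

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/deepseek-coder-1.3b-instruct")
# Only override what differs from the defaults; everything else comes from config.get(...).
config = OmegaConf.create({"max_prompt_length": 512, "filter_overlong_prompts": False})
dataset = RLHFDataset(
    data_files="~/data/gsm8k/train.parquet",  # placeholder path
    tokenizer=tokenizer,
    config=config,
)
```

Note that the in-code fallback for `filter_overlong_prompts` is now `True` (the old keyword defaulted to `False`), though the example configs in this PR still set it explicitly.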