mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 13:43:50 +08:00
[Breaking] dataset: support customized datasets for RayPPOTrainer (#924)
This PR enable user to specify their customized dataset for RayPPOTrainer. NOTE: the RLHFDataset interface has been broken into: ``` RLHFDataset( data_files: Union[str, List[str]], tokenizer: PreTrainedTokenizer, config: DictConfig, processor: Optional[ProcessorMixin] = None ) ``` and the custom dataset class MUST also use this interface. cc @eric-haibin-lin
This commit is contained in:
@ -21,12 +21,9 @@ then tokenize.
|
||||
|
||||
.. code:: python
|
||||
|
||||
self.train_dataset = RLHFDataset(parquet_files=self.config.data.train_files,
|
||||
self.train_dataset = RLHFDataset(data_files=self.config.data.train_files,
|
||||
tokenizer=self.tokenizer,
|
||||
prompt_key=self.config.data.prompt_key,
|
||||
max_prompt_length=self.config.data.max_prompt_length,
|
||||
return_raw_chat=self.config.data.get('return_raw_chat', False),
|
||||
truncation='error')
|
||||
config=self.config.data)
|
||||
|
||||
Then, the dataloader will iterate the dataset under PPO mini batch size.
|
||||
|
||||
|
@ -15,6 +15,9 @@ data:
|
||||
filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You should disable this and set `truncation='left'
|
||||
truncation: error
|
||||
image_key: images
|
||||
custom_cls:
|
||||
path: null
|
||||
name: null
|
||||
|
||||
actor_rollout_ref:
|
||||
hybrid_engine: True
|
||||
|
@ -177,14 +177,9 @@ class RayPRIMETrainer(RayPPOTrainer):
|
||||
def _create_dataloader(self):
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
# TODO: we have to make sure the batch size is divisible by the dp size
|
||||
self.train_dataset = RLHFDataset(parquet_files=self.config.data.train_files,
|
||||
self.train_dataset = RLHFDataset(data_files=self.config.data.train_files,
|
||||
tokenizer=self.tokenizer,
|
||||
prompt_key=self.config.data.prompt_key,
|
||||
max_prompt_length=self.config.data.max_prompt_length,
|
||||
return_raw_chat=self.config.data.get('return_raw_chat', False),
|
||||
truncation='error',
|
||||
filter_overlong_prompts=self.config.data.get('filter_overlong_prompts', False),
|
||||
num_workers=self.config.data.get('filter_overlong_prompts_workers', None))
|
||||
config=self.config.data)
|
||||
# use sampler for better ckpt resume
|
||||
if self.config.data.shuffle:
|
||||
train_dataloader_generator = torch.Generator()
|
||||
@ -200,14 +195,9 @@ class RayPRIMETrainer(RayPPOTrainer):
|
||||
collate_fn=collate_fn,
|
||||
sampler=sampler)
|
||||
|
||||
self.val_dataset = RLHFDataset(parquet_files=self.config.data.val_files,
|
||||
self.val_dataset = RLHFDataset(data_files=self.config.data.val_files,
|
||||
tokenizer=self.tokenizer,
|
||||
prompt_key=self.config.data.prompt_key,
|
||||
max_prompt_length=self.config.data.max_prompt_length,
|
||||
return_raw_chat=self.config.data.get('return_raw_chat', False),
|
||||
truncation='error',
|
||||
filter_overlong_prompts=self.config.data.get('filter_overlong_prompts', False),
|
||||
num_workers=self.config.data.get('filter_overlong_prompts_workers', None))
|
||||
config=self.config.data)
|
||||
self.val_dataloader = DataLoader(dataset=self.val_dataset,
|
||||
batch_size=len(self.val_dataset),
|
||||
shuffle=True,
|
||||
|
@ -15,6 +15,7 @@ import os
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoTokenizer
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
def get_gsm8k_data():
|
||||
@ -31,7 +32,13 @@ def test_rl_dataset():
|
||||
from verl.utils import hf_tokenizer
|
||||
tokenizer = hf_tokenizer('deepseek-ai/deepseek-coder-1.3b-instruct')
|
||||
local_path = get_gsm8k_data()
|
||||
dataset = RLHFDataset(parquet_files=local_path, tokenizer=tokenizer, prompt_key='prompt', max_prompt_length=256)
|
||||
config = OmegaConf.create({
|
||||
"prompt_key": "prompt",
|
||||
"max_prompt_length": 256,
|
||||
"filter_overlong_prompts": True,
|
||||
"filter_overlong_prompts_workers": 2,
|
||||
})
|
||||
dataset = RLHFDataset(data_files=local_path, tokenizer=tokenizer, config=config)
|
||||
|
||||
dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn)
|
||||
|
||||
|
@ -15,6 +15,9 @@ data:
|
||||
filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You cat set the filter_overlong_prompts_workers to use multiprocessing to speed up.
|
||||
filter_overlong_prompts_workers: 1
|
||||
truncation: error
|
||||
custom_cls:
|
||||
path: null
|
||||
name: null
|
||||
|
||||
actor_rollout_ref:
|
||||
hybrid_engine: True
|
||||
|
@ -15,6 +15,9 @@ data:
|
||||
filter_overlong_prompts_workers: 1
|
||||
truncation: error
|
||||
image_key: images
|
||||
custom_cls:
|
||||
path: null
|
||||
name: null
|
||||
|
||||
actor_rollout_ref:
|
||||
hybrid_engine: True
|
||||
|
@ -18,6 +18,7 @@ This trainer supports model-agonistic model initialization with huggingface
|
||||
|
||||
import os
|
||||
import uuid
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
@ -43,7 +44,7 @@ from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seql
|
||||
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
|
||||
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
|
||||
from verl.utils.tracking import ValidationGenerationsLogger
|
||||
from torch.utils.data import RandomSampler, SequentialSampler
|
||||
from torch.utils.data import Dataset, RandomSampler, SequentialSampler
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
WorkerType = Type[Worker]
|
||||
@ -408,19 +409,22 @@ class RayPPOTrainer(object):
|
||||
|
||||
def _create_dataloader(self):
|
||||
# TODO: we have to make sure the batch size is divisible by the dp size
|
||||
self.train_dataset = RLHFDataset(parquet_files=self.config.data.train_files,
|
||||
tokenizer=self.tokenizer,
|
||||
processor=self.processor,
|
||||
prompt_key=self.config.data.prompt_key,
|
||||
image_key=self.config.data.get('image_key', 'images'),
|
||||
max_prompt_length=self.config.data.max_prompt_length,
|
||||
return_raw_chat=self.config.data.get('return_raw_chat', False),
|
||||
truncation=self.config.data.get('truncation', 'error'),
|
||||
filter_overlong_prompts=self.config.data.filter_overlong_prompts,
|
||||
num_workers=self.config.data.get('filter_overlong_prompts_workers', None))
|
||||
assert self.train_dataset.truncation == self.config.data.get(
|
||||
'truncation', 'error'
|
||||
), f'dataset truncation {self.train_dataset.truncation} must be the same as config {self.config.data.get("truncation", "error")}'
|
||||
from verl.utils.import_utils import load_extern_type
|
||||
if "custom_cls" in self.config.data and self.config.data.custom_cls.get("path", None) is not None:
|
||||
dataset_cls = load_extern_type(self.config.data.custom_cls.path, self.config.data.custom_cls.name)
|
||||
if not issubclass(dataset_cls, Dataset):
|
||||
raise TypeError(f"The custom dataset class '{self.config.data.custom_cls.name}' from "
|
||||
f"'{self.config.data.custom_cls.path}' must inherit from torch.utils.data.Dataset")
|
||||
else:
|
||||
dataset_cls = RLHFDataset
|
||||
|
||||
self.train_dataset = dataset_cls(
|
||||
data_files=self.config.data.train_files,
|
||||
tokenizer=self.tokenizer,
|
||||
processor=self.processor,
|
||||
config=self.config.data,
|
||||
)
|
||||
|
||||
# use sampler for better ckpt resume
|
||||
if self.config.data.shuffle:
|
||||
train_dataloader_generator = torch.Generator()
|
||||
@ -437,19 +441,12 @@ class RayPPOTrainer(object):
|
||||
collate_fn=collate_fn,
|
||||
sampler=sampler)
|
||||
|
||||
self.val_dataset = RLHFDataset(parquet_files=self.config.data.val_files,
|
||||
tokenizer=self.tokenizer,
|
||||
processor=self.processor,
|
||||
prompt_key=self.config.data.prompt_key,
|
||||
image_key=self.config.data.get('image_key', 'images'),
|
||||
max_prompt_length=self.config.data.max_prompt_length,
|
||||
return_raw_chat=self.config.data.get('return_raw_chat', False),
|
||||
truncation=self.config.data.get('truncation', 'error'),
|
||||
filter_overlong_prompts=self.config.data.filter_overlong_prompts,
|
||||
num_workers=self.config.data.get('filter_overlong_prompts_workers', None))
|
||||
assert self.val_dataset.truncation == self.config.data.get(
|
||||
'truncation', 'error'
|
||||
), f'dataset truncation {self.val_dataset.truncation} must be the same as config {self.config.data.get("truncation", "error")}'
|
||||
self.val_dataset = dataset_cls(
|
||||
data_files=self.config.data.val_files,
|
||||
tokenizer=self.tokenizer,
|
||||
processor=self.processor,
|
||||
config=self.config.data,
|
||||
)
|
||||
self.val_dataloader = StatefulDataLoader(
|
||||
dataset=self.val_dataset,
|
||||
# Validation datasets are sent to inference engines as a whole batch,
|
||||
|
@ -12,9 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from omegaconf import ListConfig
|
||||
import os
|
||||
from typing import List, Union, Optional, Callable
|
||||
from typing import List, Union, Optional
|
||||
import copy
|
||||
import datasets
|
||||
from collections import defaultdict
|
||||
@ -23,6 +22,7 @@ import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
from omegaconf import ListConfig, DictConfig
|
||||
|
||||
from verl.utils.model import compute_position_id_with_mask
|
||||
import verl.utils.torch_functional as verl_F
|
||||
@ -77,40 +77,33 @@ class RLHFDataset(Dataset):
|
||||
We assume the dataset contains a column that contains prompts and other information
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
parquet_files: Union[str, List[str]],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
processor: Optional[ProcessorMixin] = None,
|
||||
prompt_key: str = 'prompt',
|
||||
image_key: str = 'images',
|
||||
max_prompt_length: int = 1024,
|
||||
cache_dir: str = '~/.cache/verl/rlhf',
|
||||
chat_template_func: Optional[Callable] = None,
|
||||
return_raw_chat: bool = False,
|
||||
truncation: str = 'error',
|
||||
filter_overlong_prompts: bool = False,
|
||||
num_workers: Optional[int] = None):
|
||||
if not isinstance(parquet_files, (List, ListConfig)):
|
||||
parquet_files = [parquet_files]
|
||||
def __init__(
|
||||
self,
|
||||
data_files: Union[str, List[str]],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
config: DictConfig,
|
||||
processor: Optional[ProcessorMixin] = None,
|
||||
):
|
||||
if not isinstance(data_files, (List, ListConfig)):
|
||||
data_files = [data_files]
|
||||
|
||||
self.parquet_files = copy.deepcopy(parquet_files)
|
||||
self.original_parquet_files = copy.deepcopy(parquet_files) # use for resume
|
||||
self.cache_dir = os.path.expanduser(cache_dir)
|
||||
self.data_files = copy.deepcopy(data_files)
|
||||
self.original_data_files = copy.deepcopy(data_files) # use for resume
|
||||
self.tokenizer = tokenizer
|
||||
self.processor = processor
|
||||
self.config = config
|
||||
|
||||
self.prompt_key = prompt_key
|
||||
self.image_key = image_key
|
||||
self.max_prompt_length = max_prompt_length
|
||||
self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf"))
|
||||
self.prompt_key = config.get("prompt_key", "prompt")
|
||||
self.image_key = config.get("image_key", "images")
|
||||
self.max_prompt_length = config.get("max_prompt_length", 1024)
|
||||
|
||||
self.return_raw_chat = return_raw_chat
|
||||
self.chat_template_func = chat_template_func
|
||||
self.truncation = truncation
|
||||
self.filter_overlong_prompts = filter_overlong_prompts
|
||||
if num_workers is None:
|
||||
self.num_workers = max(1, os.cpu_count() // 4)
|
||||
else:
|
||||
self.num_workers = min(num_workers, os.cpu_count())
|
||||
self.return_raw_chat = config.get('return_raw_chat', False)
|
||||
self.truncation = config.get('truncation', 'error')
|
||||
self.filter_overlong_prompts = config.get("filter_overlong_prompts", True)
|
||||
|
||||
self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4))
|
||||
self.num_workers = min(self.num_workers, os.cpu_count())
|
||||
|
||||
# whether to store the dataset in state_dict()
|
||||
# default not store
|
||||
@ -120,13 +113,13 @@ class RLHFDataset(Dataset):
|
||||
|
||||
def _download(self, use_origin_parquet=False):
|
||||
from verl.utils.fs import copy_to_local
|
||||
parquet_files = self.parquet_files if not use_origin_parquet else self.original_parquet_files
|
||||
for i, parquet_file in enumerate(parquet_files):
|
||||
self.parquet_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir)
|
||||
data_files = self.data_files if not use_origin_parquet else self.original_data_files
|
||||
for i, parquet_file in enumerate(data_files):
|
||||
self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir)
|
||||
|
||||
def _read_files_and_tokenize(self):
|
||||
dataframes = []
|
||||
for parquet_file in self.parquet_files:
|
||||
for parquet_file in self.data_files:
|
||||
# read parquet files and cache
|
||||
dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"]
|
||||
dataframes.append(dataframe)
|
||||
@ -147,7 +140,7 @@ class RLHFDataset(Dataset):
|
||||
print(f'filter dataset len: {len(self.dataframe)}')
|
||||
|
||||
def resume_dataset_state(self):
|
||||
self.serialize_dataset = False if hasattr(self, 'original_parquet_files') else True
|
||||
self.serialize_dataset = False if hasattr(self, 'original_data_files') else True
|
||||
# resume dataframe if not it's serialized in data.pt
|
||||
if not self.serialize_dataset:
|
||||
self._download(use_origin_parquet=True) # download and resume from original parquet files
|
||||
|
Reference in New Issue
Block a user