[trainer] breaking: pass datasets as required args to SFTTrainer; also change the PPO ray trainer to take custom datasets as inputs (#1282)

HL authored 2025-05-02 21:03:22 -07:00, committed by GitHub
parent cee3dca867
commit 52437be1a6
6 changed files with 141 additions and 68 deletions
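For downstream callers, the breaking change boils down to building the tokenizer and datasets yourself and handing them to the trainer. A minimal sketch of the new SFT call pattern, using only names that appear in the diffs below and assuming the OmegaConf config and the two device meshes are already constructed as before:

from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer, create_sft_dataset
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_to_local

# Build the tokenizer and datasets up front, then pass them to the trainer.
local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True)
tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code)
train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer)
val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)

trainer = FSDPSFTTrainer(
    config=config,                            # unchanged
    device_mesh=device_mesh,                  # unchanged
    ulysses_device_mesh=ulysses_device_mesh,  # unchanged
    tokenizer=tokenizer,                      # new required argument
    train_dataset=train_dataset,              # new required argument
    val_dataset=val_dataset,                  # new required argument
)
trainer.fit()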

View File

@@ -169,7 +169,7 @@ class RayPRIMETrainer(RayPPOTrainer):
super()._validate_config()
# TODO: Additional config checks can be added here
def _create_dataloader(self):
def _create_dataloader(self, *args, **kwargs):
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
# TODO: we have to make sure the batch size is divisible by the dp size
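The widened signature above keeps the PRIME subclass compatible with the base class, which now calls self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) from its __init__ (see the RayPPOTrainer diff further below). A hedged sketch of the pattern for any subclass that builds its own dataloaders; the class name is illustrative and the import path is assumed:

from verl.trainer.ppo.ray_trainer import RayPPOTrainer

class MyRayTrainer(RayPPOTrainer):
    def _create_dataloader(self, *args, **kwargs):
        # The base __init__ now passes (train_dataset, val_dataset, collate_fn,
        # train_sampler) positionally. A subclass with its own dataloader logic
        # can accept and ignore them, as RayPRIMETrainer does, or forward them:
        #     super()._create_dataloader(*args, **kwargs)
        ...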

View File

@@ -103,7 +103,17 @@ def create_trainer(config):
dp_size = world_size // config.ulysses_sequence_parallel_size
ulysses_device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp"))
return FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh)
# build tokenizer and datasets first
from verl.trainer.fsdp_sft_trainer import create_sft_dataset
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_to_local
local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True)
tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code)
train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer)
val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)
return FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh, tokenizer=tokenizer, train_dataset=train_dataset, val_dataset=val_dataset)
def main(config):

View File

@@ -37,7 +37,7 @@ from torch import nn, optim
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
@@ -82,16 +82,12 @@ def convert_to_regular_types(obj):
class FSDPSFTTrainer:
def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceMesh):
def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceMesh, tokenizer, train_dataset: Dataset, val_dataset: Dataset):
self.config = config
self.device_mesh = device_mesh
self.ulysses_device_mesh = ulysses_device_mesh
self.sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
# build tokenizer first
local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True)
from verl.utils import hf_tokenizer
self.tokenizer = hf_tokenizer(local_model_path, trust_remote_code=self.config.model.trust_remote_code)
self.tokenizer = tokenizer
if self.config.data.chat_template is not None:
raise ValueError("Apply Chat template from config is not supported yet.")
@@ -105,7 +101,7 @@ class FSDPSFTTrainer:
print(f"Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}")
print(f"Using remove padding: {self.use_remove_padding}")
self._build_dataloader()
self._build_dataloader(train_dataset, val_dataset)
# build model
self._build_model_optimizer()
@@ -124,24 +120,10 @@ class FSDPSFTTrainer:
assert self.config.data.train_batch_size % self.config.data.micro_batch_size_per_gpu == 0
def _build_dataloader(self):
config = self.config
def _build_dataloader(self, train_dataset, val_dataset):
# build dataset
from verl.utils.import_utils import load_extern_type
# First check if a custom dataset class is specified
if config.data.custom_cls.get("path", None):
dataset_cls = load_extern_type(config.data.custom_cls.path, config.data.custom_cls.name)
# Then check if multi-turn dataset should be used
elif config.data.get("multiturn", {}).get("enable", False):
dataset_cls = MultiTurnSFTDataset
# Default to single-turn dataset
else:
dataset_cls = SFTDataset
# Create datasets based on the selected class
self.train_dataset = dataset_cls(parquet_files=config.data.train_files, tokenizer=self.tokenizer, config=config.data)
self.val_dataset = dataset_cls(parquet_files=config.data.val_files, tokenizer=self.tokenizer, config=config.data)
config = self.config
self.train_dataset, self.val_dataset = train_dataset, val_dataset
# build dataloader
# Use data parallel rank and size instead of global rank and world size
@@ -525,9 +507,38 @@ def main(config):
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",))
dp_size = world_size // config.ulysses_sequence_parallel_size
ulysses_device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp"))
trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh)
# build tokenizer and datasets first
from verl.utils import hf_tokenizer
local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True)
tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code)
train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer)
val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)
trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh, tokenizer=tokenizer, train_dataset=train_dataset, val_dataset=val_dataset)
trainer.fit()
def create_sft_dataset(data_paths, data_config, tokenizer):
"""Create a dataset."""
# build dataset
# First check if a custom dataset class is specified
if data_config.custom_cls.get("path", None):
from verl.utils.import_utils import load_extern_type
dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)
# Then check if multi-turn dataset should be used
elif data_config.get("multiturn", {}).get("enable", False):
dataset_cls = MultiTurnSFTDataset
# Default to single-turn dataset
else:
dataset_cls = SFTDataset
# Create datasets based on the selected class
dataset = dataset_cls(parquet_files=data_paths, tokenizer=tokenizer, config=data_config)
return dataset
if __name__ == "__main__":
main()
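Since create_sft_dataset instantiates whichever class it selects as dataset_cls(parquet_files=..., tokenizer=..., config=...), a custom class referenced by config.data.custom_cls.path and .name only needs to honor that constructor. A hypothetical plug-in (class, file, and loading details are illustrative, not part of this commit):

# my_sft_dataset.py -- hypothetical module pointed to by config.data.custom_cls.path
import pandas as pd
from torch.utils.data import Dataset

class MyParquetSFTDataset(Dataset):
    def __init__(self, parquet_files, tokenizer, config):
        files = parquet_files if isinstance(parquet_files, (list, tuple)) else [parquet_files]
        self.df = pd.concat([pd.read_parquet(f) for f in files], ignore_index=True)
        self.tokenizer = tokenizer
        self.config = config

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # A real implementation should return the same tensor dict the built-in
        # SFTDataset produces (input_ids, attention_mask, ...); shown schematically.
        return dict(self.df.iloc[idx])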

View File

@@ -163,6 +163,11 @@ class TaskRunner:
val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
from verl.utils.dataset.rl_dataset import collate_fn
train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
train_sampler = create_rl_sampler(config.data, train_dataset)
trainer = RayPPOTrainer(
config=config,
tokenizer=tokenizer,
@@ -172,10 +177,73 @@ class TaskRunner:
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
train_dataset=train_dataset,
val_dataset=val_dataset,
collate_fn=collate_fn,
train_sampler=train_sampler,
)
trainer.init_workers()
trainer.fit()
def create_rl_dataset(data_paths, data_config, tokenizer, processor):
"""Create a dataset.
Arguments:
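data_paths: The path(s) to the data files.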
data_config: The data config.
tokenizer (Tokenizer): The tokenizer.
processor (Processor): The processor.
Returns:
dataset (Dataset): The dataset.
"""
from torch.utils.data import Dataset
from verl.utils.dataset.rl_dataset import RLHFDataset
if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None:
from verl.utils.import_utils import load_extern_type
dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)
if not issubclass(dataset_cls, Dataset):
raise TypeError(f"The custom dataset class '{data_config.custom_cls.name}' from '{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset")
else:
dataset_cls = RLHFDataset
print(f"Using dataset class: {dataset_cls.__name__}")
dataset = dataset_cls(
data_files=data_paths,
tokenizer=tokenizer,
processor=processor,
config=data_config,
)
return dataset
def create_rl_sampler(data_config, dataset):
"""Create a sampler for the dataset.
Arguments:
data_config: The data config.
dataset (Dataset): The dataset.
Returns:
sampler (Sampler): The sampler.
"""
import torch
from torch.utils.data import RandomSampler, SequentialSampler
# use sampler for better ckpt resume
if data_config.shuffle:
train_dataloader_generator = torch.Generator()
train_dataloader_generator.manual_seed(data_config.get("seed", 1))
sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator)
else:
sampler = SequentialSampler(data_source=dataset)
return sampler
if __name__ == "__main__":
main()
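create_rl_dataset enforces two things for a config-supplied class: it must subclass torch.utils.data.Dataset, and it is constructed with the keyword arguments data_files, tokenizer, processor, and config. A hypothetical class satisfying that contract (the name and the jsonl loading are illustrative only):

# my_rl_dataset.py -- hypothetical module pointed to by config.data.custom_cls.path
import json
from torch.utils.data import Dataset

class MyRLHFDataset(Dataset):  # must inherit Dataset, or create_rl_dataset raises TypeError
    def __init__(self, data_files, tokenizer, processor, config):
        files = data_files if isinstance(data_files, (list, tuple)) else [data_files]
        self.tokenizer = tokenizer
        self.processor = processor  # may be None for text-only models
        self.config = config
        self.samples = []
        for path in files:
            with open(path) as f:
                self.samples.extend(json.loads(line) for line in f)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Should return the same per-sample layout as the default RLHFDataset so
        # that verl.utils.dataset.rl_dataset.collate_fn can batch it.
        return self.samples[idx]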

View File

@@ -27,14 +27,14 @@ from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from pprint import pprint
from typing import Dict, Type
from typing import Dict, Optional, Type
import numpy as np
import ray
import torch
from codetiming import Timer
from omegaconf import OmegaConf, open_dict
from torch.utils.data import Dataset, RandomSampler, SequentialSampler
from torch.utils.data import Dataset, Sampler
from torchdata.stateful_dataloader import StatefulDataLoader
from tqdm import tqdm
@@ -54,7 +54,6 @@ from verl.trainer.ppo.metric_utils import (
)
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from verl.utils.torch_functional import masked_mean
from verl.utils.tracking import ValidationGenerationsLogger
@@ -279,6 +278,10 @@ class RayPPOTrainer:
processor=None,
reward_fn=None,
val_reward_fn=None,
train_dataset: Optional[Dataset] = None,
val_dataset: Optional[Dataset] = None,
collate_fn=None,
train_sampler: Optional[Sampler] = None,
):
# assert torch.cuda.is_available(), 'cuda must be available on driver'
@@ -320,7 +323,7 @@
raise NotImplementedError
self._validate_config()
self._create_dataloader()
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
def _validate_config(self):
config = self.config
@@ -435,38 +438,25 @@
print("[validate_config] All configuration checks passed successfully!")
def _create_dataloader(self):
def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler):
"""
Creates the train and validation dataloaders.
"""
# make sure the batch size is divisible by the dp size
from verl.utils.import_utils import load_extern_type
# TODO: we have to make sure the batch size is divisible by the dp size
from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
if "custom_cls" in self.config.data and self.config.data.custom_cls.get("path", None) is not None:
# Dynamically load the custom dataset class specified in config
try:
dataset_cls = load_extern_type(self.config.data.custom_cls.path, self.config.data.custom_cls.name)
if not issubclass(dataset_cls, Dataset):
raise TypeError(f"The custom dataset class '{self.config.data.custom_cls.name}' from '{self.config.data.custom_cls.path}' must inherit from torch.utils.data.Dataset")
print(f"Using custom dataset class: {dataset_cls.__name__}")
except Exception as e:
print(f"Error loading custom dataset class: {e}")
raise e
else:
dataset_cls = RLHFDataset
print(f"Using default dataset class: {dataset_cls.__name__}")
self.train_dataset = dataset_cls(
data_files=self.config.data.train_files,
tokenizer=self.tokenizer,
processor=self.processor,
config=self.config.data,
)
if self.config.data.shuffle:
train_dataloader_generator = torch.Generator()
train_dataloader_generator.manual_seed(self.config.data.get("seed", 1))
sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)
else:
sampler = SequentialSampler(data_source=self.train_dataset)
if train_dataset is None:
train_dataset = create_rl_dataset(self.config.data.train_files, self.config.data, self.tokenizer, self.processor)
if val_dataset is None:
val_dataset = create_rl_dataset(self.config.data.val_files, self.config.data, self.tokenizer, self.processor)
self.train_dataset, self.val_dataset = train_dataset, val_dataset
if train_sampler is None:
train_sampler = create_rl_sampler(self.config.data, self.train_dataset)
if collate_fn is None:
from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn
collate_fn = default_collate_fn
self.train_dataloader = StatefulDataLoader(
dataset=self.train_dataset,
@@ -474,14 +464,7 @@
num_workers=self.config.data.get("dataloader_num_workers", 8),
drop_last=True,
collate_fn=collate_fn,
sampler=sampler,
)
self.val_dataset = dataset_cls(
data_files=self.config.data.val_files,
tokenizer=self.tokenizer,
processor=self.processor,
config=self.config.data,
sampler=train_sampler,
)
val_batch_size = self.config.data.val_batch_size # Prefer config value if set
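Because _create_dataloader only falls back to create_rl_dataset, create_rl_sampler, and the default collate_fn when the corresponding argument is None, a caller can override just the pieces it cares about. A hedged sketch, assuming the tokenizer, processor, and the worker/reward plumbing from TaskRunner are already in scope and the usual RayPPOTrainer import path (the fixed-seed sampler is illustrative):

import torch
from torch.utils.data import RandomSampler

from verl.trainer.main_ppo import create_rl_dataset
from verl.trainer.ppo.ray_trainer import RayPPOTrainer

train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
generator = torch.Generator().manual_seed(42)  # caller-chosen seed
train_sampler = RandomSampler(data_source=train_dataset, generator=generator)

trainer = RayPPOTrainer(
    config=config,
    tokenizer=tokenizer,
    processor=processor,
    role_worker_mapping=role_worker_mapping,      # unchanged from the existing setup
    resource_pool_manager=resource_pool_manager,
    ray_worker_group_cls=ray_worker_group_cls,
    reward_fn=reward_fn,
    val_reward_fn=val_reward_fn,
    train_dataset=train_dataset,  # pre-built dataset
    val_dataset=None,             # None -> built internally via create_rl_dataset
    collate_fn=None,              # None -> verl.utils.dataset.rl_dataset.collate_fn
    train_sampler=train_sampler,  # takes precedence over create_rl_sampler
)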

View File

@@ -35,6 +35,7 @@ logger = logging.getLogger(__name__)
def collate_fn(data_list: list[dict]) -> dict:
"""Collate a batch of data."""
tensors = defaultdict(list)
non_tensors = defaultdict(list)
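For context, a minimal sketch of what a collate of this shape does, assuming the convention the snippet suggests (tensor fields are stacked per key, everything else is kept as an object-dtype numpy array). This is an illustration, not the verbatim verl implementation:

from collections import defaultdict

import numpy as np
import torch

def collate_fn_sketch(data_list: list[dict]) -> dict:
    """Split each sample's fields into tensor and non-tensor groups, then batch them."""
    tensors = defaultdict(list)
    non_tensors = defaultdict(list)
    for sample in data_list:
        for key, value in sample.items():
            if isinstance(value, torch.Tensor):
                tensors[key].append(value)
            else:
                non_tensors[key].append(value)
    batch = {key: torch.stack(values, dim=0) for key, values in tensors.items()}
    batch.update({key: np.array(values, dtype=object) for key, values in non_tensors.items()})
    return batch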