Mirror of https://github.com/volcengine/verl.git (synced 2025-10-20 13:43:50 +08:00)
[trainer] breaking: pass dataset as required args to SFTTrainer; also change ppo ray trainer to take custom datasets as inputs (#1282)
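As a quick orientation before the diff: after this change the SFT entry point builds the tokenizer and datasets itself and passes them to the trainer, which no longer constructs them internally. The sketch below is illustrative only, assembled from the hunks that follow; build_sft_trainer and its world_size argument are illustrative names, and FSDPSFTTrainer is assumed to be importable from verl.trainer.fsdp_sft_trainer (the module the diff imports create_sft_dataset from).

# Illustrative sketch of the new calling convention; config field names are taken from the diff.
from torch.distributed.device_mesh import init_device_mesh

from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer, create_sft_dataset
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_to_local


def build_sft_trainer(config, world_size):
    # Device meshes are built exactly as in the updated main() below.
    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",))
    dp_size = world_size // config.ulysses_sequence_parallel_size
    ulysses_device_mesh = init_device_mesh(
        device_type="cuda",
        mesh_shape=(dp_size, config.ulysses_sequence_parallel_size),
        mesh_dim_names=("dp", "sp"),
    )

    # Tokenizer and datasets are now required constructor arguments.
    local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True)
    tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code)
    train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer)
    val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)

    return FSDPSFTTrainer(
        config=config,
        device_mesh=device_mesh,
        ulysses_device_mesh=ulysses_device_mesh,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
    )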
@@ -169,7 +169,7 @@ class RayPRIMETrainer(RayPPOTrainer):
         super()._validate_config()
         # TODO: Additional config checks can be added here
 
-    def _create_dataloader(self):
+    def _create_dataloader(self, *args, **kwargs):
         from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 
         # TODO: we have to make sure the batch size is divisible by the dp size
@@ -103,7 +103,17 @@ def create_trainer(config):
     dp_size = world_size // config.ulysses_sequence_parallel_size
     ulysses_device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp"))
 
-    return FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh)
+    # build tokenizer and datasets first
+    from verl.trainer.fsdp_sft_trainer import create_sft_dataset
+    from verl.utils import hf_tokenizer
+    from verl.utils.fs import copy_to_local
+
+    local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True)
+    tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code)
+    train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer)
+    val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)
+
+    return FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh, tokenizer=tokenizer, train_dataset=train_dataset, val_dataset=val_dataset)
 
 
 def main(config):
@@ -37,7 +37,7 @@ from torch import nn, optim
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data import DataLoader, Dataset, DistributedSampler
 from tqdm import tqdm
 from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
@@ -82,16 +82,12 @@ def convert_to_regular_types(obj):
 
 
 class FSDPSFTTrainer:
-    def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceMesh):
+    def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceMesh, tokenizer, train_dataset: Dataset, val_dataset: Dataset):
         self.config = config
         self.device_mesh = device_mesh
         self.ulysses_device_mesh = ulysses_device_mesh
         self.sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
-        # build tokenizer first
-        local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True)
-        from verl.utils import hf_tokenizer
-
-        self.tokenizer = hf_tokenizer(local_model_path, trust_remote_code=self.config.model.trust_remote_code)
+        self.tokenizer = tokenizer
         if self.config.data.chat_template is not None:
             raise ValueError("Apply Chat template from config is not supported yet.")
 
@@ -105,7 +101,7 @@ class FSDPSFTTrainer:
         print(f"Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}")
         print(f"Using remove padding: {self.use_remove_padding}")
 
-        self._build_dataloader()
+        self._build_dataloader(train_dataset, val_dataset)
         # build model
         self._build_model_optimizer()
 
@@ -124,24 +120,10 @@ class FSDPSFTTrainer:
 
         assert self.config.data.train_batch_size % self.config.data.micro_batch_size_per_gpu == 0
 
-    def _build_dataloader(self):
-        config = self.config
+    def _build_dataloader(self, train_dataset, val_dataset):
         # build dataset
-        from verl.utils.import_utils import load_extern_type
-
-        # First check if a custom dataset class is specified
-        if config.data.custom_cls.get("path", None):
-            dataset_cls = load_extern_type(config.data.custom_cls.path, config.data.custom_cls.name)
-        # Then check if multi-turn dataset should be used
-        elif config.data.get("multiturn", {}).get("enable", False):
-            dataset_cls = MultiTurnSFTDataset
-        # Default to single-turn dataset
-        else:
-            dataset_cls = SFTDataset
-
-        # Create datasets based on the selected class
-        self.train_dataset = dataset_cls(parquet_files=config.data.train_files, tokenizer=self.tokenizer, config=config.data)
-        self.val_dataset = dataset_cls(parquet_files=config.data.val_files, tokenizer=self.tokenizer, config=config.data)
+        config = self.config
+        self.train_dataset, self.val_dataset = train_dataset, val_dataset
 
         # build dataloader
         # Use data parallel rank and size instead of global rank and world size
@@ -525,9 +507,38 @@ def main(config):
     device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",))
     dp_size = world_size // config.ulysses_sequence_parallel_size
     ulysses_device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp"))
-    trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh)
+    # build tokenizer and datasets first
+    from verl.utils import hf_tokenizer
+
+    local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True)
+    tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code)
+    train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer)
+    val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)
+
+    trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh, tokenizer=tokenizer, train_dataset=train_dataset, val_dataset=val_dataset)
+
     trainer.fit()
 
 
+def create_sft_dataset(data_paths, data_config, tokenizer):
+    """Create a dataset."""
+    # build dataset
+    # First check if a custom dataset class is specified
+    if data_config.custom_cls.get("path", None):
+        from verl.utils.import_utils import load_extern_type
+
+        dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)
+    # Then check if multi-turn dataset should be used
+    elif data_config.get("multiturn", {}).get("enable", False):
+        dataset_cls = MultiTurnSFTDataset
+    # Default to single-turn dataset
+    else:
+        dataset_cls = SFTDataset
+
+    # Create datasets based on the selected class
+    dataset = dataset_cls(parquet_files=data_paths, tokenizer=tokenizer, config=data_config)
+    return dataset
+
+
 if __name__ == "__main__":
     main()
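The new create_sft_dataset helper keeps the previous selection logic: a user class named by data.custom_cls takes priority, then the multi-turn dataset when data.multiturn.enable is set, otherwise SFTDataset. Whichever class is chosen is instantiated as dataset_cls(parquet_files=..., tokenizer=..., config=...), so a custom class must accept those keyword arguments. A hedged config sketch follows; the file path and class name are placeholders, and any extra keys a custom dataset needs are up to that class:

# Illustrative config fragment only; the custom_cls / multiturn keys mirror the branches above.
from omegaconf import OmegaConf

data_config = OmegaConf.create(
    {
        "train_files": ["train.parquet"],
        "val_files": ["val.parquet"],
        # Highest priority: dynamically load a user-defined dataset class via load_extern_type.
        "custom_cls": {"path": "/path/to/my_sft_dataset.py", "name": "MySFTDataset"},
        # Checked only when custom_cls.path is unset: switch to MultiTurnSFTDataset.
        "multiturn": {"enable": False},
    }
)

# The selected class is constructed as:
#   dataset_cls(parquet_files=data_config.train_files, tokenizer=tokenizer, config=data_config)
# so MySFTDataset must accept the parquet_files / tokenizer / config keyword arguments.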
@@ -163,6 +163,11 @@ class TaskRunner:
         val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1)
         resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
 
+        from verl.utils.dataset.rl_dataset import collate_fn
+
+        train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
+        val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
+        train_sampler = create_rl_sampler(config.data, train_dataset)
         trainer = RayPPOTrainer(
             config=config,
             tokenizer=tokenizer,
@@ -172,10 +177,73 @@ class TaskRunner:
             ray_worker_group_cls=ray_worker_group_cls,
             reward_fn=reward_fn,
             val_reward_fn=val_reward_fn,
+            train_dataset=train_dataset,
+            val_dataset=val_dataset,
+            collate_fn=collate_fn,
+            train_sampler=train_sampler,
         )
         trainer.init_workers()
         trainer.fit()
 
 
+def create_rl_dataset(data_paths, data_config, tokenizer, processor):
+    """Create a dataset.
+
+    Arguments:
+        data_config: The data config.
+        tokenizer (Tokenizer): The tokenizer.
+        processor (Processor): The processor.
+
+    Returns:
+        dataset (Dataset): The dataset.
+    """
+    from torch.utils.data import Dataset
+
+    from verl.utils.dataset.rl_dataset import RLHFDataset
+
+    if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None:
+        from verl.utils.import_utils import load_extern_type
+
+        dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)
+        if not issubclass(dataset_cls, Dataset):
+            raise TypeError(f"The custom dataset class '{data_config.custom_cls.name}' from '{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset")
+    else:
+        dataset_cls = RLHFDataset
+    print(f"Using dataset class: {dataset_cls.__name__}")
+
+    dataset = dataset_cls(
+        data_files=data_paths,
+        tokenizer=tokenizer,
+        processor=processor,
+        config=data_config,
+    )
+
+    return dataset
+
+
+def create_rl_sampler(data_config, dataset):
+    """Create a sampler for the dataset.
+
+    Arguments:
+        data_config: The data config.
+        dataset (Dataset): The dataset.
+
+    Returns:
+        sampler (Sampler): The sampler.
+    """
+    import torch
+    from torch.utils.data import RandomSampler, SequentialSampler
+
+    # use sampler for better ckpt resume
+    if data_config.shuffle:
+        train_dataloader_generator = torch.Generator()
+        train_dataloader_generator.manual_seed(data_config.get("seed", 1))
+        sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator)
+    else:
+        sampler = SequentialSampler(data_source=dataset)
+
+    return sampler
+
+
+if __name__ == "__main__":
+    main()
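On the PPO side, RayPPOTrainer now accepts train_dataset, val_dataset, collate_fn, and train_sampler as optional constructor arguments; anything left as None is filled in by _create_dataloader using create_rl_dataset, create_rl_sampler, and the default collate_fn (see the ray_trainer hunks further down), so existing callers keep working. The sketch below shows the shape of a custom dataset that would satisfy the issubclass check in create_rl_dataset; MyRLHFDataset and its internals are placeholders, not code added by this commit.

# Illustrative sketch. create_rl_dataset requires a torch.utils.data.Dataset subclass and
# instantiates it as dataset_cls(data_files=..., tokenizer=..., processor=..., config=...),
# so a custom class should accept those keyword arguments.
from torch.utils.data import Dataset


class MyRLHFDataset(Dataset):  # placeholder name
    def __init__(self, data_files, tokenizer, processor, config):
        self.data_files = data_files
        self.tokenizer = tokenizer
        self.processor = processor
        self.config = config
        self.rows = []  # load and preprocess data_files here

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        return self.rows[idx]


# Wiring it into the trainer (other arguments built as in the unchanged TaskRunner code):
#
#   from verl.trainer.main_ppo import create_rl_sampler
#   from verl.utils.dataset.rl_dataset import collate_fn
#
#   train_dataset = MyRLHFDataset(config.data.train_files, tokenizer, processor, config.data)
#   val_dataset = MyRLHFDataset(config.data.val_files, tokenizer, processor, config.data)
#   train_sampler = create_rl_sampler(config.data, train_dataset)
#   trainer = RayPPOTrainer(..., train_dataset=train_dataset, val_dataset=val_dataset,
#                           collate_fn=collate_fn, train_sampler=train_sampler)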
@@ -27,14 +27,14 @@ from copy import deepcopy
 from dataclasses import dataclass, field
 from enum import Enum
 from pprint import pprint
-from typing import Dict, Type
+from typing import Dict, Optional, Type
 
 import numpy as np
 import ray
 import torch
 from codetiming import Timer
 from omegaconf import OmegaConf, open_dict
-from torch.utils.data import Dataset, RandomSampler, SequentialSampler
+from torch.utils.data import Dataset, Sampler
 from torchdata.stateful_dataloader import StatefulDataLoader
 from tqdm import tqdm
 
@@ -54,7 +54,6 @@ from verl.trainer.ppo.metric_utils import (
 )
 from verl.trainer.ppo.reward import compute_reward, compute_reward_async
 from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
-from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
 from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
 from verl.utils.torch_functional import masked_mean
 from verl.utils.tracking import ValidationGenerationsLogger
@@ -279,6 +278,10 @@ class RayPPOTrainer:
         processor=None,
         reward_fn=None,
         val_reward_fn=None,
+        train_dataset: Optional[Dataset] = None,
+        val_dataset: Optional[Dataset] = None,
+        collate_fn=None,
+        train_sampler: Optional[Sampler] = None,
     ):
         # assert torch.cuda.is_available(), 'cuda must be available on driver'
 
@@ -320,7 +323,7 @@ class RayPPOTrainer:
             raise NotImplementedError
 
         self._validate_config()
-        self._create_dataloader()
+        self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
 
     def _validate_config(self):
         config = self.config
@@ -435,38 +438,25 @@ class RayPPOTrainer:
 
         print("[validate_config] All configuration checks passed successfully!")
 
-    def _create_dataloader(self):
+    def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler):
         """
         Creates the train and validation dataloaders.
         """
-        # make sure the batch size is divisible by the dp size
-        from verl.utils.import_utils import load_extern_type
+        # TODO: we have to make sure the batch size is divisible by the dp size
+        from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
 
-        if "custom_cls" in self.config.data and self.config.data.custom_cls.get("path", None) is not None:
-            # Dynamically load the custom dataset class specified in config
-            try:
-                dataset_cls = load_extern_type(self.config.data.custom_cls.path, self.config.data.custom_cls.name)
-                if not issubclass(dataset_cls, Dataset):
-                    raise TypeError(f"The custom dataset class '{self.config.data.custom_cls.name}' from '{self.config.data.custom_cls.path}' must inherit from torch.utils.data.Dataset")
-                print(f"Using custom dataset class: {dataset_cls.__name__}")
-            except Exception as e:
-                print(f"Error loading custom dataset class: {e}")
-                raise e
-        else:
-            dataset_cls = RLHFDataset
-            print(f"Using default dataset class: {dataset_cls.__name__}")
-        self.train_dataset = dataset_cls(
-            data_files=self.config.data.train_files,
-            tokenizer=self.tokenizer,
-            processor=self.processor,
-            config=self.config.data,
-        )
-        if self.config.data.shuffle:
-            train_dataloader_generator = torch.Generator()
-            train_dataloader_generator.manual_seed(self.config.data.get("seed", 1))
-            sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)
-        else:
-            sampler = SequentialSampler(data_source=self.train_dataset)
+        if train_dataset is None:
+            train_dataset = create_rl_dataset(self.config.data.train_files, self.config.data, self.tokenizer, self.processor)
+        if val_dataset is None:
+            val_dataset = create_rl_dataset(self.config.data.val_files, self.config.data, self.tokenizer, self.processor)
+        self.train_dataset, self.val_dataset = train_dataset, val_dataset
+
+        if train_sampler is None:
+            train_sampler = create_rl_sampler(self.config.data, self.train_dataset)
+        if collate_fn is None:
+            from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn
+
+            collate_fn = default_collate_fn
 
         self.train_dataloader = StatefulDataLoader(
             dataset=self.train_dataset,
@@ -474,14 +464,7 @@ class RayPPOTrainer:
             num_workers=self.config.data.get("dataloader_num_workers", 8),
             drop_last=True,
             collate_fn=collate_fn,
-            sampler=sampler,
-        )
-
-        self.val_dataset = dataset_cls(
-            data_files=self.config.data.val_files,
-            tokenizer=self.tokenizer,
-            processor=self.processor,
-            config=self.config.data,
+            sampler=train_sampler,
         )
 
         val_batch_size = self.config.data.val_batch_size  # Prefer config value if set
@@ -35,6 +35,7 @@ logger = logging.getLogger(__name__)
 
 
 def collate_fn(data_list: list[dict]) -> dict:
+    """Collate a batch of data."""
     tensors = defaultdict(list)
     non_tensors = defaultdict(list)
 
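The docstring added above is the only change to collate_fn, but since RayPPOTrainer now accepts a user-supplied collate_fn, its contract matters: it maps a list of per-sample dicts to a single batch dict. The following is a sketch of that contract inferred from the visible start of the function body, not necessarily line-for-line the verl implementation:

# Sketch of the expected collate_fn contract: tensor-valued fields are stacked along a new
# batch dimension, everything else is kept as a numpy object array under the same key.
from collections import defaultdict

import numpy as np
import torch


def collate_fn(data_list: list[dict]) -> dict:
    """Collate a batch of data."""
    tensors = defaultdict(list)
    non_tensors = defaultdict(list)

    for data in data_list:
        for key, val in data.items():
            if isinstance(val, torch.Tensor):
                tensors[key].append(val)
            else:
                non_tensors[key].append(val)

    output = {key: torch.stack(vals, dim=0) for key, vals in tensors.items()}
    output.update({key: np.array(vals, dtype=object) for key, vals in non_tensors.items()})
    return output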