# [trainer] refactor: PPO config validation fast fail (#3187)

### What does this PR do?

Make the main PPO script validate the config as soon as all needed information is available. This lets the script fail as fast as possible when there is a bug in the config: the new flow no longer downloads the checkpoint, loads the tokenizer, or loads the data before the config has been validated. Solves #3182.

### Design & Code Changes

Config validation is isolated in a utility function (moved out of `RayPPOTrainer`) and called from `main_ppo` as soon as possible.
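The sketch below condenses the reordered `TaskRunner` flow that the hunks implement. It is illustrative only: a hypothetical `run_task` helper stands in for `TaskRunner.run`, the worker-class setup is elided, and it assumes the imports added in this PR.

```python
from verl.trainer.ppo.utils import need_critic, need_reference_policy
from verl.utils.config import validate_config


def run_task(config, role_worker_mapping):
    # Fail fast: validate the config as soon as the role -> worker mapping is
    # known, before any checkpoint download, tokenizer load, or dataset load.
    validate_config(
        config=config,
        use_reference_policy=need_reference_policy(role_worker_mapping),
        use_critic=need_critic(config),
    )

    # Only after validation passes do the expensive steps run, exactly as in
    # the hunks below: copy_to_local(...), hf_tokenizer(...),
    # load_reward_manager(...), create_rl_dataset(...).
    ...
```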
@@ -23,9 +23,12 @@ import hydra
import ray
from omegaconf import OmegaConf

from recipe.one_step_off_policy.utils import need_critic
from verl.trainer.constants_ppo import get_ppo_ray_runtime_env
from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
from verl.trainer.ppo.reward import load_reward_manager
from verl.trainer.ppo.utils import need_reference_policy
from verl.utils.config import validate_config

from .ray_trainer import OneStepOffRayTrainer

@@ -87,20 +90,6 @@ class TaskRunner:

OmegaConf.resolve(config)

# Download the checkpoint from HDFS to the local machine.
# `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on
local_path = copy_to_local(
config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
)

# Instantiate the tokenizer and processor.
from verl.utils import hf_processor, hf_tokenizer

trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
# Used for multimodal LLM, could be None
processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)

# Define worker classes based on the actor strategy.
if config.actor_rollout_ref.actor.strategy == "fsdp2":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy

@@ -190,6 +179,27 @@ class TaskRunner:
role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
mapping[Role.RefPolicy] = global_pool_id

# validate config
validate_config(
config=config,
use_reference_policy=need_reference_policy(role_worker_mapping),
use_critic=need_critic(config),
)

# Download the checkpoint from HDFS to the local machine.
# `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on
local_path = copy_to_local(
config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
)

# Instantiate the tokenizer and processor.
from verl.utils import hf_processor, hf_tokenizer

trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
# Used for multimodal LLM, could be None
processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)

# Load the reward manager for training and validation.
reward_fn = load_reward_manager(
config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
@@ -28,11 +28,12 @@ from omegaconf import OmegaConf
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm

from recipe.one_step_off_policy.utils import need_critic
from verl import DataProto
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.ppo import core_algos
from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss
from verl.trainer.ppo.core_algos import agg_loss
from verl.trainer.ppo.metric_utils import (
compute_data_metrics,
compute_throughout_metrics,

@@ -41,13 +42,12 @@ from verl.trainer.ppo.metric_utils import (
from verl.trainer.ppo.ray_trainer import (
RayPPOTrainer,
ResourcePoolManager,
Role,
WorkerType,
apply_kl_penalty,
compute_advantage,
compute_response_mask,
)
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.trainer.ppo.utils import Role, WorkerType, need_reference_policy, need_reward_model
from verl.utils.debug import marked_timer
from verl.utils.metric import (
reduce_metrics,

@@ -140,8 +140,9 @@ class OneStepOffRayTrainer(RayPPOTrainer):

self.role_worker_mapping = role_worker_mapping
self.resource_pool_manager = resource_pool_manager
self.use_reference_policy = Role.RefPolicy in role_worker_mapping
self.use_rm = Role.RewardModel in role_worker_mapping
self.use_reference_policy = need_reference_policy(self.role_worker_mapping)
self.use_rm = need_reward_model(self.role_worker_mapping)
self.use_critic = need_critic(config)
self.ray_worker_group_cls = ray_worker_group_cls
self.device_name = device_name
self.validation_generations_logger = ValidationGenerationsLogger()

@@ -154,23 +155,6 @@ class OneStepOffRayTrainer(RayPPOTrainer):
if config.algorithm.use_kl_in_reward:
self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)

if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE:
self.use_critic = True
elif self.config.algorithm.adv_estimator in [
AdvantageEstimator.GRPO,
AdvantageEstimator.GRPO_PASSK,
AdvantageEstimator.REINFORCE_PLUS_PLUS,
# AdvantageEstimator.REMAX, # TODO:REMAX advantage estimator is not yet supported in one_step_off_policy
AdvantageEstimator.RLOO,
AdvantageEstimator.OPO,
AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE,
AdvantageEstimator.GPG,
]:
self.use_critic = False
else:
raise NotImplementedError

self._validate_config()
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)

def _validate(self):
38 recipe/one_step_off_policy/utils.py Normal file
@@ -0,0 +1,38 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from omegaconf import DictConfig

from verl.trainer.ppo.core_algos import AdvantageEstimator


def need_critic(config: DictConfig) -> bool:
"""Given a config, do we need critic"""
if config.algorithm.adv_estimator == AdvantageEstimator.GAE:
return True
elif config.algorithm.adv_estimator in [
AdvantageEstimator.GRPO,
AdvantageEstimator.GRPO_PASSK,
AdvantageEstimator.REINFORCE_PLUS_PLUS,
# AdvantageEstimator.REMAX, # TODO:REMAX advantage estimator is not yet supported in one_step_off_policy
AdvantageEstimator.RLOO,
AdvantageEstimator.OPO,
AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE,
AdvantageEstimator.GPG,
]:
return False
else:
raise NotImplementedError
@@ -33,6 +33,9 @@ import hydra
import ray
from omegaconf import OmegaConf

from verl.trainer.ppo.utils import need_reference_policy
from verl.utils.config import validate_config

from .prime_ray_trainer import RayPRIMETrainer

@@ -67,14 +70,6 @@ def main_task(config, compute_score=None):
pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
OmegaConf.resolve(config)

# download the checkpoint from hdfs
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)

# instantiate tokenizer
from verl.utils import hf_tokenizer

tokenizer = hf_tokenizer(local_path)

# define worker classes
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
assert config.critic.strategy in {"fsdp", "fsdp2"}

@@ -118,6 +113,21 @@ def main_task(config, compute_score=None):
role_worker_mapping[Role.RewardModel] = ray.remote(PRIMERewardModelWorker)
mapping[Role.RewardModel] = global_pool_id

# validate config
# TODO: Additional config checks can be added with proper function under prime recipe
validate_config(
config=config,
use_reference_policy=need_reference_policy(role_worker_mapping),
use_critic=False,
)

# download the checkpoint from hdfs
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)

# instantiate tokenizer
from verl.utils import hf_tokenizer

tokenizer = hf_tokenizer(local_path)
reward_manager_name = config.reward_model.get("reward_manager", "naive")
if reward_manager_name == "naive":
from verl.workers.reward_manager import NaiveRewardManager
@@ -30,7 +30,8 @@ from verl import DataProto
from verl.single_controller.ray import RayWorkerGroup
from verl.trainer.ppo.core_algos import agg_loss
from verl.trainer.ppo.metric_utils import _compute_response_info
from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType
from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager
from verl.trainer.ppo.utils import Role, WorkerType
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from verl.utils.metric import reduce_metrics

@@ -176,10 +177,6 @@ class RayPRIMETrainer(RayPPOTrainer):

self.use_critic = False

def _validate_config(self):
super()._validate_config()
# TODO: Additional config checks can be added here

def _create_dataloader(self, *args, **kwargs):
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
@@ -19,7 +19,9 @@ import hydra
import ray

from recipe.spin.spin_trainer import RaySPINTrainer
from recipe.spin.utils import validate_config
from verl.trainer.ppo.reward import get_custom_reward_fn
from verl.trainer.ppo.utils import need_reference_policy


@hydra.main(config_path="config", config_name="spin_trainer", version_base=None)

@@ -56,16 +58,6 @@ class TaskRunner:
pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
OmegaConf.resolve(config)

# download the checkpoint from hdfs
local_path = copy_to_local(config.actor_rollout_ref.model.path)

# instantiate tokenizer
from verl.utils import hf_processor, hf_tokenizer

trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
processor = hf_processor(local_path, use_fast=True)  # used for multimodal LLM, could be none

# define worker classes
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
assert config.critic.strategy in {"fsdp", "fsdp2"}

@@ -117,6 +109,23 @@ class TaskRunner:
role_worker_mapping[Role.RefPolicy] = ray.remote(SPINRolloutRefWorker)
mapping[Role.RefPolicy] = global_pool_id

# validate config
validate_config(
config=config,
use_reference_policy=need_reference_policy(self.role_worker_mapping),
use_critic=False,
)

# download the checkpoint from hdfs
local_path = copy_to_local(config.actor_rollout_ref.model.path)

# instantiate tokenizer
from verl.utils import hf_processor, hf_tokenizer

trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
processor = hf_processor(local_path, use_fast=True)  # used for multimodal LLM, could be none

from verl.workers.reward_manager import get_reward_manager_cls

# Note(haibin.lin): please make sure custom reward managers are imported and
@@ -19,7 +19,6 @@ import uuid
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from pprint import pprint
from typing import Any, Optional

@@ -35,7 +34,6 @@ from tqdm import tqdm
from recipe.spin import core_algos
from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.ppo.metric_utils import (

@@ -44,27 +42,12 @@ from verl.trainer.ppo.metric_utils import (
process_validation_metrics,
reduce_metrics,
)
from verl.trainer.ppo.ray_trainer import Role
from verl.trainer.ppo.utils import Role, WorkerType, need_reference_policy, need_reward_model
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from verl.utils.torch_functional import masked_mean
from verl.utils.tracking import ValidationGenerationsLogger

WorkerType = type[Worker]


class AdvantageEstimator(str, Enum):
"""
Using an enumeration class to avoid spelling errors in adv_estimator
"""

GAE = "gae"
GRPO = "grpo"
REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
REMAX = "remax"
RLOO = "rloo"


@dataclass
class ResourcePoolManager:

@@ -386,8 +369,9 @@ class RaySPINTrainer:

self.role_worker_mapping = role_worker_mapping
self.resource_pool_manager = resource_pool_manager
self.use_reference_policy = Role.RefPolicy in role_worker_mapping
self.use_rm = Role.RewardModel in role_worker_mapping
self.use_reference_policy = need_reference_policy(role_worker_mapping)
self.use_rm = need_reward_model(role_worker_mapping)
self.use_critic = False
self.ray_worker_group_cls = ray_worker_group_cls
self.validation_generations_logger = ValidationGenerationsLogger()
self.async_rollout_mode = False

@@ -398,146 +382,8 @@ class RaySPINTrainer:
if config.algorithm.use_kl_in_reward:
self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)

self.use_critic = False
self._validate_config()
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)

def _validate_config(self):
config = self.config
# number of GPUs total
n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes

# 1. Check total batch size for data correctness
real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
assert real_train_batch_size % n_gpus == 0, (
f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})."
)

# A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu"
# We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu".
def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
settings = {
"actor_rollout_ref.actor": "micro_batch_size",
"critic": "micro_batch_size",
"reward_model": "micro_batch_size",
"actor_rollout_ref.ref": "log_prob_micro_batch_size",
"actor_rollout_ref.rollout": "log_prob_micro_batch_size",
}

if name in settings:
param = settings[name]
param_per_gpu = f"{param}_per_gpu"

if mbs is None and mbs_per_gpu is None:
raise ValueError(
f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'."
)

if mbs is not None and mbs_per_gpu is not None:
raise ValueError(
f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. "
f"Please remove '{name}.{param}' because only '*_{param_per_gpu}' is supported "
f"(the former is deprecated)."
)

if not config.actor_rollout_ref.actor.use_dynamic_bsz:
# actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.actor.ppo_micro_batch_size,
config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu,
"actor_rollout_ref.actor",
)

if self.use_reference_policy:
# reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.ref.log_prob_micro_batch_size,
config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu,
"actor_rollout_ref.ref",
)

# The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.rollout.log_prob_micro_batch_size,
config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu,
"actor_rollout_ref.rollout",
)

if self.use_critic and not config.critic.use_dynamic_bsz:
# Check for critic micro-batch size conflicts
check_mutually_exclusive(
config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic"
)

# Check for reward model micro-batch size conflicts
if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:
check_mutually_exclusive(
config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model"
)

# Actor
# check if train_batch_size is larger than ppo_mini_batch_size
# if NOT dynamic_bsz, we must ensure:
# ppo_mini_batch_size is divisible by ppo_micro_batch_size
# ppo_micro_batch_size * sequence_parallel_size >= n_gpus
if not config.actor_rollout_ref.actor.use_dynamic_bsz:
assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size
sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1)
if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None:
assert (
config.actor_rollout_ref.actor.ppo_mini_batch_size
% config.actor_rollout_ref.actor.ppo_micro_batch_size
== 0
)
assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus

assert config.actor_rollout_ref.actor.loss_agg_mode in [
"token-mean",
"seq-mean-token-sum",
"seq-mean-token-mean",
], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}"

if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss:
print("NOTICE: You have both enabled in-reward kl and kl loss.")

# critic
if self.use_critic and not config.critic.use_dynamic_bsz:
assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size
sp_size = config.critic.get("ulysses_sequence_parallel_size", 1)
if config.critic.ppo_micro_batch_size is not None:
assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0
assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus

# Check if use_remove_padding is enabled when using sequence parallelism for fsdp
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
if (
config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1
or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1
):
assert config.actor_rollout_ref.model.use_remove_padding, (
"When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`."
)

if self.use_critic and config.critic.strategy in {"fsdp", "fsdp2"}:
if config.critic.get("ulysses_sequence_parallel_size", 1) > 1:
assert config.critic.model.use_remove_padding, (
"When using sequence parallelism for critic, you must enable `use_remove_padding`."
)

if config.data.get("val_batch_size", None) is not None:
print(
"WARNING: val_batch_size is deprecated. Validation datasets are sent to inference engines "
"as a whole batch, which will schedule the memory themselves."
)

# check eval config
if config.actor_rollout_ref.rollout.val_kwargs.do_sample:
assert config.actor_rollout_ref.rollout.temperature > 0, (
"validation gen temperature should be greater than 0 when enabling do_sample"
)

print("[validate_config] All configuration checks passed successfully!")

def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler):
"""
Creates the train and validation dataloaders.
160 recipe/spin/utils.py Normal file
@@ -0,0 +1,160 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from omegaconf import DictConfig


def validate_config(
config: DictConfig,
use_reference_policy: bool,
use_critic: bool,
) -> None:
"""
Validate an OmegaConf DictConfig

Args:
config: The OmegaConf DictConfig to validate.
use_reference_policy (bool): is ref policy needed
use_critic (bool): is critic needed
"""
# number of GPUs total
n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes

# 1. Check total batch size for data correctness
real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
assert real_train_batch_size % n_gpus == 0, (
f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})."
)

# A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu"
# We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu".
def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
settings = {
"actor_rollout_ref.actor": "micro_batch_size",
"critic": "micro_batch_size",
"reward_model": "micro_batch_size",
"actor_rollout_ref.ref": "log_prob_micro_batch_size",
"actor_rollout_ref.rollout": "log_prob_micro_batch_size",
}

if name in settings:
param = settings[name]
param_per_gpu = f"{param}_per_gpu"

if mbs is None and mbs_per_gpu is None:
raise ValueError(f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.")

if mbs is not None and mbs_per_gpu is not None:
raise ValueError(
f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. "
f"Please remove '{name}.{param}' because only '*_{param_per_gpu}' is supported "
f"(the former is deprecated)."
)

if not config.actor_rollout_ref.actor.use_dynamic_bsz:
# actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.actor.ppo_micro_batch_size,
config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu,
"actor_rollout_ref.actor",
)

if use_reference_policy:
# reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.ref.log_prob_micro_batch_size,
config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu,
"actor_rollout_ref.ref",
)

# The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.rollout.log_prob_micro_batch_size,
config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu,
"actor_rollout_ref.rollout",
)

if use_critic and not config.critic.use_dynamic_bsz:
# Check for critic micro-batch size conflicts
check_mutually_exclusive(
config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic"
)

# Check for reward model micro-batch size conflicts
if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:
check_mutually_exclusive(
config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model"
)

# Actor
# check if train_batch_size is larger than ppo_mini_batch_size
# if NOT dynamic_bsz, we must ensure:
# ppo_mini_batch_size is divisible by ppo_micro_batch_size
# ppo_micro_batch_size * sequence_parallel_size >= n_gpus
if not config.actor_rollout_ref.actor.use_dynamic_bsz:
assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size
sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1)
if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None:
assert (
config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size
== 0
)
assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus

assert config.actor_rollout_ref.actor.loss_agg_mode in [
"token-mean",
"seq-mean-token-sum",
"seq-mean-token-mean",
], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}"

if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss:
print("NOTICE: You have both enabled in-reward kl and kl loss.")

# critic
if use_critic and not config.critic.use_dynamic_bsz:
assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size
sp_size = config.critic.get("ulysses_sequence_parallel_size", 1)
if config.critic.ppo_micro_batch_size is not None:
assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0
assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus

# Check if use_remove_padding is enabled when using sequence parallelism for fsdp
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
if (
config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1
or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1
):
assert config.actor_rollout_ref.model.use_remove_padding, (
"When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`."
)

if use_critic and config.critic.strategy in {"fsdp", "fsdp2"}:
if config.critic.get("ulysses_sequence_parallel_size", 1) > 1:
assert config.critic.model.use_remove_padding, (
"When using sequence parallelism for critic, you must enable `use_remove_padding`."
)

if config.data.get("val_batch_size", None) is not None:
print(
"WARNING: val_batch_size is deprecated. Validation datasets are sent to inference engines "
"as a whole batch, which will schedule the memory themselves."
)

# check eval config
if config.actor_rollout_ref.rollout.val_kwargs.do_sample:
assert config.actor_rollout_ref.rollout.temperature > 0, (
"validation gen temperature should be greater than 0 when enabling do_sample"
)

print("[validate_config] All configuration checks passed successfully!")
@@ -24,6 +24,8 @@ import ray
from omegaconf import OmegaConf

from verl.trainer.ppo.reward import load_reward_manager
from verl.trainer.ppo.utils import need_reference_policy
from verl.utils.config import validate_config

from .sppo_ray_trainer import RaySPPOTrainer

@@ -66,16 +68,6 @@ class TaskRunner:
pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
OmegaConf.resolve(config)

# download the checkpoint from hdfs
local_path = copy_to_local(config.actor_rollout_ref.model.path)

# instantiate tokenizer
from verl.utils import hf_processor, hf_tokenizer

trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
processor = hf_processor(local_path, use_fast=True)  # used for multimodal LLM, could be none

# define worker classes
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
assert config.critic.strategy in {"fsdp", "fsdp2"}

@@ -133,6 +125,23 @@ class TaskRunner:
role_worker_mapping[Role.RefPolicy] = ray.remote(SPPOActorRolloutRefWorker)
mapping[Role.RefPolicy] = global_pool_id

# validate config
validate_config(
config=config,
use_reference_policy=need_reference_policy(role_worker_mapping),
use_critic=False,
)

# download the checkpoint from hdfs
local_path = copy_to_local(config.actor_rollout_ref.model.path)

# instantiate tokenizer
from verl.utils import hf_processor, hf_tokenizer

trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
processor = hf_processor(local_path, use_fast=True)  # used for multimodal LLM, could be none

reward_fn = load_reward_manager(
config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
)
@@ -38,12 +38,11 @@ from verl.trainer.ppo.ray_trainer import (
AdvantageEstimator,
RayPPOTrainer,
ResourcePoolManager,
Role,
WorkerType,
apply_kl_penalty,
compute_response_mask,
)
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.trainer.ppo.utils import Role, WorkerType, need_reference_policy, need_reward_model
from verl.utils.profiler.performance import simple_timer
from verl.utils.tracking import ValidationGenerationsLogger

@@ -111,8 +110,9 @@ class RaySPPOTrainer(RayPPOTrainer):

self.role_worker_mapping = role_worker_mapping
self.resource_pool_manager = resource_pool_manager
self.use_reference_policy = Role.RefPolicy in role_worker_mapping
self.use_rm = Role.RewardModel in role_worker_mapping
self.use_reference_policy = need_reference_policy(role_worker_mapping)
self.use_rm = need_reward_model(role_worker_mapping)
self.use_critic = False
self.ray_worker_group_cls = ray_worker_group_cls
self.validation_generations_logger = ValidationGenerationsLogger()
self.device_name = device_name if device_name else self.config.trainer.device

@@ -122,9 +122,6 @@ class RaySPPOTrainer(RayPPOTrainer):
if config.algorithm.use_kl_in_reward:
self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)

self.use_critic = False

self._validate_config()
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)

def fit(self):
@@ -26,6 +26,8 @@ from verl.experimental.dataset.sampler import AbstractSampler
from verl.trainer.constants_ppo import get_ppo_ray_runtime_env
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.trainer.ppo.reward import load_reward_manager
from verl.trainer.ppo.utils import need_critic, need_reference_policy
from verl.utils.config import validate_config
from verl.utils.device import is_cuda_available
from verl.utils.import_utils import load_extern_type

@@ -219,20 +221,6 @@ class TaskRunner:
pprint(OmegaConf.to_container(config, resolve=True))
OmegaConf.resolve(config)

# Download the checkpoint from HDFS to the local machine.
# `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on
local_path = copy_to_local(
config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
)

# Instantiate the tokenizer and processor.
from verl.utils import hf_processor, hf_tokenizer

trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
# Used for multimodal LLM, could be None
processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)

actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config)
self.add_critic_worker(config)

@@ -247,6 +235,27 @@ class TaskRunner:
# Add a reference policy worker if KL loss or KL reward is used.
self.add_ref_policy_worker(config, actor_rollout_cls)

# validate config
validate_config(
config=config,
use_reference_policy=need_reference_policy(self.role_worker_mapping),
use_critic=need_critic(config),
)

# Download the checkpoint from HDFS to the local machine.
# `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on
local_path = copy_to_local(
config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
)

# Instantiate the tokenizer and processor.
from verl.utils import hf_processor, hf_tokenizer

trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
# Used for multimodal LLM, could be None
processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)

# Load the reward manager for training and validation.
reward_fn = load_reward_manager(
config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
@@ -21,11 +21,9 @@ This trainer supports model-agonistic model initialization with huggingface
import json
import os
import uuid
import warnings
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from pprint import pprint
from typing import Optional

@@ -40,7 +38,6 @@ from tqdm import tqdm
from verl import DataProto
from verl.experimental.dataset.sampler import AbstractCurriculumSampler
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.config import AlgoConfig

@@ -53,6 +50,7 @@ from verl.trainer.ppo.metric_utils import (
process_validation_metrics,
)
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.debug import marked_timer

@@ -62,22 +60,6 @@ from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seql
from verl.utils.torch_functional import masked_mean
from verl.utils.tracking import ValidationGenerationsLogger

WorkerType = type[Worker]


class Role(Enum):
"""
To create more roles dynamically, you can subclass Role and add new members
"""

Actor = 0
Rollout = 1
ActorRollout = 2
Critic = 3
RefPolicy = 4
RewardModel = 5
ActorRolloutRef = 6


@dataclass
class ResourcePoolManager:

@@ -352,8 +334,9 @@ class RayPPOTrainer:

self.role_worker_mapping = role_worker_mapping
self.resource_pool_manager = resource_pool_manager
self.use_reference_policy = Role.RefPolicy in role_worker_mapping
self.use_rm = Role.RewardModel in role_worker_mapping
self.use_reference_policy = need_reference_policy(self.role_worker_mapping)
self.use_rm = need_reward_model(self.role_worker_mapping)
self.use_critic = need_critic(self.config)
self.ray_worker_group_cls = ray_worker_group_cls
self.device_name = device_name if device_name else self.config.trainer.device
self.validation_generations_logger = ValidationGenerationsLogger(

@@ -369,138 +352,8 @@ class RayPPOTrainer:
if self.config.algorithm.use_kl_in_reward:
self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl)

if config.critic.enable is not None:
self.use_critic = bool(config.critic.enable)
elif self.config.algorithm.adv_estimator == AdvantageEstimator.GAE:
self.use_critic = True
else:
warnings.warn(
"Disabled critic as algorithm.adv_estimator != gae. "
"If it is not intended, please set critic.enable=True",
stacklevel=2,
)
self.use_critic = False

self._validate_config()
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)

def _validate_config(self):
config = self.config
# number of GPUs total
n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes

if not config.actor_rollout_ref.actor.use_dynamic_bsz:
if config.actor_rollout_ref.actor.strategy == "megatron":
model_parallel_size = (
config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size
* config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size
)
assert (
n_gpus % (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size) == 0
), (
f"n_gpus ({n_gpus}) must be divisible by model_parallel_size ({model_parallel_size}) times "
f"context_parallel_size ({config.actor_rollout_ref.actor.megatron.context_parallel_size})"
)
megatron_dp = n_gpus // (
model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size
)
minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu
else:
minimal_bsz = n_gpus

# 1. Check total batch size for data correctness
real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
assert real_train_batch_size % minimal_bsz == 0, (
f"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size "
f"({minimal_bsz})"
)

# A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu"
# We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu".
def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
"""Validate mutually exclusive micro batch size configuration options.

Ensures that users don't set both deprecated micro_batch_size and
the new micro_batch_size_per_gpu parameters simultaneously.

Args:
mbs: Deprecated micro batch size parameter value.
mbs_per_gpu: New micro batch size per GPU parameter value.
name (str): Configuration section name for error messages.

Raises:
ValueError: If both parameters are set or neither is set.
"""
settings = {
"reward_model": "micro_batch_size",
"actor_rollout_ref.ref": "log_prob_micro_batch_size",
"actor_rollout_ref.rollout": "log_prob_micro_batch_size",
}

if name in settings:
param = settings[name]
param_per_gpu = f"{param}_per_gpu"

if mbs is None and mbs_per_gpu is None:
raise ValueError(
f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'."
)

if mbs is not None and mbs_per_gpu is not None:
raise ValueError(
f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove "
f"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated)."
)

# Actor validation done in ActorConfig.__post_init__ and validate()
actor_config = omega_conf_to_dataclass(config.actor_rollout_ref.actor)
actor_config.validate(n_gpus, config.data.train_batch_size, config.actor_rollout_ref.model)

if not config.actor_rollout_ref.actor.use_dynamic_bsz:
if self.use_reference_policy:
# reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.ref.log_prob_micro_batch_size,
config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu,
"actor_rollout_ref.ref",
)

# The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.rollout.log_prob_micro_batch_size,
config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu,
"actor_rollout_ref.rollout",
)

# Check for reward model micro-batch size conflicts
if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:
check_mutually_exclusive(
config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model"
)

if self.config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss:
print("NOTICE: You have both enabled in-reward kl and kl loss.")

# critic
if self.use_critic:
critic_config = omega_conf_to_dataclass(config.critic)
critic_config.validate(n_gpus, config.data.train_batch_size)

if config.data.get("val_batch_size", None) is not None:
print(
"WARNING: val_batch_size is deprecated."
+ " Validation datasets are sent to inference engines as a whole batch,"
+ " which will schedule the memory themselves."
)

# check eval config
if config.actor_rollout_ref.rollout.val_kwargs.do_sample:
assert config.actor_rollout_ref.rollout.temperature > 0, (
"validation gen temperature should be greater than 0 when enabling do_sample"
)

print("[validate_config] All configuration checks passed successfully!")

def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):
"""
Creates the train and validation dataloaders.
65 verl/trainer/ppo/utils.py Normal file
@@ -0,0 +1,65 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from enum import Enum

from omegaconf import DictConfig

from verl.single_controller.base import Worker
from verl.trainer.ppo.core_algos import AdvantageEstimator

WorkerType = type[Worker]


class Role(Enum):
"""
To create more roles dynamically, you can subclass Role and add new members
"""

Actor = 0
Rollout = 1
ActorRollout = 2
Critic = 3
RefPolicy = 4
RewardModel = 5
ActorRolloutRef = 6


def need_reference_policy(
role_worker_mapping: dict[Role, WorkerType],
) -> bool:
"""Given a role worker mapping, do we need ref policy."""
return Role.RefPolicy in role_worker_mapping


def need_reward_model(
role_worker_mapping: dict[Role, WorkerType],
) -> bool:
"""Given a role worker mapping, do we need reward model."""
return Role.RewardModel in role_worker_mapping


def need_critic(config: DictConfig) -> bool:
"""Given a config, do we need critic."""
if config.critic.enable is not None:
return bool(config.critic.enable)
elif config.algorithm.adv_estimator == AdvantageEstimator.GAE:
return True
else:
warnings.warn(
"Disabled critic as algorithm.adv_estimator != gae. If it is not intended, please set critic.enable=True",
stacklevel=2,
)
return False
@@ -17,7 +17,7 @@ from typing import Any, Optional

from omegaconf import DictConfig, ListConfig, OmegaConf

__all__ = ["omega_conf_to_dataclass"]
__all__ = ["omega_conf_to_dataclass", "validate_config"]


def omega_conf_to_dataclass(config: DictConfig | dict, dataclass_type: Optional[type[Any]] = None) -> Any:

@@ -69,3 +69,129 @@ def update_dict_with_config(dictionary: dict, config: DictConfig):
for key in dictionary:
if hasattr(config, key):
dictionary[key] = getattr(config, key)


def validate_config(
config: DictConfig,
use_reference_policy: bool,
use_critic: bool,
) -> None:
"""Validate an OmegaConf DictConfig.

Args:
config (DictConfig): The OmegaConf DictConfig to validate.
use_reference_policy (bool): is ref policy needed
use_critic (bool): is critic needed
"""
# number of GPUs total
n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes

if not config.actor_rollout_ref.actor.use_dynamic_bsz:
if config.actor_rollout_ref.actor.strategy == "megatron":
model_parallel_size = (
config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size
* config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size
)
assert (
n_gpus % (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size) == 0
), (
f"n_gpus ({n_gpus}) must be divisible by model_parallel_size ({model_parallel_size}) times "
f"context_parallel_size ({config.actor_rollout_ref.actor.megatron.context_parallel_size})"
)
megatron_dp = n_gpus // (
model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size
)
minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu
else:
minimal_bsz = n_gpus

# 1. Check total batch size for data correctness
real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
assert real_train_batch_size % minimal_bsz == 0, (
f"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size "
f"({minimal_bsz})"
)

# A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu"
# We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu".
def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
"""Validate mutually exclusive micro batch size configuration options.

Ensures that users don't set both deprecated micro_batch_size and
the new micro_batch_size_per_gpu parameters simultaneously.

Args:
mbs: Deprecated micro batch size parameter value.
mbs_per_gpu: New micro batch size per GPU parameter value.
name (str): Configuration section name for error messages.

Raises:
ValueError: If both parameters are set or neither is set.
"""
settings = {
"reward_model": "micro_batch_size",
"actor_rollout_ref.ref": "log_prob_micro_batch_size",
"actor_rollout_ref.rollout": "log_prob_micro_batch_size",
}

if name in settings:
param = settings[name]
param_per_gpu = f"{param}_per_gpu"

if mbs is None and mbs_per_gpu is None:
raise ValueError(f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.")

if mbs is not None and mbs_per_gpu is not None:
raise ValueError(
f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove "
f"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated)."
)

# Actor validation done in ActorConfig.__post_init__ and validate()
actor_config = omega_conf_to_dataclass(config.actor_rollout_ref.actor)
actor_config.validate(n_gpus, config.data.train_batch_size, config.actor_rollout_ref.model)

if not config.actor_rollout_ref.actor.use_dynamic_bsz:
if use_reference_policy:
# reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.ref.log_prob_micro_batch_size,
config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu,
"actor_rollout_ref.ref",
)

# The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.rollout.log_prob_micro_batch_size,
config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu,
"actor_rollout_ref.rollout",
)

# Check for reward model micro-batch size conflicts
if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:
check_mutually_exclusive(
config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model"
)

if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss:
print("NOTICE: You have both enabled in-reward kl and kl loss.")

# critic
if use_critic:
critic_config = omega_conf_to_dataclass(config.critic)
critic_config.validate(n_gpus, config.data.train_batch_size)

if config.data.get("val_batch_size", None) is not None:
print(
"WARNING: val_batch_size is deprecated."
+ " Validation datasets are sent to inference engines as a whole batch,"
+ " which will schedule the memory themselves."
)

# check eval config
if config.actor_rollout_ref.rollout.val_kwargs.do_sample:
assert config.actor_rollout_ref.rollout.temperature > 0, (
"validation gen temperature should be greater than 0 when enabling do_sample"
)

print("[validate_config] All configuration checks passed successfully!")