[trainer] refactor: PPO config validation fast fail (#3187)

### What does this PR do?

Make the main PPO script validate its config as soon as all the required information is
available, so that it fails as fast as possible when the config contains a bug.
With this change, the script no longer downloads and loads the tokenizer or loads the
data before the config has been validated.
Closes #3182

### Design & Code Changes

Config validation is isolated in a utils module (moved out of `RayPPOTrainer`) and called
from `main_ppo` as early as possible.
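
Concretely, each entry point now calls `validate_config` before any checkpoint download, tokenizer/processor construction, or dataset loading. A condensed sketch of the new ordering (the `run()` wrapper and the surrounding worker-mapping setup are illustrative; the call itself matches the diffs below):

```python
from verl.trainer.ppo.utils import need_critic, need_reference_policy
from verl.utils.config import validate_config


def run(config, role_worker_mapping):
    # Fail fast: validate as soon as the role -> worker mapping is known,
    # i.e. before any heavy I/O has started.
    validate_config(
        config=config,
        use_reference_policy=need_reference_policy(role_worker_mapping),
        use_critic=need_critic(config),
    )
    # Only after validation passes: download the checkpoint, build the
    # tokenizer/processor, and load the reward manager and datasets
    # (previously these steps ran before validation).
    ...
```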
Commit 7592d69cbb (parent b4a410197c), authored by Slim Frikha on 2025-08-26 06:31:39 +04:00, committed via GitHub.
14 changed files with 513 additions and 400 deletions


@@ -23,9 +23,12 @@ import hydra
import ray
from omegaconf import OmegaConf
from recipe.one_step_off_policy.utils import need_critic
from verl.trainer.constants_ppo import get_ppo_ray_runtime_env
from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
from verl.trainer.ppo.reward import load_reward_manager
from verl.trainer.ppo.utils import need_reference_policy
from verl.utils.config import validate_config
from .ray_trainer import OneStepOffRayTrainer
@@ -87,20 +90,6 @@ class TaskRunner:
OmegaConf.resolve(config)
# Download the checkpoint from HDFS to the local machine.
# `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on
local_path = copy_to_local(
config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
)
# Instantiate the tokenizer and processor.
from verl.utils import hf_processor, hf_tokenizer
trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
# Used for multimodal LLM, could be None
processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
# Define worker classes based on the actor strategy.
if config.actor_rollout_ref.actor.strategy == "fsdp2":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
@@ -190,6 +179,27 @@ class TaskRunner:
role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
mapping[Role.RefPolicy] = global_pool_id
# validate config
validate_config(
config=config,
use_reference_policy=need_reference_policy(role_worker_mapping),
use_critic=need_critic(config),
)
# Download the checkpoint from HDFS to the local machine.
# `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on
local_path = copy_to_local(
config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
)
# Instantiate the tokenizer and processor.
from verl.utils import hf_processor, hf_tokenizer
trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
# Used for multimodal LLM, could be None
processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
# Load the reward manager for training and validation.
reward_fn = load_reward_manager(
config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})


@@ -28,11 +28,12 @@ from omegaconf import OmegaConf
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm
from recipe.one_step_off_policy.utils import need_critic
from verl import DataProto
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.ppo import core_algos
from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss
from verl.trainer.ppo.core_algos import agg_loss
from verl.trainer.ppo.metric_utils import (
compute_data_metrics,
compute_throughout_metrics,
@@ -41,13 +42,12 @@ from verl.trainer.ppo.metric_utils import (
from verl.trainer.ppo.ray_trainer import (
RayPPOTrainer,
ResourcePoolManager,
Role,
WorkerType,
apply_kl_penalty,
compute_advantage,
compute_response_mask,
)
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.trainer.ppo.utils import Role, WorkerType, need_reference_policy, need_reward_model
from verl.utils.debug import marked_timer
from verl.utils.metric import (
reduce_metrics,
@@ -140,8 +140,9 @@ class OneStepOffRayTrainer(RayPPOTrainer):
self.role_worker_mapping = role_worker_mapping
self.resource_pool_manager = resource_pool_manager
self.use_reference_policy = Role.RefPolicy in role_worker_mapping
self.use_rm = Role.RewardModel in role_worker_mapping
self.use_reference_policy = need_reference_policy(self.role_worker_mapping)
self.use_rm = need_reward_model(self.role_worker_mapping)
self.use_critic = need_critic(config)
self.ray_worker_group_cls = ray_worker_group_cls
self.device_name = device_name
self.validation_generations_logger = ValidationGenerationsLogger()
@@ -154,23 +155,6 @@ class OneStepOffRayTrainer(RayPPOTrainer):
if config.algorithm.use_kl_in_reward:
self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)
if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE:
self.use_critic = True
elif self.config.algorithm.adv_estimator in [
AdvantageEstimator.GRPO,
AdvantageEstimator.GRPO_PASSK,
AdvantageEstimator.REINFORCE_PLUS_PLUS,
# AdvantageEstimator.REMAX, # TODO:REMAX advantage estimator is not yet supported in one_step_off_policy
AdvantageEstimator.RLOO,
AdvantageEstimator.OPO,
AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE,
AdvantageEstimator.GPG,
]:
self.use_critic = False
else:
raise NotImplementedError
self._validate_config()
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
def _validate(self):


@@ -0,0 +1,38 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from omegaconf import DictConfig
from verl.trainer.ppo.core_algos import AdvantageEstimator
def need_critic(config: DictConfig) -> bool:
"""Given a config, do we need critic"""
if config.algorithm.adv_estimator == AdvantageEstimator.GAE:
return True
elif config.algorithm.adv_estimator in [
AdvantageEstimator.GRPO,
AdvantageEstimator.GRPO_PASSK,
AdvantageEstimator.REINFORCE_PLUS_PLUS,
# AdvantageEstimator.REMAX, # TODO:REMAX advantage estimator is not yet supported in one_step_off_policy
AdvantageEstimator.RLOO,
AdvantageEstimator.OPO,
AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE,
AdvantageEstimator.GPG,
]:
return False
else:
raise NotImplementedError
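
The helper mirrors the branching that was removed from `OneStepOffRayTrainer.__init__` above. A minimal usage sketch (the config literal is illustrative and contains only the field the helper reads; `AdvantageEstimator` is a `str` enum, so plain strings compare equal to its members):

```python
from omegaconf import OmegaConf
from recipe.one_step_off_policy.utils import need_critic

cfg = OmegaConf.create({"algorithm": {"adv_estimator": "gae"}})
assert need_critic(cfg) is True   # GAE requires a critic

cfg.algorithm.adv_estimator = "grpo"
assert need_critic(cfg) is False  # GRPO and the other listed estimators do not
```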


@@ -33,6 +33,9 @@ import hydra
import ray
from omegaconf import OmegaConf
from verl.trainer.ppo.utils import need_reference_policy
from verl.utils.config import validate_config
from .prime_ray_trainer import RayPRIMETrainer
@@ -67,14 +70,6 @@ def main_task(config, compute_score=None):
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
# download the checkpoint from hdfs
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
# instantiate tokenizer
from verl.utils import hf_tokenizer
tokenizer = hf_tokenizer(local_path)
# define worker classes
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
assert config.critic.strategy in {"fsdp", "fsdp2"}
@@ -118,6 +113,21 @@ def main_task(config, compute_score=None):
role_worker_mapping[Role.RewardModel] = ray.remote(PRIMERewardModelWorker)
mapping[Role.RewardModel] = global_pool_id
# validate config
# TODO: Additional config checks can be added with proper function under prime recipe
validate_config(
config=config,
use_reference_policy=need_reference_policy(role_worker_mapping),
use_critic=False,
)
# download the checkpoint from hdfs
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
# instantiate tokenizer
from verl.utils import hf_tokenizer
tokenizer = hf_tokenizer(local_path)
reward_manager_name = config.reward_model.get("reward_manager", "naive")
if reward_manager_name == "naive":
from verl.workers.reward_manager import NaiveRewardManager


@@ -30,7 +30,8 @@ from verl import DataProto
from verl.single_controller.ray import RayWorkerGroup
from verl.trainer.ppo.core_algos import agg_loss
from verl.trainer.ppo.metric_utils import _compute_response_info
from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType
from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager
from verl.trainer.ppo.utils import Role, WorkerType
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from verl.utils.metric import reduce_metrics
@@ -176,10 +177,6 @@ class RayPRIMETrainer(RayPPOTrainer):
self.use_critic = False
def _validate_config(self):
super()._validate_config()
# TODO: Additional config checks can be added here
def _create_dataloader(self, *args, **kwargs):
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler


@@ -19,7 +19,9 @@ import hydra
import ray
from recipe.spin.spin_trainer import RaySPINTrainer
from recipe.spin.utils import validate_config
from verl.trainer.ppo.reward import get_custom_reward_fn
from verl.trainer.ppo.utils import need_reference_policy
@hydra.main(config_path="config", config_name="spin_trainer", version_base=None)
@@ -56,16 +58,6 @@ class TaskRunner:
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
# download the checkpoint from hdfs
local_path = copy_to_local(config.actor_rollout_ref.model.path)
# instantiate tokenizer
from verl.utils import hf_processor, hf_tokenizer
trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none
# define worker classes
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
assert config.critic.strategy in {"fsdp", "fsdp2"}
@@ -117,6 +109,23 @@ class TaskRunner:
role_worker_mapping[Role.RefPolicy] = ray.remote(SPINRolloutRefWorker)
mapping[Role.RefPolicy] = global_pool_id
# validate config
validate_config(
config=config,
use_reference_policy=need_reference_policy(self.role_worker_mapping),
use_critic=False,
)
# download the checkpoint from hdfs
local_path = copy_to_local(config.actor_rollout_ref.model.path)
# instantiate tokenizer
from verl.utils import hf_processor, hf_tokenizer
trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none
from verl.workers.reward_manager import get_reward_manager_cls
# Note(haibin.lin): please make sure custom reward managers are imported and


@@ -19,7 +19,6 @@ import uuid
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from pprint import pprint
from typing import Any, Optional
@@ -35,7 +34,6 @@ from tqdm import tqdm
from recipe.spin import core_algos
from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.ppo.metric_utils import (
@@ -44,27 +42,12 @@ from verl.trainer.ppo.metric_utils import (
process_validation_metrics,
reduce_metrics,
)
from verl.trainer.ppo.ray_trainer import Role
from verl.trainer.ppo.utils import Role, WorkerType, need_reference_policy, need_reward_model
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from verl.utils.torch_functional import masked_mean
from verl.utils.tracking import ValidationGenerationsLogger
WorkerType = type[Worker]
class AdvantageEstimator(str, Enum):
"""
Using an enumeration class to avoid spelling errors in adv_estimator
"""
GAE = "gae"
GRPO = "grpo"
REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
REMAX = "remax"
RLOO = "rloo"
@dataclass
class ResourcePoolManager:
@@ -386,8 +369,9 @@ class RaySPINTrainer:
self.role_worker_mapping = role_worker_mapping
self.resource_pool_manager = resource_pool_manager
self.use_reference_policy = Role.RefPolicy in role_worker_mapping
self.use_rm = Role.RewardModel in role_worker_mapping
self.use_reference_policy = need_reference_policy(role_worker_mapping)
self.use_rm = need_reward_model(role_worker_mapping)
self.use_critic = False
self.ray_worker_group_cls = ray_worker_group_cls
self.validation_generations_logger = ValidationGenerationsLogger()
self.async_rollout_mode = False
@@ -398,146 +382,8 @@ class RaySPINTrainer:
if config.algorithm.use_kl_in_reward:
self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)
self.use_critic = False
self._validate_config()
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
def _validate_config(self):
config = self.config
# number of GPUs total
n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes
# 1. Check total batch size for data correctness
real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
assert real_train_batch_size % n_gpus == 0, (
f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})."
)
# A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu"
# We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu".
def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
settings = {
"actor_rollout_ref.actor": "micro_batch_size",
"critic": "micro_batch_size",
"reward_model": "micro_batch_size",
"actor_rollout_ref.ref": "log_prob_micro_batch_size",
"actor_rollout_ref.rollout": "log_prob_micro_batch_size",
}
if name in settings:
param = settings[name]
param_per_gpu = f"{param}_per_gpu"
if mbs is None and mbs_per_gpu is None:
raise ValueError(
f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'."
)
if mbs is not None and mbs_per_gpu is not None:
raise ValueError(
f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. "
f"Please remove '{name}.{param}' because only '*_{param_per_gpu}' is supported "
f"(the former is deprecated)."
)
if not config.actor_rollout_ref.actor.use_dynamic_bsz:
# actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.actor.ppo_micro_batch_size,
config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu,
"actor_rollout_ref.actor",
)
if self.use_reference_policy:
# reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.ref.log_prob_micro_batch_size,
config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu,
"actor_rollout_ref.ref",
)
# The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.rollout.log_prob_micro_batch_size,
config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu,
"actor_rollout_ref.rollout",
)
if self.use_critic and not config.critic.use_dynamic_bsz:
# Check for critic micro-batch size conflicts
check_mutually_exclusive(
config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic"
)
# Check for reward model micro-batch size conflicts
if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:
check_mutually_exclusive(
config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model"
)
# Actor
# check if train_batch_size is larger than ppo_mini_batch_size
# if NOT dynamic_bsz, we must ensure:
# ppo_mini_batch_size is divisible by ppo_micro_batch_size
# ppo_micro_batch_size * sequence_parallel_size >= n_gpus
if not config.actor_rollout_ref.actor.use_dynamic_bsz:
assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size
sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1)
if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None:
assert (
config.actor_rollout_ref.actor.ppo_mini_batch_size
% config.actor_rollout_ref.actor.ppo_micro_batch_size
== 0
)
assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus
assert config.actor_rollout_ref.actor.loss_agg_mode in [
"token-mean",
"seq-mean-token-sum",
"seq-mean-token-mean",
], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}"
if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss:
print("NOTICE: You have both enabled in-reward kl and kl loss.")
# critic
if self.use_critic and not config.critic.use_dynamic_bsz:
assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size
sp_size = config.critic.get("ulysses_sequence_parallel_size", 1)
if config.critic.ppo_micro_batch_size is not None:
assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0
assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus
# Check if use_remove_padding is enabled when using sequence parallelism for fsdp
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
if (
config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1
or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1
):
assert config.actor_rollout_ref.model.use_remove_padding, (
"When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`."
)
if self.use_critic and config.critic.strategy in {"fsdp", "fsdp2"}:
if config.critic.get("ulysses_sequence_parallel_size", 1) > 1:
assert config.critic.model.use_remove_padding, (
"When using sequence parallelism for critic, you must enable `use_remove_padding`."
)
if config.data.get("val_batch_size", None) is not None:
print(
"WARNING: val_batch_size is deprecated. Validation datasets are sent to inference engines "
"as a whole batch, which will schedule the memory themselves."
)
# check eval config
if config.actor_rollout_ref.rollout.val_kwargs.do_sample:
assert config.actor_rollout_ref.rollout.temperature > 0, (
"validation gen temperature should be greater than 0 when enabling do_sample"
)
print("[validate_config] All configuration checks passed successfully!")
def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler):
"""
Creates the train and validation dataloaders.

recipe/spin/utils.py (new file, 160 lines)

@@ -0,0 +1,160 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from omegaconf import DictConfig
def validate_config(
config: DictConfig,
use_reference_policy: bool,
use_critic: bool,
) -> None:
"""
Validate an OmegaConf DictConfig
Args:
config: The OmegaConf DictConfig to validate.
use_reference_policy (bool): is ref policy needed
use_critic (bool): is critic needed
"""
# number of GPUs total
n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes
# 1. Check total batch size for data correctness
real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
assert real_train_batch_size % n_gpus == 0, (
f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})."
)
# A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu"
# We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu".
def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
settings = {
"actor_rollout_ref.actor": "micro_batch_size",
"critic": "micro_batch_size",
"reward_model": "micro_batch_size",
"actor_rollout_ref.ref": "log_prob_micro_batch_size",
"actor_rollout_ref.rollout": "log_prob_micro_batch_size",
}
if name in settings:
param = settings[name]
param_per_gpu = f"{param}_per_gpu"
if mbs is None and mbs_per_gpu is None:
raise ValueError(f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.")
if mbs is not None and mbs_per_gpu is not None:
raise ValueError(
f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. "
f"Please remove '{name}.{param}' because only '*_{param_per_gpu}' is supported "
f"(the former is deprecated)."
)
if not config.actor_rollout_ref.actor.use_dynamic_bsz:
# actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.actor.ppo_micro_batch_size,
config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu,
"actor_rollout_ref.actor",
)
if use_reference_policy:
# reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.ref.log_prob_micro_batch_size,
config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu,
"actor_rollout_ref.ref",
)
# The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.rollout.log_prob_micro_batch_size,
config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu,
"actor_rollout_ref.rollout",
)
if use_critic and not config.critic.use_dynamic_bsz:
# Check for critic micro-batch size conflicts
check_mutually_exclusive(
config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic"
)
# Check for reward model micro-batch size conflicts
if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:
check_mutually_exclusive(
config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model"
)
# Actor
# check if train_batch_size is larger than ppo_mini_batch_size
# if NOT dynamic_bsz, we must ensure:
# ppo_mini_batch_size is divisible by ppo_micro_batch_size
# ppo_micro_batch_size * sequence_parallel_size >= n_gpus
if not config.actor_rollout_ref.actor.use_dynamic_bsz:
assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size
sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1)
if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None:
assert (
config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size
== 0
)
assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus
assert config.actor_rollout_ref.actor.loss_agg_mode in [
"token-mean",
"seq-mean-token-sum",
"seq-mean-token-mean",
], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}"
if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss:
print("NOTICE: You have both enabled in-reward kl and kl loss.")
# critic
if use_critic and not config.critic.use_dynamic_bsz:
assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size
sp_size = config.critic.get("ulysses_sequence_parallel_size", 1)
if config.critic.ppo_micro_batch_size is not None:
assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0
assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus
# Check if use_remove_padding is enabled when using sequence parallelism for fsdp
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
if (
config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1
or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1
):
assert config.actor_rollout_ref.model.use_remove_padding, (
"When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`."
)
if use_critic and config.critic.strategy in {"fsdp", "fsdp2"}:
if config.critic.get("ulysses_sequence_parallel_size", 1) > 1:
assert config.critic.model.use_remove_padding, (
"When using sequence parallelism for critic, you must enable `use_remove_padding`."
)
if config.data.get("val_batch_size", None) is not None:
print(
"WARNING: val_batch_size is deprecated. Validation datasets are sent to inference engines "
"as a whole batch, which will schedule the memory themselves."
)
# check eval config
if config.actor_rollout_ref.rollout.val_kwargs.do_sample:
assert config.actor_rollout_ref.rollout.temperature > 0, (
"validation gen temperature should be greater than 0 when enabling do_sample"
)
print("[validate_config] All configuration checks passed successfully!")
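
Because the checks are plain asserts and raises over the config, they can be exercised in isolation. A minimal sketch of the fast-fail behavior (the config literal is illustrative and contains only the fields the first checks read):

```python
from omegaconf import OmegaConf
from recipe.spin.utils import validate_config

cfg = OmegaConf.create({
    "trainer": {"n_gpus_per_node": 8, "nnodes": 1},
    "data": {"train_batch_size": 64},
    "actor_rollout_ref": {
        "rollout": {"n": 1},
        "actor": {
            "use_dynamic_bsz": False,
            # Setting both the deprecated name and the per-GPU name is rejected.
            "ppo_micro_batch_size": 4,
            "ppo_micro_batch_size_per_gpu": 4,
        },
    },
})

try:
    validate_config(cfg, use_reference_policy=False, use_critic=False)
except ValueError as err:
    print(err)  # "[actor_rollout_ref.actor] You have set both ..."
```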


@@ -24,6 +24,8 @@ import ray
from omegaconf import OmegaConf
from verl.trainer.ppo.reward import load_reward_manager
from verl.trainer.ppo.utils import need_reference_policy
from verl.utils.config import validate_config
from .sppo_ray_trainer import RaySPPOTrainer
@@ -66,16 +68,6 @@ class TaskRunner:
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
# download the checkpoint from hdfs
local_path = copy_to_local(config.actor_rollout_ref.model.path)
# instantiate tokenizer
from verl.utils import hf_processor, hf_tokenizer
trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none
# define worker classes
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
assert config.critic.strategy in {"fsdp", "fsdp2"}
@@ -133,6 +125,23 @@ class TaskRunner:
role_worker_mapping[Role.RefPolicy] = ray.remote(SPPOActorRolloutRefWorker)
mapping[Role.RefPolicy] = global_pool_id
# validate config
validate_config(
config=config,
use_reference_policy=need_reference_policy(role_worker_mapping),
use_critic=False,
)
# download the checkpoint from hdfs
local_path = copy_to_local(config.actor_rollout_ref.model.path)
# instantiate tokenizer
from verl.utils import hf_processor, hf_tokenizer
trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none
reward_fn = load_reward_manager(
config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
)


@@ -38,12 +38,11 @@ from verl.trainer.ppo.ray_trainer import (
AdvantageEstimator,
RayPPOTrainer,
ResourcePoolManager,
Role,
WorkerType,
apply_kl_penalty,
compute_response_mask,
)
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.trainer.ppo.utils import Role, WorkerType, need_reference_policy, need_reward_model
from verl.utils.profiler.performance import simple_timer
from verl.utils.tracking import ValidationGenerationsLogger
@@ -111,8 +110,9 @@ class RaySPPOTrainer(RayPPOTrainer):
self.role_worker_mapping = role_worker_mapping
self.resource_pool_manager = resource_pool_manager
self.use_reference_policy = Role.RefPolicy in role_worker_mapping
self.use_rm = Role.RewardModel in role_worker_mapping
self.use_reference_policy = need_reference_policy(role_worker_mapping)
self.use_rm = need_reward_model(role_worker_mapping)
self.use_critic = False
self.ray_worker_group_cls = ray_worker_group_cls
self.validation_generations_logger = ValidationGenerationsLogger()
self.device_name = device_name if device_name else self.config.trainer.device
@@ -122,9 +122,6 @@ class RaySPPOTrainer(RayPPOTrainer):
if config.algorithm.use_kl_in_reward:
self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)
self.use_critic = False
self._validate_config()
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
def fit(self):


@@ -26,6 +26,8 @@ from verl.experimental.dataset.sampler import AbstractSampler
from verl.trainer.constants_ppo import get_ppo_ray_runtime_env
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.trainer.ppo.reward import load_reward_manager
from verl.trainer.ppo.utils import need_critic, need_reference_policy
from verl.utils.config import validate_config
from verl.utils.device import is_cuda_available
from verl.utils.import_utils import load_extern_type
@@ -219,20 +221,6 @@ class TaskRunner:
pprint(OmegaConf.to_container(config, resolve=True))
OmegaConf.resolve(config)
# Download the checkpoint from HDFS to the local machine.
# `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on
local_path = copy_to_local(
config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
)
# Instantiate the tokenizer and processor.
from verl.utils import hf_processor, hf_tokenizer
trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
# Used for multimodal LLM, could be None
processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config)
self.add_critic_worker(config)
@@ -247,6 +235,27 @@ class TaskRunner:
# Add a reference policy worker if KL loss or KL reward is used.
self.add_ref_policy_worker(config, actor_rollout_cls)
# validate config
validate_config(
config=config,
use_reference_policy=need_reference_policy(self.role_worker_mapping),
use_critic=need_critic(config),
)
# Download the checkpoint from HDFS to the local machine.
# `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on
local_path = copy_to_local(
config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
)
# Instantiate the tokenizer and processor.
from verl.utils import hf_processor, hf_tokenizer
trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
# Used for multimodal LLM, could be None
processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
# Load the reward manager for training and validation.
reward_fn = load_reward_manager(
config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})


@@ -21,11 +21,9 @@ This trainer supports model-agonistic model initialization with huggingface
import json
import os
import uuid
import warnings
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from pprint import pprint
from typing import Optional
@@ -40,7 +38,6 @@ from tqdm import tqdm
from verl import DataProto
from verl.experimental.dataset.sampler import AbstractCurriculumSampler
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.config import AlgoConfig
@@ -53,6 +50,7 @@ from verl.trainer.ppo.metric_utils import (
process_validation_metrics,
)
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.debug import marked_timer
@@ -62,22 +60,6 @@ from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seql
from verl.utils.torch_functional import masked_mean
from verl.utils.tracking import ValidationGenerationsLogger
WorkerType = type[Worker]
class Role(Enum):
"""
To create more roles dynamically, you can subclass Role and add new members
"""
Actor = 0
Rollout = 1
ActorRollout = 2
Critic = 3
RefPolicy = 4
RewardModel = 5
ActorRolloutRef = 6
@dataclass
class ResourcePoolManager:
@@ -352,8 +334,9 @@ class RayPPOTrainer:
self.role_worker_mapping = role_worker_mapping
self.resource_pool_manager = resource_pool_manager
self.use_reference_policy = Role.RefPolicy in role_worker_mapping
self.use_rm = Role.RewardModel in role_worker_mapping
self.use_reference_policy = need_reference_policy(self.role_worker_mapping)
self.use_rm = need_reward_model(self.role_worker_mapping)
self.use_critic = need_critic(self.config)
self.ray_worker_group_cls = ray_worker_group_cls
self.device_name = device_name if device_name else self.config.trainer.device
self.validation_generations_logger = ValidationGenerationsLogger(
@@ -369,138 +352,8 @@ class RayPPOTrainer:
if self.config.algorithm.use_kl_in_reward:
self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl)
if config.critic.enable is not None:
self.use_critic = bool(config.critic.enable)
elif self.config.algorithm.adv_estimator == AdvantageEstimator.GAE:
self.use_critic = True
else:
warnings.warn(
"Disabled critic as algorithm.adv_estimator != gae. "
"If it is not intended, please set critic.enable=True",
stacklevel=2,
)
self.use_critic = False
self._validate_config()
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
def _validate_config(self):
config = self.config
# number of GPUs total
n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes
if not config.actor_rollout_ref.actor.use_dynamic_bsz:
if config.actor_rollout_ref.actor.strategy == "megatron":
model_parallel_size = (
config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size
* config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size
)
assert (
n_gpus % (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size) == 0
), (
f"n_gpus ({n_gpus}) must be divisible by model_parallel_size ({model_parallel_size}) times "
f"context_parallel_size ({config.actor_rollout_ref.actor.megatron.context_parallel_size})"
)
megatron_dp = n_gpus // (
model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size
)
minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu
else:
minimal_bsz = n_gpus
# 1. Check total batch size for data correctness
real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
assert real_train_batch_size % minimal_bsz == 0, (
f"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size "
f"({minimal_bsz})"
)
# A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu"
# We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu".
def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
"""Validate mutually exclusive micro batch size configuration options.
Ensures that users don't set both deprecated micro_batch_size and
the new micro_batch_size_per_gpu parameters simultaneously.
Args:
mbs: Deprecated micro batch size parameter value.
mbs_per_gpu: New micro batch size per GPU parameter value.
name (str): Configuration section name for error messages.
Raises:
ValueError: If both parameters are set or neither is set.
"""
settings = {
"reward_model": "micro_batch_size",
"actor_rollout_ref.ref": "log_prob_micro_batch_size",
"actor_rollout_ref.rollout": "log_prob_micro_batch_size",
}
if name in settings:
param = settings[name]
param_per_gpu = f"{param}_per_gpu"
if mbs is None and mbs_per_gpu is None:
raise ValueError(
f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'."
)
if mbs is not None and mbs_per_gpu is not None:
raise ValueError(
f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove "
f"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated)."
)
# Actor validation done in ActorConfig.__post_init__ and validate()
actor_config = omega_conf_to_dataclass(config.actor_rollout_ref.actor)
actor_config.validate(n_gpus, config.data.train_batch_size, config.actor_rollout_ref.model)
if not config.actor_rollout_ref.actor.use_dynamic_bsz:
if self.use_reference_policy:
# reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.ref.log_prob_micro_batch_size,
config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu,
"actor_rollout_ref.ref",
)
# The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.rollout.log_prob_micro_batch_size,
config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu,
"actor_rollout_ref.rollout",
)
# Check for reward model micro-batch size conflicts
if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:
check_mutually_exclusive(
config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model"
)
if self.config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss:
print("NOTICE: You have both enabled in-reward kl and kl loss.")
# critic
if self.use_critic:
critic_config = omega_conf_to_dataclass(config.critic)
critic_config.validate(n_gpus, config.data.train_batch_size)
if config.data.get("val_batch_size", None) is not None:
print(
"WARNING: val_batch_size is deprecated."
+ " Validation datasets are sent to inference engines as a whole batch,"
+ " which will schedule the memory themselves."
)
# check eval config
if config.actor_rollout_ref.rollout.val_kwargs.do_sample:
assert config.actor_rollout_ref.rollout.temperature > 0, (
"validation gen temperature should be greater than 0 when enabling do_sample"
)
print("[validate_config] All configuration checks passed successfully!")
def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):
"""
Creates the train and validation dataloaders.

verl/trainer/ppo/utils.py (new file, 65 lines)

@@ -0,0 +1,65 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from enum import Enum
from omegaconf import DictConfig
from verl.single_controller.base import Worker
from verl.trainer.ppo.core_algos import AdvantageEstimator
WorkerType = type[Worker]
class Role(Enum):
"""
To create more roles dynamically, you can subclass Role and add new members
"""
Actor = 0
Rollout = 1
ActorRollout = 2
Critic = 3
RefPolicy = 4
RewardModel = 5
ActorRolloutRef = 6
def need_reference_policy(
role_worker_mapping: dict[Role, WorkerType],
) -> bool:
"""Given a role worker mapping, do we need ref policy."""
return Role.RefPolicy in role_worker_mapping
def need_reward_model(
role_worker_mapping: dict[Role, WorkerType],
) -> bool:
"""Given a role worker mapping, do we need reward model."""
return Role.RewardModel in role_worker_mapping
def need_critic(config: DictConfig) -> bool:
"""Given a config, do we need critic."""
if config.critic.enable is not None:
return bool(config.critic.enable)
elif config.algorithm.adv_estimator == AdvantageEstimator.GAE:
return True
else:
warnings.warn(
"Disabled critic as algorithm.adv_estimator != gae. If it is not intended, please set critic.enable=True",
stacklevel=2,
)
return False
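
Usage sketch for the relocated `Role` enum and the three helpers (the mapping values and config literal are illustrative; real callers map roles to `ray.remote` worker classes):

```python
from omegaconf import OmegaConf
from verl.trainer.ppo.utils import Role, need_critic, need_reference_policy, need_reward_model

role_worker_mapping = {Role.ActorRollout: ..., Role.RefPolicy: ...}

print(need_reference_policy(role_worker_mapping))  # True: a RefPolicy worker is mapped
print(need_reward_model(role_worker_mapping))      # False: no RewardModel worker

cfg = OmegaConf.create({"critic": {"enable": None}, "algorithm": {"adv_estimator": "grpo"}})
print(need_critic(cfg))  # False, plus a warning suggesting critic.enable=True if unintended
```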


@@ -17,7 +17,7 @@ from typing import Any, Optional
from omegaconf import DictConfig, ListConfig, OmegaConf
__all__ = ["omega_conf_to_dataclass"]
__all__ = ["omega_conf_to_dataclass", "validate_config"]
def omega_conf_to_dataclass(config: DictConfig | dict, dataclass_type: Optional[type[Any]] = None) -> Any:
@@ -69,3 +69,129 @@ def update_dict_with_config(dictionary: dict, config: DictConfig):
for key in dictionary:
if hasattr(config, key):
dictionary[key] = getattr(config, key)
def validate_config(
config: DictConfig,
use_reference_policy: bool,
use_critic: bool,
) -> None:
"""Validate an OmegaConf DictConfig.
Args:
config (DictConfig): The OmegaConf DictConfig to validate.
use_reference_policy (bool): is ref policy needed
use_critic (bool): is critic needed
"""
# number of GPUs total
n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes
if not config.actor_rollout_ref.actor.use_dynamic_bsz:
if config.actor_rollout_ref.actor.strategy == "megatron":
model_parallel_size = (
config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size
* config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size
)
assert (
n_gpus % (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size) == 0
), (
f"n_gpus ({n_gpus}) must be divisible by model_parallel_size ({model_parallel_size}) times "
f"context_parallel_size ({config.actor_rollout_ref.actor.megatron.context_parallel_size})"
)
megatron_dp = n_gpus // (
model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size
)
minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu
else:
minimal_bsz = n_gpus
# 1. Check total batch size for data correctness
real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
assert real_train_batch_size % minimal_bsz == 0, (
f"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size "
f"({minimal_bsz})"
)
# A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu"
# We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu".
def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
"""Validate mutually exclusive micro batch size configuration options.
Ensures that users don't set both deprecated micro_batch_size and
the new micro_batch_size_per_gpu parameters simultaneously.
Args:
mbs: Deprecated micro batch size parameter value.
mbs_per_gpu: New micro batch size per GPU parameter value.
name (str): Configuration section name for error messages.
Raises:
ValueError: If both parameters are set or neither is set.
"""
settings = {
"reward_model": "micro_batch_size",
"actor_rollout_ref.ref": "log_prob_micro_batch_size",
"actor_rollout_ref.rollout": "log_prob_micro_batch_size",
}
if name in settings:
param = settings[name]
param_per_gpu = f"{param}_per_gpu"
if mbs is None and mbs_per_gpu is None:
raise ValueError(f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.")
if mbs is not None and mbs_per_gpu is not None:
raise ValueError(
f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove "
f"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated)."
)
# Actor validation done in ActorConfig.__post_init__ and validate()
actor_config = omega_conf_to_dataclass(config.actor_rollout_ref.actor)
actor_config.validate(n_gpus, config.data.train_batch_size, config.actor_rollout_ref.model)
if not config.actor_rollout_ref.actor.use_dynamic_bsz:
if use_reference_policy:
# reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.ref.log_prob_micro_batch_size,
config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu,
"actor_rollout_ref.ref",
)
# The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.rollout.log_prob_micro_batch_size,
config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu,
"actor_rollout_ref.rollout",
)
# Check for reward model micro-batch size conflicts
if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:
check_mutually_exclusive(
config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model"
)
if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss:
print("NOTICE: You have both enabled in-reward kl and kl loss.")
# critic
if use_critic:
critic_config = omega_conf_to_dataclass(config.critic)
critic_config.validate(n_gpus, config.data.train_batch_size)
if config.data.get("val_batch_size", None) is not None:
print(
"WARNING: val_batch_size is deprecated."
+ " Validation datasets are sent to inference engines as a whole batch,"
+ " which will schedule the memory themselves."
)
# check eval config
if config.actor_rollout_ref.rollout.val_kwargs.do_sample:
assert config.actor_rollout_ref.rollout.temperature > 0, (
"validation gen temperature should be greater than 0 when enabling do_sample"
)
print("[validate_config] All configuration checks passed successfully!")
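
With validation exposed as a standalone function, a broken config now fails before any tokenizer or checkpoint download. A minimal sketch (assumes verl is installed; the config literal is illustrative and deliberately makes the rollout batch non-divisible by the GPU count):

```python
from omegaconf import OmegaConf
from verl.utils.config import validate_config

cfg = OmegaConf.create({
    "trainer": {"n_gpus_per_node": 8, "nnodes": 1},
    "data": {"train_batch_size": 30},  # 30 * rollout.n is not divisible by 8 GPUs
    "actor_rollout_ref": {
        "actor": {"use_dynamic_bsz": False, "strategy": "fsdp"},
        "rollout": {"n": 1},
    },
})

try:
    validate_config(cfg, use_reference_policy=False, use_critic=False)
except AssertionError as err:
    # Fails on the batch-size divisibility check; no model I/O has happened yet.
    print(err)
```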