verl/recipe/prime/main_prime.py
Slim Frikha 7592d69cbb [trainer] refactor: PPO config validation fast fail (#3187)
### What does this PR do?

Make the main PPO script validate the config as soon as all the needed
information is available, so that it fails as fast as possible when the
config contains a bug. With this change, the tokenizer is no longer
downloaded and loaded, and the data is no longer loaded, before the config
has been validated.

Resolves #3182

### Design & Code Changes

Config validation was isolated in a utility function (moved out of
RayPPOTrainer) and is called from main_ppo as early as possible.
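A condensed sketch of the resulting ordering (illustrative only, not the actual diff; `validate_config` and its keyword arguments follow the new utility, while the other helpers are hypothetical placeholders):

```python
from verl.utils.config import validate_config

def main_task(config):
    # cheap, in-memory setup only -- no network or disk I/O yet
    role_worker_mapping = build_role_worker_mapping(config)  # hypothetical helper

    # validate as soon as all needed info is available; a bad config aborts here
    validate_config(config=config, use_reference_policy=..., use_critic=...)

    # only after validation do we pay for checkpoint download, tokenizer, and data loading
    load_tokenizer_and_data(config)  # hypothetical helper
```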
2025-08-26 10:31:39 +08:00


# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
"""
import hydra
import ray
from omegaconf import OmegaConf

from verl.trainer.ppo.utils import need_reference_policy
from verl.utils.config import validate_config

from .prime_ray_trainer import RayPRIMETrainer


@hydra.main(config_path="config", config_name="prime_trainer", version_base=None)
def main(config):
    run_prime(config)


def run_prime(config, compute_score=None):
    if not ray.is_initialized():
        default_runtime_env = {"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}
        ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
        runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
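        # OmegaConf.merge gives later arguments precedence, so user-supplied
        # runtime_env entries override the defaults above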
        runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
        ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
        print(f"ray init kwargs: {ray_init_kwargs}")
        # this is for a local ray cluster
        ray.init(**OmegaConf.to_container(ray_init_kwargs))

    ray.get(main_task.remote(config, compute_score))


@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on the head node
def main_task(config, compute_score=None):
    # print the initial config
    from pprint import pprint

    from omegaconf import OmegaConf

    from verl.utils.fs import copy_local_path_from_hdfs

    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
    OmegaConf.resolve(config)
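    # interpolations are now materialized in place, so every later access to
    # the config sees concrete values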

    # define worker classes
    if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
        assert config.critic.strategy in {"fsdp", "fsdp2"}
        from verl.single_controller.ray import RayWorkerGroup
        from verl.workers.fsdp_workers import ActorRolloutRefWorker

        ray_worker_group_cls = RayWorkerGroup
    elif config.actor_rollout_ref.actor.strategy == "megatron":
        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
        from verl.single_controller.ray import RayWorkerGroup
        from verl.workers.megatron_workers import ActorRolloutRefWorker

        ray_worker_group_cls = RayWorkerGroup
    else:
        raise NotImplementedError

    from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role

    role_worker_mapping = {
        Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
    }
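
    # a single global resource pool: every role is co-located on the same GPUs
    # rather than getting dedicated hardware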
    global_pool_id = "global_pool"
    resource_pool_spec = {
        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
    }
    mapping = {
        Role.ActorRollout: global_pool_id,
    }

    # use a reference model when a KL term against the frozen initial policy is
    # needed, either folded into the reward or as an actor loss
    if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
        role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
        mapping[Role.RefPolicy] = global_pool_id

    if config.reward_model.enable:
        from .prime_fsdp_workers import PRIMERewardModelWorker

        role_worker_mapping[Role.RewardModel] = ray.remote(PRIMERewardModelWorker)
        mapping[Role.RewardModel] = global_pool_id

    # validate config
    # TODO: additional config checks can be added with a proper function under the prime recipe
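    # failing fast here means a bad config aborts before any checkpoint
    # download, tokenizer instantiation, or data loading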
    validate_config(
        config=config,
        use_reference_policy=need_reference_policy(role_worker_mapping),
        use_critic=False,
    )

    # download the checkpoint from hdfs
    local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)

    # instantiate tokenizer
    from verl.utils import hf_tokenizer

    tokenizer = hf_tokenizer(local_path)

    reward_manager_name = config.reward_model.get("reward_manager", "naive")
    if reward_manager_name == "naive":
        from verl.workers.reward_manager import NaiveRewardManager

        reward_manager_cls = NaiveRewardManager
    elif reward_manager_name == "prime":
        from verl.workers.reward_manager import PrimeRewardManager

        reward_manager_cls = PrimeRewardManager
    else:
        raise NotImplementedError
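
    # num_examine sets how many batches of decoded responses each manager
    # prints for inspection: none during training, one during validation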
    reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score)

    # Note that we always use function-based RM for validation
    val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score)

    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

    trainer = RayPRIMETrainer(
        config=config,
        tokenizer=tokenizer,
        role_worker_mapping=role_worker_mapping,
        resource_pool_manager=resource_pool_manager,
        ray_worker_group_cls=ray_worker_group_cls,
        reward_fn=reward_fn,
        val_reward_fn=val_reward_fn,
    )
    trainer.init_workers()
    trainer.fit()


if __name__ == "__main__":
    main()