### What does this PR do?
Make the main PPO script validate the config as soon as all needed information is available, so the script fails as fast as possible when the config contains a bug. The new ordering avoids downloading and loading the tokenizer, and loading data, before the config has been validated. Solves #3182.

### Design & Code Changes
Isolated config validation in utils (out of RayPPOTrainer) and call it from main_ppo as soon as possible.
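A minimal sketch of the resulting ordering in `main_task` (names taken from this file; the `validate_config` call shape follows `verl.utils.config`):

```python
# role_worker_mapping is cheap to build and is all validate_config needs
validate_config(
    config=config,
    use_reference_policy=need_reference_policy(role_worker_mapping),
    use_critic=False,
)
# heavyweight setup only runs once validation has passed
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
tokenizer = hf_tokenizer(local_path)
```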
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
|
|
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
|
|
"""
|
|
|
|
import hydra
import ray
from omegaconf import OmegaConf

from verl.trainer.ppo.utils import need_reference_policy
from verl.utils.config import validate_config

from .prime_ray_trainer import RayPRIMETrainer


@hydra.main(config_path="config", config_name="prime_trainer", version_base=None)
def main(config):
    run_prime(config)
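

# Example launch (hypothetical invocation; Hydra reads config/prime_trainer.yaml
# relative to this file):
#   python -m recipe.prime.main_prime trainer.nnodes=1 trainer.n_gpus_per_node=8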


def run_prime(config, compute_score=None):
    if not ray.is_initialized():
        # defaults for every Ray worker; user-supplied runtime_env values take
        # precedence in the OmegaConf.merge below
        default_runtime_env = {"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}
        ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
        runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
        runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
        ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
        print(f"ray init kwargs: {ray_init_kwargs}")
        # this is for a local ray cluster
        ray.init(**OmegaConf.to_container(ray_init_kwargs))

    ray.get(main_task.remote(config, compute_score))
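

# A hypothetical CLI override layering extra env vars on top of the runtime_env
# defaults merged in run_prime above (user-supplied values win in the merge):
#   python -m recipe.prime.main_prime \
#       +ray_kwargs.ray_init.runtime_env.env_vars.NCCL_DEBUG=INFO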


@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on the head node
def main_task(config, compute_score=None):
    # print the initial config
    from pprint import pprint

    from omegaconf import OmegaConf

    from verl.utils.fs import copy_local_path_from_hdfs

    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
    # resolve interpolations in place so later accesses see concrete values
    OmegaConf.resolve(config)

    # define worker classes
    if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
        assert config.critic.strategy in {"fsdp", "fsdp2"}
        from verl.single_controller.ray import RayWorkerGroup
        from verl.workers.fsdp_workers import ActorRolloutRefWorker

        ray_worker_group_cls = RayWorkerGroup

    elif config.actor_rollout_ref.actor.strategy == "megatron":
        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
        from verl.single_controller.ray import RayWorkerGroup
        from verl.workers.megatron_workers import ActorRolloutRefWorker

        ray_worker_group_cls = RayWorkerGroup

    else:
        raise NotImplementedError(f"Unknown actor strategy: {config.actor_rollout_ref.actor.strategy}")
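
    # A hypothetical YAML snippet selecting the FSDP2 backend for both roles:
    #   actor_rollout_ref:
    #     actor:
    #       strategy: fsdp2
    #   critic:
    #     strategy: fsdp2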

    from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role

    role_worker_mapping = {
        Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
    }

    global_pool_id = "global_pool"
    resource_pool_spec = {
        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
    }
    mapping = {
        Role.ActorRollout: global_pool_id,
    }
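
    # With, e.g., trainer.nnodes=2 and trainer.n_gpus_per_node=8, resource_pool_spec
    # maps "global_pool" to [8, 8]: two nodes contributing eight GPUs each, and
    # every role above draws from that single pool.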

    # use a reference model when KL is applied in the reward or in the actor loss
    if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
        role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
        mapping[Role.RefPolicy] = global_pool_id

    if config.reward_model.enable:
        from .prime_fsdp_workers import PRIMERewardModelWorker

        role_worker_mapping[Role.RewardModel] = ray.remote(PRIMERewardModelWorker)
        mapping[Role.RewardModel] = global_pool_id

    # validate the config now, before any heavyweight setup (checkpoint download,
    # tokenizer and data loading), so that a broken config fails fast
    # TODO: additional config checks can be added with a proper function under the prime recipe
    validate_config(
        config=config,
        use_reference_policy=need_reference_policy(role_worker_mapping),
        use_critic=False,
    )

    # download the checkpoint from hdfs
    local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)

    # instantiate tokenizer
    from verl.utils import hf_tokenizer

    tokenizer = hf_tokenizer(local_path)
    reward_manager_name = config.reward_model.get("reward_manager", "naive")
    if reward_manager_name == "naive":
        from verl.workers.reward_manager import NaiveRewardManager

        reward_manager_cls = NaiveRewardManager
    elif reward_manager_name == "prime":
        from verl.workers.reward_manager import PrimeRewardManager

        reward_manager_cls = PrimeRewardManager
    else:
        raise NotImplementedError(f"Unknown reward manager: {reward_manager_name}")
    # num_examine: number of batches of decoded responses to print to the console
    reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score)
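
    # A hypothetical CLI override selecting the PRIME reward manager:
    #   python -m recipe.prime.main_prime reward_model.reward_manager=prime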

    # Note that we always use function-based RM for validation
    val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score)

    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

    trainer = RayPRIMETrainer(
        config=config,
        tokenizer=tokenizer,
        role_worker_mapping=role_worker_mapping,
        resource_pool_manager=resource_pool_manager,
        ray_worker_group_cls=ray_worker_group_cls,
        reward_fn=reward_fn,
        val_reward_fn=val_reward_fn,
    )
    trainer.init_workers()
    trainer.fit()


if __name__ == "__main__":
    main()