# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
|
|
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
|
|
"""
|
|
|
|
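# Split-placement entry point: the actor/rollout/ref workers and the critic are
# assigned to separate resource pools, and the PPO loop is driven by the
# monkey-patched `fit` from split_monkey_patch (see below). Example launch with
# Hydra overrides (illustrative values only; adjust the script path to wherever
# this file lives):
#   python3 main_ppo_split.py trainer.nnodes=1 trainer.n_gpus_per_node=8
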
import hydra
import ray
import torch
from omegaconf import OmegaConf
from split_monkey_patch import fit

from verl import DataProto
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.utils.reward_score import gsm8k, math_reward


def _select_rm_score_fn(data_source):
    if data_source == "openai/gsm8k":
        return gsm8k.compute_score
    elif data_source == "lighteval/MATH":
        return math_reward.compute_score
    else:
        raise NotImplementedError(f"unsupported data source: {data_source}")


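# To support another dataset, extend _select_rm_score_fn with a branch keyed on
# its `data_source` tag, for example (hypothetical names, sketch only):
#     elif data_source == "my_org/my_dataset":
#         return my_dataset.compute_score

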
class RewardManager:
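    """Compute token-level rewards for a rollout batch using rule-based (function-based) scorers."""
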
    def __init__(self, tokenizer, num_examine) -> None:
        self.tokenizer = tokenizer
        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console

    def __call__(self, data: DataProto, return_dict: bool = False):
        """We will expand this function gradually based on the available datasets."""

        # If the batch already carries reward-model scores, return them directly;
        # otherwise compute scores with the rule-based reward functions.
        if "rm_scores" in data.batch.keys():
            if return_dict:
                return {"reward_tensor": data.batch["rm_scores"]}
            return data.batch["rm_scores"]

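        # Token-level reward tensor with the same shape as `responses`; it stays
        # zero everywhere except the positions filled in below.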
        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)

        already_print_data_sources = {}

        for i in range(len(data)):
            data_item = data[i]  # DataProtoItem

            prompt_ids = data_item.batch["prompts"]
            prompt_length = prompt_ids.shape[-1]

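            # Prompts are left-padded and responses right-padded, so the valid
            # prompt tokens sit at the right end of the prompt slice and the
            # valid response tokens at the left end of the response slice.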
            valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum()
            valid_prompt_ids = prompt_ids[-valid_prompt_length:]

            response_ids = data_item.batch["responses"]
            valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
            valid_response_ids = response_ids[:valid_response_length]

            # decode
            sequences = torch.cat((valid_prompt_ids, valid_response_ids))
            sequences_str = self.tokenizer.decode(sequences)

            ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"]

            # select the rule-based scoring function for this sample's data source
            data_source = data_item.non_tensor_batch["data_source"]
            compute_score_fn = _select_rm_score_fn(data_source)

            score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth)
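            # sparse outcome reward: assign the scalar score to the last valid response token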
            reward_tensor[i, valid_response_length - 1] = score

            if data_source not in already_print_data_sources:
                already_print_data_sources[data_source] = 0

            if already_print_data_sources[data_source] < self.num_examine:
                already_print_data_sources[data_source] += 1
                print(sequences_str)

        if return_dict:
            return {"reward_tensor": reward_tensor}
        else:
            return reward_tensor


@hydra.main(config_path="config", config_name="ppo_trainer_split", version_base=None)
def main(config):
    if not ray.is_initialized():
        # this is for a local ray cluster
        default_runtime_env = {"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}
        ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
        runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
        runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
        ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
        print(f"ray init kwargs: {ray_init_kwargs}")
        ray.init(**OmegaConf.to_container(ray_init_kwargs))

    ray.get(main_task.remote(config))


@ray.remote
def main_task(config):
    # print initial config
    from pprint import pprint

    from omegaconf import OmegaConf

    from verl.utils.fs import copy_to_local

    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
    OmegaConf.resolve(config)

    # download the checkpoint from hdfs
    local_path = copy_to_local(config.actor_rollout_ref.model.path)

    # instantiate tokenizer
    from verl.utils import hf_tokenizer

    tokenizer = hf_tokenizer(local_path)

    # define worker classes
    if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
        assert config.critic.strategy in {"fsdp", "fsdp2"}
        from verl.single_controller.ray import RayWorkerGroup
        from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker

        ray_worker_group_cls = RayWorkerGroup

    elif config.actor_rollout_ref.actor.strategy == "megatron":
        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
        from verl.single_controller.ray import RayWorkerGroup
        from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker

        ray_worker_group_cls = RayWorkerGroup

    else:
        raise NotImplementedError(f"unsupported actor strategy: {config.actor_rollout_ref.actor.strategy}")

    from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role

    role_worker_mapping = {
        Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
        Role.Critic: ray.remote(CriticWorker),
    }

    # NOTE: initialize two resource pools
    actor_rollout_ref_pool_id = "actor_rollout_ref_pool"
    critic_pool_id = "critic_pool"
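    # Split placement: on a single node, each pool gets half of the GPUs; with
    # multiple nodes, each pool gets half of the nodes.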
    if config.trainer.nnodes // 2 == 0 and config.trainer.n_gpus_per_node // 2 > 0:
        resource_pool_spec = {
            actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes,
            critic_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes,
        }
    else:
        resource_pool_spec = {
            actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2),
            critic_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2),
        }
    print(f"resource_pool_spec: {resource_pool_spec}")
    mapping = {
        Role.ActorRollout: actor_rollout_ref_pool_id,
        Role.Critic: critic_pool_id,
    }

    # add the reference policy only when KL regularization is enabled; it shares the actor/rollout pool
    if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
        role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
        mapping[Role.RefPolicy] = actor_rollout_ref_pool_id

    # We should adopt a multi-source reward function here:
    # - for rule-based rm, we directly call a reward score function
    # - for model-based rm, we call a model
    # - for code-related prompts, we send them to a sandbox if test cases are available
    # - finally, we combine all the rewards together
    # - the reward type depends on the tag of the data
    if config.reward_model.enable:
        if config.reward_model.strategy in {"fsdp", "fsdp2"}:
            from verl.workers.fsdp_workers import RewardModelWorker
        elif config.reward_model.strategy == "megatron":
            from verl.workers.megatron_workers import RewardModelWorker
        else:
            raise NotImplementedError(f"unsupported reward model strategy: {config.reward_model.strategy}")
        role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
        mapping[Role.RewardModel] = critic_pool_id

    reward_fn = RewardManager(tokenizer=tokenizer, num_examine=0)

    # Note that we always use function-based RM for validation
    val_reward_fn = RewardManager(tokenizer=tokenizer, num_examine=1)

    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

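    # Replace the default PPO training loop with the split-placement `fit` imported
    # from split_monkey_patch, so the actor and critic updates can run on their
    # separate pools (see split_monkey_patch.py for the modified loop).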
    RayPPOTrainer.fit = fit
    trainer = RayPPOTrainer(
        config=config,
        tokenizer=tokenizer,
        role_worker_mapping=role_worker_mapping,
        resource_pool_manager=resource_pool_manager,
        ray_worker_group_cls=ray_worker_group_cls,
        reward_fn=reward_fn,
        val_reward_fn=val_reward_fn,
    )
    trainer.init_workers()
    trainer.fit()


if __name__ == "__main__":
    main()