# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.

import copy
import time
from typing import List, Union, Dict

import ray
import torch
from codetiming import Timer
from torch.utils.data import DataLoader

from mindspeed_rl.trainer.utils.transfer_dock import put_prompts_experience
from mindspeed_rl.utils.tokenizer import BaseTokenizer
from mindspeed_rl.workers.rule_reward import RuleReward
from mindspeed_rl.trainer.base import RayBaseTrainer
from mindspeed_rl.config_cls.mindstudio_config import ProfilerConfig
from mindspeed_rl.trainer.utils import GRPOTransferDock, MMGRPOTransferDock
from mindspeed_rl.trainer.utils.compute_utils import compute_advantage, compute_grpo_data_metrics
from mindspeed_rl.workers.scheduler.launcher import RayActorGroup
from mindspeed_rl.utils.loggers import Loggers
from mindspeed_rl.utils.metrics import Metric
from mindspeed_rl.utils.utils import metrics_post_processing, compute_tps, metrics_sort, is_multimodal


class RayGRPOTrainer(RayBaseTrainer):
    """
    RayGRPOTrainer class. This trainer runs on the driver process on a single CPU/GPU node.

    Args:
        actor_worker: RayActorGroup The actor worker group.
        ref_worker: RayActorGroup The reference worker group.
        reward_list: List[Union[RayActorGroup, RuleReward]] List of reward workers or rule-based rewards.
        train_iters: int = 1 The number of training iterations.
        save_interval: int = 1 The interval (in iterations) for saving checkpoints.
        kl_ctrl_type: str = 'fixed' The type of KL divergence control (e.g., 'fixed', 'adaptive').
        adv_estimator: str = "group_norm" The method for estimating advantages (e.g., 'group_norm', 'mean').
        kl_horizon: int = 1000 The time horizon for KL divergence control (used in adaptive methods).
        kl_target: float = 100.0 The target value for KL divergence (used in adaptive methods).
        init_kl_coef: float = 0.01 The initial coefficient for the KL divergence penalty.
        global_batch_size: int = 1 The global batch size for training (number of prompts per iteration).
        micro_batch_size: int = 1 The micro batch size (used as the experience count when dispatching advantage computation).
        n_samples_per_prompt: int = 1 The number of samples generated per prompt.
        tokenizer: BaseTokenizer = None The tokenizer to use.
        dataset_additional_keys: List[str] = None Additional keys to include in the dataset.
        blocking: bool = False Whether to enable blocking (synchronous) mode.
        guarantee_order: bool = False Whether to enforce ordered reads from the transfer dock.
        num_cpus_for_local_task: int = 1 Number of CPUs for local ray tasks.
        partial_rollout_max_split: int = 1 Maximum number of rollout splits; values greater than 1 enable partial rollout.
        **kwargs: Additional parameters for base class argument passing.
    """

    def __init__(
            self,
            actor_worker: RayActorGroup,
            ref_worker: RayActorGroup,
            reward_list: List[Union[RayActorGroup, RuleReward]],
            train_iters: int = 1,
            save_interval: int = 1,
            kl_ctrl_type: str = 'fixed',
            adv_estimator: str = "group_norm",
            kl_horizon: int = 1000,
            kl_target: float = 100.0,
            init_kl_coef: float = 0.01,
            global_batch_size: int = 1,
            micro_batch_size: int = 1,
            n_samples_per_prompt: int = 1,
            tokenizer: BaseTokenizer = None,
            dataset_additional_keys: List[str] = None,
            blocking: bool = False,
            guarantee_order: bool = False,
            num_cpus_for_local_task: int = 1,
            partial_rollout_max_split: int = 1,
            **kwargs
    ):
        super().__init__(
            actor_worker,
            ref_worker,
            reward_list,
            train_iters=train_iters,
            save_interval=save_interval,
            kl_ctrl_type=kl_ctrl_type,
            kl_horizon=kl_horizon,
            kl_target=kl_target,
            adv_estimator=adv_estimator,
            init_kl_coef=init_kl_coef,
            global_batch_size=global_batch_size,
            micro_batch_size=micro_batch_size,
            n_samples_per_prompt=n_samples_per_prompt,
            tokenizer=tokenizer,
            dataset_additional_keys=dataset_additional_keys,
            blocking=blocking,
            guarantee_order=guarantee_order,
            num_cpus_for_local_task=num_cpus_for_local_task,
            partial_rollout_max_split=partial_rollout_max_split,
            **kwargs
        )

        self.transfer_dock = None
        self.mm_transfer_dock = None
        self.enable_partial_rollout = self.partial_rollout_max_split > 1
        self.metrics = Metric()
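        # With partial rollout (partial_rollout_max_split > 1), the dock is sized for
        # two global batches so the next prompt batch can be staged alongside samples
        # that are still being rolled out (fit() enqueues one batch ahead in that mode).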
        if self.enable_partial_rollout:
            self.td_max_len = self.global_batch_size * 2
        else:
            self.td_max_len = self.global_batch_size
        self.transfer_dock_init()
        self.kwargs = kwargs
        self.set_actor_log_prob_skip_flag()

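    # The transfer dock is a Ray actor that serves as the shared experience buffer:
    # prompts, rollouts and per-step metrics all pass through it, and every worker
    # group receives a handle to it below.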
    def transfer_dock_init(self):
        self.transfer_dock = GRPOTransferDock.remote(
            prompts_num=self.td_max_len,  # max sample num
            n_samples_per_prompt=self.n_samples_per_prompt,
            metrics=self.metrics,
            max_age=self.partial_rollout_max_split,
            GBS_train=self.global_batch_size,  # global batch size used for training
            addition_columns=self.dataset_additional_keys
        )
        if is_multimodal():
            self.mm_transfer_dock = MMGRPOTransferDock.remote(self.global_batch_size, self.n_samples_per_prompt)

        self.actor_worker.sync_init_transfer_dock(self.transfer_dock, self.mm_transfer_dock)
        self.ref_worker.sync_init_transfer_dock(self.transfer_dock, self.mm_transfer_dock)
        for reward in self.reward_list:
            if hasattr(reward, 'sync_init_transfer_dock'):
                reward.sync_init_transfer_dock(self.transfer_dock, self.mm_transfer_dock)
            else:
                reward.init_transfer_dock.remote(self.transfer_dock, self.mm_transfer_dock)

    def set_actor_log_prob_skip_flag(self):
        global_batch_size = self.actor_worker.megatron_config.global_batch_size
        mini_batch_size = self.actor_worker.rl_config.mini_batch_size
        n_samples_per_prompt = self.actor_worker.rl_config.n_samples_per_prompt
        epochs = self.actor_worker.rl_config.epochs
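        # The old-log-prob pass can be skipped when the whole rollout batch is consumed
        # in a single mini-batch over a single epoch, since the policy is not updated
        # between generating the rollout and training on it. Worked example (assumed
        # values): global_batch_size=2, n_samples_per_prompt=8, mini_batch_size=16,
        # epochs=1 -> 2 * 8 == 16 and epochs == 1, so skip_actor_log_prob is True.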
        self.skip_actor_log_prob = (global_batch_size * n_samples_per_prompt == mini_batch_size and epochs == 1)
        self.actor_worker.skip_actor_log_prob = self.skip_actor_log_prob

    def fit(self, data_iters):
        """
        The main training loop of GRPO.
        """
        logger = Loggers('grpo_trainer_hybrid')
        metrics = Metric()

        iteration = self.actor_worker.get_iteration()

        if self.blocking:
            logger.info('sync start grpo training at iteration: {}/{} ...'.format(iteration, self.train_iters))
        else:
            logger.info('async start grpo training at iteration: {}/{} ...'.format(iteration, self.train_iters))
        if self.enable_partial_rollout:
            first_batch = next(data_iters)
            batch, indexes = put_prompts_experience(first_batch, self.n_samples_per_prompt,
                                                    self.dataset_additional_keys)
            ray.get(self.transfer_dock.put_experience.remote(data_dict=batch, indexes=indexes, is_prompt=True))
            logger.info(f'training start, put first batch')

        while iteration < self.train_iters:
            last_iter = iteration == self.train_iters - 1
            with Timer(name='iteration', logger=None) as all_timer:
                batch = next(data_iters)
                if self.enable_partial_rollout:
                    if not last_iter:  # and batch is not None: # None?
                        batch, indexes = put_prompts_experience(batch, self.n_samples_per_prompt,
                                                                self.dataset_additional_keys,
                                                                add_another_batch=True)
                        ray.get(self.transfer_dock.put_experience.remote(data_dict=batch, indexes=indexes, is_prompt=True))
                else:
                    batch_dict, indexes = put_prompts_experience(batch, self.n_samples_per_prompt, self.dataset_additional_keys)
                    ray.get(self.transfer_dock.put_experience.remote(data_dict=batch_dict, indexes=indexes, is_prompt=True))
                if is_multimodal():
                    ray.get(self.mm_transfer_dock.clear.remote())
                    ray.get(self.mm_transfer_dock.put_experience.remote(batch, indexes=[i for i in range(len(batch['prompts']) * self.n_samples_per_prompt)]))

                self.actor_worker.generate_sequences(blocking=self.blocking)

                # compute rm scores.
                rule_reward = []
                for reward_worker in self.reward_list:
                    if isinstance(reward_worker, RayActorGroup):
                        reward_worker.compute_rm_score(blocking=self.blocking)
                    else:
                        rule_reward.append(reward_worker.compute_rm_score.remote())
                ray.get(rule_reward)

                # compute advantages, executed on the driver process
                self.compute_advantage(blocking=True, guarantee_order=self.guarantee_order)

                # compute reference log_prob
                self.ref_worker.compute_ref_log_prob(blocking=self.blocking)

                # compute old log_prob
                if not self.skip_actor_log_prob:
                    self.actor_worker.compute_log_prob(blocking=self.blocking)

                self.actor_worker.wait_all_ref_objs_run_over()

                self.ref_worker.wait_all_ref_objs_run_over()
                for reward in self.reward_list:
                    if hasattr(reward, 'wait_all_ref_objs_run_over'):
                        reward.wait_all_ref_objs_run_over()

                # update actor
                self.actor_worker.update(self.kl_ctrl, self.skip_actor_log_prob)

            # collect metrics
            grpo_data_metrics = compute_grpo_data_metrics(self.transfer_dock,
                                                          self.global_batch_size * self.n_samples_per_prompt,
                                                          self.tokenizer,
                                                          self.global_batch_size * self.n_samples_per_prompt,
                                                          self.guarantee_order)
            metrics_result = ray.get(self.transfer_dock.get_metrics.remote())

            metrics_result = metrics_post_processing(metrics_result)
            metrics_result = metrics_sort(metrics_result, all_timer.last)
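            # Throughput metrics: tokens per second over the whole iteration (e2e),
            # over the actor update phase, and over the vLLM rollout phase.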
            log_max_throughput = self.actor_worker.rl_config.log_max_throughput
            tps = compute_tps(self.kwargs, grpo_data_metrics, self.global_batch_size, self.n_samples_per_prompt,
                              all_timer.last, log_max_throughput)
            update_tps = compute_tps(self.kwargs, grpo_data_metrics, self.global_batch_size, self.n_samples_per_prompt,
                                     metrics_result["timing/update"], log_max_throughput)
            vllm_tps = compute_tps(self.kwargs, grpo_data_metrics, self.global_batch_size, self.n_samples_per_prompt,
                                   metrics_result["timing/rollout"], log_max_throughput)
            metrics.update(value=metrics_result)
            metrics.update(value=grpo_data_metrics)
            metrics.update("e2e_tps", tps)
            metrics.update("update_tps", update_tps)
            metrics.update("vllm_tps", vllm_tps)
            iteration += 1
            logger.info(metrics.metric, iteration, self.train_iters)
            ray.get(self.transfer_dock.clear.remote())
            if self.tensorboard is not None:
                for k, v in metrics.metric.items():
                    self.tensorboard.add_scalar(f"train/{k}", v, iteration)
            if self.wandb is not None:
                self.wandb.log_metrics(metrics.metric, iteration)
            if iteration % self.save_interval == 0 or iteration == self.train_iters:
                self.save_checkpoint(iteration)

        logger.info('grpo training is done')
        ray.shutdown()

    def compute_advantage(self, blocking=False, guarantee_order=False):
        experience_count = self.micro_batch_size
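
        # For the default adv_estimator="group_norm" (GRPO), the advantage is
        # conventionally the group-normalised reward per prompt, roughly
        #     A_i = (r_i - mean(r_group)) / (std(r_group) + eps);
        # the exact formula lives in compute_advantage (compute_utils), which is
        # dispatched below as a Ray task using num_cpus_for_local_task CPUs.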
        start_adv_time = time.time()
        compute_advantage_ref = compute_advantage.options(num_cpus=self.num_cpus_for_local_task).remote(
            self.transfer_dock,
            self.gamma,
            self.lam,
            adv_estimator=self.adv_estimator,
            experience_count=experience_count,
            tokenizer=self.tokenizer,
            global_batch_size=self.global_batch_size * self.n_samples_per_prompt,
            guarantee_order=guarantee_order,
            n_sample_per_prompt=self.actor_worker.rl_config.n_samples_per_prompt
        )
        if blocking:
            ray.get(compute_advantage_ref)
        end_adv_time = time.time()
        ray.get(
            self.transfer_dock.update_metrics.remote(
                "timing/adv",
                value=[round(end_adv_time, 4), round(start_adv_time, 4)],
                cumulate=True
            )
        )
        ray.get(
            self.transfer_dock.update_metrics.remote(
                "end_time/end_adv_time",
                value=[round(end_adv_time, 4)],
                cumulate=True
            )
        )

    def save_checkpoint(self, iteration: int):
        self.actor_worker.save_checkpoint(iteration)