# MindSpeed-RL/mindspeed_rl/trainer/grpo_trainer_hybrid.py
# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
import copy
import time
from typing import List, Union, Dict
import ray
import torch
from codetiming import Timer
from torch.utils.data import DataLoader
from mindspeed_rl.trainer.utils.transfer_dock import put_prompts_experience
from mindspeed_rl.utils.tokenizer import BaseTokenizer
from mindspeed_rl.workers.rule_reward import RuleReward
from mindspeed_rl.trainer.base import RayBaseTrainer
from mindspeed_rl.config_cls.mindstudio_config import ProfilerConfig
from mindspeed_rl.trainer.utils import GRPOTransferDock, MMGRPOTransferDock
from mindspeed_rl.trainer.utils.compute_utils import compute_advantage, compute_grpo_data_metrics
from mindspeed_rl.workers.scheduler.launcher import RayActorGroup
from mindspeed_rl.utils.loggers import Loggers
from mindspeed_rl.utils.metrics import Metric
from mindspeed_rl.utils.utils import metrics_post_processing, compute_tps, metrics_sort, is_multimodal
class RayGRPOTrainer(RayBaseTrainer):
"""
RayGRPOTrainer class. This trainer runs on the driver process on a single CPU/GPU node.
Args:
actor_worker: RayActorGroup The actor worker group.
ref_worker: RayActorGroup The reference worker group.
reward_list: List[Union[RayActorGroup, RuleReward]] List of reward workers or rule-based rewards.
train_iters: int = 1 The number of training iterations.
save_interval: int = 1 The interval (in iterations) for saving checkpoints.
kl_ctrl_type: str = 'fixed' The type of KL divergence control (e.g., 'fixed', 'adaptive').
adv_estimator: str = "group_norm" The method for estimating advantages (e.g., 'group_norm', 'mean').
kl_horizon: int = 1000 The time horizon for KL divergence control (used in adaptive methods).
kl_target: float = 100.0 The target value for KL divergence (used in adaptive methods).
init_kl_coef: float = 0.01 The initial coefficient for KL divergence penalty.
global_batch_size: int = 1 The global batch size for training (number of prompts per iteration).
        micro_batch_size: int = 1 The micro batch size used when dispatching experiences.
        n_samples_per_prompt: int = 1 The number of samples generated per prompt.
        tokenizer: BaseTokenizer = None The tokenizer to use.
        dataset_additional_keys: List[str] = None Additional keys to include in the dataset.
        blocking: bool = False Whether to run worker calls in blocking (synchronous) mode.
        guarantee_order: bool = False Whether to fetch experiences from the transfer dock in order.
        num_cpus_for_local_task: int = 1 Number of CPUs for local Ray tasks.
        partial_rollout_max_split: int = 1 Maximum number of rollout splits; values greater than 1 enable partial rollout.
        **kwargs: Additional parameters for base class argument passing.
"""
def __init__(
self,
actor_worker: RayActorGroup,
ref_worker: RayActorGroup,
reward_list: List[Union[RayActorGroup, RuleReward]],
train_iters: int = 1,
save_interval: int = 1,
kl_ctrl_type: str = 'fixed',
adv_estimator: str = "group_norm",
kl_horizon: int = 1000,
kl_target: float = 100.0,
init_kl_coef: float = 0.01,
global_batch_size: int = 1,
micro_batch_size: int = 1,
n_samples_per_prompt: int = 1,
tokenizer: BaseTokenizer = None,
dataset_additional_keys: List[str] = None,
blocking: bool = False,
guarantee_order: bool = False,
num_cpus_for_local_task: int = 1,
partial_rollout_max_split: int = 1,
**kwargs
):
super().__init__(
actor_worker,
ref_worker,
reward_list,
train_iters=train_iters,
save_interval=save_interval,
kl_ctrl_type=kl_ctrl_type,
kl_horizon=kl_horizon,
kl_target=kl_target,
adv_estimator=adv_estimator,
init_kl_coef=init_kl_coef,
global_batch_size=global_batch_size,
micro_batch_size=micro_batch_size,
n_samples_per_prompt=n_samples_per_prompt,
tokenizer=tokenizer,
dataset_additional_keys=dataset_additional_keys,
blocking=blocking,
guarantee_order=guarantee_order,
num_cpus_for_local_task=num_cpus_for_local_task,
partial_rollout_max_split=partial_rollout_max_split,
**kwargs
)
self.transfer_dock = None
self.mm_transfer_dock = None
self.enable_partial_rollout = self.partial_rollout_max_split > 1
self.metrics = Metric()
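        # With partial rollout enabled, the next prompt batch is prefetched while the
        # current one may still be rolling out, so the dock must hold two global batches.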
if self.enable_partial_rollout:
self.td_max_len = self.global_batch_size * 2
else:
self.td_max_len = self.global_batch_size
self.transfer_dock_init()
self.kwargs = kwargs
self.set_actor_log_prob_skip_flag()
def transfer_dock_init(self):
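        """
        Create the GRPOTransferDock (plus the multimodal dock when needed) and
        register it with the actor, reference, and reward workers.
        """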
self.transfer_dock = GRPOTransferDock.remote(
prompts_num=self.td_max_len, # max sample num
n_samples_per_prompt=self.n_samples_per_prompt,
metrics=self.metrics,
max_age=self.partial_rollout_max_split,
GBS_train=self.global_batch_size, # GBS_train
addition_columns=self.dataset_additional_keys
)
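        # Multimodal runs keep the non-text inputs in a separate transfer dock.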
if is_multimodal():
self.mm_transfer_dock = MMGRPOTransferDock.remote(self.global_batch_size, self.n_samples_per_prompt)
self.actor_worker.sync_init_transfer_dock(self.transfer_dock, self.mm_transfer_dock)
self.ref_worker.sync_init_transfer_dock(self.transfer_dock, self.mm_transfer_dock)
for reward in self.reward_list:
if hasattr(reward, 'sync_init_transfer_dock'):
reward.sync_init_transfer_dock(self.transfer_dock, self.mm_transfer_dock)
else:
reward.init_transfer_dock.remote(self.transfer_dock, self.mm_transfer_dock)
def set_actor_log_prob_skip_flag(self):
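        """
        Decide whether the old log-prob forward pass can be skipped: when
        global_batch_size * n_samples_per_prompt equals mini_batch_size and
        epochs == 1, the actor is updated only once per rollout, so the old
        and current policies coincide.
        """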
global_batch_size = self.actor_worker.megatron_config.global_batch_size
mini_batch_size = self.actor_worker.rl_config.mini_batch_size
n_samples_per_prompt = self.actor_worker.rl_config.n_samples_per_prompt
epochs = self.actor_worker.rl_config.epochs
self.skip_actor_log_prob = (global_batch_size * n_samples_per_prompt == mini_batch_size and epochs == 1)
self.actor_worker.skip_actor_log_prob = self.skip_actor_log_prob
def fit(self, data_iters):
"""
The utils loop of GRPO
"""
logger = Loggers('grpo_trainer_hybrid')
metrics = Metric()
iteration = self.actor_worker.get_iteration()
if self.blocking:
logger.info('sync start grpo training at iteration: {}/{} ...'.format(iteration, self.train_iters))
else:
logger.info('async start grpo training at iteration: {}/{} ...'.format(iteration, self.train_iters))
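        # With partial rollout enabled, pre-load the first prompt batch into the
        # transfer dock so the first loop iteration already has prompts available
        # before the loop queues the following batches.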
if self.enable_partial_rollout:
first_batch = next(data_iters)
batch, indexes = put_prompts_experience(first_batch, self.n_samples_per_prompt,
self.dataset_additional_keys)
ray.get(self.transfer_dock.put_experience.remote(data_dict=batch, indexes=indexes, is_prompt=True))
            logger.info('partial rollout enabled: first prompt batch put into transfer dock')
while iteration < self.train_iters:
last_iter = iteration == self.train_iters - 1
with Timer(name='iteration', logger=None) as all_timer:
batch = next(data_iters)
if self.enable_partial_rollout:
                    if not last_iter:
                        # Queue the next prompt batch into the transfer dock; the last
                        # iteration has no following batch to prefetch.
batch, indexes = put_prompts_experience(batch, self.n_samples_per_prompt,
self.dataset_additional_keys,
add_another_batch=True)
ray.get(self.transfer_dock.put_experience.remote(data_dict=batch, indexes=indexes, is_prompt=True))
else:
batch_dict, indexes = put_prompts_experience(batch, self.n_samples_per_prompt, self.dataset_additional_keys)
ray.get(self.transfer_dock.put_experience.remote(data_dict=batch_dict, indexes=indexes, is_prompt=True))
if is_multimodal():
ray.get(self.mm_transfer_dock.clear.remote())
ray.get(self.mm_transfer_dock.put_experience.remote(batch, indexes=[i for i in range(len(batch['prompts']) * self.n_samples_per_prompt)]))
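                # rollout: generate responses for the queued prompts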
self.actor_worker.generate_sequences(blocking=self.blocking)
# compute rm scores.
rule_reward = []
for reward_worker in self.reward_list:
if isinstance(reward_worker, RayActorGroup):
reward_worker.compute_rm_score(blocking=self.blocking)
else:
rule_reward.append(reward_worker.compute_rm_score.remote())
ray.get(rule_reward)
# compute advantages, executed on the driver process
self.compute_advantage(blocking=True, guarantee_order=self.guarantee_order)
# compute reference log_prob
self.ref_worker.compute_ref_log_prob(blocking=self.blocking)
# compute old log_prob
if not self.skip_actor_log_prob:
self.actor_worker.compute_log_prob(blocking=self.blocking)
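                # synchronize: make sure all asynchronous worker tasks have finished
                # before updating the actor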
self.actor_worker.wait_all_ref_objs_run_over()
self.ref_worker.wait_all_ref_objs_run_over()
for reward in self.reward_list:
if hasattr(reward, 'wait_all_ref_objs_run_over'):
reward.wait_all_ref_objs_run_over()
# update actor
self.actor_worker.update(self.kl_ctrl, self.skip_actor_log_prob)
# collect metrics
grpo_data_metrics = compute_grpo_data_metrics(self.transfer_dock,
self.global_batch_size * self.n_samples_per_prompt,
self.tokenizer,
self.global_batch_size * self.n_samples_per_prompt,
self.guarantee_order)
metrics_result = ray.get(self.transfer_dock.get_metrics.remote())
metrics_result = metrics_post_processing(metrics_result)
metrics_result = metrics_sort(metrics_result, all_timer.last)
log_max_throughput = self.actor_worker.rl_config.log_max_throughput
tps = compute_tps(self.kwargs, grpo_data_metrics, self.global_batch_size, self.n_samples_per_prompt,
all_timer.last, log_max_throughput)
update_tps = compute_tps(self.kwargs, grpo_data_metrics, self.global_batch_size, self.n_samples_per_prompt,
metrics_result["timing/update"], log_max_throughput)
vllm_tps = compute_tps(self.kwargs, grpo_data_metrics, self.global_batch_size, self.n_samples_per_prompt,
metrics_result["timing/rollout"], log_max_throughput)
metrics.update(value=metrics_result)
metrics.update(value=grpo_data_metrics)
metrics.update("e2e_tps", tps)
metrics.update("update_tps", update_tps)
metrics.update("vllm_tps", vllm_tps)
iteration += 1
logger.info(metrics.metric, iteration, self.train_iters)
ray.get(self.transfer_dock.clear.remote())
if self.tensorboard is not None:
for k, v in metrics.metric.items():
self.tensorboard.add_scalar(f"train/{k}", v, iteration)
if self.wandb is not None:
self.wandb.log_metrics(metrics.metric, iteration)
if iteration % self.save_interval == 0 or iteration == self.train_iters:
self.save_checkpoint(iteration)
        logger.info('grpo training is done')
ray.shutdown()
def compute_advantage(self, blocking=False, guarantee_order=False):
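        """
        Launch advantage computation as a Ray task on the driver side and record
        its timing in the transfer dock metrics.
        """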
experience_count = self.micro_batch_size
start_adv_time = time.time()
compute_advantage_ref = compute_advantage.options(num_cpus=self.num_cpus_for_local_task).remote(
self.transfer_dock,
self.gamma,
self.lam,
adv_estimator=self.adv_estimator,
experience_count=experience_count,
tokenizer=self.tokenizer,
global_batch_size=self.global_batch_size * self.n_samples_per_prompt,
guarantee_order=guarantee_order,
n_sample_per_prompt=self.actor_worker.rl_config.n_samples_per_prompt
)
if blocking:
ray.get(compute_advantage_ref)
end_adv_time = time.time()
ray.get(
self.transfer_dock.update_metrics.remote(
"timing/adv",
value=[round(end_adv_time, 4), round(start_adv_time, 4)],
cumulate=True
)
)
ray.get(
self.transfer_dock.update_metrics.remote(
"end_time/end_adv_time",
value=[round(end_adv_time, 4)],
cumulate=True
)
)
def save_checkpoint(self, iteration: int):
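        """Save the actor worker checkpoint at the given iteration."""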
self.actor_worker.save_checkpoint(iteration)