!578 [Multimodal][feat.] Supports DAPO

Merge pull request !578 from wanghao/dapo
This commit is contained in:
wanghao
2025-09-05 09:35:38 +00:00
committed by i-robot
parent c5318e7b85
commit d9212e12ae
11 changed files with 110 additions and 24 deletions

View File

@@ -8,12 +8,13 @@ from mindspeed_rl.utils.tokenizer import BaseTokenizer
from mindspeed_rl.workers.dynamic_sampling import DynamicSampling
from mindspeed_rl.workers.rule_reward import RuleReward
from mindspeed_rl.trainer.base import RayBaseTrainer
from mindspeed_rl.trainer.utils import MMGRPOTransferDock
from mindspeed_rl.trainer.utils.transfer_dock import GRPOTransferDock, put_prompts_experience
from mindspeed_rl.trainer.utils.compute_utils import compute_advantage, compute_dapo_data_metrics
from mindspeed_rl.workers.scheduler.launcher import RayActorGroup
from mindspeed_rl.utils.loggers import Loggers
from mindspeed_rl.utils.metrics import Metric
from mindspeed_rl.utils.utils import metrics_post_processing, compute_tps, metrics_sort
from mindspeed_rl.utils.utils import metrics_post_processing, compute_tps, metrics_sort, is_multimodal
class RayDAPOTrainer(RayBaseTrainer):
@@ -87,6 +88,8 @@ class RayDAPOTrainer(RayBaseTrainer):
self.transfer_dock = None
self.sampling_transfer_dock = None
self.mm_transfer_dock = None
self.mm_sampling_transfer_dock = None
self.metrics = Metric()
self.kwargs = kwargs
self.should_filter = self.kwargs['filter_groups_enable']
@@ -121,24 +124,28 @@
metrics=self.metrics,
addition_columns=self.addition_columns,
addition_consumers=self.addition_consumers)
if is_multimodal():
self.mm_transfer_dock = MMGRPOTransferDock.remote(self.max_num_prompt_in_batch, self.n_samples_per_prompt)
self.mm_sampling_transfer_dock = MMGRPOTransferDock.remote(self.global_batch_size, self.n_samples_per_prompt)
for sampling in self.dynamic_sampling_list:
sampling.init_transfer_dock.remote(self.transfer_dock,
sampling_transfer_dock=self.sampling_transfer_dock)
sampling.init_transfer_dock.remote(self.transfer_dock, self.mm_transfer_dock, self.sampling_transfer_dock, self.mm_sampling_transfer_dock)
else:
self.transfer_dock = GRPOTransferDock.remote(self.td_max_len, self.n_samples_per_prompt,
max_age=self.partial_rollout_max_split,
GBS_train=self.global_batch_size,
metrics=self.metrics, addition_columns=self.addition_columns,
addition_consumers=self.addition_consumers)
if is_multimodal():
self.mm_transfer_dock = MMGRPOTransferDock.remote(self.global_batch_size, self.n_samples_per_prompt)
self.actor_worker.sync_init_transfer_dock(self.transfer_dock, sampling_transfer_dock=self.sampling_transfer_dock)
self.actor_worker.sync_init_transfer_dock(self.transfer_dock, self.mm_transfer_dock, self.sampling_transfer_dock, self.mm_sampling_transfer_dock)
if self.ref_worker:
self.ref_worker.sync_init_transfer_dock(self.transfer_dock, sampling_transfer_dock=self.sampling_transfer_dock)
self.ref_worker.sync_init_transfer_dock(self.transfer_dock, self.mm_transfer_dock, self.sampling_transfer_dock, self.mm_sampling_transfer_dock)
for reward in self.reward_list:
if hasattr(reward, 'sync_init_transfer_dock'):
reward.sync_init_transfer_dock(self.transfer_dock, sampling_transfer_dock=self.sampling_transfer_dock)
reward.sync_init_transfer_dock(self.transfer_dock, self.mm_transfer_dock, self.sampling_transfer_dock, self.mm_sampling_transfer_dock)
else:
reward.init_transfer_dock.remote(self.transfer_dock, sampling_transfer_dock=self.sampling_transfer_dock)
reward.init_transfer_dock.remote(self.transfer_dock, self.mm_transfer_dock, self.sampling_transfer_dock, self.mm_sampling_transfer_dock)
def set_actor_log_prob_skip_flag(self):
if self.should_filter:
@@ -155,10 +162,16 @@ class RayDAPOTrainer(RayBaseTrainer):
ray.get(self.sampling_transfer_dock.clear.remote(consumer='dynamic_sampling'))
index_list = ray.get(self.sampling_transfer_dock.prefetch_request_index.remote(data_num))
if index_list:
if is_multimodal():
ray.get(self.mm_sampling_transfer_dock.clear.remote())
ray.get(self.mm_sampling_transfer_dock.put_experience.remote(batch, indexes=[i for i in range(len(batch['prompts']) * self.n_samples_per_prompt)]))
batch, indexes = put_prompts_experience(batch, self.n_samples_per_prompt, self.dataset_additional_keys,
indexes=index_list, add_another_batch=add_another_batch)
ray.get(self.sampling_transfer_dock.put_experience.remote(data_dict=batch, indexes=indexes, is_prompt=True))
else:
if is_multimodal():
ray.get(self.mm_transfer_dock.clear.remote())
ray.get(self.mm_transfer_dock.put_experience.remote(batch, indexes=[i for i in range(len(batch['prompts']) * self.n_samples_per_prompt)]))
batch, indexes = put_prompts_experience(batch, self.n_samples_per_prompt, self.dataset_additional_keys,
add_another_batch=add_another_batch)
ray.get(self.transfer_dock.put_experience.remote(data_dict=batch, indexes=indexes, is_prompt=True))
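Note on the indexing used above: the multimodal `put_experience` calls address the MM docks with one row per rollout sample, i.e. `range(len(batch['prompts']) * n_samples_per_prompt)`. A plain-Python sketch of that index fan-out (the real docks are Ray actors; `prompt_batch_indexes` is a hypothetical helper, not part of the repository):

def prompt_batch_indexes(num_prompts, n_samples_per_prompt):
    # Each prompt is rolled out n_samples_per_prompt times, so the docks are
    # addressed with num_prompts * n_samples_per_prompt consecutive rows.
    return list(range(num_prompts * n_samples_per_prompt))

# Example: 2 prompts, 4 rollouts each -> rows 0..7; row r maps back to prompt r // 4.
rows = prompt_batch_indexes(2, 4)
print(rows)                    # [0, 1, 2, 3, 4, 5, 6, 7]
print([r // 4 for r in rows])  # [0, 0, 0, 0, 1, 1, 1, 1]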

View File

@@ -43,6 +43,7 @@ class MMGRPOTransferDock(TransferDock):
self.n_samples_per_prompt = n_samples_per_prompt
self.consumer_columns = {
"actor_rollout": ["image", "image_shape", "image_num", "video", "video_shape", "video_fps", "video_num"],
"dynamic_sampling": ["pixel_values", "image_grid_thw", "image_num", "video_num"],
"actor_log_prob": ["pixel_values", "image_grid_thw", "image_num", "video_num"],
"ref_log_prob": ["pixel_values", "image_grid_thw", "image_num", "video_num"],
"actor_train": ["pixel_values", "image_grid_thw", "image_num", "video_num"],
@@ -97,6 +98,39 @@ class MMGRPOTransferDock(TransferDock):
return trans_mm_experience_to_output(experience, experience_columns)
def get_experience_dict(
self,
experience_columns: List[str],
indexes: List[int] = None,
get_n_samples: bool = True,
):
"""Get multimodal experience data from GRPOTransferDock.
Args:
experience_columns: Columns from which to get data.
indexes: Rows from which to get data.
Returns: Data dict.
"""
if indexes is None:
return {}
if get_n_samples:
indexes = indexes[::self.n_samples_per_prompt]
indexes = [i // self.n_samples_per_prompt for i in indexes]
experience = []
for single_column in experience_columns:
if len(indexes) == 1:
experience.append([self.experience_data[single_column][indexes[0]]])
else:
experience.append(list(itemgetter(*indexes)(self.experience_data[single_column])))
result = {}
for i, columns in enumerate(experience_columns):
result[columns] = experience[i]
return result
def put_experience(
self,
batch: Dict[str, Tensor],
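A standalone sketch of the index reduction `get_experience_dict` performs, assuming the MM dock stores one record per prompt while consumers address rows per sample (`gather_mm_rows` and the plain-dict `store` are stand-ins for the dock's internal storage, which the real method reads via `operator.itemgetter`):

from operator import itemgetter
from typing import Dict, List

def gather_mm_rows(store: Dict[str, list], columns: List[str],
                   indexes: List[int], n_samples_per_prompt: int,
                   get_n_samples: bool = True) -> Dict[str, list]:
    if not indexes:
        return {}
    if get_n_samples:
        # keep one sample index per prompt group, then map it to the prompt row
        indexes = indexes[::n_samples_per_prompt]
        indexes = [i // n_samples_per_prompt for i in indexes]
    if len(indexes) == 1:
        return {col: [store[col][indexes[0]]] for col in columns}
    return {col: list(itemgetter(*indexes)(store[col])) for col in columns}

# Example: 2 prompts, 4 samples each; sample indexes 0..7 collapse to prompt rows 0 and 1.
store = {"image_num": [1, 2], "video_num": [0, 0]}
print(gather_mm_rows(store, ["image_num", "video_num"], list(range(8)), 4))
# {'image_num': [1, 2], 'video_num': [0, 0]}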

View File

@@ -154,10 +154,11 @@ class ActorHybridWorkerBase(BaseWorker):
self.actor_profiler = profiler_start(self.profiler_config, self.profiler_config.role)
MsProbe.config_init(self.msprobe_config)
def init_transfer_dock(self, td, mm_td, sampling_transfer_dock=None):
def init_transfer_dock(self, td, mm_td=None, sampling_transfer_dock=None, mm_sampling_transfer_dock=None):
self.td = td
self.mm_td = mm_td
self.sampling_transfer_dock = sampling_transfer_dock
self.mm_sampling_transfer_dock = mm_sampling_transfer_dock
self.empty_cache()
def get_iteration(self):

View File

@@ -453,7 +453,10 @@ class BaseWorker(BaseRayWorker, ABC):
client = self.zmq_client if self.zmq_client is not None else None
if is_multimodal():
mm_columns = ray.get(self.mm_td.get_columns.remote(experience_consumer_stage))
if self.sampling_transfer_dock and is_generate:
mm_columns = ray.get(self.mm_sampling_transfer_dock.get_columns.remote(experience_consumer_stage))
else:
mm_columns = ray.get(self.mm_td.get_columns.remote(experience_consumer_stage))
else:
mm_columns = []
@@ -482,7 +485,11 @@ class BaseWorker(BaseRayWorker, ABC):
if not index: # check whether any data was fetched; entries are -1 when nothing was fetched
index = [-1] * experience_count
elif is_multimodal():
batch_mm_data = ray.get(self.mm_td.get_experience.remote(mm_columns, index, get_n_samples))
if self.sampling_transfer_dock and is_generate:
batch_mm_data = ray.get(self.mm_sampling_transfer_dock.get_experience.remote(mm_columns, index,
get_n_samples))
else:
batch_mm_data = ray.get(self.mm_td.get_experience.remote(mm_columns, index, get_n_samples))
if not index:
index = [-1] * experience_count
@@ -678,11 +685,21 @@ class BaseWorker(BaseRayWorker, ABC):
ray.get(self.sampling_transfer_dock.put_experience.remote(data_dict=output, indexes=index))
else:
self.sampling_transfer_dock.put_experience.remote(data_dict=output, indexes=index)
if is_multimodal():
if sync:
ray.get(self.mm_sampling_transfer_dock.put_experience.remote(batch=output, indexes=index))
else:
self.mm_sampling_transfer_dock.put_experience.remote(batch=output, indexes=index)
else:
if sync:
ray.get(self.td.put_experience.remote(data_dict=output, indexes=index))
else:
self.td.put_experience.remote(data_dict=output, indexes=index)
if is_multimodal():
if sync:
ray.get(self.mm_td.put_experience.remote(batch=output, indexes=index))
else:
self.mm_td.put_experience.remote(batch=output, indexes=index)
@mstx_timer_decorator
def collect_transfer_dock_mm_data(self, output, index, use_vllm=False):
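Both the read path and the write path above switch between the main docks and the sampling-side docks. A minimal sketch of that routing decision, without Ray (the dock arguments are stand-ins for the actor handles held by BaseWorker; `route_docks` is a hypothetical helper):

from typing import Optional, Tuple

def route_docks(td, mm_td, sampling_td, mm_sampling_td,
                is_generate: bool, multimodal: bool) -> Tuple[object, Optional[object]]:
    # During rollout/generation with dynamic sampling enabled, reads and writes
    # go to the sampling-side docks; otherwise to the main docks. The MM dock
    # is only used when the run is multimodal.
    use_sampling = sampling_td is not None and is_generate
    text_dock = sampling_td if use_sampling else td
    mm_dock = (mm_sampling_td if use_sampling else mm_td) if multimodal else None
    return text_dock, mm_dock

# Example: generation stage with dynamic sampling on targets the sampling-side docks.
print(route_docks("td", "mm_td", "sampling_td", "mm_sampling_td", True, True))
# ('sampling_td', 'mm_sampling_td')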

View File

@@ -114,10 +114,11 @@ class CriticWorkerBase(BaseWorker):
self.critic_profiler = profiler_start(self.profiler_config, self.profiler_config.role)
MsProbe.config_init(self.msprobe_config)
def init_transfer_dock(self, td, mm_td, sampling_transfer_dock=None):
def init_transfer_dock(self, td, mm_td=None, sampling_transfer_dock=None, mm_sampling_transfer_dock=None):
self.td = td
self.mm_td = mm_td
self.sampling_transfer_dock = sampling_transfer_dock
self.mm_sampling_transfer_dock = mm_sampling_transfer_dock
def get_iteration(self):
return self.args.iteration

View File

@@ -3,7 +3,7 @@ import ray
import numpy as np
from mindspeed_rl.utils.loggers import Loggers
from mindspeed_rl.utils.utils import get_current_dp_range_indexes, extract_from_dict
from mindspeed_rl.utils.utils import get_current_dp_range_indexes, extract_from_dict, is_multimodal
from mindspeed_rl.utils.pad_process import remove_padding_tensor_dict_to_dict, padding_dict_to_tensor_dict
logger = Loggers("DynamicSampling")
@@ -18,9 +18,11 @@ class DynamicSampling(object):
self.global_batch_size = megatron_config.global_batch_size
self.guarantee_order = rl_config.guarantee_order
def init_transfer_dock(self, td, sampling_transfer_dock):
def init_transfer_dock(self, td, mm_td=None, sampling_transfer_dock=None, mm_sampling_transfer_dock=None):
self.td = td
self.mm_td = mm_td
self.sampling_transfer_dock = sampling_transfer_dock
self.mm_sampling_transfer_dock = mm_sampling_transfer_dock
def dynamic_sampling(self):
experience_consumer_stage = 'dynamic_sampling'
@@ -59,4 +61,9 @@ class DynamicSampling(object):
experience_data = extract_from_dict(batch_data, kept_idx_list)
experience_data = padding_dict_to_tensor_dict(experience_data)
ray.get(self.td.put_experience.remote(experience_data, index_list))
if is_multimodal():
mm_columns = ray.get(self.mm_sampling_transfer_dock.get_columns.remote(experience_consumer_stage))
batch_mm_data = ray.get(self.mm_sampling_transfer_dock.get_experience_dict.remote(mm_columns, kept_idx_list, False))
mm_index_list = [i // self.n_samples_per_prompt for i in index_list]
ray.get(self.mm_td.put_experience.remote(batch_mm_data, mm_index_list))
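The hunk above only forwards the multimodal data of the kept groups; the group filter itself lies outside this diff. As a rough sketch of the DAPO-style dynamic-sampling idea it supports (an assumption about behavior not shown here): a prompt group is kept only when its responses are neither all correct nor all wrong, since uniform groups contribute no advantage signal.

from typing import List

def kept_sample_indexes(metric: List[float], n_samples_per_prompt: int) -> List[int]:
    # metric: per-response 0/1 accuracy, n_samples_per_prompt entries per prompt group.
    kept = []
    for g in range(len(metric) // n_samples_per_prompt):
        group = metric[g * n_samples_per_prompt:(g + 1) * n_samples_per_prompt]
        if 0.0 < sum(group) < len(group):  # drop all-correct and all-wrong groups
            kept.extend(range(g * n_samples_per_prompt, (g + 1) * n_samples_per_prompt))
    return kept

# Example: 2 prompts, 4 samples each; the second group (all wrong) is dropped.
print(kept_sample_indexes([1, 0, 1, 1, 0, 0, 0, 0], 4))  # [0, 1, 2, 3]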

View File

@@ -94,10 +94,11 @@ class ReferenceWorkerBase(BaseWorker):
context_parallel_size=self.megatron_config.context_parallel_size
)
def init_transfer_dock(self, td, mm_td, sampling_transfer_dock=None):
def init_transfer_dock(self, td, mm_td=None, sampling_transfer_dock=None, mm_sampling_transfer_dock=None):
self.td = td
self.mm_td = mm_td
self.sampling_transfer_dock = sampling_transfer_dock
self.mm_sampling_transfer_dock = mm_sampling_transfer_dock
@mstx_timer_decorator
def compute_ref_log_prob(self):

View File

@@ -84,9 +84,11 @@ class RewardWorkerBase(BaseWorker):
context_parallel_size=self.megatron_config.context_parallel_size
)
def init_transfer_dock(self, td, sampling_transfer_dock=None):
def init_transfer_dock(self, td, mm_td=None, sampling_transfer_dock=None, mm_sampling_transfer_dock=None):
self.td = td
self.mm_td = mm_td
self.sampling_transfer_dock = sampling_transfer_dock
self.mm_sampling_transfer_dock = mm_sampling_transfer_dock
def compute_rm_score(self):
experience_consumer_stage = 'reward_scores'

View File

@@ -3,7 +3,7 @@ import ray
from transformers import AutoTokenizer
import torch
from mindspeed_rl.models.rule_verifier import compute_verifier_score, math_compute_score
from mindspeed_rl.models.rule_verifier import compute_verifier_score, math_compute_score, math_acc_reward
from mindspeed_rl.utils.loggers import Loggers
from mindspeed_rl.trainer.utils.transfer_dock import pad_experience
from mindspeed_rl.utils.pad_process import remove_padding_tensor_dict_to_dict, padding_dict_to_tensor_dict
@@ -23,10 +23,11 @@ class RuleReward(object):
self.hf_tokenizer = AutoTokenizer.from_pretrained(megatron_config.tokenizer_name_or_path,
trust_remote_code=trust_remote_code)
def init_transfer_dock(self, td, mm_td=None, sampling_transfer_dock=None):
def init_transfer_dock(self, td, mm_td=None, sampling_transfer_dock=None, mm_sampling_transfer_dock=None):
self.td = td
self.mm_td = mm_td
self.sampling_transfer_dock = sampling_transfer_dock
self.mm_sampling_transfer_dock = mm_sampling_transfer_dock
def compute_rm_score(self):
experience_consumer_stage = 'rule_reward'
@@ -85,8 +86,9 @@ class RuleReward(object):
output = padding_dict_to_tensor_dict(output)
cur_td.put_experience.remote(data_dict=output, indexes=index)
else:
mm_columns = ray.get(self.mm_td.get_columns.remote(experience_consumer_stage))
batch_mm_data = ray.get(self.mm_td.get_experience.remote(mm_columns, index))
mm_cur_td = self.mm_sampling_transfer_dock if self.mm_sampling_transfer_dock else self.mm_td
mm_columns = ray.get(mm_cur_td.get_columns.remote(experience_consumer_stage))
batch_mm_data = ray.get(mm_cur_td.get_experience.remote(mm_columns, index))
batch_data.update(batch_mm_data)
reward_tensor = torch.zeros((batch_data['responses'].size(0), 1), dtype=torch.float32)
@@ -98,9 +100,12 @@
for label in batch_data['labels']:
labels.append(label)
metrics_score = []
for i, (response_str, label) in enumerate(zip(response_strs, labels)):
token_level_rewards = math_compute_score(response_str, label)
reward_tensor[i, 0] = token_level_rewards
metrics_score.append(int(math_acc_reward(response_str, label)))
metrics = {"acc_for_dapo_rewards/mean": metrics_score}
rm_scores = reward_tensor
reward_tensor_reshaped = reward_tensor.reshape(-1, self.n_samples_per_prompt)
reward_mean = reward_tensor_reshaped.mean(dim=1, keepdim=True)
@@ -108,5 +113,10 @@
reward_tensor_normalized = (reward_tensor_reshaped - reward_mean) / reward_std
reward_tensor = reward_tensor_normalized.reshape(original_shape)
output = {"rm_scores": rm_scores, "token_level_rewards": reward_tensor}
if self.rl_config.filter_groups_enable:
metric = torch.tensor(metrics[self.rl_config.filter_groups_metric], dtype=torch.float32,
device=rm_scores.device)
metric = metric.reshape(rm_scores.shape)
output["metric_for_dapo"] = metric
output = padding_dict_to_tensor_dict(output)
self.td.put_experience.remote(data_dict=output, indexes=index)
cur_td.put_experience.remote(data_dict=output, indexes=index)
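A standalone sketch of the group-wise reward normalization performed above: scores are reshaped to (num_prompts, n_samples_per_prompt), normalized within each prompt group, and flattened back. The `+ 1e-6` epsilon is an assumption for the sketch; the std line itself is cut off by the hunk boundary above.

import torch

def normalize_rewards_per_group(rm_scores: torch.Tensor, n_samples_per_prompt: int) -> torch.Tensor:
    original_shape = rm_scores.shape
    grouped = rm_scores.reshape(-1, n_samples_per_prompt)   # (num_prompts, n_samples_per_prompt)
    mean = grouped.mean(dim=1, keepdim=True)
    std = grouped.std(dim=1, keepdim=True)
    normalized = (grouped - mean) / (std + 1e-6)             # epsilon added here for safety
    return normalized.reshape(original_shape)

# Example: two prompt groups of four samples each, 0/1 rule-based rewards.
scores = torch.tensor([[1.], [0.], [1.], [1.], [0.], [0.], [1.], [0.]])
print(normalize_rewards_per_group(scores, 4).squeeze(-1))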

View File

@@ -258,9 +258,9 @@ class RayActorGroup:
for actor in self.actor_handlers:
self.temp_actor_ref_objs.append(actor.init_transfer_dock.remote(transfer_dock, mm_transfer_dock))
def sync_init_transfer_dock(self, transfer_dock, mm_transfer_dock=None, sampling_transfer_dock=None):
def sync_init_transfer_dock(self, transfer_dock, mm_transfer_dock=None, sampling_transfer_dock=None, mm_sampling_transfer_dock=None):
for actor in self.actor_handlers:
ray.get(actor.init_transfer_dock.remote(transfer_dock, mm_transfer_dock, sampling_transfer_dock))
ray.get(actor.init_transfer_dock.remote(transfer_dock, mm_transfer_dock, sampling_transfer_dock, mm_sampling_transfer_dock))
def enter_infer_mode(self, blocking=False):
for actor in self.actor_handlers:
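A self-contained toy of the fan-out pattern `sync_init_transfer_dock` relies on: every actor handle receives the same dock handles, and `ray.get` blocks until all of them have stored the references. Requires `ray` to be installed; the `Worker` class below is a stand-in, not the repository's worker, and plain strings stand in for the dock actor handles.

import ray

@ray.remote
class Worker:
    def init_transfer_dock(self, td, mm_td=None, sampling_transfer_dock=None, mm_sampling_transfer_dock=None):
        # The real worker keeps Ray actor handles; plain values stand in here.
        self.docks = (td, mm_td, sampling_transfer_dock, mm_sampling_transfer_dock)
        return self.docks

if __name__ == "__main__":
    ray.init()
    actors = [Worker.remote() for _ in range(2)]
    # Synchronous wiring: block until every actor has received the dock handles.
    print(ray.get([a.init_transfer_dock.remote("td", "mm_td") for a in actors]))
    ray.shutdown()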

View File

@@ -212,9 +212,9 @@ class RayActorGroupMs:
for actor in self.actor_handlers:
self.temp_actor_ref_objs.append(actor.init_transfer_dock.remote(transfer_dock))
def sync_init_transfer_dock(self, transfer_dock, sampling_transfer_dock=None):
def sync_init_transfer_dock(self, transfer_dock, sampling_transfer_dock=None, mm_sampling_transfer_dock=None):
for actor in self.actor_handlers:
ray.get(actor.init_transfer_dock.remote(transfer_dock, sampling_transfer_dock))
ray.get(actor.init_transfer_dock.remote(transfer_dock, sampling_transfer_dock, mm_sampling_transfer_dock))
def enter_infer_mode(self):
for actor in self.actor_handlers: