mirror of https://gitee.com/ascend/MindSpeed-RL.git
synced 2025-10-20 16:23:45 +08:00
!578 [Multimodal] [feat.] Supports DAPO
Merge pull request !578 from wanghao/dapo
@@ -8,12 +8,13 @@ from mindspeed_rl.utils.tokenizer import BaseTokenizer
 from mindspeed_rl.workers.dynamic_sampling import DynamicSampling
 from mindspeed_rl.workers.rule_reward import RuleReward
 from mindspeed_rl.trainer.base import RayBaseTrainer
+from mindspeed_rl.trainer.utils import MMGRPOTransferDock
 from mindspeed_rl.trainer.utils.transfer_dock import GRPOTransferDock, put_prompts_experience
 from mindspeed_rl.trainer.utils.compute_utils import compute_advantage, compute_dapo_data_metrics
 from mindspeed_rl.workers.scheduler.launcher import RayActorGroup
 from mindspeed_rl.utils.loggers import Loggers
 from mindspeed_rl.utils.metrics import Metric
-from mindspeed_rl.utils.utils import metrics_post_processing, compute_tps, metrics_sort
+from mindspeed_rl.utils.utils import metrics_post_processing, compute_tps, metrics_sort, is_multimodal


 class RayDAPOTrainer(RayBaseTrainer):
@@ -87,6 +88,8 @@ class RayDAPOTrainer(RayBaseTrainer):

         self.transfer_dock = None
         self.sampling_transfer_dock = None
+        self.mm_transfer_dock = None
+        self.mm_sampling_transfer_dock = None
         self.metrics = Metric()
         self.kwargs = kwargs
         self.should_filter = self.kwargs['filter_groups_enable']
@@ -121,24 +124,28 @@ class RayDAPOTrainer(RayBaseTrainer):
                                                          metrics=self.metrics,
                                                          addition_columns=self.addition_columns,
                                                          addition_consumers=self.addition_consumers)
+            if is_multimodal():
+                self.mm_transfer_dock = MMGRPOTransferDock.remote(self.max_num_prompt_in_batch, self.n_samples_per_prompt)
+                self.mm_sampling_transfer_dock = MMGRPOTransferDock.remote(self.global_batch_size, self.n_samples_per_prompt)
             for sampling in self.dynamic_sampling_list:
-                sampling.init_transfer_dock.remote(self.transfer_dock,
-                                                   sampling_transfer_dock=self.sampling_transfer_dock)
+                sampling.init_transfer_dock.remote(self.transfer_dock, self.mm_transfer_dock, self.sampling_transfer_dock, self.mm_sampling_transfer_dock)
         else:
             self.transfer_dock = GRPOTransferDock.remote(self.td_max_len, self.n_samples_per_prompt,
                                                          max_age=self.partial_rollout_max_split,
                                                          GBS_train=self.global_batch_size,
                                                          metrics=self.metrics, addition_columns=self.addition_columns,
                                                          addition_consumers=self.addition_consumers)
+            if is_multimodal():
+                self.mm_transfer_dock = MMGRPOTransferDock.remote(self.global_batch_size, self.n_samples_per_prompt)

-        self.actor_worker.sync_init_transfer_dock(self.transfer_dock, sampling_transfer_dock=self.sampling_transfer_dock)
+        self.actor_worker.sync_init_transfer_dock(self.transfer_dock, self.mm_transfer_dock, self.sampling_transfer_dock, self.mm_sampling_transfer_dock)
         if self.ref_worker:
-            self.ref_worker.sync_init_transfer_dock(self.transfer_dock, sampling_transfer_dock=self.sampling_transfer_dock)
+            self.ref_worker.sync_init_transfer_dock(self.transfer_dock, self.mm_transfer_dock, self.sampling_transfer_dock, self.mm_sampling_transfer_dock)
         for reward in self.reward_list:
             if hasattr(reward, 'sync_init_transfer_dock'):
-                reward.sync_init_transfer_dock(self.transfer_dock, sampling_transfer_dock=self.sampling_transfer_dock)
+                reward.sync_init_transfer_dock(self.transfer_dock, self.mm_transfer_dock, self.sampling_transfer_dock, self.mm_sampling_transfer_dock)
             else:
-                reward.init_transfer_dock.remote(self.transfer_dock, sampling_transfer_dock=self.sampling_transfer_dock)
+                reward.init_transfer_dock.remote(self.transfer_dock, self.mm_transfer_dock, self.sampling_transfer_dock, self.mm_sampling_transfer_dock)

     def set_actor_log_prob_skip_flag(self):
         if self.should_filter:
@@ -155,10 +162,16 @@ class RayDAPOTrainer(RayBaseTrainer):
             ray.get(self.sampling_transfer_dock.clear.remote(consumer='dynamic_sampling'))
             index_list = ray.get(self.sampling_transfer_dock.prefetch_request_index.remote(data_num))
             if index_list:
+                if is_multimodal():
+                    ray.get(self.mm_sampling_transfer_dock.clear.remote())
+                    ray.get(self.mm_sampling_transfer_dock.put_experience.remote(batch, indexes=[i for i in range(len(batch['prompts']) * self.n_samples_per_prompt)]))
                 batch, indexes = put_prompts_experience(batch, self.n_samples_per_prompt, self.dataset_additional_keys,
                                                         indexes=index_list, add_another_batch=add_another_batch)
                 ray.get(self.sampling_transfer_dock.put_experience.remote(data_dict=batch, indexes=indexes, is_prompt=True))
         else:
+            if is_multimodal():
+                ray.get(self.mm_transfer_dock.clear.remote())
+                ray.get(self.mm_transfer_dock.put_experience.remote(batch, indexes=[i for i in range(len(batch['prompts']) * self.n_samples_per_prompt)]))
             batch, indexes = put_prompts_experience(batch, self.n_samples_per_prompt, self.dataset_additional_keys,
                                                     add_another_batch=add_another_batch)
             ray.get(self.transfer_dock.put_experience.remote(data_dict=batch, indexes=indexes, is_prompt=True))
@@ -43,6 +43,7 @@ class MMGRPOTransferDock(TransferDock):
         self.n_samples_per_prompt = n_samples_per_prompt
         self.consumer_columns = {
             "actor_rollout": ["image", "image_shape", "image_num", "video", "video_shape", "video_fps", "video_num"],
+            "dynamic_sampling": ["pixel_values", "image_grid_thw", "image_num", "video_num"],
             "actor_log_prob": ["pixel_values", "image_grid_thw", "image_num", "video_num"],
             "ref_log_prob": ["pixel_values", "image_grid_thw", "image_num", "video_num"],
             "actor_train": ["pixel_values", "image_grid_thw", "image_num", "video_num"],
@@ -97,6 +98,39 @@ class MMGRPOTransferDock(TransferDock):

         return trans_mm_experience_to_output(experience, experience_columns)

+    def get_experience_dict(
+            self,
+            experience_columns: List[str],
+            indexes: List[int] = None,
+            get_n_samples: bool = True,
+    ):
+        """Get multimodal experience data from GRPOTransferDock.
+
+        Args:
+            experience_columns: Columns from which to get data.
+            indexes: Rows from which to get data.
+
+        Returns: Data dict.
+
+        """
+        if indexes is None:
+            return {}
+
+        if get_n_samples:
+            indexes = indexes[::self.n_samples_per_prompt]
+
+        indexes = [i // self.n_samples_per_prompt for i in indexes]
+        experience = []
+        for single_column in experience_columns:
+            if len(indexes) == 1:
+                experience.append([self.experience_data[single_column][indexes[0]]])
+            else:
+                experience.append(list(itemgetter(*indexes)(self.experience_data[single_column])))
+        result = {}
+        for i, columns in enumerate(experience_columns):
+            result[columns] = experience[i]
+        return result
+
     def put_experience(
             self,
             batch: Dict[str, Tensor],
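Note on get_experience_dict above: the dock stores multimodal features once per prompt, while callers pass indexes at response granularity (n_samples_per_prompt responses per prompt). With get_n_samples=True the indexes are first thinned to one per prompt group, and every index is then mapped to its prompt row by integer division. Below is a minimal stand-alone sketch of that index handling; the MMDockSketch class and the toy data are illustrative assumptions, not the repo's API.

from operator import itemgetter

class MMDockSketch:
    # Hypothetical stand-in: multimodal features are stored once per prompt,
    # indexes arrive at response granularity.
    def __init__(self, experience_data, n_samples_per_prompt):
        self.experience_data = experience_data
        self.n_samples_per_prompt = n_samples_per_prompt

    def get_experience_dict(self, experience_columns, indexes=None, get_n_samples=True):
        if indexes is None:
            return {}
        if get_n_samples:
            # keep one response index per prompt group
            indexes = indexes[::self.n_samples_per_prompt]
        # map response-level indexes to prompt-level rows
        indexes = [i // self.n_samples_per_prompt for i in indexes]
        result = {}
        for column in experience_columns:
            rows = self.experience_data[column]
            result[column] = [rows[indexes[0]]] if len(indexes) == 1 else list(itemgetter(*indexes)(rows))
        return result

dock = MMDockSketch({"pixel_values": ["img0", "img1", "img2"]}, n_samples_per_prompt=4)
# response indexes 0..7 cover prompts 0 and 1
print(dock.get_experience_dict(["pixel_values"], indexes=list(range(8))))  # {'pixel_values': ['img0', 'img1']}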
@@ -154,10 +154,11 @@ class ActorHybridWorkerBase(BaseWorker):
         self.actor_profiler = profiler_start(self.profiler_config, self.profiler_config.role)
         MsProbe.config_init(self.msprobe_config)

-    def init_transfer_dock(self, td, mm_td, sampling_transfer_dock=None):
+    def init_transfer_dock(self, td, mm_td=None, sampling_transfer_dock=None, mm_sampling_transfer_dock=None):
         self.td = td
         self.mm_td = mm_td
         self.sampling_transfer_dock = sampling_transfer_dock
+        self.mm_sampling_transfer_dock = mm_sampling_transfer_dock
         self.empty_cache()

     def get_iteration(self):
@@ -453,7 +453,10 @@ class BaseWorker(BaseRayWorker, ABC):
         client = self.zmq_client if self.zmq_client is not None else None

         if is_multimodal():
-            mm_columns = ray.get(self.mm_td.get_columns.remote(experience_consumer_stage))
+            if self.sampling_transfer_dock and is_generate:
+                mm_columns = ray.get(self.mm_sampling_transfer_dock.get_columns.remote(experience_consumer_stage))
+            else:
+                mm_columns = ray.get(self.mm_td.get_columns.remote(experience_consumer_stage))
         else:
             mm_columns = []
@@ -482,7 +485,11 @@ class BaseWorker(BaseRayWorker, ABC):
             if not index:  # check whether data was fetched; entries not fetched are marked -1
                 index = [-1] * experience_count
             elif is_multimodal():
-                batch_mm_data = ray.get(self.mm_td.get_experience.remote(mm_columns, index, get_n_samples))
+                if self.sampling_transfer_dock and is_generate:
+                    batch_mm_data = ray.get(self.mm_sampling_transfer_dock.get_experience.remote(mm_columns, index,
+                                                                                                 get_n_samples))
+                else:
+                    batch_mm_data = ray.get(self.mm_td.get_experience.remote(mm_columns, index, get_n_samples))

             if not index:
                 index = [-1] * experience_count
@@ -678,11 +685,21 @@ class BaseWorker(BaseRayWorker, ABC):
                 ray.get(self.sampling_transfer_dock.put_experience.remote(data_dict=output, indexes=index))
             else:
                 self.sampling_transfer_dock.put_experience.remote(data_dict=output, indexes=index)
+            if is_multimodal():
+                if sync:
+                    ray.get(self.mm_sampling_transfer_dock.put_experience.remote(batch=output, indexes=index))
+                else:
+                    self.mm_sampling_transfer_dock.put_experience.remote(batch=output, indexes=index)
         else:
             if sync:
                 ray.get(self.td.put_experience.remote(data_dict=output, indexes=index))
             else:
                 self.td.put_experience.remote(data_dict=output, indexes=index)
+            if is_multimodal():
+                if sync:
+                    ray.get(self.mm_td.put_experience.remote(batch=output, indexes=index))
+                else:
+                    self.mm_td.put_experience.remote(batch=output, indexes=index)

     @mstx_timer_decorator
     def collect_transfer_dock_mm_data(self, output, index, use_vllm=False):
@@ -114,10 +114,11 @@ class CriticWorkerBase(BaseWorker):
         self.critic_profiler = profiler_start(self.profiler_config, self.profiler_config.role)
         MsProbe.config_init(self.msprobe_config)

-    def init_transfer_dock(self, td, mm_td, sampling_transfer_dock=None):
+    def init_transfer_dock(self, td, mm_td=None, sampling_transfer_dock=None, mm_sampling_transfer_dock=None):
         self.td = td
         self.mm_td = mm_td
         self.sampling_transfer_dock = sampling_transfer_dock
+        self.mm_sampling_transfer_dock = mm_sampling_transfer_dock

     def get_iteration(self):
         return self.args.iteration
@@ -3,7 +3,7 @@ import ray
 import numpy as np

 from mindspeed_rl.utils.loggers import Loggers
-from mindspeed_rl.utils.utils import get_current_dp_range_indexes, extract_from_dict
+from mindspeed_rl.utils.utils import get_current_dp_range_indexes, extract_from_dict, is_multimodal
 from mindspeed_rl.utils.pad_process import remove_padding_tensor_dict_to_dict, padding_dict_to_tensor_dict

 logger = Loggers("DynamicSampling")
@@ -18,9 +18,11 @@ class DynamicSampling(object):
         self.global_batch_size = megatron_config.global_batch_size
         self.guarantee_order = rl_config.guarantee_order

-    def init_transfer_dock(self, td, sampling_transfer_dock):
+    def init_transfer_dock(self, td, mm_td=None, sampling_transfer_dock=None, mm_sampling_transfer_dock=None):
         self.td = td
+        self.mm_td = mm_td
         self.sampling_transfer_dock = sampling_transfer_dock
+        self.mm_sampling_transfer_dock = mm_sampling_transfer_dock

     def dynamic_sampling(self):
         experience_consumer_stage = 'dynamic_sampling'
@@ -59,4 +61,9 @@ class DynamicSampling(object):
             experience_data = extract_from_dict(batch_data, kept_idx_list)
             experience_data = padding_dict_to_tensor_dict(experience_data)
             ray.get(self.td.put_experience.remote(experience_data, index_list))
+            if is_multimodal():
+                mm_columns = ray.get(self.mm_sampling_transfer_dock.get_columns.remote(experience_consumer_stage))
+                batch_mm_data = ray.get(self.mm_sampling_transfer_dock.get_experience_dict.remote(mm_columns, kept_idx_list, False))
+                mm_index_list = [i // self.n_samples_per_prompt for i in index_list]
+                ray.get(self.mm_td.put_experience.remote(batch_mm_data, mm_index_list))
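The multimodal branch added above re-reads the kept rows from the sampling dock with get_n_samples=False (one row per kept response) and remaps the destination index_list from response slots to prompt rows before writing into the main multimodal dock. A toy illustration of that index arithmetic (the numbers are made up):

# Illustrative only: 2 prompts survive filtering, 4 samples per prompt.
n_samples_per_prompt = 4
index_list = [8, 9, 10, 11, 20, 21, 22, 23]           # response slots in the main dock
mm_index_list = [i // n_samples_per_prompt for i in index_list]
print(mm_index_list)                                  # [2, 2, 2, 2, 5, 5, 5, 5] -> prompt rows 2 and 5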
@@ -94,10 +94,11 @@ class ReferenceWorkerBase(BaseWorker):
             context_parallel_size=self.megatron_config.context_parallel_size
         )

-    def init_transfer_dock(self, td, mm_td, sampling_transfer_dock=None):
+    def init_transfer_dock(self, td, mm_td=None, sampling_transfer_dock=None, mm_sampling_transfer_dock=None):
         self.td = td
         self.mm_td = mm_td
         self.sampling_transfer_dock = sampling_transfer_dock
+        self.mm_sampling_transfer_dock = mm_sampling_transfer_dock

     @mstx_timer_decorator
     def compute_ref_log_prob(self):
@@ -84,9 +84,11 @@ class RewardWorkerBase(BaseWorker):
             context_parallel_size=self.megatron_config.context_parallel_size
         )

-    def init_transfer_dock(self, td, sampling_transfer_dock=None):
+    def init_transfer_dock(self, td, mm_td=None, sampling_transfer_dock=None, mm_sampling_transfer_dock=None):
         self.td = td
+        self.mm_td = mm_td
         self.sampling_transfer_dock = sampling_transfer_dock
+        self.mm_sampling_transfer_dock = mm_sampling_transfer_dock

     def compute_rm_score(self):
         experience_consumer_stage = 'reward_scores'
@@ -3,7 +3,7 @@ import ray
 from transformers import AutoTokenizer
 import torch

-from mindspeed_rl.models.rule_verifier import compute_verifier_score, math_compute_score
+from mindspeed_rl.models.rule_verifier import compute_verifier_score, math_compute_score, math_acc_reward
 from mindspeed_rl.utils.loggers import Loggers
 from mindspeed_rl.trainer.utils.transfer_dock import pad_experience
 from mindspeed_rl.utils.pad_process import remove_padding_tensor_dict_to_dict, padding_dict_to_tensor_dict
@@ -23,10 +23,11 @@ class RuleReward(object):
         self.hf_tokenizer = AutoTokenizer.from_pretrained(megatron_config.tokenizer_name_or_path,
                                                           trust_remote_code=trust_remote_code)

-    def init_transfer_dock(self, td, mm_td=None, sampling_transfer_dock=None):
+    def init_transfer_dock(self, td, mm_td=None, sampling_transfer_dock=None, mm_sampling_transfer_dock=None):
         self.td = td
         self.mm_td = mm_td
         self.sampling_transfer_dock = sampling_transfer_dock
+        self.mm_sampling_transfer_dock = mm_sampling_transfer_dock

     def compute_rm_score(self):
         experience_consumer_stage = 'rule_reward'
@@ -85,8 +86,9 @@ class RuleReward(object):
                 output = padding_dict_to_tensor_dict(output)
                 cur_td.put_experience.remote(data_dict=output, indexes=index)
             else:
-                mm_columns = ray.get(self.mm_td.get_columns.remote(experience_consumer_stage))
-                batch_mm_data = ray.get(self.mm_td.get_experience.remote(mm_columns, index))
+                mm_cur_td = self.mm_sampling_transfer_dock if self.mm_sampling_transfer_dock else self.mm_td
+                mm_columns = ray.get(mm_cur_td.get_columns.remote(experience_consumer_stage))
+                batch_mm_data = ray.get(mm_cur_td.get_experience.remote(mm_columns, index))
                 batch_data.update(batch_mm_data)

                 reward_tensor = torch.zeros((batch_data['responses'].size(0), 1), dtype=torch.float32)
@@ -98,9 +100,12 @@ class RuleReward(object):
                 for label in batch_data['labels']:
                     labels.append(label)

+                metrics_score = []
                 for i, (response_str, label) in enumerate(zip(response_strs, labels)):
                     token_level_rewards = math_compute_score(response_str, label)
                     reward_tensor[i, 0] = token_level_rewards
+                    metrics_score.append(int(math_acc_reward(response_str, label)))
+                metrics = {"acc_for_dapo_rewards/mean": metrics_score}
                 rm_scores = reward_tensor
                 reward_tensor_reshaped = reward_tensor.reshape(-1, self.n_samples_per_prompt)
                 reward_mean = reward_tensor_reshaped.mean(dim=1, keepdim=True)
@@ -108,5 +113,10 @@ class RuleReward(object):
                 reward_tensor_normalized = (reward_tensor_reshaped - reward_mean) / reward_std
                 reward_tensor = reward_tensor_normalized.reshape(original_shape)
                 output = {"rm_scores": rm_scores, "token_level_rewards": reward_tensor}
+                if self.rl_config.filter_groups_enable:
+                    metric = torch.tensor(metrics[self.rl_config.filter_groups_metric], dtype=torch.float32,
+                                          device=rm_scores.device)
+                    metric = metric.reshape(rm_scores.shape)
+                    output["metric_for_dapo"] = metric
                 output = padding_dict_to_tensor_dict(output)
-                self.td.put_experience.remote(data_dict=output, indexes=index)
+                cur_td.put_experience.remote(data_dict=output, indexes=index)
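The RuleReward changes above score each response with math_acc_reward, normalize rewards within each prompt group of n_samples_per_prompt responses, and, when filter_groups_enable is set, attach the raw accuracy as "metric_for_dapo" so dynamic sampling can drop groups whose responses are all right or all wrong. A stand-alone sketch of that group-wise normalization (toy values; the epsilon on the std is an assumption, since the diff does not show how reward_std is built):

import torch

n_samples_per_prompt = 4
reward_tensor = torch.tensor([[1.], [0.], [1.], [1.],     # prompt 0: 3/4 correct
                              [0.], [0.], [0.], [0.]])    # prompt 1: all wrong -> zero-variance group
original_shape = reward_tensor.shape

grouped = reward_tensor.reshape(-1, n_samples_per_prompt)
reward_mean = grouped.mean(dim=1, keepdim=True)
reward_std = grouped.std(dim=1, keepdim=True) + 1e-6      # epsilon assumed for the zero-variance case
normalized = ((grouped - reward_mean) / reward_std).reshape(original_shape)
print(normalized.squeeze(-1))
# prompt 1 normalizes to all zeros, which is presumably why DAPO-style group filtering keys off the
# raw accuracy metric ("metric_for_dapo") rather than the normalized rewards.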
@@ -258,9 +258,9 @@ class RayActorGroup:
         for actor in self.actor_handlers:
             self.temp_actor_ref_objs.append(actor.init_transfer_dock.remote(transfer_dock, mm_transfer_dock))

-    def sync_init_transfer_dock(self, transfer_dock, mm_transfer_dock=None, sampling_transfer_dock=None):
+    def sync_init_transfer_dock(self, transfer_dock, mm_transfer_dock=None, sampling_transfer_dock=None, mm_sampling_transfer_dock=None):
         for actor in self.actor_handlers:
-            ray.get(actor.init_transfer_dock.remote(transfer_dock, mm_transfer_dock, sampling_transfer_dock))
+            ray.get(actor.init_transfer_dock.remote(transfer_dock, mm_transfer_dock, sampling_transfer_dock, mm_sampling_transfer_dock))

     def enter_infer_mode(self, blocking=False):
         for actor in self.actor_handlers:
@@ -212,9 +212,9 @@ class RayActorGroupMs:
         for actor in self.actor_handlers:
             self.temp_actor_ref_objs.append(actor.init_transfer_dock.remote(transfer_dock))

-    def sync_init_transfer_dock(self, transfer_dock, sampling_transfer_dock=None):
+    def sync_init_transfer_dock(self, transfer_dock, sampling_transfer_dock=None, mm_sampling_transfer_dock=None):
         for actor in self.actor_handlers:
-            ray.get(actor.init_transfer_dock.remote(transfer_dock, sampling_transfer_dock))
+            ray.get(actor.init_transfer_dock.remote(transfer_dock, sampling_transfer_dock, mm_sampling_transfer_dock))

     def enter_infer_mode(self):
         for actor in self.actor_handlers: