!507 partial rollout

Merge pull request !507 from panchenyi/master
panchenyi
2025-08-04 11:07:37 +00:00
committed by i-robot
parent 9958cb3c40
commit 5cc9a136a8
20 changed files with 547 additions and 154 deletions

View File

@ -4,7 +4,14 @@
MindSpeed RL is a reinforcement learning acceleration framework built on the Ascend ecosystem. It aims to provide Huawei [Ascend chip](https://www.hiascend.com/) ecosystem partners with an end-to-end RL training-and-inference solution, with core acceleration capabilities such as colocated/disaggregated training-inference deployment on very large Ascend clusters, asynchronous multi-model pipeline scheduling, and heterogeneous training-inference partitioning and communication.
---
## NEWS
🚀🚀🚀 Partial Rollout is now supported! 🚀🚀🚀
---
## Installation Guide
@ -195,6 +202,15 @@ MindSpeed RL is a reinforcement learning acceleration framework built on the Ascend ecosystem. It aims to provide Huawei [
</td>
<td> Preview</td>
</tr>
<tr>
<td>Partial Rollout</td>
<td><a href="docs/features/partial_rollout.md">Doc</a></td>
<td rowspan="1">
GRPO <br>
</td>
<td> Preview</td>
</tr>
</tbody>
</table>

View File

@ -0,0 +1,45 @@
# Partial Rollout
## Introduction
The core idea of partial rollout is to interrupt the generation of long response sequences early and resume those samples in the next rollout round, so that a single long-tail sample does not tie up inference resources. This reduces the impact of long-tail samples on end-to-end performance in long-sequence inference scenarios.
## Usage
```yaml
rl_config:
partial_rollout_max_split: N # with N > 1, the longest sequence is fully generated within N rollout rounds
```
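
For illustration, here is a minimal sketch of the per-round generation budget this setting implies. The integer division mirrors the worker-side change later in this commit; the concrete numbers are hypothetical:

```python
# Hypothetical values for illustration; max_tokens comes from the sampling config.
max_tokens = 8192               # full per-response token budget
partial_rollout_max_split = 4   # N: finish the longest sequence within 4 rounds

# Each round generates at most this many new tokens per sample; a sample that
# has not emitted EOD by then is truncated and resumed in the next round.
per_round_budget = max_tokens // partial_rollout_max_split
print(per_round_budget)  # 2048
```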
## Technical Design
### Synchronous Inference Engine
Core idea: resumable generation plus cross-iteration long-tail scheduling, so that inference resources never sit idle.
Synchronous engine: data is processed in batches; all samples in a batch enter the inference engine together, and results are returned only once every sample in the batch has finished.
Key techniques:
1. Long-sequence truncation: a truncation point is derived from the maximum generation length and the number of rollout rounds, and truncated samples are placed into the TransferDock. Once at least GBS prompts have fully completed generation, the pipeline proceeds to the downstream computation; otherwise data is fetched from the TransferDock for another round of generation, keeping resource utilization high (see the sketch after Figure 2).
2. Priority-based mixed data reordering and sampling: in the next rollout round, truncated samples are fetched first, so that model quality and convergence are unaffected.
![img.png](../../sources/images/partial_rollout/sync.png)
Figure 1: Synchronous engine design
![img_1.png](../../sources/images/partial_rollout/sync_1.png)
Figure 2: Synchronous engine flow
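
The scheduling loop of the synchronous scheme can be sketched as follows. This is a simplified illustration under assumed names: `transfer_dock`, `rollout`, and their methods are hypothetical stand-ins for the TransferDock and inference engine of this PR, not the framework's actual API.

```python
def sync_partial_rollout(transfer_dock, rollout, gbs, per_round_budget):
    """Sketch: re-dispatch truncated samples until at least gbs prompts finish."""
    while transfer_dock.num_completed_prompts() < gbs:
        # Fetch pending samples; truncated ones come first (priority sampling).
        batch = transfer_dock.get_pending_samples()
        # Resume each sample from its current suffix, capped by the round budget.
        outputs = rollout.generate(batch, max_tokens=per_round_budget)
        for sample, out in zip(batch, outputs):
            sample.append_response(out.tokens)
            sample.completed = out.finished  # EOD emitted or length limit reached
            transfer_dock.put(sample)
    return transfer_dock.take_completed(gbs)
```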
### Asynchronous Inference Engine
Core idea: resumable generation plus cross-iteration long-tail scheduling, so that inference resources never sit idle.
Asynchronous engine: data enters the inference engine in batches, but results can be returned asynchronously at per-sample granularity.
Key techniques:
1. Real-time long-sequence truncation: by interacting with the inference engine, the truncation length of long-tail sequences is determined dynamically. Once at least GBS prompts have fully completed generation, the rollout is interrupted and the truncated samples are placed into the TransferDock, preventing long-tail sequences from slowing down the overall rollout and leaving resources idle (see the sketch after Figure 4).
2. Priority-based mixed data reordering and sampling: in the next rollout round, truncated samples are fetched first and mixed with new samples for generation.
3. Convergence and stability guarantee: every sample is guaranteed to finish generation within the configured number of rollout rounds.
![img_2.png](../../sources/images/partial_rollout/async.png)
Figure 3: Asynchronous engine design
![img_3.png](../../sources/images/partial_rollout/async_1.png)
Figure 4: Asynchronous engine flow
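
The interruption mechanism can be sketched as below; this condenses the vLLM engine-loop change included in this commit (`engine.step()`, `abort_request`, and the periodic stop-signal poll follow that code, while the surrounding names are simplified):

```python
def async_partial_rollout(engine, requests, stop_signal_func, poll_every=20):
    """Sketch: yield samples as they finish; abort the rest once the signal fires."""
    for req in requests:
        engine.add_request(req)
    stop, step = False, 0
    while engine.has_unfinished_requests():
        step += 1
        if step % poll_every == 0:
            # True once >= GBS prompt groups have fully finished generation.
            stop = stop_signal_func()
        for out in engine.step():
            if out.finished or stop:
                yield out  # truncated outputs are written back to the TransferDock
                if stop:
                    engine.abort_request([out.request_id])
```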
## Validation
![img_4.png](../../sources/images/partial_rollout/sync_partial_compare_result.png)
Figure 5: Synchronous engine validation results

View File

@ -57,6 +57,8 @@ class RLConfig(BaseConfig):
use_dp_batch_balance: Whether to use dynamic batch size balancing across data parallel ranks (default: False)
# Default values can still be defined if no config is provided
use_remove_padding: Whether to use packed sequences for forward (default: False)
partial_rollout_max_split: The number of splits applied to max_tokens when partial rollout is enabled; the longest sequence is fully generated within this many rollout rounds (default: 1)
require_max_age_all_finished: whether responses that have reached max_age must be completed in this iteration, or may remain incomplete (default: True)
'''
def __init__(self, config_dict):
@ -150,6 +152,10 @@ class RLConfig(BaseConfig):
self.filter_groups_max_batches = 1
self.filter_groups_train_batch_size = 1
self.partial_rollout_max_split = 1
self.require_max_age_all_finished = True
if config_dict.get("actor_resource") is not None:
for key, _ in config_dict["actor_resource"].items():
if key not in self.actor_resource:

View File

@ -2,16 +2,12 @@ from typing import Dict, Callable, Optional, Any
import numpy as np
import torch
from torch.utils.data import DataLoader
from mindspeed_rl.datasets.utils import _infer_seqlen, get_prompt_index
from mindspeed_rl.datasets.indexed_dataset import get_packed_indexed_dataset
from mindspeed_rl.datasets.base_dataset import BaseDataset
from mindspeed_rl.datasets.templates import get_model_template
from mindspeed_rl.datasets.utils import _build_index_mappings
from mindspeed_rl.datasets.data_samplers import PromptSampler
class PromptDataset(BaseDataset):

View File

@ -84,21 +84,18 @@ class ActorRolloutHybrid(ABC):
self,
prompts_list: List[List[int]],
indexes=None,
n_samples_per_prompt=None,
async_engine=False,
max_tokens=128,
stop_singal_func=None,
**kwargs) -> Tensor:
if async_engine:
res = self.inference_actor.async_generate_sequences(
prompts_list,
indexes,
n_samples_per_prompt=n_samples_per_prompt,
max_tokens=max_tokens,
**kwargs
stop_singal_func=stop_singal_func,
**kwargs,
)
else:
res = self.inference_actor.generate_sequences(prompts_list, **kwargs)[0]
return res
@mstx_timer_decorator

View File

@ -330,7 +330,8 @@ class VLLMInferEngine(BaseInferEngine):
@torch.no_grad()
def async_generate_sequences(self, idx_list, indexes, n_samples_per_prompt=None, **kwargs):
def async_generate_sequences(self, idx_list, indexes, stop_singal_func=None, **kwargs):
STOP_SIGNAL = None
with self.update_sampling_params(**kwargs):
for i, prompt_token_ids in enumerate(idx_list):
request_id = f"req_{indexes[i]}_{uuid.uuid4().hex[:6]}"
@ -340,17 +341,23 @@ class VLLMInferEngine(BaseInferEngine):
prompt={"prompt_token_ids": prompt_token_ids},
params=self.sampling_params
)
count = 0
while self.engine.has_unfinished_requests():
count += 1
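# poll the external stop signal only every 20 engine steps to keep polling overhead low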
if stop_singal_func is not None and count % 20 == 0:
STOP_SIGNAL = stop_singal_func()
step_outputs = self.engine.step()
for output in step_outputs:
if output.finished:
if output.finished or STOP_SIGNAL:
request_id = output.request_id
index = int(request_id.split("_")[1])
prompt_ids = [torch.tensor(idx_list[indexes.index(index)])]
index = [index]
response_ids = self._post_process_outputs([output])
yield (prompt_ids, *response_ids), index
if STOP_SIGNAL:
self.engine.abort_request([request_id])
def _post_process_outputs(self, request_outputs):

View File

@ -42,6 +42,7 @@ class RayBaseTrainer(object):
guarantee_order: bool = False,
use_dp_batch_balance: bool = False,
num_cpus_for_local_task: float = 0.1,
partial_rollout_max_split: int = 1,
use_kl_in_reward: bool = False,
**kwargs):
@ -70,6 +71,7 @@ class RayBaseTrainer(object):
self.guarantee_order = guarantee_order
self.use_dp_batch_balance = use_dp_batch_balance
self.num_cpus_for_local_task = num_cpus_for_local_task
self.partial_rollout_max_split = partial_rollout_max_split
self.use_kl_in_reward = use_kl_in_reward
self.kwargs = kwargs

View File

@ -44,7 +44,6 @@ class RayGRPOTrainer(RayBaseTrainer):
num_cpus_for_local_task: int = 1 Number of CPUs for local ray task.
**kwargs: Additional parameters for base class argument passing.
"""
def __init__(
self,
actor_worker: RayActorGroup,
@ -65,6 +64,7 @@ class RayGRPOTrainer(RayBaseTrainer):
blocking: bool = False,
guarantee_order: bool = False,
num_cpus_for_local_task: int = 1,
partial_rollout_max_split: int = 1,
**kwargs
):
super().__init__(
@ -86,21 +86,29 @@ class RayGRPOTrainer(RayBaseTrainer):
blocking=blocking,
guarantee_order=guarantee_order,
num_cpus_for_local_task=num_cpus_for_local_task,
partial_rollout_max_split=partial_rollout_max_split,
**kwargs
)
self.transfer_dock = None
self.mm_transfer_dock = None
self.enable_partial_rollout = self.partial_rollout_max_split > 1
self.metrics = Metric()
if self.enable_partial_rollout:
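# with partial rollout the TD holds two batches of prompts: the current batch plus the prefetched next batch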
self.td_max_len = self.global_batch_size * 2
else:
self.td_max_len = self.global_batch_size
self.transfer_dock_init()
self.kwargs = kwargs
self.set_actor_log_prob_skip_flag()
def transfer_dock_init(self):
self.transfer_dock = GRPOTransferDock.remote(
self.global_batch_size,
self.n_samples_per_prompt,
self.metrics,
prompts_num=self.td_max_len, # max sample num
n_samples_per_prompt=self.n_samples_per_prompt,
metrics=self.metrics,
max_age=self.partial_rollout_max_split,
GBS_train=self.global_batch_size, # global batch size consumed per training step
addition_columns=self.dataset_additional_keys
)
if is_multimodal():
@ -135,20 +143,30 @@ class RayGRPOTrainer(RayBaseTrainer):
logger.info('sync start grpo training at iteration: {}/{} ...'.format(iteration, self.train_iters))
else:
logger.info('async start grpo training at iteration: {}/{} ...'.format(iteration, self.train_iters))
if self.enable_partial_rollout:
first_batch = next(data_iters)
batch, indexes = put_prompts_experience(first_batch, self.n_samples_per_prompt,
self.dataset_additional_keys)
ray.get(self.transfer_dock.put_experience.remote(data_dict=batch, indexes=indexes, is_prompt=True))
logger.info('training start, put first batch')
while iteration < self.train_iters:
ray.get(self.transfer_dock.clear.remote())
last_iter = iteration == self.train_iters - 1
with Timer(name='iteration', logger=None) as all_timer:
batch = next(data_iters)
if self.enable_partial_rollout:
if not last_iter:
batch, indexes = put_prompts_experience(batch, self.n_samples_per_prompt,
self.dataset_additional_keys,
add_another_batch=True)
ray.get(self.transfer_dock.put_experience.remote(data_dict=batch, indexes=indexes, is_prompt=True))
else:
batch_dict, indexes = put_prompts_experience(batch, self.n_samples_per_prompt, self.dataset_additional_keys)
ray.get(self.transfer_dock.put_experience.remote(data_dict=batch_dict, indexes=indexes))
ray.get(self.transfer_dock.put_experience.remote(data_dict=batch_dict, indexes=indexes, is_prompt=True))
if is_multimodal():
ray.get(self.mm_transfer_dock.clear.remote())
ray.get(self.mm_transfer_dock.put_experience.remote(batch, indexes=[i for i in range(len(batch['prompts']) * self.n_samples_per_prompt)]))
with Timer(name='iteration', logger=None) as all_timer:
# generate sequences
self.actor_worker.generate_sequences(blocking=self.blocking)
# compute rm scores.
@ -158,9 +176,10 @@ class RayGRPOTrainer(RayBaseTrainer):
reward_worker.compute_rm_score(blocking=self.blocking)
else:
rule_reward.append(reward_worker.compute_rm_score.remote())
ray.get(rule_reward)
# compute advantages, executed on the driver process
self.compute_advantage(blocking=False, guarantee_order=self.guarantee_order)
self.compute_advantage(blocking=True, guarantee_order=self.guarantee_order)
# compute reference log_prob
self.ref_worker.compute_ref_log_prob(blocking=self.blocking)
@ -203,6 +222,7 @@ class RayGRPOTrainer(RayBaseTrainer):
metrics.update("vllm_tps", vllm_tps)
iteration += 1
logger.info(metrics.metric, iteration, self.train_iters)
ray.get(self.transfer_dock.clear.remote())
if self.tensorboard is not None:
for k, v in metrics.metric.items():
self.tensorboard.add_scalar(f"train/{k}", v, iteration)

View File

@ -118,9 +118,10 @@ class TransferDock(ABC):
)
for column_idx, single_column in enumerate(experience_columns):
for i, index in enumerate(indexes):
self.experience_data[single_column][index] = experience[column_idx][i]
self.experience_data_status[single_column][index] = 1
for i, idx in enumerate(indexes):
if idx >= 0:
self.experience_data[single_column][idx] = experience[column_idx][i]
self.experience_data_status[single_column][idx] = 1
def _get(self, experience_columns: List[str], indexes: List[int]):
"""Get data based on row and column numbers.
@ -188,7 +189,7 @@ class TransferDock(ABC):
elapsed_time > self.timeout
and elapsed_time % self.timeout_interval < 0.1
):
logger.warning(f"TIMEOUT: data_ready has slept {elapsed_time} second")
logger.info(f"TIMEOUT: data_ready has slept {elapsed_time} second, because {single_column} not ready")
time.sleep(0.1)
if len(indexes) == 1:
data_ready = self.experience_data_status[single_column][indexes] == 1
@ -256,6 +257,8 @@ class GRPOTransferDock(TransferDock):
prompts_num: int,
n_samples_per_prompt: int,
metrics=None,
max_age: int = 1,
GBS_train: int = 0,
addition_columns: Union[List[str], None] = None,
addition_consumers: Union[List[str], None] = None,
timeout: Union[int, None] = None,
@ -331,6 +334,16 @@ class GRPOTransferDock(TransferDock):
key: threading.Lock()
for key in self.experience_consumers
}
self.max_age = max_age
self.GBS_train = GBS_train
self.rollout_completed = torch.zeros(self.max_len, dtype=torch.int32) # whether each sample has finished rollout (EOD emitted or max_tokens reached)
self.age = torch.zeros(self.max_len, dtype=torch.int32) # training steps each sample lags behind the current actor parameters; updated when the TD evicts and reorders samples
self.enable_partial_rollout = max_age > 1 # max_age = 1 means zero resume rounds, since rollout_completed is checked outside the TD
if self.enable_partial_rollout:
self.stop_partial_rollout_signal = False
self.global_ready_mask = torch.zeros(self.max_len, dtype=torch.int32)
self.metrics = metrics
self.prefetch_request_index_lock = threading.Lock()
self.cur_index = 0
@ -365,6 +378,7 @@ class GRPOTransferDock(TransferDock):
consumer: str,
experience_columns: List[str],
experience_count: int = None,
dp_size: int = 1,
indexes: List[int] = None,
get_n_samples: bool = True,
use_batch_seqlen_balance: bool = False
@ -392,9 +406,18 @@ class GRPOTransferDock(TransferDock):
for experience_column in experience_columns:
if experience_column not in self.experience_columns:
if experience_column != 'age':
raise ValueError(
f"get experience ERROR: {experience_column} not in TD experience_column {self.experience_columns}"
)
elif consumer == 'actor_rollout' and self.enable_partial_rollout:
experience_columns.remove('age')
if consumer == "actor_rollout" and self.enable_partial_rollout:
if get_n_samples:
raise ValueError(
"get_n_samples not supported for rollout when actor_rollout enables partial_rollout"
)
if indexes is None:
if experience_count > self.max_len:
@ -402,7 +425,7 @@ class GRPOTransferDock(TransferDock):
f"TD max_len: {self.max_len} need >= experience_count: {experience_count}"
)
if self.max_len % experience_count != 0:
if self.max_len % experience_count != 0 and not self.enable_partial_rollout:
raise ValueError(
f"TD max_len:{self.max_len} need be divisible by experience_count: {experience_count}"
)
@ -430,6 +453,22 @@ class GRPOTransferDock(TransferDock):
self.experience_consumer_status[consumer][indexes] = 1
experience = self._get(experience_columns, indexes)
if consumer == "actor_rollout" and self.enable_partial_rollout:
experience_columns.append('age')
age_list = [torch.tensor([i]) for i in self.age[indexes]]
experience.append(age_list)
# all status flags are refreshed when samples are fetched
self.experience_data_status["responses"][indexes] = 0
self.experience_data_status["response_length"][indexes] = 0
sample_num = len(indexes)
if sample_num < experience_count and sample_num > 0:
min_dp_size_multiple = ((sample_num + dp_size - 1) // dp_size) * dp_size
indexes_extend = indexes + [-2] * (min_dp_size_multiple - sample_num)
for col, _ in enumerate(experience):
for _, _ in enumerate(indexes_extend[sample_num:]):
experience[col].append(experience[col][sample_num - 1]) # duplicate the last sample to pad up to a multiple of dp_size
indexes = indexes_extend
experience_batch = {}
for i, experience_column in enumerate(experience_columns):
experience_batch[experience_column] = experience[i]
@ -440,6 +479,7 @@ class GRPOTransferDock(TransferDock):
self,
data_dict: Dict[str, Union[Tensor, List[Tensor]]],
indexes: List[int] = None,
is_prompt: bool = False
):
"""Put data into specified columns and rows.
@ -456,9 +496,38 @@ class GRPOTransferDock(TransferDock):
"put experience into TD without indexes, indexes must be provided"
)
data_dict = remove_padding_tensor_dict_to_dict(data_dict)
if self.enable_partial_rollout and self.GBS_train == 0:
raise ValueError("GBS for update must be provided when enabling partial rollout")
if self.enable_partial_rollout and 'responses' in data_dict.keys():
if 'rollout_completed' not in data_dict.keys():
raise ValueError(
"partial rollout enabled, when putting responses, rollout_completed status must be provided in data dict"
)
experience_columns, experience = trans_input_to_experience(data_dict)
if "responses" in experience_columns: # 确定是rollout阶段
if self.enable_partial_rollout: # 确定partial rollout功能开启
rollout_completed_col_id = experience_columns.index('rollout_completed')
rollout_completed_column = experience.pop(rollout_completed_col_id)
experience_columns.pop(rollout_completed_col_id)
for i, idx in enumerate(indexes):
if idx >= 0:
if rollout_completed_column[i][0] == 1:
self.rollout_completed[idx] = 1
self._put(experience_columns, experience, indexes)
# _get resets the consumer status afterwards, so it must be refreshed here
if ("responses" in experience_columns) and self.enable_partial_rollout:
self.experience_consumer_status['actor_rollout'][indexes] = copy.deepcopy(self.rollout_completed[indexes])
if self.enable_partial_rollout and is_prompt:
self.experience_data_status['responses'][indexes] = 1
self.experience_data_status['response_length'][indexes] = 1
for i in indexes:
self.experience_data['responses'][i] = torch.tensor([-1], dtype=torch.int32)
self.experience_data['response_length'][i] = torch.tensor([0], dtype=torch.int32)
def _sample_ready_index(
self,
@ -486,15 +555,26 @@ class GRPOTransferDock(TransferDock):
[self.experience_data_status[single_column] == 1 for single_column in experience_columns]
), dim=0,
)
if self.enable_partial_rollout and consumer != 'actor_rollout':
update_ready_indexes = self.global_ready_mask == 1
usable_indexes = (not_consumed_indexes & data_ready_indexes & update_ready_indexes).nonzero(as_tuple=True)[0]
else:
usable_indexes = (not_consumed_indexes & data_ready_indexes).nonzero(as_tuple=True)[0]
if len(usable_indexes) < experience_count:
if self.enable_partial_rollout and consumer == 'actor_rollout' and len(usable_indexes) > 0:
experience_count = len(usable_indexes)
else:
return None
if experience_count <= 0:
return None
if consumer in self.batch_seqlen_balance_mapper and use_batch_seqlen_balance and len(
if self.enable_partial_rollout and consumer == 'actor_rollout':
sampled_indexes = [int(i) for i in usable_indexes[:experience_count]]
elif consumer in self.batch_seqlen_balance_mapper and use_batch_seqlen_balance and len(
usable_indexes) % experience_count == 0:
sampled_indexes = self.batch_seqlen_balance_sampler(
consumer, usable_indexes, experience_count, get_n_samples=False
@ -558,12 +638,19 @@ class GRPOTransferDock(TransferDock):
dim=0,
)
if not self.enable_partial_rollout:
usable_indexes = (not_consumed_indexes & data_ready_indexes).nonzero(as_tuple=True)[0]
elif self.enable_partial_rollout:
group_states = self.global_ready_mask.view(self.prompts_num, self.n_samples_per_prompt)
update_ready_group_indexes = group_states.sum(dim=1) == self.n_samples_per_prompt
usable_indexes = (not_consumed_indexes & data_ready_indexes & update_ready_group_indexes).nonzero(as_tuple=True)[0]
if len(usable_indexes) < experience_count_n_samples:
return None
if consumer in self.batch_seqlen_balance_mapper and use_batch_seqlen_balance and len(
if self.enable_partial_rollout:
sampled_indexes_n_sample = [int(i) for i in usable_indexes[:experience_count_n_samples]]
elif consumer in self.batch_seqlen_balance_mapper and use_batch_seqlen_balance and len(
usable_indexes) % experience_count_n_samples == 0:
sampled_indexes_n_sample = self.batch_seqlen_balance_sampler(
consumer, usable_indexes, experience_count_n_samples, get_n_samples=True
@ -593,12 +680,6 @@ class GRPOTransferDock(TransferDock):
return sampled_indexes
def print_consumer_status(self, consumer: str, td_type: str):
if consumer == 'actor_train':
logger.info(f"td_type={td_type},consumer status ={self.experience_consumer_status[consumer]}")
def all_consumed(self, consumer: str):
"""If consumer has consumed all data in GRPOTransferDock.
@ -608,21 +689,131 @@ class GRPOTransferDock(TransferDock):
Returns: True or False.
"""
if self.enable_partial_rollout:
if self.GBS_train == 0:
raise ValueError("GBS for update must be provided when enabling partial rollout")
if consumer == 'actor_rollout':
all_consumed_group_num, global_ready_mask, _ = self.find_all_consumed_n_samples_groups(consumer='actor_rollout')
self.stop_partial_rollout_signal = all_consumed_group_num >= self.GBS_train
self.global_ready_mask = global_ready_mask
return all_consumed_group_num >= self.GBS_train
else:
return self.experience_consumer_status[consumer].sum() == self.GBS_train * self.n_samples_per_prompt
else:
return self.experience_consumer_status[consumer].sum() == self.max_len
def find_all_consumed_n_samples_groups(self, consumer: str):
if consumer != 'actor_rollout':
raise ValueError(f"Consumer {consumer} is not supported for partial rollout stop signal.")
num_groups = self.max_len // self.n_samples_per_prompt # i.e., self.prompts_num
all_consumed_status = self.rollout_completed
group_states = all_consumed_status[:num_groups * self.n_samples_per_prompt].view(num_groups,
self.n_samples_per_prompt)
all_consumed_groups_mask = (group_states == 1).all(dim=1)
global_mask = torch.zeros(self.max_len, dtype=torch.int32)
all_consumed_group_start_indices = []
for group_idx in range(num_groups):
start_idx = group_idx * self.n_samples_per_prompt
end_idx = (group_idx + 1) * self.n_samples_per_prompt
if all_consumed_groups_mask[group_idx]:
all_consumed_group_start_indices.append(start_idx)
global_mask[start_idx:end_idx] = 1
all_consumed_group_count = len(all_consumed_group_start_indices)
return all_consumed_group_count, global_mask, all_consumed_group_start_indices # all_consumed_indices
def get_update_ready(self, require_max_age_all_finished=True):
all_consumed_group_num, global_ready_mask, _ = self.find_all_consumed_n_samples_groups(consumer='actor_rollout')
self.stop_partial_rollout_signal = all_consumed_group_num >= self.GBS_train
self.global_ready_mask = global_ready_mask
if require_max_age_all_finished:
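# every sample that has reached the oldest allowed age must have finished rollout before training proceeds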
max_age_index = (self.age == self.max_age - 1).nonzero(as_tuple=True)[0]
self.max_age_all_finished = self.rollout_completed[max_age_index].sum().item() == len(max_age_index)
return (self.stop_partial_rollout_signal and self.max_age_all_finished)
else:
return self.stop_partial_rollout_signal
def sort_every_n_samples_by_age(self):
group_indices = torch.arange(0, self.max_len,
self.n_samples_per_prompt) # with n_samples_per_prompt=8: [0, 8, 16, ...]
group_ages = []
for i in group_indices:
group_ages.append(self.age[i:i + self.n_samples_per_prompt].max())
self.age[i:i + self.n_samples_per_prompt] = group_ages[-1]
# sort groups by age, oldest first
sorted_group_idx = sorted(range(len(group_ages)), key=group_ages.__getitem__, reverse=True)
# build the global index mapping
global_indices = []
for group_idx in sorted_group_idx:
start_idx = group_idx * self.n_samples_per_prompt
end_idx = start_idx + self.n_samples_per_prompt
group_range = torch.arange(start_idx, end_idx)
global_indices.append(group_range)
# concatenate all indexes
global_indices = torch.cat(global_indices)
# reorder the experience data (a dict of lists of tensors)
new_experience_data = {}
for key, col_list in self.experience_data.items():
new_col_list = [col_list[i] for i in global_indices]
new_experience_data[key] = new_col_list
self.experience_data = new_experience_data
# reorder the status dicts
new_experience_data_status = {}
new_experience_consumer_status = {}
for key, value in self.experience_data_status.items():
new_experience_data_status[key] = value[global_indices]
for key, value in self.experience_consumer_status.items():
new_experience_consumer_status[key] = value[global_indices]
self.experience_data_status = new_experience_data_status
self.experience_consumer_status = new_experience_consumer_status
# reorder age and rollout_completed
self.age = self.age[global_indices]
self.age[self.age == -1] = 0
self.rollout_completed = self.rollout_completed[global_indices]
def clear(self):
"""Reset consumer status.Clear data and data status in GRPOTransferDock.
Returns: None
"""
if self.enable_partial_rollout:
all_consumed_indexes = (self.experience_consumer_status["actor_train"] == 1).nonzero(as_tuple=True)[0]
# the first round needs neither sort nor clear
if all_consumed_indexes.numel() > 0:
for key in self.experience_consumer_status:
self.experience_consumer_status[key][all_consumed_indexes] = 0
self._clear_experience_data_and_status(indexes=all_consumed_indexes)
self.age = self.age + (self.experience_data_status['input_ids'] == 1).to(torch.int32)
self.age[all_consumed_indexes] = -1
self.global_ready_mask[all_consumed_indexes] = 0
self.rollout_completed[all_consumed_indexes] = 0
self.sort_every_n_samples_by_age()
self.stop_partial_rollout_signal = False
else:
self.experience_consumer_status = {
key: torch.zeros(self.max_len, dtype=torch.int32)
for key in self.experience_consumers
}
self.metrics.reset()
self._clear_experience_data_and_status()
self.cur_index = 0
self.metrics.reset()
def get_consumer_status(self):
"""Get consumer status.
@ -680,6 +871,11 @@ class GRPOTransferDock(TransferDock):
return sampled_indexes
def get_incomplete_response_num(self):
incomplete_response_num = self.experience_data_status['prompts'].sum() - self.rollout_completed.sum()
return incomplete_response_num
def pad_experience(
experience_batch: Dict[str, List[Tensor]],
pad_id: int,
@ -741,7 +937,7 @@ def pad_experience(
raise ValueError("ERROR: when pad, get an empty experience_batch")
else:
for experience_column, experience in experience_batch.items():
if experience_column in ["prompt_length", "response_length"]:
if experience_column in ["prompt_length", "response_length", "age"]:
padded = torch.cat(experience).reshape(-1, 1)
elif experience_column in ["position_ids"]:
padded = pad_sequence(experience, batch_first=True, padding_value=pad_id)
@ -811,7 +1007,7 @@ def trans_input_to_experience(experience_dict: Dict[str, Union[Tensor, List[Tens
return experience_columns, experience_list
def pack_experience_columns(experience_dict, experience_count):
def pack_experience_columns(experience_consumer_stage, experience_dict, experience_count, enable_partial_rollout=False):
"""
Compress experiences by packing tensors into ONE.
from experience_dict
@ -843,10 +1039,15 @@ def pack_experience_columns(experience_dict, experience_count):
batch_data = {}
batch_data_length = {}
if enable_partial_rollout and experience_consumer_stage == 'actor_rollout':
value = experience_dict['prompts']
experience_count = len(value)
else:
for key, value in experience_dict.items():
if len(value) != experience_count:
raise ValueError(f"ERROR: when pack, experience '{key}' number does not match experience_count")
raise ValueError(f"ERROR: when pack, experience '{key}' number={len(value)} does not match {experience_count=}")
for key, value in experience_dict.items():
# check whether the tensors are 1-D or 2-D
is_2d = len(value[0].shape) > 1
if is_2d:
@ -922,7 +1123,7 @@ def unpack_pad_experience(batch_data, batch_data_length, pad_id, multiple):
padded_batch_data = {}
for key, length_list in batch_data_length.items():
if key in ['prompt_length', 'response_length']:
if key in ['prompt_length', 'response_length', 'age']:
padded_batch_data[key] = batch_data[key].view(-1, 1)
continue
data = batch_data[key]
@ -994,7 +1195,7 @@ def unpack_pad_experience(batch_data, batch_data_length, pad_id, multiple):
def put_prompts_experience(
batch: Dict[str, torch.Tensor], n_samples_per_prompt, dataset_additional_keys: List[str] = None, indexes=None,
batch: Dict[str, torch.Tensor], n_samples_per_prompt, dataset_additional_keys: List[str] = None, indexes=None, add_another_batch=False,
):
"""Put data into specified columns and rows.
@ -1027,7 +1228,10 @@ def put_prompts_experience(
for _ in range(n_samples_per_prompt):
values.append(value)
add_vals[add_keys] = values
if indexes is None:
prompt_nums = len(prompt_length)
if add_another_batch:
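# place the prefetched batch in the TD slots after the current batch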
indexes = [prompt_nums + i for i in range(prompt_nums)]
elif indexes is None:
indexes = [i for i in range(len(prompt_length))]
data_dict = dict(

View File

@ -150,6 +150,9 @@ def truncate_rows(tensor, index_tensor, left_pad=False):
truncated_tensors = []
for i in range(mbs):
if index_tensor[i].item() == 0 and tensor[i, 0].item() == -1:
truncated_row = torch.tensor([], dtype=torch.int32).cpu()
else:
# get the truncation index of the current row
trunc_idx = index_tensor[i].item()
# truncate the current row

View File

@ -8,13 +8,12 @@ from enum import Enum
from typing import Callable
import logging as logger
import numpy as np
import ray
from torch import nn
import torch
from transformers import AutoConfig
from mindspeed_rl.config_cls.megatron_config import MegatronConfig
from mindspeed_rl.utils.optimizer_module import OptimizerConfig
from mindspeed_rl.config_cls.rl_config import RLConfig
from mindspeed_rl.config_cls.generate_config import GenerateConfig
from mindspeed_rl.config_cls.mindstudio_config import ProfilerConfig, MsprobeConfig
@ -85,6 +84,7 @@ class ActorHybridWorkerBase(BaseWorker):
self.actor_profiler = None
self.prof_iteration = 1
self.idx = 0
self.enable_partial_rollout = self.rl_config.partial_rollout_max_split > 1
def initialize(self):
self.setup_distributed_rank()
@ -234,16 +234,15 @@ class ActorHybridWorkerBase(BaseWorker):
start_time_defined = False
while self.all_consumed(experience_consumer_stage, sorted_indexes) > 0:
batch_data, index = self.dispatch_transfer_dock_data(
experience_consumer_stage,
batch_data, index = self.dispatch_transfer_dock_data(experience_consumer_stage,
experience_columns,
experience_count,
self.megatron_config.tensor_model_parallel_size,
self.megatron_config.context_parallel_size,
self.megatron_config.context_parallel_algo,
indexes=sorted_indexes.pop(0) if self.rl_config.guarantee_order else None,
get_n_samples=False
)
indexes=sorted_indexes.pop(
0) if self.rl_config.guarantee_order else None,
get_n_samples=self.enable_partial_rollout)
if not start_time_defined:
start_time = time.time()
start_time_defined = True
@ -287,6 +286,11 @@ class ActorHybridWorkerBase(BaseWorker):
self.num_floating_point_operations_so_far)
self.sharding_manager.exit_train_mode()
def get_partial_rollout_stop_signal(self):
if not self.enable_partial_rollout:
return False
return ray.get(self.td.get_update_ready.remote(require_max_age_all_finished=self.rl_config.require_max_age_all_finished))
@mstx_timer_decorator
def generate_sequences(self):
sharding_infer_interval = 0
@ -299,7 +303,13 @@ class ActorHybridWorkerBase(BaseWorker):
experience_columns = ['prompts', 'prompt_length']
if is_multimodal():
experience_columns.extend(['input_ids', 'input_ids_length'])
if self.enable_partial_rollout:
experience_columns.extend(['responses', 'response_length', 'age'])
if self.enable_partial_rollout and (self.rl_config.async_engine or self.iteration == self.megatron_config.train_iters - 1):
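# async engine or final iteration: drain all remaining incomplete responses, split evenly across DP ranks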
incomplete_resp_num = ray.get(self.td.get_incomplete_response_num.remote())
experience_count = int(np.ceil(incomplete_resp_num / self.generate_config.data_parallel_size))
else:
experience_count = self.rl_config.actor_rollout_dispatch_size
pad_token_id = self.tokenizer.pad if self.tokenizer.pad else self.tokenizer.eod
@ -310,7 +320,8 @@ class ActorHybridWorkerBase(BaseWorker):
profiler_iteration=self.prof_iteration)
MsProbe.debugger_start(self.inference_model.model, tag='actor_generate_sequences')
start_time_defined = False
start_time = time.time()
while self.all_consumed(experience_consumer_stage, sorted_indexes, use_vllm=True) > 0:
batch_data, index = self.dispatch_transfer_dock_data(
experience_consumer_stage,
@ -320,25 +331,28 @@ class ActorHybridWorkerBase(BaseWorker):
cp_size=self.megatron_config.context_parallel_size,
cp_algo=self.megatron_config.context_parallel_algo,
indexes=sorted_indexes.pop(0) if self.rl_config.guarantee_order else None,
use_vllm=True
use_vllm=True,
get_n_samples=not self.enable_partial_rollout,
enable_partial_rollout=self.enable_partial_rollout
)
if not start_time_defined:
start_time = time.time()
start_time_defined = True
if batch_data and index:
gc.collect()
torch.cuda.empty_cache()
if self.rl_config.async_engine:
logger.info(f"do async generate process.")
prompts_data = batch_data['prompts']
prompt_length_data = batch_data['prompt_length']
prompts = truncate_rows(prompts_data, prompt_length_data)
prompts_list = [prompt.numpy().tolist() for prompt in prompts]
self.async_generate_process(experience_count, index, pad_token_id, prompts_list, start_time)
self.async_generate_process(batch_data, index, pad_token_id)
else:
self.sync_generate_process(batch_data, experience_count, index, pad_token_id, start_time)
self.sync_generate_process(batch_data, experience_count, index, pad_token_id)
if self.enable_partial_rollout:
torch.distributed.barrier()
end_time = time.time()
ray.get(
self.td.update_metrics.remote(
"timing/rollout",
value=[round(end_time, 4), round(start_time, 4)],
cumulate=True
)
)
profiler_step(actor_generate_profiler)
MsProbe.debugger_stop('actor_generate_sequences')
@ -358,29 +372,58 @@ class ActorHybridWorkerBase(BaseWorker):
)
logger.info("finish generate_sequences")
def sync_generate_process(self, batch_data, experience_count, index, pad_token_id, start_time):
def sync_generate_process(self, batch_data, experience_count, index, pad_token_id):
if not self.enable_partial_rollout:
indexes = list(range(0, experience_count, self.rl_config.n_samples_per_prompt))
prompts_data = batch_data['prompts'][indexes]
prompt_length_data = batch_data['prompt_length'][indexes]
# preprocess, remove padding
prompts = truncate_rows(prompts_data, prompt_length_data)
prompts_list = [prompt.numpy().tolist() for prompt in prompts]
with replace_torch_compile():
responses_pad_right = self.actor_hybrid.generate_sequences(copy.deepcopy(prompts_list), indexes,
n_samples_per_prompt=self.rl_config.n_samples_per_prompt,
async_engine=self.rl_config.async_engine,
else:
prompts_data = batch_data['prompts']
prompt_length_data = batch_data['prompt_length']
responses = batch_data['responses']
responses_length_partial = batch_data['response_length']
responses_partial = truncate_rows(responses, responses_length_partial)
prompts = truncate_rows(prompts_data, prompt_length_data)
prompts_for_vllm = [torch.cat(
(prompt, response), dim=0) for prompt, response in
zip(prompts, responses_partial)]
prompts_list = [prompt.numpy().tolist() for prompt in prompts_for_vllm]
if self.enable_partial_rollout:
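# each rollout round gets 1/partial_rollout_max_split of the full token budget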
max_tokens = self.generate_config.sampling_config["max_tokens"] // self.rl_config.partial_rollout_max_split
responses_pad_right = self.actor_hybrid.generate_sequences(copy.deepcopy(prompts_list),
max_tokens=max_tokens, n=1,
extra_info=batch_data)
else:
responses_pad_right = self.actor_hybrid.generate_sequences(copy.deepcopy(prompts_list),
extra_info=batch_data)
responses = remove_padding_and_split_to_list(responses_pad_right, self.tokenizer.eod, pad_token_id)
responses_length = [torch.tensor([len(response)]) for response in responses]
if is_multimodal():
prompts_data = batch_data['input_ids'][indexes].cpu().unbind()
else:
prompts_data = prompts
if self.enable_partial_rollout:
new_responses = []
for response_partial, response in zip(responses_partial, responses):
new_resp = torch.cat((response_partial, response), dim=0)
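# replace any out-of-vocabulary token ids with 0 before the concatenated response is reused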
test_resp = new_resp >= self.tokenizer.vocab_size
if test_resp.sum() > 0:
new_resp[test_resp] = 0
new_responses.append(new_resp)
responses = new_responses
else:
prompts = []
for prompt in prompts_data:
for _ in range(self.rl_config.n_samples_per_prompt):
prompts.append(copy.deepcopy(prompt))
responses_length = [torch.tensor([len(response)]) for response in responses]
input_ids_list = []
for prompt, response in zip(prompts, responses):
input_ids_list.append(torch.cat((prompt, response), dim=0))
@ -391,32 +434,74 @@ class ActorHybridWorkerBase(BaseWorker):
}
if is_multimodal():
outputs['prompt_length'] = batch_data['input_ids_length']
self.collect_transfer_dock_data(outputs, index, use_vllm=True)
end_time = time.time()
MsProbe.save_data({"responses": responses, "prompts": prompts})
ray.get(
self.td.update_metrics.remote(
"timing/rollout",
value=[round(end_time, 4), round(start_time, 4)],
cumulate=True
)
)
def async_generate_process(self, experience_count, index, pad_token_id, prompts_list, start_time):
# inference
if self.enable_partial_rollout:
finish_status = [torch.tensor([0])] * len(responses_length)
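# a response is complete once it ends with EOD or hits the context-length or max-token limit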
for idx, _ in enumerate(responses):
if responses[idx][-1] == self.tokenizer.eod or \
(prompt_length_data[idx][0] + responses_length[
idx][0] >= self.generate_config.max_model_len) or responses_length[
idx][0] >= self.generate_config.sampling_config["max_tokens"]:
finish_status[idx] = torch.tensor([1])
outputs["rollout_completed"] = finish_status
self.collect_transfer_dock_data(outputs, index, use_vllm=True)
MsProbe.save_data({"responses": responses, "prompts": prompts})
def async_generate_process(self, batch_data, index, pad_token_id):
self.actor_hybrid.inference_actor.init_cache_engine()
with replace_torch_compile():
prompts_data = batch_data['prompts']
prompt_length_data = batch_data['prompt_length']
prompts = truncate_rows(prompts_data, prompt_length_data)
if self.enable_partial_rollout:
responses = batch_data['responses']
responses_length_partial = batch_data['response_length']
responses_partial = truncate_rows(responses, responses_length_partial)
prompts_for_vllm = [torch.cat((prompt, response), dim=0) for prompt, response in zip(prompts, responses_partial)]
prompts_list = [prompt.numpy().tolist() for prompt in prompts_for_vllm]
else:
prompts_list = [prompt.numpy().tolist() for prompt in prompts]
if self.enable_partial_rollout:
response_generator = self.actor_hybrid.generate_sequences(
copy.deepcopy(prompts_list),
indexes=index,
n=1,
async_engine=True,
stop_singal_func=self.get_partial_rollout_stop_signal,
)
else:
response_generator = self.actor_hybrid.generate_sequences(
copy.deepcopy(prompts_list),
indexes=index,
max_tokens=self.generate_config.sampling_config["max_tokens"],
n_samples_per_prompt=1,
n=1,
async_engine=True,
)
for samples, idx in response_generator:
for samples, idx_output in response_generator:
prompts, responses, log_probs = samples
responses = remove_padding_and_split_to_list(responses, self.tokenizer.eod, pad_token_id)
remove_input_ids = False
if self.enable_partial_rollout and len(responses[0]) == 1:
remove_input_ids = True
if self.enable_partial_rollout:
responses_partial_new = []
prompt_length_new = []
for idx in range(len(responses)):
iidx = index.index(idx_output[idx])
responses_partial_new.append(responses_partial[iidx])
prompt_length_new.append(prompt_length_data[iidx])
new_responses = []
for response_partial, response in zip(responses_partial_new, responses):
new_resp = torch.cat((response_partial, response), dim=0)
test_resp = new_resp >= self.tokenizer.vocab_size
if test_resp.sum() > 0:
new_resp[test_resp] = 0
new_responses.append(new_resp)
responses = new_responses
responses_length = [torch.tensor([len(response)]) for response in responses]
input_ids_list = []
@ -428,16 +513,20 @@ class ActorHybridWorkerBase(BaseWorker):
'input_ids': input_ids_list,
'response_length': responses_length
}
self.collect_transfer_dock_data(outputs, idx, use_vllm=True)
if remove_input_ids:
outputs.pop("input_ids")
end_time = time.time()
ray.get(
self.td.update_metrics.remote(
"timing/rollout",
value=[round(end_time, 4), round(start_time, 4)],
cumulate=True
)
)
if self.enable_partial_rollout:
finish_status = [torch.tensor([0])] * len(responses_length)
for idx, _ in enumerate(responses):
if responses[idx][-1] == self.tokenizer.eod or \
prompt_length_new[idx][0].to('cpu') + responses_length[
idx][0] >= self.generate_config.max_model_len or responses_length[
idx][0] >= self.generate_config.sampling_config["max_tokens"]:
finish_status[idx] = torch.tensor([1])
outputs["rollout_completed"] = finish_status
self.collect_transfer_dock_data(outputs, idx_output, use_vllm=True)
MsProbe.save_data({"responses": responses, "prompts": prompts})
self.actor_hybrid.inference_actor.free_cache_engine()
@mstx_timer_decorator
@ -466,16 +555,15 @@ class ActorHybridWorkerBase(BaseWorker):
start_time_defined = False
while self.all_consumed(experience_consumer_stage, sorted_indexes) > 0:
batch_data, index = self.dispatch_transfer_dock_data(
experience_consumer_stage,
batch_data, index = self.dispatch_transfer_dock_data(experience_consumer_stage,
experience_columns,
experience_count,
tp_size=self.megatron_config.tensor_model_parallel_size,
cp_size=self.megatron_config.context_parallel_size,
cp_algo=self.megatron_config.context_parallel_algo,
indexes=sorted_indexes.pop(0) if self.rl_config.guarantee_order else None,
get_n_samples=False
)
indexes=sorted_indexes.pop(
0) if self.rl_config.guarantee_order else None,
get_n_samples=self.enable_partial_rollout)
if not start_time_defined:
start_time = time.time()
start_time_defined = True

View File

@ -134,7 +134,6 @@ class BaseWorker(BaseRayWorker, ABC):
current_device = next(self.model[0].parameters()).device
status = torch.tensor(0, device=current_device)
rank_flg = False
if not use_vllm:
rank_flg = (get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and
get_context_parallel_rank(self.parallel_state, use_vllm) == 0 and
@ -225,7 +224,7 @@ class BaseWorker(BaseRayWorker, ABC):
def dispatch_transfer_dock_data(self, experience_consumer_stage,
experience_columns, experience_count, tp_size=1, cp_size=1, cp_algo=None,
use_vllm=False, indexes=None,
get_n_samples=True):
get_n_samples=True, enable_partial_rollout=False):
pad_id = self.tokenizer.pad if self.tokenizer.pad else self.tokenizer.eod
if is_multimodal():
mm_columns = ray.get(self.mm_td.get_columns.remote(experience_consumer_stage))
@ -254,6 +253,12 @@ class BaseWorker(BaseRayWorker, ABC):
experience_count, indexes=indexes,
get_n_samples=get_n_samples,
use_batch_seqlen_balance=self.rl_config.use_dp_batch_balance)) # CPU-side data
elif enable_partial_rollout:
# fetch per-sample data; slots short of a full DP multiple are padded with duplicated samples
dp_world_size = self.parallel_state.get_data_parallel_world_size()
batch_data, index = ray.get(self.td.get_experience.remote(experience_consumer_stage, experience_columns,
experience_count, dp_world_size, indexes=indexes,
get_n_samples=get_n_samples)) # CPU-side data
else:
batch_data, index = ray.get(
self.td.get_experience.remote(experience_consumer_stage, experience_columns,
@ -290,7 +295,10 @@ class BaseWorker(BaseRayWorker, ABC):
return None, None
if rank_flg:
batch_data, batch_data_length = pack_experience_columns(batch_data, experience_count)
batch_data, batch_data_length = pack_experience_columns(experience_consumer_stage, batch_data,
experience_count,
enable_partial_rollout=enable_partial_rollout,
)
for key in experience_columns:
if rank_flg:

View File

@ -118,7 +118,7 @@ class ReferenceWorkerBase(BaseWorker):
cp_algo=self.megatron_config.context_parallel_algo,
indexes=sorted_indexes.pop(
0) if self.rl_config.guarantee_order else None,
get_n_samples=False)
get_n_samples=self.rl_config.partial_rollout_max_split > 1)
if not start_time_defined:
start_time = time.time()

View File

@ -108,6 +108,7 @@ class RewardWorkerBase(BaseWorker):
cp_algo=self.megatron_config.context_parallel_algo,
indexes=sorted_indexes.pop(
0) if self.rl_config.guarantee_order else None,
get_n_samples=self.rl_config.partial_rollout_max_split > 1
)
if not start_time_defined:
start_time = time.time()

View File

@ -1,6 +1,5 @@
# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
import ray
import torch
from transformers import AutoTokenizer
import torch
@ -46,7 +45,8 @@ class RuleReward(object):
experience_consumer_stage,
experience_columns,
experience_count,
indexes=sorted_indexes.pop(0) if self.rl_config.guarantee_order else None
indexes=sorted_indexes.pop(0) if self.rl_config.guarantee_order else None,
get_n_samples=True
)
) # CPU-side data
batch_data = remove_padding_tensor_dict_to_dict(batch_data)

Binary file not shown (new image, 105 KiB)

Binary file not shown (new image, 221 KiB)

Binary file not shown (new image, 89 KiB)

Binary file not shown (new image, 152 KiB)

Binary file not shown (new image, 698 KiB)