Mirror of https://gitee.com/ascend/MindSpeed-RL.git, synced 2025-10-20 16:23:45 +08:00
README.md (16 changed lines)
@@ -4,7 +4,14 @@

MindSpeed RL is a reinforcement learning acceleration framework built on the Ascend ecosystem. It provides Huawei [Ascend chip](https://www.hiascend.com/) ecosystem partners with an end-to-end RL training and inference solution, supporting core acceleration capabilities such as co-located and disaggregated training/inference deployment on very large Ascend clusters, asynchronous pipelined scheduling across multiple models, and heterogeneous sharding communication between training and inference.

---

## NEWS

🚀🚀🚀 Partial Rollout is now supported! 🚀🚀🚀

---

## Installation Guide
@@ -195,6 +202,15 @@ MindSpeed RL is a reinforcement learning acceleration framework built on the Ascend ecosystem, providing Huawei [

</td>
<td> Preview</td>
</tr>
<tr>
<td>Partial Rollout</td>
<td><a href="docs/features/partial_rollout.md">Doc</a></td>
<td rowspan="1">
GRPO <br>
</td>
</td>
<td> Preview</td>
</tr>
</tbody>

</table>
docs/features/partial_rollout.md (new file, 45 lines)
@@ -0,0 +1,45 @@

# Partial Rollout

## Introduction

The core idea of partial rollout is to interrupt inference samples with long responses before they finish, and to resume those samples from the breakpoint in the next inference pass, so that a single long-tail sample does not tie up inference resources. This capability reduces the impact of long-tail samples on end-to-end performance in long-sequence inference scenarios.

## Usage

```yaml
rl_config:
  partial_rollout_max_split: N  # set N > 1 to fully generate the longest sequences within N rounds
```
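In effect, `partial_rollout_max_split: N` divides the per-sample generation budget: each inference round asks the engine for at most `max_tokens / N` new tokens, and unfinished samples are parked and resumed later. A minimal sketch of that arithmetic, matching the quotient the synchronous path computes in this change set (the literal values below are illustrative, not defaults):

```python
# Illustrative sketch only: how the per-round token budget is derived.
max_tokens = 4096                # generate_config.sampling_config["max_tokens"]
partial_rollout_max_split = 4    # rl_config.partial_rollout_max_split (the N above)

# Each round generates at most this many new tokens per sample, so even the
# longest response completes within N rounds.
per_round_budget = max_tokens // partial_rollout_max_split   # 1024
```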
## Technical Approach

### Synchronous inference engine

Core idea: resume-from-breakpoint inference plus cross-iteration long-tail scheduling, so inference resources never sit idle.

Synchronous engine: data is processed batch by batch; a batch enters the inference engine together, and results are returned only after every sample in the batch has finished inference.

Key techniques:

1. Long-sequence truncation: a truncation point is set from the maximum generation length and the number of rounds, and truncated samples are placed into the TransferDock. Once at least GBS prompts have fully finished inference, the pipeline proceeds to the subsequent computation tasks; otherwise data is fetched from the TransferDock for another inference round, keeping resource utilization high.
2. Priority-based mixed data reordering and sampling: in the next inference round, truncated samples are fetched first for continued inference, which avoids hurting quality and convergence. (See the sketch below.)
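A hedged sketch of this synchronous loop, with the TransferDock modeled as a plain queue. The engine interface and the sample fields used here (`engine.generate`, `out.tokens`, `out.hit_eos`) are assumptions for illustration, not the framework's actual API:

```python
from collections import deque

def sync_partial_rollout(engine, prompts, gbs, max_tokens, n_splits):
    """Sketch: generate in rounds of at most max_tokens // n_splits new tokens,
    parking truncated samples and resuming them with priority next round."""
    budget = max_tokens // n_splits
    pending = deque((p, []) for p in prompts)   # (prompt_ids, partial_response)
    complete = []
    while len(complete) < gbs and pending:
        batch = [pending.popleft() for _ in range(min(gbs, len(pending)))]
        # Resume each sample from its breakpoint: feed prompt + partial response.
        outs = engine.generate([p + r for p, r in batch], max_new_tokens=budget)
        for (p, r), out in zip(batch, outs):
            r = r + out.tokens                  # extend the partial response
            if out.hit_eos or len(r) >= max_tokens:
                complete.append((p, r))         # fully rolled out: ready to train
            else:
                pending.appendleft((p, r))      # truncated: front of the queue
    return complete
```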

|
||||
|
||||
图1 同步引擎方案示意图
|
||||
|
||||

|
||||
|
||||
图2 同步引擎流程图
|
||||
|
||||
### Asynchronous inference engine

Core idea: resume-from-breakpoint inference plus cross-iteration long-tail scheduling, so inference resources never sit idle.

Asynchronous engine: data enters the inference engine in batches, but results can be returned asynchronously at per-sample granularity.

Key techniques:

1. Real-time long-sequence truncation: by interacting with the inference engine, the truncation length of long-tail sequences is determined dynamically. Once at least GBS prompts have fully finished inference, the inference pass is interrupted and the truncated samples are placed into the TransferDock, preventing long-tail sequences from dragging out the overall inference time and leaving resources idle. (See the sketch below.)
2. Priority-based mixed data reordering and sampling: in the next inference round, truncated samples are fetched first and mixed with new samples for inference.
3. Convergence and stability guarantee: every sample is guaranteed to finish inference within the configured number of iterations.
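The vLLM engine-loop change later in this diff implements the interrupt by polling a trainer-side stop signal every 20 engine steps and aborting the remaining long-tail requests. A simplified sketch of that control flow (illustrative, patterned on the diff rather than reproducing it):

```python
def async_partial_rollout(engine, stop_signal_func=None):
    """Sketch: yield finished samples as they complete; once the stop signal
    fires (>= GBS prompt groups done), flush and abort the stragglers."""
    stop = False
    step = 0
    while engine.has_unfinished_requests():
        step += 1
        # Poll the stop signal only every 20 steps so the check stays cheap.
        if stop_signal_func is not None and step % 20 == 0:
            stop = stop_signal_func()
        for out in engine.step():
            if out.finished or stop:
                yield out                       # finished, or truncated long tail
                if stop:
                    engine.abort_request([out.request_id])
```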

|
||||
图3 异步引擎方案示意图
|
||||
|
||||

|
||||
|
||||
图4 异步引擎流程图
|
||||
|
||||
## 验证情况
|
||||

|
||||
图5 同步引擎验证结果
|
@@ -57,6 +57,8 @@ class RLConfig(BaseConfig):
        use_dp_batch_balance: Whether to use dynamic batch size balancing across data parallel ranks (default: False)
        # Default values can still be defined if no config is provided
        use_remove_padding: Whether to use packed sequences for forward (default: False)
        partial_rollout_max_split: The multiple of token splitting for max tokens when partial rollout is enabled. (default: 1)
        require_max_age_all_finished: whether responses that have reached max_age must be completed in this iteration or may remain incomplete (default: True)
    '''

    def __init__(self, config_dict):
@@ -150,6 +152,10 @@ class RLConfig(BaseConfig):
        self.filter_groups_max_batches = 1
        self.filter_groups_train_batch_size = 1

        self.partial_rollout_max_split = 1
        self.require_max_age_all_finished = True

        if config_dict.get("actor_resource") is not None:
            for key, _ in config_dict["actor_resource"].items():
                if key not in self.actor_resource:
@@ -2,16 +2,12 @@ from typing import Dict, Callable, Optional, Any

import numpy as np

import torch
from torch.utils.data import DataLoader

from mindspeed_rl.datasets.utils import _infer_seqlen, get_prompt_index

from mindspeed_rl.datasets.indexed_dataset import get_packed_indexed_dataset
from mindspeed_rl.datasets.base_dataset import BaseDataset
from mindspeed_rl.datasets.templates import get_model_template
from mindspeed_rl.datasets.utils import _build_index_mappings
from mindspeed_rl.datasets.data_samplers import PromptSampler


class PromptDataset(BaseDataset):
@@ -84,21 +84,18 @@ class ActorRolloutHybrid(ABC):
            self,
            prompts_list: List[List[int]],
            indexes=None,
            n_samples_per_prompt=None,
            async_engine=False,
            max_tokens=128,
            stop_singal_func=None,
            **kwargs) -> Tensor:
        if async_engine:
            res = self.inference_actor.async_generate_sequences(
                prompts_list,
                indexes,
                n_samples_per_prompt=n_samples_per_prompt,
                max_tokens=max_tokens,
                **kwargs
                stop_singal_func=stop_singal_func,
                **kwargs,
            )
        else:
            res = self.inference_actor.generate_sequences(prompts_list, **kwargs)[0]

        return res

    @mstx_timer_decorator
@@ -330,7 +330,8 @@ class VLLMInferEngine(BaseInferEngine):


    @torch.no_grad()
    def async_generate_sequences(self, idx_list, indexes, n_samples_per_prompt=None, **kwargs):
    def async_generate_sequences(self, idx_list, indexes, stop_singal_func=None, **kwargs):
        STOP_SIGNAL = None
        with self.update_sampling_params(**kwargs):
            for i, prompt_token_ids in enumerate(idx_list):
                request_id = f"req_{indexes[i]}_{uuid.uuid4().hex[:6]}"
@@ -340,17 +341,23 @@ class VLLMInferEngine(BaseInferEngine):
                    prompt={"prompt_token_ids": prompt_token_ids},
                    params=self.sampling_params
                )

            count = 0
            while self.engine.has_unfinished_requests():
                count += 1
                if stop_singal_func is not None and count % 20 == 0:
                    STOP_SIGNAL = stop_singal_func()

                step_outputs = self.engine.step()
                for output in step_outputs:
                    if output.finished:
                    if output.finished or STOP_SIGNAL:
                        request_id = output.request_id
                        index = int(request_id.split("_")[1])
                        prompt_ids = [torch.tensor(idx_list[indexes.index(index)])]
                        index = [index]
                        response_ids = self._post_process_outputs([output])
                        yield (prompt_ids, *response_ids), index
                        if STOP_SIGNAL:
                            self.engine.abort_request([request_id])


    def _post_process_outputs(self, request_outputs):
@@ -42,6 +42,7 @@ class RayBaseTrainer(object):
            guarantee_order: bool = False,
            use_dp_batch_balance: bool = False,
            num_cpus_for_local_task: float = 0.1,
            partial_rollout_max_split: int = 1,
            use_kl_in_reward: bool = False,
            **kwargs):

@@ -70,6 +71,7 @@ class RayBaseTrainer(object):
        self.guarantee_order = guarantee_order
        self.use_dp_batch_balance = use_dp_batch_balance
        self.num_cpus_for_local_task = num_cpus_for_local_task
        self.partial_rollout_max_split = partial_rollout_max_split
        self.use_kl_in_reward = use_kl_in_reward
        self.kwargs = kwargs

@@ -44,7 +44,6 @@ class RayGRPOTrainer(RayBaseTrainer):
        num_cpus_for_local_task: int = 1 Number of CPUs for local ray task.
        **kwargs: Additional parameters for base class argument passing.
    """

    def __init__(
            self,
            actor_worker: RayActorGroup,

@@ -65,6 +64,7 @@ class RayGRPOTrainer(RayBaseTrainer):
            blocking: bool = False,
            guarantee_order: bool = False,
            num_cpus_for_local_task: int = 1,
            partial_rollout_max_split: int = 1,
            **kwargs
    ):
        super().__init__(

@@ -86,21 +86,29 @@ class RayGRPOTrainer(RayBaseTrainer):
            blocking=blocking,
            guarantee_order=guarantee_order,
            num_cpus_for_local_task=num_cpus_for_local_task,
            partial_rollout_max_split=partial_rollout_max_split,
            **kwargs
        )

        self.transfer_dock = None
        self.mm_transfer_dock = None
        self.enable_partial_rollout = self.partial_rollout_max_split > 1
        self.metrics = Metric()
        if self.enable_partial_rollout:
            self.td_max_len = self.global_batch_size * 2
        else:
            self.td_max_len = self.global_batch_size
        self.transfer_dock_init()
        self.kwargs = kwargs
        self.set_actor_log_prob_skip_flag()

    def transfer_dock_init(self):
        self.transfer_dock = GRPOTransferDock.remote(
            self.global_batch_size,
            self.n_samples_per_prompt,
            self.metrics,
            prompts_num=self.td_max_len,  # max sample num
            n_samples_per_prompt=self.n_samples_per_prompt,
            metrics=self.metrics,
            max_age=self.partial_rollout_max_split,
            GBS_train=self.global_batch_size,  # GBS_train
            addition_columns=self.dataset_additional_keys
        )
        if is_multimodal():
@@ -135,20 +143,30 @@ class RayGRPOTrainer(RayBaseTrainer):
            logger.info('sync start grpo training at iteration: {}/{} ...'.format(iteration, self.train_iters))
        else:
            logger.info('async start grpo training at iteration: {}/{} ...'.format(iteration, self.train_iters))
        if self.enable_partial_rollout:
            first_batch = next(data_iters)
            batch, indexes = put_prompts_experience(first_batch, self.n_samples_per_prompt,
                                                    self.dataset_additional_keys)
            ray.get(self.transfer_dock.put_experience.remote(data_dict=batch, indexes=indexes, is_prompt=True))
            logger.info(f'training start, put first batch')

        while iteration < self.train_iters:
            ray.get(self.transfer_dock.clear.remote())

            batch = next(data_iters)
            batch_dict, indexes = put_prompts_experience(batch, self.n_samples_per_prompt, self.dataset_additional_keys)
            ray.get(self.transfer_dock.put_experience.remote(data_dict=batch_dict, indexes=indexes))

            if is_multimodal():
                ray.get(self.mm_transfer_dock.clear.remote())
                ray.get(self.mm_transfer_dock.put_experience.remote(batch, indexes=[i for i in range(len(batch['prompts']) * self.n_samples_per_prompt)]))

            last_iter = iteration == self.train_iters - 1
            with Timer(name='iteration', logger=None) as all_timer:
                # generate sequences
                batch = next(data_iters)
                if self.enable_partial_rollout:
                    if not last_iter:  # and batch is not None: # None?
                        batch, indexes = put_prompts_experience(batch, self.n_samples_per_prompt,
                                                                self.dataset_additional_keys,
                                                                add_another_batch=True)
                        ray.get(self.transfer_dock.put_experience.remote(data_dict=batch, indexes=indexes, is_prompt=True))
                else:
                    batch_dict, indexes = put_prompts_experience(batch, self.n_samples_per_prompt, self.dataset_additional_keys)
                    ray.get(self.transfer_dock.put_experience.remote(data_dict=batch_dict, indexes=indexes, is_prompt=True))
                if is_multimodal():
                    ray.get(self.mm_transfer_dock.clear.remote())
                    ray.get(self.mm_transfer_dock.put_experience.remote(batch, indexes=[i for i in range(len(batch['prompts']) * self.n_samples_per_prompt)]))

                self.actor_worker.generate_sequences(blocking=self.blocking)

                # compute rm scores.
@@ -158,9 +176,10 @@ class RayGRPOTrainer(RayBaseTrainer):
                        reward_worker.compute_rm_score(blocking=self.blocking)
                    else:
                        rule_reward.append(reward_worker.compute_rm_score.remote())
                ray.get(rule_reward)

                # compute advantages, executed on the driver process
                self.compute_advantage(blocking=False, guarantee_order=self.guarantee_order)
                self.compute_advantage(blocking=True, guarantee_order=self.guarantee_order)

                # compute reference log_prob
                self.ref_worker.compute_ref_log_prob(blocking=self.blocking)
@@ -203,6 +222,7 @@ class RayGRPOTrainer(RayBaseTrainer):
            metrics.update("vllm_tps", vllm_tps)
            iteration += 1
            logger.info(metrics.metric, iteration, self.train_iters)
            ray.get(self.transfer_dock.clear.remote())
            if self.tensorboard is not None:
                for k, v in metrics.metric.items():
                    self.tensorboard.add_scalar(f"train/{k}", v, iteration)
@@ -118,9 +118,10 @@ class TransferDock(ABC):
        )

        for column_idx, single_column in enumerate(experience_columns):
            for i, index in enumerate(indexes):
                self.experience_data[single_column][index] = experience[column_idx][i]
                self.experience_data_status[single_column][index] = 1
            for i, idx in enumerate(indexes):
                if idx >= 0:
                    self.experience_data[single_column][idx] = experience[column_idx][i]
                    self.experience_data_status[single_column][idx] = 1

    def _get(self, experience_columns: List[str], indexes: List[int]):
        """Get data based on row and column numbers.
@@ -188,7 +189,7 @@ class TransferDock(ABC):
                elapsed_time > self.timeout
                and elapsed_time % self.timeout_interval < 0.1
            ):
                logger.warning(f"TIMEOUT: data_ready has slept {elapsed_time} second")
                logger.info(f"TIMEOUT: data_ready has slept {elapsed_time} second, because {single_column} not ready")
            time.sleep(0.1)
            if len(indexes) == 1:
                data_ready = self.experience_data_status[single_column][indexes] == 1
@@ -256,6 +257,8 @@ class GRPOTransferDock(TransferDock):
            prompts_num: int,
            n_samples_per_prompt: int,
            metrics=None,
            max_age: int = 1,
            GBS_train: int = 0,
            addition_columns: Union[List[str], None] = None,
            addition_consumers: Union[List[str], None] = None,
            timeout: Union[int, None] = None,
@@ -331,6 +334,16 @@ class GRPOTransferDock(TransferDock):
            key: threading.Lock()
            for key in self.experience_consumers
        }

        self.max_age = max_age
        self.GBS_train = GBS_train
        self.rollout_completed = torch.zeros(self.max_len, dtype=torch.int32)  # flags whether each sample has finished its rollout: hit EOD or reached max_tokens
        self.age = torch.zeros(self.max_len, dtype=torch.int32)  # training steps each sample lags behind the current actor weights; age must be updated when the TD evicts and reorders samples (whether to sort by age is an open question)
        self.enable_partial_rollout = max_age > 1  # max_age = 1 means zero resumptions, since the rollout_completed check happens outside the TD
        if self.enable_partial_rollout:
            self.stop_partial_rollout_signal = False
            self.global_ready_mask = torch.zeros(self.max_len, dtype=torch.int32)

        self.metrics = metrics
        self.prefetch_request_index_lock = threading.Lock()
        self.cur_index = 0
@@ -365,6 +378,7 @@ class GRPOTransferDock(TransferDock):
            consumer: str,
            experience_columns: List[str],
            experience_count: int = None,
            dp_size: int = 1,
            indexes: List[int] = None,
            get_n_samples: bool = True,
            use_batch_seqlen_balance: bool = False
@@ -392,8 +406,17 @@ class GRPOTransferDock(TransferDock):

        for experience_column in experience_columns:
            if experience_column not in self.experience_columns:
                if experience_column != 'age':
                    raise ValueError(
                        f"get experience ERROR: {experience_column} not in TD experience_column {self.experience_columns}"
                    )
                elif consumer == 'actor_rollout' and self.enable_partial_rollout:
                    experience_columns.remove('age')

        if consumer == "actor_rollout" and self.enable_partial_rollout:
            if get_n_samples:
                raise ValueError(
                    f"get experience ERROR: {experience_column} not in TD experience_column {self.experience_columns}"
                    "get_n_samples not supported for rollout when actor_rollout enables partial_rollout"
                )

        if indexes is None:
@@ -402,7 +425,7 @@ class GRPOTransferDock(TransferDock):
                    f"TD max_len: {self.max_len} need >= experience_count: {experience_count}"
                )

            if self.max_len % experience_count != 0:
            if self.max_len % experience_count != 0 and not self.enable_partial_rollout:
                raise ValueError(
                    f"TD max_len:{self.max_len} need be divisible by experience_count: {experience_count}"
                )
@@ -430,6 +453,22 @@ class GRPOTransferDock(TransferDock):
            self.experience_consumer_status[consumer][indexes] = 1
            experience = self._get(experience_columns, indexes)

            if consumer == "actor_rollout" and self.enable_partial_rollout:
                experience_columns.append('age')
                age_list = [torch.tensor([i]) for i in self.age[indexes]]
                experience.append(age_list)
                ## all state flags are refreshed when samples are fetched
                self.experience_data_status["responses"][indexes] = 0
                self.experience_data_status["response_length"][indexes] = 0
            sample_num = len(indexes)
            if sample_num < experience_count and sample_num > 0:
                min_dp_size_multiple = ((sample_num + dp_size - 1) // dp_size) * dp_size
                indexes_extend = indexes + [-2] * (min_dp_size_multiple - sample_num)
                for col, _ in enumerate(experience):
                    for _, _ in enumerate(indexes_extend[sample_num:]):
                        experience[col].append(experience[col][sample_num - 1])  # repeat the last sample as padding
                indexes = indexes_extend

            experience_batch = {}
            for i, experience_column in enumerate(experience_columns):
                experience_batch[experience_column] = experience[i]
@@ -440,6 +479,7 @@ class GRPOTransferDock(TransferDock):
            self,
            data_dict: Dict[str, Union[Tensor, List[Tensor]]],
            indexes: List[int] = None,
            is_prompt: bool = False
    ):
        """Put data into specified columns and rows.

@@ -456,9 +496,38 @@ class GRPOTransferDock(TransferDock):
                "put experience into TD without indexes, indexes must be provided"
            )
        data_dict = remove_padding_tensor_dict_to_dict(data_dict)

        if self.enable_partial_rollout and self.GBS_train == 0:
            raise ValueError("GBS for update must be provided when enabling partial rollout")

        if self.enable_partial_rollout and 'responses' in data_dict.keys():
            if 'rollout_completed' not in data_dict.keys():
                raise ValueError(
                    "partial rollout enabled, when putting responses, rollout_completed status must be provided in data dict"
                )

        experience_columns, experience = trans_input_to_experience(data_dict)

        if "responses" in experience_columns:  # we are in the rollout stage
            if self.enable_partial_rollout:  # partial rollout is enabled
                rollout_completed_col_id = experience_columns.index('rollout_completed')
                rollout_completed_column = experience.pop(rollout_completed_col_id)
                experience_columns.pop(rollout_completed_col_id)
                for i, idx in enumerate(indexes):
                    if idx >= 0:
                        if rollout_completed_column[i][0] == 1:
                            self.rollout_completed[idx] = 1

        self._put(experience_columns, experience, indexes)
        # _get refreshes the consumer status, so it has to be updated again here
        if ("responses" in experience_columns) and self.enable_partial_rollout:
            self.experience_consumer_status['actor_rollout'][indexes] = copy.deepcopy(self.rollout_completed[indexes])
        if self.enable_partial_rollout and is_prompt:
            self.experience_data_status['responses'][indexes] = 1
            self.experience_data_status['response_length'][indexes] = 1
            for i in indexes:
                self.experience_data['responses'][i] = torch.tensor([-1], dtype=torch.int32)
                self.experience_data['response_length'][i] = torch.tensor([0], dtype=torch.int32)

    def _sample_ready_index(
            self,
@@ -486,15 +555,26 @@ class GRPOTransferDock(TransferDock):
                [self.experience_data_status[single_column] == 1 for single_column in experience_columns]
            ), dim=0,
        )
        usable_indexes = (not_consumed_indexes & data_ready_indexes).nonzero(as_tuple=True)[0]

        if self.enable_partial_rollout and consumer != 'actor_rollout':
            update_ready_indexes = self.global_ready_mask == 1
            usable_indexes = (not_consumed_indexes & data_ready_indexes & update_ready_indexes).nonzero(as_tuple=True)[0]
        else:
            usable_indexes = (not_consumed_indexes & data_ready_indexes).nonzero(as_tuple=True)[0]

        if len(usable_indexes) < experience_count:
            return None
            if self.enable_partial_rollout and consumer == 'actor_rollout' and len(usable_indexes) > 0:
                experience_count = len(usable_indexes)
            else:
                return None

        if experience_count <= 0:
            return None

        if consumer in self.batch_seqlen_balance_mapper and use_batch_seqlen_balance and len(
        if self.enable_partial_rollout and consumer == 'actor_rollout':
            sampled_indexes = [int(i) for i in usable_indexes[:experience_count]]

        elif consumer in self.batch_seqlen_balance_mapper and use_batch_seqlen_balance and len(
                usable_indexes) % experience_count == 0:
            sampled_indexes = self.batch_seqlen_balance_sampler(
                consumer, usable_indexes, experience_count, get_n_samples=False
@@ -558,12 +638,19 @@ class GRPOTransferDock(TransferDock):
            dim=0,
        )

        usable_indexes = (not_consumed_indexes & data_ready_indexes).nonzero(as_tuple=True)[0]
        if not self.enable_partial_rollout:
            usable_indexes = (not_consumed_indexes & data_ready_indexes).nonzero(as_tuple=True)[0]
        elif self.enable_partial_rollout:
            group_states = self.global_ready_mask.view(self.prompts_num, self.n_samples_per_prompt)
            update_ready_group_indexes = group_states.sum(dim=1) == self.n_samples_per_prompt
            usable_indexes = (not_consumed_indexes & data_ready_indexes & update_ready_group_indexes).nonzero(as_tuple=True)[0]

        if len(usable_indexes) < experience_count_n_samples:
            return None

        if consumer in self.batch_seqlen_balance_mapper and use_batch_seqlen_balance and len(
        if self.enable_partial_rollout:
            sampled_indexes_n_sample = [int(i) for i in usable_indexes[:experience_count_n_samples]]
        elif consumer in self.batch_seqlen_balance_mapper and use_batch_seqlen_balance and len(
                usable_indexes) % experience_count_n_samples == 0:
            sampled_indexes_n_sample = self.batch_seqlen_balance_sampler(
                consumer, usable_indexes, experience_count_n_samples, get_n_samples=True
@@ -593,12 +680,6 @@ class GRPOTransferDock(TransferDock):

        return sampled_indexes

    def print_consumer_status(self, consumer: str, td_type: str):
        if consumer == 'actor_train':
            logger.info(f"td_type={td_type},consumer status ={self.experience_consumer_status[consumer]}")


    def all_consumed(self, consumer: str):
        """If consumer has consumed all data in GRPOTransferDock.

@@ -608,7 +689,100 @@ class GRPOTransferDock(TransferDock):
        Returns: True or False.

        """
        return self.experience_consumer_status[consumer].sum() == self.max_len
        if self.enable_partial_rollout:
            if self.GBS_train == 0:
                raise ValueError("GBS for update must be provided when enabling partial rollout")
            if consumer == 'actor_rollout':
                all_consumed_group_num, global_ready_mask, _ = self.find_all_consumed_n_samples_groups(consumer='actor_rollout')
                self.stop_partial_rollout_signal = all_consumed_group_num >= self.GBS_train
                self.global_ready_mask = global_ready_mask
                return all_consumed_group_num >= self.GBS_train
            else:
                return self.experience_consumer_status[consumer].sum() == self.GBS_train * self.n_samples_per_prompt
        else:
            return self.experience_consumer_status[consumer].sum() == self.max_len

    def find_all_consumed_n_samples_groups(self, consumer: str):
        if consumer != 'actor_rollout':
            raise ValueError(f"Consumer {consumer} is not supported for partial rollout stop signal.")

        num_groups = self.max_len // self.n_samples_per_prompt  # i.e., self.prompts_num
        all_consumed_status = self.rollout_completed
        group_states = all_consumed_status[:num_groups * self.n_samples_per_prompt].view(num_groups,
                                                                                         self.n_samples_per_prompt)
        all_consumed_groups_mask = (group_states == 1).all(dim=1)
        global_mask = torch.zeros(self.max_len, dtype=torch.int32)

        all_consumed_group_start_indices = []

        for group_idx in range(num_groups):
            start_idx = group_idx * self.n_samples_per_prompt
            end_idx = (group_idx + 1) * self.n_samples_per_prompt

            if all_consumed_groups_mask[group_idx]:
                all_consumed_group_start_indices.append(start_idx)
                global_mask[start_idx:end_idx] = 1

        all_consumed_group_count = len(all_consumed_group_start_indices)
        return all_consumed_group_count, global_mask, all_consumed_group_start_indices  # all_consumed_indices


    def get_update_ready(self, require_max_age_all_finished=True):
        all_consumed_group_num, global_ready_mask, _ = self.find_all_consumed_n_samples_groups(consumer='actor_rollout')
        self.stop_partial_rollout_signal = all_consumed_group_num >= self.GBS_train
        self.global_ready_mask = global_ready_mask

        if require_max_age_all_finished:
            max_age_index = (self.age == self.max_age - 1).nonzero(as_tuple=True)[0]
            self.max_age_all_finished = self.rollout_completed[max_age_index].sum().item() == len(max_age_index)
            return (self.stop_partial_rollout_signal and self.max_age_all_finished)
        else:
            return self.stop_partial_rollout_signal

    def sort_every_n_samples_by_age(self):
        group_indices = torch.arange(0, self.max_len,
                                     self.n_samples_per_prompt)  # n=8, this should be [0, 8, 16]
        group_ages = []
        for i in group_indices:
            group_ages.append(self.age[i:i + self.n_samples_per_prompt].max())
            self.age[i:i + self.n_samples_per_prompt] = group_ages[-1]

        # sort the groups by age
        sorted_group_idx = sorted(range(len(group_ages)), key=group_ages.__getitem__, reverse=True)

        # build the global index mapping
        global_indices = []
        for group_idx in sorted_group_idx:
            start_idx = group_idx * self.n_samples_per_prompt
            end_idx = start_idx + self.n_samples_per_prompt
            group_range = torch.arange(start_idx, end_idx)
            global_indices.append(group_range)

        # concatenate all indexes
        global_indices = torch.cat(global_indices)

        # reorder the experience data (a dict of lists of tensors)
        new_experience_data = {}
        for key, col_list in self.experience_data.items():
            new_col_list = [col_list[i] for i in global_indices]
            new_experience_data[key] = new_col_list
        self.experience_data = new_experience_data

        # reorder the status dicts
        new_experience_data_status = {}
        new_experience_consumer_status = {}

        for key, value in self.experience_data_status.items():
            new_experience_data_status[key] = value[global_indices]
        for key, value in self.experience_consumer_status.items():
            new_experience_consumer_status[key] = value[global_indices]
        self.experience_data_status = new_experience_data_status
        self.experience_consumer_status = new_experience_consumer_status

        # reorder age and rollout_completed
        self.age = self.age[global_indices]
        self.age[self.age == -1] = 0
        self.rollout_completed = self.rollout_completed[global_indices]

    def clear(self):
        """Reset consumer status. Clear data and data status in GRPOTransferDock.
@@ -616,13 +790,30 @@ class GRPOTransferDock(TransferDock):
        Returns: None

        """
        self.experience_consumer_status = {
            key: torch.zeros(self.max_len, dtype=torch.int32)
            for key in self.experience_consumers
        }
        self.metrics.reset()
        self._clear_experience_data_and_status()

        if self.enable_partial_rollout:
            all_consumed_indexes = (self.experience_consumer_status["actor_train"] == 1).nonzero(as_tuple=True)[0]
            # the first round needs neither sort nor clear
            if all_consumed_indexes.numel() > 0:
                for key in self.experience_consumer_status:
                    self.experience_consumer_status[key][all_consumed_indexes] = 0
                self._clear_experience_data_and_status(indexes=all_consumed_indexes)

                self.age = self.age + (self.experience_data_status['input_ids'] == 1).to(torch.int32)
                self.age[all_consumed_indexes] = -1
                self.global_ready_mask[all_consumed_indexes] = 0
                self.rollout_completed[all_consumed_indexes] = 0

                self.sort_every_n_samples_by_age()
            self.stop_partial_rollout_signal = False
        else:
            self.experience_consumer_status = {
                key: torch.zeros(self.max_len, dtype=torch.int32)
                for key in self.experience_consumers
            }
            self._clear_experience_data_and_status()
        self.cur_index = 0
        self.metrics.reset()

    def get_consumer_status(self):
        """Get consumer status.
@@ -680,6 +871,11 @@ class GRPOTransferDock(TransferDock):
        return sampled_indexes


    def get_incomplete_response_num(self):
        incomplete_response_num = self.experience_data_status['prompts'].sum() - self.rollout_completed.sum()
        return incomplete_response_num


def pad_experience(
        experience_batch: Dict[str, List[Tensor]],
        pad_id: int,
@@ -741,7 +937,7 @@ def pad_experience(
        raise ValueError("ERROR: when pad, get an empty experience_batch")
    else:
        for experience_column, experience in experience_batch.items():
            if experience_column in ["prompt_length", "response_length"]:
            if experience_column in ["prompt_length", "response_length", "age"]:
                padded = torch.cat(experience).reshape(-1, 1)
            elif experience_column in ["position_ids"]:
                padded = pad_sequence(experience, batch_first=True, padding_value=pad_id)
@@ -811,7 +1007,7 @@ def trans_input_to_experience(experience_dict: Dict[str, Union[Tensor, List[Tens
    return experience_columns, experience_list


def pack_experience_columns(experience_dict, experience_count):
def pack_experience_columns(experience_consumer_stage, experience_dict, experience_count, enable_partial_rollout=False):
    """
    Compress experiences by packing tensors into ONE.
    from experience_dict

@@ -843,10 +1039,15 @@ def pack_experience_columns(experience_dict, experience_count):
    batch_data = {}
    batch_data_length = {}

    for key, value in experience_dict.items():
        if len(value) != experience_count:
            raise ValueError(f"ERROR: when pack, experience '{key}' number does not match experience_count")
    if enable_partial_rollout and experience_consumer_stage == 'actor_rollout':
        value = experience_dict['prompts']
        experience_count = len(value)
    else:
        for key, value in experience_dict.items():
            if len(value) != experience_count:
                raise ValueError(f"ERROR: when pack, experience '{key}' number={len(value)} does not match {experience_count=}")

    for key, value in experience_dict.items():
        # check whether the tensor is 1-D or 2-D
        is_2d = len(value[0].shape) > 1
        if is_2d:
@@ -922,7 +1123,7 @@ def unpack_pad_experience(batch_data, batch_data_length, pad_id, multiple):

    padded_batch_data = {}
    for key, length_list in batch_data_length.items():
        if key in ['prompt_length', 'response_length']:
        if key in ['prompt_length', 'response_length', 'age']:
            padded_batch_data[key] = batch_data[key].view(-1, 1)
            continue
        data = batch_data[key]
@@ -994,7 +1195,7 @@ def unpack_pad_experience(batch_data, batch_data_length, pad_id, multiple):


def put_prompts_experience(
        batch: Dict[str, torch.Tensor], n_samples_per_prompt, dataset_additional_keys: List[str] = None, indexes=None,
        batch: Dict[str, torch.Tensor], n_samples_per_prompt, dataset_additional_keys: List[str] = None, indexes=None, add_another_batch=False,
):
    """Put data into specified columns and rows.

@@ -1027,7 +1228,10 @@ def put_prompts_experience(
        for _ in range(n_samples_per_prompt):
            values.append(value)
        add_vals[add_keys] = values
    if indexes is None:
        prompt_nums = len(prompt_length)
        if add_another_batch:
            indexes = [prompt_nums + i for i in range(prompt_nums)]
        elif indexes is None:
            indexes = [i for i in range(len(prompt_length))]

    data_dict = dict(
@@ -150,13 +150,16 @@ def truncate_rows(tensor, index_tensor, left_pad=False):
    truncated_tensors = []

    for i in range(mbs):
        # get the truncation index of the current row
        trunc_idx = index_tensor[i].item()
        # truncate the current row
        if left_pad:
            truncated_row = tensor[i, -trunc_idx:].cpu()
        if index_tensor[i].item() == 0 and tensor[i, 0].item() == -1:
            truncated_row = torch.tensor([], dtype=torch.int32).cpu()
        else:
            truncated_row = tensor[i, :trunc_idx].cpu()
            # get the truncation index of the current row
            trunc_idx = index_tensor[i].item()
            # truncate the current row
            if left_pad:
                truncated_row = tensor[i, -trunc_idx:].cpu()
            else:
                truncated_row = tensor[i, :trunc_idx].cpu()
        # append the truncated row to the list
        truncated_tensors.append(truncated_row)

@@ -8,13 +8,12 @@ from enum import Enum
from typing import Callable
import logging as logger

import numpy as np
import ray
from torch import nn
import torch
from transformers import AutoConfig

from mindspeed_rl.config_cls.megatron_config import MegatronConfig
from mindspeed_rl.utils.optimizer_module import OptimizerConfig
from mindspeed_rl.config_cls.rl_config import RLConfig
from mindspeed_rl.config_cls.generate_config import GenerateConfig
from mindspeed_rl.config_cls.mindstudio_config import ProfilerConfig, MsprobeConfig
@@ -85,6 +84,7 @@ class ActorHybridWorkerBase(BaseWorker):
        self.actor_profiler = None
        self.prof_iteration = 1
        self.idx = 0
        self.enable_partial_rollout = self.rl_config.partial_rollout_max_split > 1

    def initialize(self):
        self.setup_distributed_rank()
@@ -234,16 +234,15 @@ class ActorHybridWorkerBase(BaseWorker):

        start_time_defined = False
        while self.all_consumed(experience_consumer_stage, sorted_indexes) > 0:
            batch_data, index = self.dispatch_transfer_dock_data(
                experience_consumer_stage,
                experience_columns,
                experience_count,
                self.megatron_config.tensor_model_parallel_size,
                self.megatron_config.context_parallel_size,
                self.megatron_config.context_parallel_algo,
                indexes=sorted_indexes.pop(0) if self.rl_config.guarantee_order else None,
                get_n_samples=False
            )
            batch_data, index = self.dispatch_transfer_dock_data(experience_consumer_stage,
                                                                 experience_columns,
                                                                 experience_count,
                                                                 self.megatron_config.tensor_model_parallel_size,
                                                                 self.megatron_config.context_parallel_size,
                                                                 self.megatron_config.context_parallel_algo,
                                                                 indexes=sorted_indexes.pop(
                                                                     0) if self.rl_config.guarantee_order else None,
                                                                 get_n_samples=self.enable_partial_rollout)
            if not start_time_defined:
                start_time = time.time()
                start_time_defined = True
@@ -287,6 +286,11 @@ class ActorHybridWorkerBase(BaseWorker):
                                              self.num_floating_point_operations_so_far)
        self.sharding_manager.exit_train_mode()

    def get_partial_rollout_stop_signal(self):
        if not self.enable_partial_rollout:
            return False
        return ray.get(self.td.get_update_ready.remote(require_max_age_all_finished=self.rl_config.require_max_age_all_finished))

    @mstx_timer_decorator
    def generate_sequences(self):
        sharding_infer_interval = 0
@@ -299,8 +303,14 @@ class ActorHybridWorkerBase(BaseWorker):
        experience_columns = ['prompts', 'prompt_length']
        if is_multimodal():
            experience_columns.extend(['input_ids', 'input_ids_length'])
        if self.enable_partial_rollout:
            experience_columns.extend(['responses', 'response_length', 'age'])

        experience_count = self.rl_config.actor_rollout_dispatch_size
        if self.enable_partial_rollout and (self.rl_config.async_engine or self.iteration == self.megatron_config.train_iters - 1):
            incomplete_resp_num = ray.get(self.td.get_incomplete_response_num.remote())
            experience_count = int(np.ceil(incomplete_resp_num / self.generate_config.data_parallel_size))
        else:
            experience_count = self.rl_config.actor_rollout_dispatch_size

        pad_token_id = self.tokenizer.pad if self.tokenizer.pad else self.tokenizer.eod
        sorted_indexes = self.get_dp_range_indexes(experience_count,
@@ -310,7 +320,8 @@ class ActorHybridWorkerBase(BaseWorker):
                                               profiler_iteration=self.prof_iteration)
        MsProbe.debugger_start(self.inference_model.model, tag='actor_generate_sequences')

        start_time_defined = False

        start_time = time.time()
        while self.all_consumed(experience_consumer_stage, sorted_indexes, use_vllm=True) > 0:
            batch_data, index = self.dispatch_transfer_dock_data(
                experience_consumer_stage,

@@ -320,25 +331,28 @@ class ActorHybridWorkerBase(BaseWorker):
                cp_size=self.megatron_config.context_parallel_size,
                cp_algo=self.megatron_config.context_parallel_algo,
                indexes=sorted_indexes.pop(0) if self.rl_config.guarantee_order else None,
                use_vllm=True
                use_vllm=True,
                get_n_samples=not self.enable_partial_rollout,
                enable_partial_rollout=self.enable_partial_rollout
            )
            if not start_time_defined:
                start_time = time.time()
                start_time_defined = True

            if batch_data and index:

                gc.collect()
                torch.cuda.empty_cache()

                if self.rl_config.async_engine:
                    logger.info(f"do async generate process.")
                    prompts_data = batch_data['prompts']
                    prompt_length_data = batch_data['prompt_length']
                    prompts = truncate_rows(prompts_data, prompt_length_data)
                    prompts_list = [prompt.numpy().tolist() for prompt in prompts]
                    self.async_generate_process(experience_count, index, pad_token_id, prompts_list, start_time)
                    self.async_generate_process(batch_data, index, pad_token_id)
                else:
                    self.sync_generate_process(batch_data, experience_count, index, pad_token_id, start_time)
                    self.sync_generate_process(batch_data, experience_count, index, pad_token_id)
        if self.enable_partial_rollout:
            torch.distributed.barrier()
        end_time = time.time()
        ray.get(
            self.td.update_metrics.remote(
                "timing/rollout",
                value=[round(end_time, 4), round(start_time, 4)],
                cumulate=True
            )
        )

        profiler_step(actor_generate_profiler)
        MsProbe.debugger_stop('actor_generate_sequences')
@@ -358,29 +372,58 @@ class ActorHybridWorkerBase(BaseWorker):
        )
        logger.info("finish generate_sequences")

    def sync_generate_process(self, batch_data, experience_count, index, pad_token_id, start_time):
        indexes = list(range(0, experience_count, self.rl_config.n_samples_per_prompt))
        prompts_data = batch_data['prompts'][indexes]
        prompt_length_data = batch_data['prompt_length'][indexes]
        # preprocess, remove padding
        prompts = truncate_rows(prompts_data, prompt_length_data)
        prompts_list = [prompt.numpy().tolist() for prompt in prompts]
    def sync_generate_process(self, batch_data, experience_count, index, pad_token_id):
        if not self.enable_partial_rollout:
            indexes = list(range(0, experience_count, self.rl_config.n_samples_per_prompt))
            prompts_data = batch_data['prompts'][indexes]
            prompt_length_data = batch_data['prompt_length'][indexes]
            # preprocess, remove padding
            prompts = truncate_rows(prompts_data, prompt_length_data)
            prompts_list = [prompt.numpy().tolist() for prompt in prompts]
        else:
            prompts_data = batch_data['prompts']
            prompt_length_data = batch_data['prompt_length']
            responses = batch_data['responses']
            responses_length_partial = batch_data['response_length']
            responses_partial = truncate_rows(responses, responses_length_partial)
            prompts = truncate_rows(prompts_data, prompt_length_data)
            prompts_for_vllm = [torch.cat(
                (prompt, response), dim=0) for prompt, response in
                zip(prompts, responses_partial)]
            prompts_list = [prompt.numpy().tolist() for prompt in prompts_for_vllm]
        if self.enable_partial_rollout:
            max_tokens = self.generate_config.sampling_config["max_tokens"] // self.rl_config.partial_rollout_max_split
            responses_pad_right = self.actor_hybrid.generate_sequences(copy.deepcopy(prompts_list),
                                                                       max_tokens=max_tokens, n=1,
                                                                       extra_info=batch_data)
        else:
            responses_pad_right = self.actor_hybrid.generate_sequences(copy.deepcopy(prompts_list),
                                                                       extra_info=batch_data)

        with replace_torch_compile():
            responses_pad_right = self.actor_hybrid.generate_sequences(copy.deepcopy(prompts_list), indexes,
                                                                       n_samples_per_prompt=self.rl_config.n_samples_per_prompt,
                                                                       async_engine=self.rl_config.async_engine,
                                                                       extra_info=batch_data)
        responses = remove_padding_and_split_to_list(responses_pad_right, self.tokenizer.eod, pad_token_id)
        responses_length = [torch.tensor([len(response)]) for response in responses]

        if is_multimodal():
            prompts_data = batch_data['input_ids'][indexes].cpu().unbind()
        else:
            prompts_data = prompts
        prompts = []
        for prompt in prompts_data:
            for _ in range(self.rl_config.n_samples_per_prompt):
                prompts.append(copy.deepcopy(prompt))

        if self.enable_partial_rollout:
            new_responses = []
            for response_partial, response in zip(responses_partial, responses):
                new_resp = torch.cat((response_partial, response), dim=0)
                test_resp = new_resp >= self.tokenizer.vocab_size
                if test_resp.sum() > 0:
                    new_resp[test_resp] = 0
                new_responses.append(new_resp)
            responses = new_responses
        else:
            prompts = []
            for prompt in prompts_data:
                for _ in range(self.rl_config.n_samples_per_prompt):
                    prompts.append(copy.deepcopy(prompt))

        responses_length = [torch.tensor([len(response)]) for response in responses]

        input_ids_list = []
        for prompt, response in zip(prompts, responses):
            input_ids_list.append(torch.cat((prompt, response), dim=0))
@@ -391,32 +434,74 @@ class ActorHybridWorkerBase(BaseWorker):
        }
        if is_multimodal():
            outputs['prompt_length'] = batch_data['input_ids_length']
        self.collect_transfer_dock_data(outputs, index, use_vllm=True)
        end_time = time.time()
        MsProbe.save_data({"responses": responses, "prompts": prompts})
        ray.get(
            self.td.update_metrics.remote(
                "timing/rollout",
                value=[round(end_time, 4), round(start_time, 4)],
                cumulate=True
            )
        )

    def async_generate_process(self, experience_count, index, pad_token_id, prompts_list, start_time):
        # inference
        if self.enable_partial_rollout:
            finish_status = [torch.tensor([0])] * len(responses_length)
            for idx, _ in enumerate(responses):
                if responses[idx][-1] == self.tokenizer.eod or \
                        (prompt_length_data[idx][0] + responses_length[
                            idx][0] >= self.generate_config.max_model_len) or responses_length[
                        idx][0] >= self.generate_config.sampling_config["max_tokens"]:
                    finish_status[idx] = torch.tensor([1])
            outputs["rollout_completed"] = finish_status

        self.collect_transfer_dock_data(outputs, index, use_vllm=True)
        MsProbe.save_data({"responses": responses, "prompts": prompts})


    def async_generate_process(self, batch_data, index, pad_token_id):
        self.actor_hybrid.inference_actor.init_cache_engine()
        with replace_torch_compile():
            prompts_data = batch_data['prompts']
            prompt_length_data = batch_data['prompt_length']
            prompts = truncate_rows(prompts_data, prompt_length_data)
            if self.enable_partial_rollout:
                responses = batch_data['responses']
                responses_length_partial = batch_data['response_length']
                responses_partial = truncate_rows(responses, responses_length_partial)
                prompts_for_vllm = [torch.cat((prompt, response), dim=0) for prompt, response in zip(prompts, responses_partial)]
                prompts_list = [prompt.numpy().tolist() for prompt in prompts_for_vllm]
            else:
                prompts_list = [prompt.numpy().tolist() for prompt in prompts]
            if self.enable_partial_rollout:
                response_generator = self.actor_hybrid.generate_sequences(
                    copy.deepcopy(prompts_list),
                    indexes=index,
                    n=1,
                    async_engine=True,
                    stop_singal_func=self.get_partial_rollout_stop_signal,
                )
            else:
                response_generator = self.actor_hybrid.generate_sequences(
                    copy.deepcopy(prompts_list),
                    indexes=index,
                    max_tokens=self.generate_config.sampling_config["max_tokens"],
                    n_samples_per_prompt=1,
                    n=1,
                    async_engine=True,
                )
            for samples, idx in response_generator:

            for samples, idx_output in response_generator:
                prompts, responses, log_probs = samples
                responses = remove_padding_and_split_to_list(responses, self.tokenizer.eod, pad_token_id)

                remove_input_ids = False
                if self.enable_partial_rollout and len(responses[0]) == 1:
                    remove_input_ids = True

                if self.enable_partial_rollout:
                    responses_partial_new = []
                    prompt_length_new = []
                    for idx in range(len(responses)):
                        iidx = index.index(idx_output[idx])
                        responses_partial_new.append(responses_partial[iidx])
                        prompt_length_new.append(prompt_length_data[iidx])

                    new_responses = []
                    for response_partial, response in zip(responses_partial_new, responses):
                        new_resp = torch.cat((response_partial, response), dim=0)
                        test_resp = new_resp >= self.tokenizer.vocab_size
                        if test_resp.sum() > 0:
                            new_resp[test_resp] = 0
                        new_responses.append(new_resp)
                    responses = new_responses
                responses_length = [torch.tensor([len(response)]) for response in responses]

                input_ids_list = []
@@ -428,16 +513,20 @@ class ActorHybridWorkerBase(BaseWorker):
                    'input_ids': input_ids_list,
                    'response_length': responses_length
                }
                self.collect_transfer_dock_data(outputs, idx, use_vllm=True)
                if remove_input_ids:
                    outputs.pop("input_ids")

        end_time = time.time()
        ray.get(
            self.td.update_metrics.remote(
                "timing/rollout",
                value=[round(end_time, 4), round(start_time, 4)],
                cumulate=True
            )
        )
                if self.enable_partial_rollout:
                    finish_status = [torch.tensor([0])] * len(responses_length)
                    for idx, _ in enumerate(responses):
                        if responses[idx][-1] == self.tokenizer.eod or \
                                prompt_length_new[idx][0].to('cpu') + responses_length[
                                    idx][0] >= self.generate_config.max_model_len or responses_length[
                                idx][0] >= self.generate_config.sampling_config["max_tokens"]:
                            finish_status[idx] = torch.tensor([1])
                    outputs["rollout_completed"] = finish_status
                self.collect_transfer_dock_data(outputs, idx_output, use_vllm=True)
        MsProbe.save_data({"responses": responses, "prompts": prompts})
        self.actor_hybrid.inference_actor.free_cache_engine()

    @mstx_timer_decorator
@@ -466,16 +555,15 @@ class ActorHybridWorkerBase(BaseWorker):

        start_time_defined = False
        while self.all_consumed(experience_consumer_stage, sorted_indexes) > 0:
            batch_data, index = self.dispatch_transfer_dock_data(
                experience_consumer_stage,
                experience_columns,
                experience_count,
                tp_size=self.megatron_config.tensor_model_parallel_size,
                cp_size=self.megatron_config.context_parallel_size,
                cp_algo=self.megatron_config.context_parallel_algo,
                indexes=sorted_indexes.pop(0) if self.rl_config.guarantee_order else None,
                get_n_samples=False
            )
            batch_data, index = self.dispatch_transfer_dock_data(experience_consumer_stage,
                                                                 experience_columns,
                                                                 experience_count,
                                                                 tp_size=self.megatron_config.tensor_model_parallel_size,
                                                                 cp_size=self.megatron_config.context_parallel_size,
                                                                 cp_algo=self.megatron_config.context_parallel_algo,
                                                                 indexes=sorted_indexes.pop(
                                                                     0) if self.rl_config.guarantee_order else None,
                                                                 get_n_samples=self.enable_partial_rollout)
            if not start_time_defined:
                start_time = time.time()
                start_time_defined = True
@@ -134,7 +134,6 @@ class BaseWorker(BaseRayWorker, ABC):
        current_device = next(self.model[0].parameters()).device
        status = torch.tensor(0, device=current_device)

        rank_flg = False
        if not use_vllm:
            rank_flg = (get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and
                        get_context_parallel_rank(self.parallel_state, use_vllm) == 0 and

@@ -225,7 +224,7 @@ class BaseWorker(BaseRayWorker, ABC):
    def dispatch_transfer_dock_data(self, experience_consumer_stage,
                                    experience_columns, experience_count, tp_size=1, cp_size=1, cp_algo=None,
                                    use_vllm=False, indexes=None,
                                    get_n_samples=True):
                                    get_n_samples=True, enable_partial_rollout=False):
        pad_id = self.tokenizer.pad if self.tokenizer.pad else self.tokenizer.eod
        if is_multimodal():
            mm_columns = ray.get(self.mm_td.get_columns.remote(experience_consumer_stage))

@@ -254,6 +253,12 @@ class BaseWorker(BaseRayWorker, ABC):
                                                                       experience_count, indexes=indexes,
                                                                       get_n_samples=get_n_samples,
                                                                       use_batch_seqlen_balance=self.rl_config.use_dp_batch_balance))  # CPU-side data
        elif enable_partial_rollout:
            # fetch single samples; positions that fall short are padded with repeated samples
            dp_world_size = self.parallel_state.get_data_parallel_world_size()
            batch_data, index = ray.get(self.td.get_experience.remote(experience_consumer_stage, experience_columns,
                                                                      experience_count, dp_world_size, indexes=indexes,
                                                                      get_n_samples=get_n_samples))  # CPU-side data
        else:
            batch_data, index = ray.get(
                self.td.get_experience.remote(experience_consumer_stage, experience_columns,

@@ -290,7 +295,10 @@ class BaseWorker(BaseRayWorker, ABC):
                return None, None

        if rank_flg:
            batch_data, batch_data_length = pack_experience_columns(batch_data, experience_count)
            batch_data, batch_data_length = pack_experience_columns(experience_consumer_stage, batch_data,
                                                                    experience_count,
                                                                    enable_partial_rollout=enable_partial_rollout,
                                                                    )

        for key in experience_columns:
            if rank_flg:
@@ -118,7 +118,7 @@ class ReferenceWorkerBase(BaseWorker):
                                                                 cp_algo=self.megatron_config.context_parallel_algo,
                                                                 indexes=sorted_indexes.pop(
                                                                     0) if self.rl_config.guarantee_order else None,
                                                                 get_n_samples=False)
                                                                 get_n_samples=self.rl_config.partial_rollout_max_split > 1)

            if not start_time_defined:
                start_time = time.time()
@@ -108,6 +108,7 @@ class RewardWorkerBase(BaseWorker):
                                                                 cp_algo=self.megatron_config.context_parallel_algo,
                                                                 indexes=sorted_indexes.pop(
                                                                     0) if self.rl_config.guarantee_order else None,
                                                                 get_n_samples=self.rl_config.partial_rollout_max_split > 1
                                                                 )
            if not start_time_defined:
                start_time = time.time()
@@ -1,6 +1,5 @@
# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
import ray
import torch
from transformers import AutoTokenizer
import torch

@@ -46,7 +45,8 @@ class RuleReward(object):
                experience_consumer_stage,
                experience_columns,
                experience_count,
                indexes=sorted_indexes.pop(0) if self.rl_config.guarantee_order else None
                indexes=sorted_indexes.pop(0) if self.rl_config.guarantee_order else None,
                get_n_samples=True
            )
            )  # CPU-side data
            batch_data = remove_padding_tensor_dict_to_dict(batch_data)
BIN  sources/images/partial_rollout/async.png (new file, 105 KiB; binary file not shown)
BIN  sources/images/partial_rollout/async_1.png (new file, 221 KiB; binary file not shown)
BIN  sources/images/partial_rollout/sync.png (new file, 89 KiB; binary file not shown)
BIN  sources/images/partial_rollout/sync_1.png (new file, 152 KiB; binary file not shown)
BIN  sources/images/partial_rollout/sync_partial_compare_result.png (new file, 698 KiB; binary file not shown)