# Mirror of https://gitee.com/ascend/MindSpeed-RL.git
# Synced 2025-10-20 16:23:45 +08:00 · 184 lines · 8.0 KiB · Python
from typing import Dict, Callable, Optional, Any
|
|
|
|
import numpy as np
|
|
|
|
from mindspeed_rl.datasets.utils import _infer_seqlen, get_prompt_index
|
|
|
|
from mindspeed_rl.datasets.indexed_dataset import get_packed_indexed_dataset
|
|
from mindspeed_rl.datasets.base_dataset import BaseDataset
|
|
from mindspeed_rl.datasets.templates import get_model_template
|
|
from mindspeed_rl.datasets.utils import _build_index_mappings
|
|
|
|
|
|
class PromptDataset(BaseDataset):
    """Prompt dataset backed by a packed indexed dataset.

    Samples are read from ``get_packed_indexed_dataset`` through an index
    mapping built by ``_build_index_mappings`` (called with ``no_shuffle=True``).
    Each sample is truncated to ``seq_length`` tokens and returned as a dict
    of numpy ``int64`` arrays. Only packed data is supported.
    """

    # Label value marking positions that are excluded from the response
    # (e.g. prompt tokens); used to locate prompt/response boundaries.
    IGNORE_INDEX = -100

    def __init__(
        self,
        data_prefix: str = "",
        is_packed_data: bool = False,
        tokenizer: Callable = None,
        seq_length: int = 128,
        num_samples: int = None,
        name: str = "",
        documents: Any = None,
        seed: int = 42,
        full_shuffle_instruction_dataset: bool = False,
        token_param: Optional[Dict] = None,
        preprocess_template: Optional[str] = None,
        pad_token: int = 0,
        eos_token: int = 1,
        extra_param: Any = None,
        **kwargs,
    ):
        """Load the packed dataset and build its sample index mapping.

        Args:
            data_prefix: Path prefix of the packed indexed dataset.
            is_packed_data: Must be True; non-packed data is not supported.
            tokenizer: Stored on the instance; not used within this class.
            seq_length: Maximum number of tokens kept per returned sample.
            num_samples: Forwarded to ``_build_index_mappings``.
            name: Forwarded to ``_build_index_mappings``.
            documents: Indexable sequence of document ids. ``documents[0]``
                and ``len(documents)`` are forwarded to
                ``_build_index_mappings``, so it must be non-empty.
            seed: Forwarded to ``_build_index_mappings``.
            full_shuffle_instruction_dataset: Forwarded to
                ``_build_index_mappings``.
            token_param, preprocess_template, pad_token, eos_token: Stored on
                the instance; not used within this class.
            extra_param: Argument namespace stored as ``self.args``. Attributes
                read here and in the cutting helpers: ``is_pairwise_dataset``,
                ``max_prompt_length`` (optional), ``dataset_additional_keys``,
                ``prompt_type`` (optional), ``prompt_type_path``,
                ``enable_thinking``.
            **kwargs: Only ``parallel_state`` is read and forwarded to
                ``_build_index_mappings``.

        Raises:
            NotImplementedError: If ``is_packed_data`` is False.
        """
        self.data_prefix = data_prefix
        self.is_packed_data = is_packed_data
        self.tokenizer = tokenizer
        self.token_param = token_param
        self.seq_length = seq_length
        self.preprocess_template = preprocess_template
        self.pad_token = pad_token
        self.eos_token = eos_token
        self.num_samples = num_samples
        self.args = extra_param

        if not self.is_packed_data:
            raise NotImplementedError('non packed data are not supported yet.')

        self.res_dataset = get_packed_indexed_dataset(data_prefix=self.data_prefix,
                                                      filter_length=getattr(extra_param, 'max_prompt_length', None),
                                                      is_pairwise_dataset=self.args.is_pairwise_dataset)
        self.shuffle_index = _build_index_mappings(name=name,
                                                   data_prefix=self.data_prefix,
                                                   start_index=documents[0],
                                                   nb_documents=len(documents),
                                                   num_samples=self.num_samples,
                                                   seed=seed,
                                                   full_shuffle_instruction_dataset=full_shuffle_instruction_dataset,
                                                   parallel_state=kwargs.get('parallel_state'),
                                                   no_shuffle=True)
        dataset_type = "Prompt_DS_Packed"

        super().__init__(self.res_dataset, dataset_type)

    def __len__(self):
        """Return the number of entries in the sample index mapping."""
        return len(self.shuffle_index)

    def __getitem__(self, index):
        """Fetch the sample at ``index`` and truncate it to ``seq_length``.

        Pairwise datasets go through ``_cut_pairwise_token``. All other
        datasets go through ``_cut_instruction_token``. Arrays are cast to
        ``np.int64``.
        """
        doc_idx = self.shuffle_index[index]

        item = self.res_dataset[doc_idx]
        if self.args.is_pairwise_dataset:
            return self._cut_pairwise_token(item, np.int64)
        return self._cut_instruction_token(item, np.int64)

    def _cut_instruction_token(self, item, dtype):
        """Truncate a single (non-pairwise) sample to ``self.seq_length`` tokens.

        An item that has ``labels`` and no configured
        ``dataset_additional_keys`` is truncated turn by turn, keeping its
        labels. Every other item is treated as prompt-only: ``input_ids`` is
        cut and the configured additional keys are passed through.

        Args:
            item: Mapping with at least ``input_ids`` (numpy array).
            dtype: numpy dtype the returned arrays are cast to.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` (all ones), plus
            ``labels`` for labeled items, or the additional keys for
            prompt-only items.
        """
        # ``dataset_additional_keys`` may be None; normalise so that the
        # prompt-only path can iterate it safely.
        additional_keys = self.args.dataset_additional_keys or ()
        if "labels" in item.keys() and not additional_keys:
            return self._cut_labeled_turns(item, dtype)
        return self._cut_prompt_only(item, dtype, additional_keys)

    def _cut_labeled_turns(self, item, dtype):
        """Truncate a labeled, possibly multi-turn, item turn by turn via ``_infer_seqlen``."""
        src_input_ids = item["input_ids"]
        src_labels = item["labels"]
        token_length = len(src_input_ids)
        if token_length <= self.seq_length:
            return {
                "input_ids": src_input_ids.astype(dtype),
                "attention_mask": np.ones_like(src_input_ids).astype(dtype),
                "labels": src_labels.astype(dtype)
            }

        template = None
        # get model chat template
        if hasattr(self.args, "prompt_type") and self.args.prompt_type is not None:
            template = get_model_template(self.args.prompt_type, self.args.prompt_type_path, self.args.enable_thinking)
        use_efficient_eos = template is not None and template.efficient_eos

        prompt_begin_list, prompt_end_list = get_prompt_index(src_labels, self.IGNORE_INDEX)

        multi_turns = len(prompt_begin_list)
        total_length = 0

        if use_efficient_eos:
            # Reserve one slot for the last input token, re-appended after the loop.
            total_length = 1
            prompt_end_list = [x - 1 for x in prompt_end_list]
            eos_token_id = src_input_ids[token_length - 1]
            # Slice into locals instead of writing back into ``item``; the
            # dict comes from the underlying dataset and may be reused.
            src_input_ids = src_input_ids[:token_length]
            src_labels = src_labels[:token_length]

        cutoff_len = self.seq_length
        input_ids = np.array([], dtype=dtype)
        labels = np.array([], dtype=dtype)

        for turn_idx in range(multi_turns):
            if total_length >= cutoff_len:
                break
            prompt_begin = prompt_begin_list[turn_idx]
            prompt_end = prompt_end_list[turn_idx]
            source_ids = src_input_ids[prompt_begin:prompt_end]
            mask_ids = src_labels[prompt_begin:prompt_end]

            # The response span runs from this prompt's end to the next prompt's
            # start, or to the end for the last turn. It is read from labels.
            # NOTE(review): this assumes response labels equal the response
            # input ids - confirm against the data preprocessing.
            if turn_idx != multi_turns - 1:
                target_ids = src_labels[prompt_end:prompt_begin_list[turn_idx + 1]]
            else:
                target_ids = src_labels[prompt_end:]

            source_len, target_len = _infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)

            source_ids = source_ids[:source_len]
            target_ids = target_ids[:target_len]
            mask_ids = mask_ids[:source_len]

            total_length += source_len + target_len
            input_ids = np.concatenate((input_ids, source_ids, target_ids), axis=0)
            labels = np.concatenate((labels, mask_ids, target_ids), axis=0)

        if use_efficient_eos:
            input_ids = np.concatenate((input_ids, np.array([eos_token_id], dtype=dtype)), axis=0)
            labels = np.concatenate((labels, np.array([eos_token_id], dtype=dtype)), axis=0)

        return {
            "input_ids": input_ids.astype(dtype),
            "attention_mask": np.ones_like(input_ids).astype(dtype),
            "labels": labels.astype(dtype)
        }

    def _cut_prompt_only(self, item, dtype, additional_keys):
        """Cut ``input_ids`` to ``seq_length`` and copy any present ``additional_keys`` through."""
        input_ids = item["input_ids"][:self.seq_length]
        res = {
            "input_ids": input_ids.astype(dtype),
            "attention_mask": np.ones_like(input_ids).astype(dtype)
        }
        # Additional fields are passed through untouched (not truncated).
        res.update({key: item[key] for key in additional_keys if key in item.keys()})
        return res

    def _cut_pairwise_token(self, item, dtype):
        """Cut prompt and response proportionally for pairwise datasets.

        The prompt length is the index of the first non-ignored entry in
        ``chosen_labels``. The shared prompt and both responses are truncated
        with ``_infer_seqlen`` so that the prompt plus the longer response fits
        in ``self.seq_length``.

        Raises:
            ValueError: If ``chosen_labels`` has no non-ignored entry. This is
                raised explicitly because a bare ``IndexError`` would be taken
                as end-of-data by the legacy ``__getitem__`` iteration protocol.
        """
        ignore_index = self.IGNORE_INDEX
        response_positions = (item["chosen_labels"] != ignore_index).nonzero()[0]
        if len(response_positions) == 0:
            raise ValueError("pairwise sample has no chosen response tokens: "
                             "every entry of 'chosen_labels' equals IGNORE_INDEX")
        prompt_length = response_positions[0]
        prompt_ids = item["chosen_input_ids"][:prompt_length]
        chosen_ids = item["chosen_input_ids"][prompt_length:]
        rejected_ids = item["rejected_input_ids"][prompt_length:]
        source_len, target_len = _infer_seqlen(
            len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), self.seq_length
        )
        prompt_ids = prompt_ids[:source_len]
        chosen_ids = chosen_ids[:target_len]
        rejected_ids = rejected_ids[:target_len]

        chosen_input_ids = np.append(prompt_ids, chosen_ids)
        chosen_labels = np.append(ignore_index * np.ones(source_len), chosen_ids)
        rejected_input_ids = np.append(prompt_ids, rejected_ids)
        rejected_labels = np.append(ignore_index * np.ones(source_len), rejected_ids)

        return {
            "chosen_input_ids": chosen_input_ids.astype(dtype),
            "chosen_attention_mask": np.ones_like(chosen_input_ids).astype(dtype),
            "chosen_labels": chosen_labels.astype(dtype),
            "rejected_input_ids": rejected_input_ids.astype(dtype),
            "rejected_attention_mask": np.ones_like(rejected_input_ids).astype(dtype),
            "rejected_labels": rejected_labels.astype(dtype)
        }