!203 openmind supports reward (RM) training
Merge pull request !203 from 幽若/master-reward-pr
examples/features/train/train_rm_lora.yaml (new file, 35 lines)
@@ -0,0 +1,35 @@
model_name_or_path: Qwen2.5-7B

# method
stage: rm
do_train: true
finetuning_type: lora

template: qwen

deepspeed: examples/deepspeed/ds_z2_config.json

trust_remote_code: True

# dataset
dataset: rlhf-reward-datasets
cutoff_len: 1024
max_length: 1024

# output
output_dir: saves
logging_steps: 1
save_steps: 20000
overwrite_output_dir: true

# train
per_device_train_batch_size: 2
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
max_steps: 10
seed: 1234

save_strategy: "no"
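Note: this YAML is consumed through the new stage dispatch shown further down. A rough sketch of driving it programmatically, not part of this PR; the imports match the diff below, but the yaml_path keyword for initialize_openmind is an assumption based on its use in the CLI entry point:

# Hypothetical usage sketch (not part of this PR); initialize_openmind's signature is assumed.
from openmind.flow.arguments import get_args, initialize_openmind
from openmind.flow.train import run_rm
from openmind.utils.constants import Stages

initialize_openmind(yaml_path="examples/features/train/train_rm_lora.yaml")  # assumed keyword
args = get_args()
if args.stage == Stages.RM:
    run_rm()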
@@ -65,6 +65,7 @@ pt-cpu = [
    "lm_eval == 0.4.3",
    "diffusers >= 0.29.0, <= 0.31.0",
    "peft >= 0.12.0",
+    "trl == 0.9.3",
]

pt = [
@@ -14,7 +14,7 @@
import sys

from openmind.flow.arguments import get_args, initialize_openmind
-from openmind.flow.train import run_sft, run_pt, run_dpo
+from openmind.flow.train import run_sft, run_pt, run_dpo, run_rm
from openmind.flow.callbacks import get_swanlab_callbacks
from openmind.utils.constants import Stages
@@ -37,6 +37,8 @@ def run_train(**kwargs):
        run_pt(callbacks)
    elif args.stage == Stages.DPO:
        run_dpo()
+    elif args.stage == Stages.RM:
+        run_rm()


if __name__ == "__main__":
@@ -36,6 +36,7 @@ else:

if is_trl_available():
    from trl.trainer.dpo_config import DPOConfig
+    from trl.trainer.reward_config import RewardConfig

if is_peft_available():
    from peft import PeftConfig, LoraConfig
@@ -120,8 +121,9 @@ def parse_args(yaml_path=None, ignore_unknown_args=False, custom_args=None):
    validate_args(args)
    add_special_args(args)

-    # add rlhf arguments (ppo/dpo)
+    # add rlhf arguments (ppo/dpo/rm)
    add_dpo_args(args)
+    add_reward_args(args)

    return args

@@ -151,6 +153,15 @@ def add_special_args(args):
    setattr(args, "hf_seq2seq_args", seq2seq_args)


+def add_reward_args(args):
+    # add RewardConfig from trl package
+    reward_args = None
+    if is_trl_available():
+        hf_parser = HfArgumentParser(RewardConfig)
+        reward_args = hf_parser.parse_dict(vars(args), allow_extra_keys=True)[0]
+    setattr(args, "reward_args", reward_args)
+
+
def add_dpo_args(args):
    # add DPOConfig from trl package
    dpo_args = None
@@ -215,7 +226,7 @@ def validate_args(args):
        raise ValueError("The version of transformers is required at least 4.45.0 to run quantization.")

    # stage and finetune type
-    valid_stages = [Stages.SFT, Stages.PT, Stages.DPO]
+    valid_stages = [Stages.SFT, Stages.PT, Stages.DPO, Stages.RM]
    if args.stage not in valid_stages:
        raise ValueError(f"Currently supported stage list is {valid_stages}")
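Note: add_reward_args relies on HfArgumentParser.parse_dict with allow_extra_keys=True, so only the fields that RewardConfig actually defines are picked out of openMind's flat argument namespace; everything else (for example cutoff_len) is silently dropped. A minimal standalone illustration, with values loosely taken from the example YAML above:

# Standalone illustration of the parse_dict pattern used by add_reward_args (not part of the PR).
from transformers import HfArgumentParser
from trl.trainer.reward_config import RewardConfig

flat_args = {
    "output_dir": "saves",
    "per_device_train_batch_size": 2,
    "learning_rate": 1.0e-5,
    "max_length": 1024,   # a RewardConfig field, so it is consumed
    "cutoff_len": 1024,   # openMind-only key, ignored because allow_extra_keys=True
}
reward_args = HfArgumentParser(RewardConfig).parse_dict(flat_args, allow_extra_keys=True)[0]
print(reward_args.max_length, reward_args.learning_rate)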
@@ -68,6 +68,18 @@
            "system_tag": "system"
        }
    },
+    "rlhf-reward-datasets": {
+        "hub_url": {
+            "modelers": "PyTorch-NPU/rlhf-reward-datasets"
+        },
+        "formatting": "pairwise",
+        "ranking": true,
+        "columns": {
+            "prompt": "prompt",
+            "chosen": "chosen",
+            "rejected": "rejected"
+        }
+    },
    "OpenR1-Math-220k_filtered_step3_SFT": {
        "hub_url": {
            "modelers": "openmind/OpenR1-Math-220k_filtered_step3_SFT"
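Note: "formatting": "pairwise" plus "ranking": true marks this as a preference dataset. After column mapping, one record is expected to look roughly like the sketch below; the text is invented for illustration, the real content comes from the PyTorch-NPU/rlhf-reward-datasets hub repository.

# Hypothetical record shape after column mapping; field contents are made up.
sample = {
    "prompt": "Explain why the sky is blue.",
    "chosen": "The sky looks blue because shorter wavelengths scatter more strongly ...",
    "rejected": "Because it reflects the ocean.",
}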
@@ -27,6 +27,7 @@ from openmind.flow.datasets.preprocess import (
    preprocess_supervised_dataset,
    preprocess_pretrain_dataset,
    preprocess_pairwise_dataset,
+    preprocess_reward_dataset,
)
from openmind.flow.arguments import get_args
from openmind.flow.datasets.template import Template
@@ -165,6 +166,9 @@ def _get_preprocessed_dataset(
    )
    preprocess_func = _get_preprocess_func(template, tokenizer, processor)
    logger.info_rank0(f"\n******removed columes: {column_names} *********\n")
+    if args.stage == Stages.RM:
+        dataset = preprocess_func(dataset=dataset, args=args)
+    else:
        dataset = dataset.map(
            preprocess_func,
            batched=True,
@@ -177,6 +181,9 @@ def _get_preprocessed_dataset(
    if args.stage in [Stages.SFT, Stages.PT]:
        logger.info_rank0("\ninput:\n{}".format(tokenizer.decode(dataset["input_ids"][0])))
        logger.info_rank0("\ninput_ids:\n{}\n".format(dataset["input_ids"][0]))
+    if args.stage == Stages.RM:
+        logger.info_rank0("\nchosen input:\n{}".format(dataset["chosen"][0]))
+        logger.info_rank0("\nrejected input:\n{}".format(dataset["rejected"][0]))
    return dataset


@@ -193,6 +200,9 @@ def _get_preprocess_func(template, tokenizer, processor):
        preprocess_func = partial(
            preprocess_pairwise_dataset, template=template, tokenizer=tokenizer, cutoff_len=args.cutoff_len
        )
+    elif args.stage == Stages.RM:
+        preprocess_func = partial(preprocess_reward_dataset, tokenizer=tokenizer)
+
    else:
        raise NotImplementedError
    return preprocess_func
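Note: the RM branch calls the preprocess function directly rather than wrapping it in dataset.map, because preprocess_reward_dataset (new file below) runs its own map and filter internally. Condensed, the two call paths above are:

# Illustrative contrast of the two call paths (condensed, not literal code from the diff).
preprocess_func = partial(preprocess_reward_dataset, tokenizer=tokenizer)

if args.stage == Stages.RM:
    dataset = preprocess_func(dataset=dataset, args=args)   # maps/filters internally
else:
    dataset = dataset.map(preprocess_func, batched=True)    # sft/pt/dpo path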
@@ -102,7 +102,8 @@ def get_dataset_attr(dataset: Optional[str], dataset_info) -> "InstructionDatase

    if "file_name" in dataset_info[dataset]:
        dataset_attr.set_attr("file_name", dataset_info[dataset], default=None)

    if "ranking" in dataset_info[dataset]:
        dataset_attr.set_attr("ranking", dataset_info[dataset], default=False)
    dataset_attr.set_attr("formatting", dataset_info[dataset], default="alpaca")
    dataset_attr.set_attr("is_custom", dataset_info[dataset], default=False)
    dataset_attr.set_attr("subset", dataset_info[dataset])
@@ -20,6 +20,7 @@ if TYPE_CHECKING:
    from .sft import preprocess_supervised_dataset
    from .pt import preprocess_pretrain_dataset
    from .dpo import preprocess_pairwise_dataset
+    from .rm import preprocess_reward_dataset
    from .generic import align_dataset, merge_datasets
else:
    import sys
@@ -28,6 +29,7 @@ else:
        "sft": ["preprocess_supervised_dataset"],
        "pt": ["preprocess_pretrain_dataset"],
        "dpo": ["preprocess_pairwise_dataset"],
+        "rm": ["preprocess_reward_dataset"],
        "generic": ["align_dataset", "merge_datasets"],
    }
src/openmind/flow/datasets/preprocess/rm.py (new file, 76 lines)
@@ -0,0 +1,76 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, Union, Callable, List, Dict

import trl
from accelerate import PartialState
from transformers import PreTrainedTokenizerBase


def _tokenize(batch: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizerBase") -> Dict[str, List[Any]]:
    """Tokenize a batch from a reward modelling dataset."""
    new_examples = {
        "input_ids_chosen": [],
        "attention_mask_chosen": [],
        "input_ids_rejected": [],
        "attention_mask_rejected": [],
    }
    for chosen, rejected in zip(batch["chosen"], batch["rejected"]):
        tokenized_chosen = tokenizer(chosen)
        tokenized_rejected = tokenizer(rejected)
        new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
        new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
        new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
        new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])

    return new_examples


def _apply_chat_template(
    example,
    tokenizer: PreTrainedTokenizerBase,
    tools: Optional[List[Union[Dict, Callable]]] = None,
) -> Dict[str, str]:
    new_example = {
        "prompt": example["_prompt"],
        "chosen": [example["_response"][0][0]],
        "rejected": [example["_response"][0][1]],
    }
    return trl.data_utils.maybe_apply_chat_template(tokenizer=tokenizer, example=new_example, tools=tools)


def preprocess_reward_dataset(dataset, tokenizer, args):
    with PartialState().main_process_first():
        fn_kwargs = {"tokenizer": tokenizer}
        max_length = args.max_length
        if dataset is not None:
            dataset = dataset.map(_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
            dataset = dataset.map(
                _tokenize,
                fn_kwargs=fn_kwargs,
                batched=True,
                num_proc=None,
            )
            # This filter is important because otherwise you get samples that exceed the model's context length and
            # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
            # user might get surprised if N samples are missing from training.
            if max_length is not None:
                dataset = dataset.filter(
                    lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
                    num_proc=None,
                )
            return dataset
    return dataset
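Note: a quick, hypothetical smoke test of the tokenize step above (not part of the PR). It assumes network access to fetch a tokenizer; any PreTrainedTokenizer will do, gpt2 is only used because it is small.

# Hypothetical smoke test for _tokenize.
from datasets import Dataset
from transformers import AutoTokenizer

from openmind.flow.datasets.preprocess.rm import _tokenize  # module introduced above

tok = AutoTokenizer.from_pretrained("gpt2")
pairs = Dataset.from_dict({
    "chosen": ["The sky is blue because of Rayleigh scattering."],
    "rejected": ["The sky is blue because it mirrors the sea."],
})
tokenized = pairs.map(_tokenize, fn_kwargs={"tokenizer": tok}, batched=True)
# keeps the original columns and adds input_ids_chosen / attention_mask_chosen /
# input_ids_rejected / attention_mask_rejected
print(tokenized.column_names)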
@@ -339,7 +339,7 @@ def get_model():
    if args.do_train:
        model.train()

-    if args.use_gradient_checkpointing and model.supports_gradient_checkpointing:
+    if args.use_gradient_checkpointing and model.supports_gradient_checkpointing and args.stage != "rm":
        model.gradient_checkpointing_enable()
        logger.info_rank0("Gradient checkpointing has been enabled.")
    else:
@@ -1,3 +1,4 @@
from .sft import run_sft
from .pt import run_pt
from .dpo import run_dpo
+from .rm import run_rm
src/openmind/flow/train/rm/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .workflow import run_rm
src/openmind/flow/train/rm/workflow.py (new file, 151 lines)
@@ -0,0 +1,151 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
#
# openMind is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
# http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.

import random
from dataclasses import dataclass
from typing import List, Any, Optional, Union, Dict

import torch
from transformers import PreTrainedTokenizerBase, TrainerCallback
import numpy as np
from trl.trainer import reward_trainer

from openmind.flow.arguments import get_args
from openmind.flow.datasets import get_template, get_dataset_module
from openmind.flow.model import get_model, get_tokenizer
from openmind.utils import get_logger

logger = get_logger(__name__)


def setup_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


@dataclass
class RewardDataCollatorWithPadding:
    r"""
    Reward DataCollator class that pads the inputs to the maximum length of the batch.

    Args:
        tokenizer (`PreTrainedTokenizerBase`):
            The tokenizer used for encoding the data.
        padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`):
            padding_strategy to pass to the tokenizer.
        pad_to_multiple_of (`int` or `None`, `optional`, defaults to `None`):
            If set will pad the sequence to a multiple of the provided value.
        return_tensors (`str`, `optional`, defaults to `"pt"`):
            The tensor type to use.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str] = True
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        features_chosen = []
        features_rejected = []
        margin = []
        # check if we have a margin. If we do, we need to batch it as well
        has_margin = "margin" in features[0]
        max_length = 0
        for feature in features:
            # check if the keys are named as expected
            max_length = max(max_length, len(feature["input_ids_chosen"]), len(feature["input_ids_rejected"]))
            keys_exist = (
                "input_ids_chosen" in feature
                and "input_ids_rejected" in feature
                and "attention_mask_chosen" in feature
                and "attention_mask_rejected" in feature
            )
            if not keys_exist:
                raise ValueError(
                    "The features should include `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`"
                )

            features_chosen.append(
                {
                    "input_ids": feature["input_ids_chosen"],
                    "attention_mask": feature["attention_mask_chosen"],
                }
            )
            features_rejected.append(
                {
                    "input_ids": feature["input_ids_rejected"],
                    "attention_mask": feature["attention_mask_rejected"],
                }
            )
            if has_margin:
                margin.append(feature["margin"])
        batch_chosen = self.tokenizer.pad(
            features_chosen,
            padding="max_length",
            max_length=max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        batch_rejected = self.tokenizer.pad(
            features_rejected,
            padding="max_length",
            max_length=max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        batch = {
            "input_ids_chosen": batch_chosen["input_ids"],
            "attention_mask_chosen": batch_chosen["attention_mask"],
            "input_ids_rejected": batch_rejected["input_ids"],
            "attention_mask_rejected": batch_rejected["attention_mask"],
            "return_loss": True,
        }
        if has_margin:
            margin = torch.tensor(margin, dtype=torch.float)
            batch["margin"] = margin
        return batch


def run_rm(
    callbacks: Optional[List["TrainerCallback"]] = None,
):
    args = get_args()
    setup_seed(args.seed)
    tokenizer = get_tokenizer()
    model = get_model()
    model.find_unused_parameters = True

    template = get_template()
    dataset_module = get_dataset_module(tokenizer, template)

    train_args = args.reward_args
    train_args.remove_unused_columns = False

    data_collator = RewardDataCollatorWithPadding(tokenizer=tokenizer)
    trainer = reward_trainer.RewardTrainer(
        model=model,
        args=train_args,
        data_collator=data_collator,
        callbacks=callbacks,
        processing_class=tokenizer,
        **dataset_module,
    )

    if args.do_train:
        logger.info_rank0("Start training.")
        train_result = trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
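Note: the collator pads chosen and rejected to the same per-batch maximum length and sets return_loss=True so RewardTrainer computes its pairwise loss. A small, hypothetical sanity check (not part of the PR; the tokenizer download is assumed, and any tokenizer with a pad token works):

# Hypothetical sanity check for RewardDataCollatorWithPadding.
from transformers import AutoTokenizer
from openmind.flow.train.rm.workflow import RewardDataCollatorWithPadding

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")  # assumed to ship with a pad token
collator = RewardDataCollatorWithPadding(tokenizer=tok)
features = [
    {"input_ids_chosen": [1, 2, 3], "attention_mask_chosen": [1, 1, 1],
     "input_ids_rejected": [4, 5], "attention_mask_rejected": [1, 1]},
    {"input_ids_chosen": [6, 7], "attention_mask_chosen": [1, 1],
     "input_ids_rejected": [8, 9, 10, 11], "attention_mask_rejected": [1, 1, 1, 1]},
]
batch = collator(features)
# both sides are padded to the longest sequence in the batch (here 4)
print(batch["input_ids_chosen"].shape, batch["input_ids_rejected"].shape)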