!203 openmind supports reward training

Merge pull request !203 from 幽若/master-reward-pr
This commit is contained in:
2025-05-09 02:02:22 +00:00
committed by i-robot
parent 2945b4969e
commit 2ede23881f
13 changed files with 315 additions and 12 deletions

View File

@@ -0,0 +1,35 @@
model_name_or_path: Qwen2.5-7B
# method
stage: rm
do_train: true
finetuning_type: lora
template: qwen
deepspeed: examples/deepspeed/ds_z2_config.json
trust_remote_code: True
# dataset
dataset: rlhf-reward-datasets
cutoff_len: 1024
max_length: 1024
# output
output_dir: saves
logging_steps: 1
save_steps: 20000
overwrite_output_dir: true
# train
per_device_train_batch_size: 2
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
max_steps: 10
seed: 1234
save_strategy: "no"

View File

@@ -65,6 +65,7 @@ pt-cpu = [
    "lm_eval == 0.4.3",
    "diffusers >= 0.29.0, <= 0.31.0",
    "peft >= 0.12.0",
    "trl == 0.9.3",
]
pt = [

View File

@@ -14,7 +14,7 @@
import sys

from openmind.flow.arguments import get_args, initialize_openmind
from openmind.flow.train import run_sft, run_pt, run_dpo
from openmind.flow.train import run_sft, run_pt, run_dpo, run_rm
from openmind.flow.callbacks import get_swanlab_callbacks
from openmind.utils.constants import Stages

@@ -37,6 +37,8 @@ def run_train(**kwargs):
        run_pt(callbacks)
    elif args.stage == Stages.DPO:
        run_dpo()
    elif args.stage == Stages.RM:
        run_rm()


if __name__ == "__main__":

View File

@@ -36,6 +36,7 @@ else:

if is_trl_available():
    from trl.trainer.dpo_config import DPOConfig
    from trl.trainer.reward_config import RewardConfig

if is_peft_available():
    from peft import PeftConfig, LoraConfig

@@ -120,8 +121,9 @@ def parse_args(yaml_path=None, ignore_unknown_args=False, custom_args=None):
    validate_args(args)
    add_special_args(args)

    # add rlhf arguments (ppo/dpo)
    # add rlhf arguments (ppo/dpo/rm)
    add_dpo_args(args)
    add_reward_args(args)

    return args

@@ -151,6 +153,15 @@ def add_special_args(args):
    setattr(args, "hf_seq2seq_args", seq2seq_args)


def add_reward_args(args):
    # add RewardConfig from trl package
    reward_args = None
    if is_trl_available():
        hf_parser = HfArgumentParser(RewardConfig)
        reward_args = hf_parser.parse_dict(vars(args), allow_extra_keys=True)[0]
    setattr(args, "reward_args", reward_args)


def add_dpo_args(args):
    # add DPOConfig from trl package
    dpo_args = None

@@ -215,7 +226,7 @@ def validate_args(args):
        raise ValueError("The version of transformers is required at least 4.45.0 to run quantization.")

    # stage and finetune type
    valid_stages = [Stages.SFT, Stages.PT, Stages.DPO]
    valid_stages = [Stages.SFT, Stages.PT, Stages.DPO, Stages.RM]
    if args.stage not in valid_stages:
        raise ValueError(f"Currently supported stage list is {valid_stages}")

View File

@@ -68,6 +68,18 @@
            "system_tag": "system"
        }
    },
    "rlhf-reward-datasets": {
        "hub_url": {
            "modelers": "PyTorch-NPU/rlhf-reward-datasets"
        },
        "formatting": "pairwise",
        "ranking": true,
        "columns": {
            "prompt": "prompt",
            "chosen": "chosen",
            "rejected": "rejected"
        }
    },
    "OpenR1-Math-220k_filtered_step3_SFT": {
        "hub_url": {
            "modelers": "openmind/OpenR1-Math-220k_filtered_step3_SFT"
View File

@@ -27,6 +27,7 @@ from openmind.flow.datasets.preprocess import (
    preprocess_supervised_dataset,
    preprocess_pretrain_dataset,
    preprocess_pairwise_dataset,
    preprocess_reward_dataset,
)
from openmind.flow.arguments import get_args
from openmind.flow.datasets.template import Template

@@ -165,18 +166,24 @@ def _get_preprocessed_dataset(
    )
    preprocess_func = _get_preprocess_func(template, tokenizer, processor)

    logger.info_rank0(f"\n******removed columns: {column_names} *********\n")
    dataset = dataset.map(
        preprocess_func,
        batched=True,
        batch_size=args.preprocessing_batch_size,
        remove_columns=column_names,
        **preprocess_kwargs,
    )
    if args.stage == Stages.RM:
        dataset = preprocess_func(dataset=dataset, args=args)
    else:
        dataset = dataset.map(
            preprocess_func,
            batched=True,
            batch_size=args.preprocessing_batch_size,
            remove_columns=column_names,
            **preprocess_kwargs,
        )
    logger.info_rank0(f"\n******processed new columns: {dataset.column_names} *********\n")

    # print datasets example applied template
    if args.stage in [Stages.SFT, Stages.PT]:
        logger.info_rank0("\ninput:\n{}".format(tokenizer.decode(dataset["input_ids"][0])))
        logger.info_rank0("\ninput_ids:\n{}\n".format(dataset["input_ids"][0]))
    if args.stage == Stages.RM:
        logger.info_rank0("\nchosen input:\n{}".format(dataset["chosen"][0]))
        logger.info_rank0("\nrejected input:\n{}".format(dataset["rejected"][0]))

    return dataset

@@ -193,6 +200,9 @@ def _get_preprocess_func(template, tokenizer, processor):
        preprocess_func = partial(
            preprocess_pairwise_dataset, template=template, tokenizer=tokenizer, cutoff_len=args.cutoff_len
        )
    elif args.stage == Stages.RM:
        preprocess_func = partial(preprocess_reward_dataset, tokenizer=tokenizer)
    else:
        raise NotImplementedError

    return preprocess_func

View File

@@ -102,7 +102,8 @@ def get_dataset_attr(dataset: Optional[str], dataset_info) -> "InstructionDatase
    if "file_name" in dataset_info[dataset]:
        dataset_attr.set_attr("file_name", dataset_info[dataset], default=None)
    if "ranking" in dataset_info[dataset]:
        dataset_attr.set_attr("ranking", dataset_info[dataset], default=False)
    dataset_attr.set_attr("formatting", dataset_info[dataset], default="alpaca")
    dataset_attr.set_attr("is_custom", dataset_info[dataset], default=False)
    dataset_attr.set_attr("subset", dataset_info[dataset])

View File

@@ -20,6 +20,7 @@ if TYPE_CHECKING:
    from .sft import preprocess_supervised_dataset
    from .pt import preprocess_pretrain_dataset
    from .dpo import preprocess_pairwise_dataset
    from .rm import preprocess_reward_dataset
    from .generic import align_dataset, merge_datasets
else:
    import sys

@@ -28,6 +29,7 @@ else:
        "sft": ["preprocess_supervised_dataset"],
        "pt": ["preprocess_pretrain_dataset"],
        "dpo": ["preprocess_pairwise_dataset"],
        "rm": ["preprocess_reward_dataset"],
        "generic": ["align_dataset", "merge_datasets"],
    }

View File

@@ -0,0 +1,76 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, Union, Callable, List, Dict

import trl
from accelerate import PartialState
from transformers import PreTrainedTokenizerBase


def _tokenize(batch: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizerBase") -> Dict[str, List[Any]]:
    """Tokenize a batch from a reward modelling dataset."""
    new_examples = {
        "input_ids_chosen": [],
        "attention_mask_chosen": [],
        "input_ids_rejected": [],
        "attention_mask_rejected": [],
    }
    for chosen, rejected in zip(batch["chosen"], batch["rejected"]):
        tokenized_chosen = tokenizer(chosen)
        tokenized_rejected = tokenizer(rejected)
        new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
        new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
        new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
        new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
    return new_examples


def _apply_chat_template(
    example,
    tokenizer: PreTrainedTokenizerBase,
    tools: Optional[List[Union[Dict, Callable]]] = None,
) -> Dict[str, str]:
    new_example = {
        "prompt": example["_prompt"],
        "chosen": [example["_response"][0][0]],
        "rejected": [example["_response"][0][1]],
    }
    return trl.data_utils.maybe_apply_chat_template(tokenizer=tokenizer, example=new_example, tools=tools)


def preprocess_reward_dataset(dataset, tokenizer, args):
    with PartialState().main_process_first():
        fn_kwargs = {"tokenizer": tokenizer}
        max_length = args.max_length
        if dataset is not None:
            dataset = dataset.map(_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
            dataset = dataset.map(
                _tokenize,
                fn_kwargs=fn_kwargs,
                batched=True,
                num_proc=None,
            )
            # This filter is important: otherwise samples that exceed the model's context length
            # get truncated, which turns the chosen/rejected preference into a noisy signal. The
            # downside is that the user might get surprised if N samples are missing from training.
            if max_length is not None:
                dataset = dataset.filter(
                    lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
                    num_proc=None,
                )
            return dataset
    return dataset
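To make the resulting schema concrete, here is a toy walkthrough (the whitespace tokenizer below is a stand-in, not part of the PR) of the four columns _tokenize emits and of how the max_length filter drops an over-long pair as a whole instead of truncating it:

# Fake tokenizer mimicking the input_ids / attention_mask interface used by _tokenize.
def toy_tokenizer(text):
    ids = [len(tok) for tok in text.split()]  # token length stands in for a vocabulary id
    return {"input_ids": ids, "attention_mask": [1] * len(ids)}

chosen, rejected = "the answer is 4", "it might possibly be 5"
row = {
    "input_ids_chosen": toy_tokenizer(chosen)["input_ids"],
    "attention_mask_chosen": toy_tokenizer(chosen)["attention_mask"],
    "input_ids_rejected": toy_tokenizer(rejected)["input_ids"],
    "attention_mask_rejected": toy_tokenizer(rejected)["attention_mask"],
}
# The filter keeps a row only if BOTH sides fit within max_length, so the preference pair is
# dropped as a unit rather than silently truncated.
max_length = 5
keep = len(row["input_ids_chosen"]) <= max_length and len(row["input_ids_rejected"]) <= max_length
print(keep)  # True: 4 and 5 tokens respectively, both within max_length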

View File

@@ -339,7 +339,7 @@ def get_model():
    if args.do_train:
        model.train()

    if args.use_gradient_checkpointing and model.supports_gradient_checkpointing:
    if args.use_gradient_checkpointing and model.supports_gradient_checkpointing and args.stage != "rm":
        model.gradient_checkpointing_enable()
        logger.info_rank0("Gradient checkpointing has been enabled.")
    else:

View File

@@ -1,3 +1,4 @@
from .sft import run_sft
from .pt import run_pt
from .dpo import run_dpo
from .rm import run_rm

View File

@@ -0,0 +1 @@
from .workflow import run_rm

View File

@@ -0,0 +1,151 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
#
# openMind is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
#     http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.

import random
from dataclasses import dataclass
from typing import List, Any, Optional, Union, Dict

import torch
from transformers import PreTrainedTokenizerBase, TrainerCallback
import numpy as np
from trl.trainer import reward_trainer

from openmind.flow.arguments import get_args
from openmind.flow.datasets import get_template, get_dataset_module
from openmind.flow.model import get_model, get_tokenizer
from openmind.utils import get_logger

logger = get_logger(__name__)


def setup_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


@dataclass
class RewardDataCollatorWithPadding:
    r"""
    Reward DataCollator class that pads the inputs to the maximum length of the batch.

    Args:
        tokenizer (`PreTrainedTokenizerBase`):
            The tokenizer used for encoding the data.
        padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`):
            padding_strategy to pass to the tokenizer.
        pad_to_multiple_of (`int` or `None`, `optional`, defaults to `None`):
            If set will pad the sequence to a multiple of the provided value.
        return_tensors (`str`, `optional`, defaults to `"pt"`):
            The tensor type to use.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str] = True
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        features_chosen = []
        features_rejected = []
        margin = []
        # check if we have a margin. If we do, we need to batch it as well
        has_margin = "margin" in features[0]
        max_length = 0
        for feature in features:
            # check if the keys are named as expected
            max_length = max(max_length, len(feature["input_ids_chosen"]), len(feature["input_ids_rejected"]))
            keys_exist = (
                "input_ids_chosen" in feature
                and "input_ids_rejected" in feature
                and "attention_mask_chosen" in feature
                and "attention_mask_rejected" in feature
            )
            if not keys_exist:
                raise ValueError(
                    "The features should include `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`"
                )
            features_chosen.append(
                {
                    "input_ids": feature["input_ids_chosen"],
                    "attention_mask": feature["attention_mask_chosen"],
                }
            )
            features_rejected.append(
                {
                    "input_ids": feature["input_ids_rejected"],
                    "attention_mask": feature["attention_mask_rejected"],
                }
            )
            if has_margin:
                margin.append(feature["margin"])
        batch_chosen = self.tokenizer.pad(
            features_chosen,
            padding="max_length",
            max_length=max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        batch_rejected = self.tokenizer.pad(
            features_rejected,
            padding="max_length",
            max_length=max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        batch = {
            "input_ids_chosen": batch_chosen["input_ids"],
            "attention_mask_chosen": batch_chosen["attention_mask"],
            "input_ids_rejected": batch_rejected["input_ids"],
            "attention_mask_rejected": batch_rejected["attention_mask"],
            "return_loss": True,
        }
        if has_margin:
            margin = torch.tensor(margin, dtype=torch.float)
            batch["margin"] = margin
        return batch


def run_rm(
    callbacks: Optional[List["TrainerCallback"]] = None,
):
    args = get_args()
    setup_seed(args.seed)
    tokenizer = get_tokenizer()
    model = get_model()
    model.find_unused_parameters = True
    template = get_template()
    dataset_module = get_dataset_module(tokenizer, template)

    train_args = args.reward_args
    train_args.remove_unused_columns = False
    data_collator = RewardDataCollatorWithPadding(tokenizer=tokenizer)

    trainer = reward_trainer.RewardTrainer(
        model=model,
        args=train_args,
        data_collator=data_collator,
        callbacks=callbacks,
        processing_class=tokenizer,
        **dataset_module,
    )

    if args.do_train:
        logger.info_rank0("Start training.")
        train_result = trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
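A hedged usage sketch of the collator defined above; it assumes the class is importable from this new workflow module and uses the GPT-2 tokenizer purely as a stand-in, since any tokenizer with a pad token behaves the same way:

# Sketch only: feed two hand-made features through RewardDataCollatorWithPadding.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token  # GPT-2 ships without a pad token

features = [
    {"input_ids_chosen": [11, 22, 33], "attention_mask_chosen": [1, 1, 1],
     "input_ids_rejected": [44, 55], "attention_mask_rejected": [1, 1]},
    {"input_ids_chosen": [66], "attention_mask_chosen": [1],
     "input_ids_rejected": [77, 88, 99, 12], "attention_mask_rejected": [1, 1, 1, 1]},
]
batch = RewardDataCollatorWithPadding(tokenizer=tok)(features)
# Chosen and rejected sides are padded to the longest sequence across the whole batch (4 here);
# "return_loss": True marks the batch as loss-producing for RewardTrainer.
print(batch["input_ids_chosen"].shape, batch["input_ids_rejected"].shape)  # torch.Size([2, 4]) torch.Size([2, 4])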