!202 [feat] add dpo training workflow

Merge pull request !202 from Calvin Huang/dpo
Calvin Huang
2025-05-08 12:16:56 +00:00
committed by i-robot
parent 4ab4b482aa
commit 2945b4969e
15 changed files with 357 additions and 11 deletions

View File

@ -0,0 +1,32 @@
# model
model_name_or_path: Qwen2.5-7B
# method
stage: dpo
do_train: true
finetuning_type: lora
lora_rank: 8
lora_alpha: 16
deepspeed: examples/deepspeed/ds_z2_config.json
# dataset
dataset: dpo_pair
custom_dataset_info: "custom_dataset.json"
template: qwen
cutoff_len: 1024
preprocessing_num_workers: 16
# output
output_dir: saves/qwen2.5-7b-dpo-lora
logging_steps: 1
save_steps: 10
overwrite_output_dir: true
# train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 5.0e-7
num_train_epochs: 1.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
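
A note on the dataset side: `dataset: dpo_pair` resolves through `custom_dataset_info`, and the pairwise format registered later in this MR requires `chosen`/`rejected` columns (with optional `prompt`, `system`, `tools`). A minimal sketch of what the registration and data files might look like, written as Python for convenience; the file names, the `columns` mapping key, and the sample contents are illustrative assumptions, not part of this MR:

import json

# hypothetical registration entry for the "dpo_pair" dataset referenced above
dataset_info = {
    "dpo_pair": {
        "file_name": "dpo_pair.json",
        "formatting": "pairwise",
        "ranking": True,
        "columns": {"prompt": "prompt", "chosen": "chosen", "rejected": "rejected"},
    }
}

# one illustrative preference pair
samples = [
    {
        "prompt": "What is DPO?",
        "chosen": "Direct Preference Optimization fine-tunes a policy directly on preference pairs.",
        "rejected": "I don't know.",
    }
]

with open("custom_dataset.json", "w", encoding="utf-8") as f:
    json.dump(dataset_info, f, indent=2)
with open("dpo_pair.json", "w", encoding="utf-8") as f:
    json.dump(samples, f, indent=2)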

View File

@ -14,7 +14,7 @@
import sys
from openmind.flow.arguments import get_args, initialize_openmind
from openmind.flow.train import run_sft, run_pt
from openmind.flow.train import run_sft, run_pt, run_dpo
from openmind.flow.callbacks import get_swanlab_callbacks
from openmind.utils.constants import Stages
@ -35,6 +35,8 @@ def run_train(**kwargs):
        run_sft(callbacks)
    elif args.stage == Stages.PT:
        run_pt(callbacks)
    elif args.stage == Stages.DPO:
        run_dpo(callbacks)
if __name__ == "__main__":

View File

@ -19,11 +19,12 @@ import importlib.metadata
import re
import yaml
from typing import Optional
from openmind.utils.constants import Stages, FinetuneType, Frameworks
from openmind.utils.import_utils import is_swanlab_available
from openmind.utils.arguments_utils import str2bool
from openmind.utils import logging, is_transformers_available, is_torch_available
from openmind.utils import logging, is_transformers_available, is_torch_available, is_trl_available, is_peft_available
from openmind.flow.legacy_arguments import _add_legacy_args, _migrate_legacy_args
if is_torch_available():
@ -33,6 +34,12 @@ if is_torch_available():
else:
    from mindformers.trainer.utils import get_last_checkpoint
if is_trl_available():
    from trl.trainer.dpo_config import DPOConfig

if is_peft_available():
    from peft import PeftConfig, LoraConfig
logger = logging.get_logger(__name__)
_GLOBAL_ARGS = None
@ -44,6 +51,28 @@ def get_args():
    return _GLOBAL_ARGS
def get_peft_config() -> "Optional[PeftConfig]":
    args = get_args()
    if args.finetuning_type != FinetuneType.LORA:
        return None

    if not is_peft_available():
        raise ValueError(
            "LoRA fine-tuning requires the `peft` library. Make sure to run `pip install -U peft`."
        )

    peft_config = LoraConfig(
        task_type="CAUSAL_LM",
        r=args.lora_rank,
        target_modules=args.lora_target_modules,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        use_dora=args.use_dora,
    )
    return peft_config
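
For the example YAML above (lora_rank 8, lora_alpha 16), this resolves to roughly the following config; the dropout, target-module, and DoRA values below are assumed defaults rather than values shown in this MR:

from peft import LoraConfig

peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,                   # lora_rank from the YAML
    lora_alpha=16,         # lora_alpha from the YAML
    lora_dropout=0.0,      # assumed default
    target_modules=None,   # assumed default; peft then infers modules per architecture
    use_dora=False,        # assumed default
)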
def initialize_openmind(yaml_path=None, ignore_unknown_args=False, **kwargs):
    args = parse_args(yaml_path, ignore_unknown_args, custom_args=kwargs)
    global _GLOBAL_ARGS
@ -63,6 +92,9 @@ def parse_args(yaml_path=None, ignore_unknown_args=False, custom_args=None):
    parser = _add_eval_args(parser)
    parser = _add_legacy_args(parser)
    parser = _add_deploy_args(parser)
    # dynamically add the trl DPOConfig attributes as CLI arguments;
    # needs refactoring to resolve naming conflicts with existing arguments
    parser = _add_rlhf_args(parser)
    unknown_args = None
    if custom_args:
@ -87,6 +119,10 @@ def parse_args(yaml_path=None, ignore_unknown_args=False, custom_args=None):
    _migrate_legacy_args(parser, vars(args), unknown_args)
    validate_args(args)
    add_special_args(args)
    # materialize trl's DPOConfig from the parsed arguments (dpo stage)
    add_dpo_args(args)

    return args
@ -115,6 +151,15 @@ def add_special_args(args):
setattr(args, "hf_seq2seq_args", seq2seq_args)
def add_dpo_args(args):
    # add DPOConfig from trl package
    dpo_args = None
    if is_trl_available():
        hf_parser = HfArgumentParser(DPOConfig)
        dpo_args = hf_parser.parse_dict(vars(args), allow_extra_keys=True)[0]
    setattr(args, "dpo_args", dpo_args)
def validate_args(args):
    """do sanity check"""
@ -170,8 +215,10 @@ def validate_args(args):
raise ValueError("The version of transformers is required at least 4.45.0 to run quantization.")
# stage and finetune type
    if args.stage not in [Stages.SFT, Stages.PT]:
        raise ValueError(f"Currently supported stage list is [{Stages.SFT, Stages.PT}]")
    valid_stages = [Stages.SFT, Stages.PT, Stages.DPO]
    if args.stage not in valid_stages:
        raise ValueError(f"Currently supported stage list is {valid_stages}")
    if args.finetuning_type not in [FinetuneType.FULL, FinetuneType.LORA]:
        raise ValueError(f"Currently supported fine-tuning method list is [{FinetuneType.FULL}, {FinetuneType.LORA}]")
    if args.finetuning_type != FinetuneType.LORA and args.use_dora:
@ -302,6 +349,48 @@ def _add_data_args(parser):
    return parser
def _add_rlhf_args(parser):
    group = parser.add_argument_group(title="rlhf")
    group.add_argument(
        "--reward_model_path",
        type=str,
        default=None,
        help="Path to the reward model.",
    )
    group.add_argument(
        "--model_adapter_name",
        type=str,
        default=None,
        help="Name of the PEFT adapter to train, when using LoRA with multiple adapters.",
    )
    group.add_argument(
        "--ref_adapter_name",
        type=str,
        default=None,
        help="Name of the reference PEFT adapter, when using LoRA with multiple adapters.",
    )
    # dynamically expose every trl DPOConfig attribute as a CLI argument
    if is_trl_available():
        from trl.trainer.dpo_config import DPOConfig
        import inspect

        existing_args = {action.dest for action in parser._actions}
        signature = inspect.signature(DPOConfig.__init__)
        for param_name, param in signature.parameters.items():
            if param_name in ("self", "kwargs") or param_name in existing_args:
                continue
            default = param.default if param.default is not inspect.Parameter.empty else None
            # note: deriving the argparse type from the default is lossy; in particular,
            # type=bool treats any non-empty string (including "false") as True
            param_type = type(default) if default is not None else str
            group.add_argument(f"--{param_name}", type=param_type, default=default, help=f"DPO parameter: {param_name}")
    return parser
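
One caveat with deriving the argparse `type` from the default value: boolean DPOConfig fields get `type=bool`, and `bool("false")` is `True`, so string values on the command line cannot turn such flags off. A small sketch of the pitfall and a possible fix using the `str2bool` helper this module already imports (assumed to map "true"/"false" strings to booleans):

import argparse
from openmind.utils.arguments_utils import str2bool

parser = argparse.ArgumentParser()
# pitfall: --demo_flag false parses to True (hypothetical flag for illustration)
parser.add_argument("--demo_flag", type=bool, default=False)
# possible fix: parse boolean strings explicitly
parser.add_argument("--precompute_ref_log_probs", type=str2bool, default=False)

ns = parser.parse_args(["--demo_flag", "false", "--precompute_ref_log_probs", "false"])
print(ns.demo_flag)                 # True, because any non-empty string is truthy
print(ns.precompute_ref_log_probs)  # False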
def _add_model_args(parser):
group = parser.add_argument_group(title="model")

View File

@ -18,13 +18,13 @@ from openmind.utils import _LazyModule
if TYPE_CHECKING:
    from .loader import get_dataset_module
    from .template import get_template
    from .template import get_template, fix_tokenizer_with_template
else:
    import sys

    _import_structure = {
        "loader": ["get_dataset_module"],
        "template": ["get_template"],
        "template": ["get_template", "fix_tokenizer_with_template"],
    }

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@ -26,6 +26,7 @@ from openmind.flow.datasets.preprocess import (
    merge_datasets,
    preprocess_supervised_dataset,
    preprocess_pretrain_dataset,
    preprocess_pairwise_dataset,
)
from openmind.flow.arguments import get_args
from openmind.flow.datasets.template import Template
@ -57,6 +58,10 @@ DATASET_FORMAT_REGISTRY: Dict[str, DatasetFormatConfig] = {
required_columns=["conversations"], allowed_columns=["conversations", "system", "tools"]
),
"text": DatasetFormatConfig(required_columns=["text"], allowed_columns=["text"]),
"pairwise": DatasetFormatConfig(
required_columns=["chosen", "rejected"],
allowed_columns=["prompt", "chosen", "rejected", "response", "system", "tools"],
),
}
@ -159,6 +164,7 @@ def _get_preprocessed_dataset(
desc="Start running tokenizer on datasets",
)
preprocess_func = _get_preprocess_func(template, tokenizer, processor)
logger.info_rank0(f"\n******removed columes: {column_names} *********\n")
dataset = dataset.map(
preprocess_func,
batched=True,
@ -166,9 +172,11 @@ def _get_preprocessed_dataset(
remove_columns=column_names,
**preprocess_kwargs,
)
logger.info_rank0(f"\n******processed new columes: {dataset.column_names} *********\n")
# print datasets example applied template
logger.info_rank0("\ninput:\n{}".format(tokenizer.decode(dataset["input_ids"][0])))
logger.info_rank0("\ninput_ids:\n{}\n".format(dataset["input_ids"][0]))
if args.stage in [Stages.SFT, Stages.PT]:
logger.info_rank0("\ninput:\n{}".format(tokenizer.decode(dataset["input_ids"][0])))
logger.info_rank0("\ninput_ids:\n{}\n".format(dataset["input_ids"][0]))
return dataset
@ -181,6 +189,10 @@ def _get_preprocess_func(template, tokenizer, processor):
        preprocess_func = partial(
            preprocess_supervised_dataset, template=template, tokenizer=tokenizer, processor=processor
        )
    elif args.stage == Stages.DPO:
        preprocess_func = partial(
            preprocess_pairwise_dataset, template=template, tokenizer=tokenizer, cutoff_len=args.cutoff_len
        )
    else:
        raise NotImplementedError
    return preprocess_func

View File

@ -28,7 +28,7 @@ class InstructionDatasetAttr:
    name: Optional[str] = None
    load_from: Optional[str] = "om_hub"
    file_name: Optional[str] = None
    formatting: Literal["alpaca", "sharegpt", "text"] = "alpaca"
    formatting: Literal["alpaca", "sharegpt", "pairwise", "text"] = "alpaca"
    ranking: bool = False
    is_custom = False
    # extra configs
@ -50,6 +50,9 @@ class InstructionDatasetAttr:
    query: Optional[str] = "input"
    response: Optional[str] = "output"
    history: Optional[str] = "history"
    # pairwise columns
    chosen: Optional[str] = "chosen"
    rejected: Optional[str] = "rejected"
    # sharegpt columns
    messages: Optional[str] = "conversations"
    # sharegpt tags
@ -113,6 +116,8 @@ def get_dataset_attr(dataset: Optional[str], dataset_info) -> "InstructionDatase
column_names.extend(["prompt", "query", "response", "history"])
elif dataset_attr.formatting == "text":
column_names.extend(["text_column"])
elif dataset_attr.formatting == "pairwise":
column_names.extend(["prompt", "query", "response", "chosen", "rejected"])
else:
column_names.extend(["messages"])

View File

@ -19,6 +19,7 @@ from openmind.utils import _LazyModule
if TYPE_CHECKING:
    from .sft import preprocess_supervised_dataset
    from .pt import preprocess_pretrain_dataset
    from .dpo import preprocess_pairwise_dataset
    from .generic import align_dataset, merge_datasets
else:
    import sys
@ -26,6 +27,7 @@ else:
    _import_structure = {
        "sft": ["preprocess_supervised_dataset"],
        "pt": ["preprocess_pretrain_dataset"],
        "dpo": ["preprocess_pairwise_dataset"],
        "generic": ["align_dataset", "merge_datasets"],
    }

View File

@ -0,0 +1,41 @@
# Copyright 2024 the LlamaFactory team.
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
#
# This code is inspired by LLaMA-Factory.
# https://github.com/hiyouga/LLaMA-Factory/blob/main/src/llamafactory/data/processors/supervised.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Dict, Any
from collections import defaultdict
from openmind.archived.models.auto import AutoTokenizer
from openmind.utils import get_logger
from openmind.flow.datasets.template import Template
logger = get_logger(__name__) # pylint: disable=invalid-name
def preprocess_pairwise_dataset(
    examples, template: Template, tokenizer: AutoTokenizer, cutoff_len: int
) -> Dict[str, List[Any]]:
    model_inputs = defaultdict(list)
    for i in range(len(examples["_prompt"])):
        # a valid example has a user-terminated prompt (odd number of messages)
        # and exactly one (chosen, rejected) response pair
        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
            continue
        model_inputs["prompt"].append(examples["_prompt"][i])
        model_inputs["chosen"].append([examples["_response"][i][0][0]])
        model_inputs["rejected"].append([examples["_response"][i][0][1]])
    return model_inputs
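
For reference, a minimal illustration of the intermediate format this function expects: an odd-length `_prompt` (ending on a user turn) and exactly one (chosen, rejected) pair in `_response`. The `template`, `tokenizer`, and `cutoff_len` arguments are unused by the body, so placeholders stand in for them here:

examples = {
    "_prompt": [[{"role": "user", "content": "What is DPO?"}]],
    "_response": [[
        [
            {"role": "assistant", "content": "A preference-based fine-tuning method."},
            {"role": "assistant", "content": "No idea."},
        ]
    ]],
}
out = preprocess_pairwise_dataset(examples, template=None, tokenizer=None, cutoff_len=1024)
print(out["prompt"][0])    # [{"role": "user", "content": "What is DPO?"}]
print(out["chosen"][0])    # [the chosen assistant message]
print(out["rejected"][0])  # [the rejected assistant message]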

View File

@ -38,6 +38,43 @@ class ConversionOutput(TypedDict):
    audios: List
def convert_pairwise(examples, datasets_attr: InstructionDatasetAttr):
    """
    Convert the dataset to pairwise (preference) format.

    Args:
        examples: examples of the dataset.
        datasets_attr: The attributes of the dataset.

    Returns:
        A ConversionOutput dict with `_prompt` and `_response` columns.
    """
    outputs: ConversionOutput = {"_prompt": [], "_response": []}
    for i in range(len(examples[datasets_attr.prompt])):
        prompt = []
        content = []
        response = []
        if examples[datasets_attr.prompt][i]:
            content.append(examples[datasets_attr.prompt][i])
        prompt.append({"role": "user", "content": "\n".join(content)})
        if examples[datasets_attr.chosen][i] and examples[datasets_attr.rejected][i]:
            response.append(
                [
                    {"role": "assistant", "content": examples[datasets_attr.chosen][i]},
                    {"role": "assistant", "content": examples[datasets_attr.rejected][i]},
                ]
            )
        outputs["_prompt"].append(prompt)
        outputs["_response"].append(response)
    return outputs
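
A small sketch of the conversion, assuming the dataclass exposes a `prompt` column field (as the `column_names` handling above suggests) and that the raw columns follow the default names:

attr = InstructionDatasetAttr(formatting="pairwise", prompt="prompt")  # assumed field names
batch = {
    "prompt": ["What is DPO?"],
    "chosen": ["A preference-based fine-tuning method."],
    "rejected": ["No idea."],
}
out = convert_pairwise(batch, attr)
# out["_prompt"][0]   -> [{"role": "user", "content": "What is DPO?"}]
# out["_response"][0] -> [[chosen assistant message, rejected assistant message]]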
def convert_alpaca(examples, datasets_attr: InstructionDatasetAttr, convert_system=False, convert_tools=False):
"""
Convert the dataset to alpaca format.
@ -80,6 +117,7 @@ def convert_alpaca(examples, datasets_attr: InstructionDatasetAttr, convert_syst
            content.append(examples[datasets_attr.query][i])
        prompt.append({"role": "user", "content": "\n".join(content)})

        if isinstance(examples[datasets_attr.response][i], str):
            response = [{"role": "assistant", "content": examples[datasets_attr.response][i]}]
        else:  # unsupervised
@ -251,10 +289,15 @@ def align_dataset(
            convert_system=True if "system" in dataset.column_names else False,
            convert_tools=True if "tools" in dataset.column_names else False,
        )
    elif dataset_attr.formatting == "pairwise":
        convert_func = partial(
            convert_pairwise,
            datasets_attr=dataset_attr,
        )
    elif dataset_attr.formatting == "text":
        convert_func = partial(convert_text, text_column=dataset_attr.text_column)
    else:
        raise ValueError("Currently, Dataset formats only support alpaca, sharegpt, text.")
        raise ValueError("Currently, dataset formats only support alpaca, sharegpt, pairwise, and text.")

    # The following code is consistent with the format of datasets in llama factory.
    column_names = list(next(iter(dataset)).keys())

View File

@ -29,6 +29,8 @@ from openmind.utils import get_logger
from openmind.flow.arguments import get_args
from openmind.flow.datasets.mm_plugin import BasePlugin, parse_mm_plugin
from transformers import PreTrainedTokenizer
logger = get_logger(__name__)
# {"qwen": openmind.flow.datasets.template.Template object}
@ -104,16 +106,31 @@ class Template:
        return encoded_pairs

    @staticmethod
    def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
        r"""Add or replace the eos token in the tokenizer."""
        is_added = tokenizer.eos_token_id is None
        num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
        if is_added:
            logger.info_rank0(f"Add eos token: {tokenizer.eos_token}.")
        else:
            logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}.")
        if num_added_tokens > 0:
            logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
    def encode_oneturn(
        self,
        tokenizer,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
    ):
        r"""
        Returns a single pair of token ids representing prompt and response respectively.
        """
        encoded_messages = self.encode(tokenizer, messages, system)
        encoded_messages = self.encode(tokenizer, messages, system, tools)
        prompt_ids = []
        for encoded_ids in encoded_messages[:-1]:
            prompt_ids += encoded_ids
@ -153,6 +170,17 @@ class Template:
        return self._make_pairs(encoded_messages, args.cutoff_len, args.reserved_label_len)

    def fix_special_tokens(self, tokenizer: PreTrainedTokenizer) -> None:
        r"""
        Add eos and pad tokens to the tokenizer if they are missing.
        """
        if tokenizer.eos_token_id is None:
            self._add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token
            logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
def str_to_dict_for_mm_plugin(name):
    EMPTY = {"plugin_name": "base"}
@ -264,3 +292,10 @@ def get_template():
logger.info_rank0(f"Apply template {template_type}")
return template
def fix_tokenizer_with_template(tokenizer: PreTrainedTokenizer, template: Template) -> None:
    """
    Fix the tokenizer's special tokens (eos/pad) according to the chat template.
    """
    template.fix_special_tokens(tokenizer)
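
The pad-token fallback matters for DPO in particular, since trl pads chosen and rejected sequences into one batch. A quick illustration with a tokenizer that ships without a pad token (gpt2 is used here purely as a convenient public example):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # has an eos token but no pad token
assert tok.pad_token_id is None
# this mirrors what fix_special_tokens does for the pad token:
if tok.pad_token_id is None:
    tok.pad_token = tok.eos_token
print(tok.pad_token)  # <|endoftext|>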

View File

@ -1,2 +1,3 @@
from .sft import run_sft
from .pt import run_pt
from .dpo import run_dpo

View File

@ -0,0 +1 @@
from .workflow import run_dpo

View File

@ -0,0 +1,69 @@
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
#
# openMind is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
# http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
from typing import List, Optional
from transformers import TrainerCallback
from trl import DPOTrainer
from openmind.utils import get_logger
from openmind.flow.model import get_model, get_tokenizer
from openmind.flow.datasets import get_template, get_dataset_module, fix_tokenizer_with_template
from openmind.flow.arguments import get_args, get_peft_config
logger = get_logger(__name__)
def run_dpo(
    callbacks: Optional[List["TrainerCallback"]] = None,
):
    tokenizer = get_tokenizer()
    template = get_template()
    fix_tokenizer_with_template(tokenizer, template)
    dataset_module = get_dataset_module(tokenizer, template)
    args = get_args()
    peft_config = get_peft_config()
    model = get_model()
    logger.info_rank0(f"DPO args: {args.dpo_args}")

    # when a peft config is provided, ref_model must be None: trl derives the
    # implicit reference model by disabling the adapters on the policy model
    if peft_config is not None:
        ref_model = None
    else:
        ref_model = get_model()

    trainer = DPOTrainer(
        args=args.dpo_args,
        processing_class=tokenizer,
        model=model,
        ref_model=ref_model,
        peft_config=peft_config,
        callbacks=callbacks,
        **dataset_module,
    )

    if args.do_train:
        logger.info_rank0("Start DPO training.")
        train_result = trainer.train()
        trainer.save_model(args.output_dir)
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
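
Putting the pieces together, a minimal driver mirroring what the train entry point does for the DPO stage; the YAML path is illustrative, not a file added by this MR:

from openmind.flow.arguments import initialize_openmind
from openmind.flow.train import run_dpo

initialize_openmind(yaml_path="examples/dpo_qwen2_5_7b_lora.yaml")  # hypothetical path
run_dpo()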

View File

@ -18,6 +18,8 @@ __all__ = [
"is_torch_npu_available",
"is_mindformers_available",
"is_transformers_available",
"is_trl_available",
"is_peft_available",
"is_diffusers_available",
"is_mindone_available",
"is_mindnlp_available",
@ -38,10 +40,12 @@ from .import_utils import (
    is_torch_npu_available,
    is_ms_available,
    is_transformers_available,
    is_trl_available,
    is_mindformers_available,
    is_diffusers_available,
    is_mindone_available,
    is_mindnlp_available,
    is_peft_available,
    is_sentencepiece_available,
    is_timm_available,
    is_vision_available,

View File

@ -137,6 +137,16 @@ def is_transformers_available():
return _is_package_available("transformers")
@lru_cache
def is_trl_available():
return _is_package_available("trl")
@lru_cache
def is_peft_available():
return _is_package_available("peft")
@lru_cache
def is_mindformers_available():
return _is_package_available("mindformers")