!202 [feat] add dpo training workflow
Merge pull request !202 from Calvin Huang/dpo
examples/features/train/train_dpo_lora.yaml (new file, 32 lines)
@@ -0,0 +1,32 @@
# model
model_name_or_path: Qwen2.5-7B

# method
stage: dpo
do_train: true
finetuning_type: lora
lora_rank: 8
lora_alpha: 16
deepspeed: examples/deepspeed/ds_z2_config.json

# dataset
dataset: dpo_pair
custom_dataset_info: "custom_dataset.json"
template: qwen
cutoff_len: 1024
preprocessing_num_workers: 16

# output
output_dir: saves/qwen2.5-7b-dpo-lora
logging_steps: 1
save_steps: 10
overwrite_output_dir: true

# train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 5.0e-7
num_train_epochs: 1.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
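The config points `dataset: dpo_pair` at a dataset registered through `custom_dataset_info: "custom_dataset.json"`. A minimal sketch of such a registration plus one pairwise record, modeled on the `pairwise` format and column names introduced further down in this MR; the file name, schema keys, and sample values are illustrative assumptions, not part of the change:

import json

# Hypothetical registration for the "dpo_pair" dataset referenced by the YAML above.
# Key names mirror the InstructionDatasetAttr fields added in this MR (formatting,
# file_name, prompt/chosen/rejected columns); the exact schema is an assumption.
dataset_info = {
    "dpo_pair": {
        "file_name": "dpo_pair.json",
        "formatting": "pairwise",
        "columns": {"prompt": "prompt", "chosen": "chosen", "rejected": "rejected"},
    }
}

# One made-up record in the pairwise format: a prompt with a preferred and a rejected answer.
record = {
    "prompt": "What is the capital of France?",
    "chosen": "The capital of France is Paris.",
    "rejected": "I don't know.",
}

with open("custom_dataset.json", "w") as f:
    json.dump(dataset_info, f, indent=2)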
@@ -14,7 +14,7 @@
 import sys
 
 from openmind.flow.arguments import get_args, initialize_openmind
-from openmind.flow.train import run_sft, run_pt
+from openmind.flow.train import run_sft, run_pt, run_dpo
 from openmind.flow.callbacks import get_swanlab_callbacks
 from openmind.utils.constants import Stages
 
@@ -35,6 +35,8 @@ def run_train(**kwargs):
         run_sft(callbacks)
     elif args.stage == Stages.PT:
         run_pt(callbacks)
+    elif args.stage == Stages.DPO:
+        run_dpo()
 
 
 if __name__ == "__main__":
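With the new branch in the dispatcher, a DPO run is kicked off the same way as SFT or PT. A minimal programmatic sketch using the example YAML; the direct `run_dpo()` call simply mirrors the dispatcher above, and whether a CLI wrapper exposes this is not shown in this MR:

from openmind.flow.arguments import initialize_openmind
from openmind.flow.train import run_dpo

# Sketch: parse the example config (stage: dpo), then invoke the stage the dispatcher would pick.
initialize_openmind(yaml_path="examples/features/train/train_dpo_lora.yaml")
run_dpo()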
@@ -19,11 +19,12 @@ import importlib.metadata
 import re
 
 import yaml
+from typing import Optional
 
 from openmind.utils.constants import Stages, FinetuneType, Frameworks
 from openmind.utils.import_utils import is_swanlab_available
 from openmind.utils.arguments_utils import str2bool
-from openmind.utils import logging, is_transformers_available, is_torch_available
+from openmind.utils import logging, is_transformers_available, is_torch_available, is_trl_available, is_peft_available
 from openmind.flow.legacy_arguments import _add_legacy_args, _migrate_legacy_args
 
 if is_torch_available():
@@ -33,6 +34,12 @@ if is_torch_available():
 else:
     from mindformers.trainer.utils import get_last_checkpoint
 
+if is_trl_available():
+    from trl.trainer.dpo_config import DPOConfig
+
+if is_peft_available():
+    from peft import PeftConfig, LoraConfig
+
 logger = logging.get_logger(__name__)
 
 _GLOBAL_ARGS = None
@@ -44,6 +51,28 @@ def get_args():
     return _GLOBAL_ARGS
 
 
+def get_peft_config() -> "Optional[PeftConfig]":
+    args = get_args()
+    if args.finetuning_type != FinetuneType.LORA:
+        return None
+
+    if not is_peft_available():
+        raise ValueError(
+            "You need to have PEFT library installed in your environment, make sure to install `peft`. "
+            "Make sure to run `pip install -U peft`."
+        )
+
+    peft_config = LoraConfig(
+        task_type="CAUSAL_LM",
+        r=args.lora_rank,
+        target_modules=args.lora_target_modules,
+        lora_alpha=args.lora_alpha,
+        lora_dropout=args.lora_dropout,
+        use_dora=args.use_dora,
+    )
+    return peft_config
+
+
 def initialize_openmind(yaml_path=None, ignore_unknown_args=False, **kwargs):
     args = parse_args(yaml_path, ignore_unknown_args, custom_args=kwargs)
     global _GLOBAL_ARGS
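For the example YAML above (`finetuning_type: lora`, `lora_rank: 8`, `lora_alpha: 16`), `get_peft_config()` amounts to roughly the following; fields not set in the YAML keep the argument defaults, which this diff does not show:

from peft import LoraConfig

# Rough equivalent of get_peft_config() for the example config; only YAML-set fields shown,
# everything else (dropout, target modules, DoRA) keeps its default.
peft_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=16)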
@@ -63,6 +92,9 @@ def parse_args(yaml_path=None, ignore_unknown_args=False, custom_args=None):
     parser = _add_eval_args(parser)
     parser = _add_legacy_args(parser)
     parser = _add_deploy_args(parser)
+    # dynamically add trl DPO attributes
+    # TODO: refactor to resolve argument-name conflicts
+    parser = _add_rlhf_args(parser)
 
     unknown_args = None
     if custom_args:
@@ -87,6 +119,10 @@
     _migrate_legacy_args(parser, vars(args), unknown_args)
     validate_args(args)
     add_special_args(args)
+
+    # add rlhf arguments (ppo/dpo)
+    add_dpo_args(args)
+
     return args
 
 
@@ -115,6 +151,15 @@ def add_special_args(args):
     setattr(args, "hf_seq2seq_args", seq2seq_args)
 
 
+def add_dpo_args(args):
+    # build a DPOConfig from the trl package
+    dpo_args = None
+    if is_trl_available():
+        hf_parser = HfArgumentParser(DPOConfig)
+        dpo_args = hf_parser.parse_dict(vars(args), allow_extra_keys=True)[0]
+    setattr(args, "dpo_args", dpo_args)
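`add_dpo_args` relies on `HfArgumentParser.parse_dict` to pull the DPOConfig-relevant fields out of the flat argument namespace while tolerating everything else. A standalone sketch of that pattern, with values taken from the example YAML and one extra key just to show it is ignored:

from transformers import HfArgumentParser
from trl import DPOConfig

flat_args = {
    "output_dir": "saves/qwen2.5-7b-dpo-lora",
    "learning_rate": 5.0e-7,
    "per_device_train_batch_size": 1,
    "model_name_or_path": "Qwen2.5-7B",  # not a DPOConfig field; dropped thanks to allow_extra_keys
}
dpo_args = HfArgumentParser(DPOConfig).parse_dict(flat_args, allow_extra_keys=True)[0]
print(dpo_args.learning_rate)  # 5e-07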
+
+
 def validate_args(args):
     """do sanity check"""
 
@@ -170,8 +215,10 @@ def validate_args(args):
         raise ValueError("The version of transformers is required at least 4.45.0 to run quantization.")
 
     # stage and finetune type
-    if args.stage not in [Stages.SFT, Stages.PT]:
-        raise ValueError(f"Currently supported stage list is [{Stages.SFT, Stages.PT}]")
+    valid_stages = [Stages.SFT, Stages.PT, Stages.DPO]
+    if args.stage not in valid_stages:
+        raise ValueError(f"Currently supported stage list is {valid_stages}")
 
     if args.finetuning_type not in [FinetuneType.FULL, FinetuneType.LORA]:
         raise ValueError(f"Currently supported fine-tuning method list is [{FinetuneType.FULL}, {FinetuneType.LORA}]")
     if args.finetuning_type != FinetuneType.LORA and args.use_dora:
@@ -302,6 +349,48 @@ def _add_data_args(parser):
     return parser
 
 
+def _add_rlhf_args(parser):
+    group = parser.add_argument_group(title="rlhf")
+
+    group.add_argument(
+        "--reward_model_path",
+        type=str,
+        default=None,
+        help="Path to the reward model.",
+    )
+
+    group.add_argument(
+        "--model_adapter_name",
+        type=str,
+        default=None,
+        help="Name of the train target PEFT adapter, when using LoRA with multiple adapters.",
+    )
+
+    group.add_argument(
+        "--ref_adapter_name",
+        type=str,
+        default=None,
+        help="Name of the reference PEFT adapter, when using LoRA with multiple adapters.",
+    )
+
+    # add trl DPO attributes
+    if is_trl_available():
+        from trl.trainer.dpo_config import DPOConfig
+        import inspect
+
+        existing_args = {action.dest for action in parser._actions}
+
+        signature = inspect.signature(DPOConfig.__init__)
+        for param_name, param in signature.parameters.items():
+            if param_name in ("self", "kwargs") or param_name in existing_args:
+                continue
+            default = param.default if param.default is not inspect.Parameter.empty else None
+            param_type = type(default) if default is not None else str
+            group.add_argument(f"--{param_name}", type=param_type, default=default, help=f"DPO parameter: {param_name}")
+
+    return parser
+
+
 def _add_model_args(parser):
     group = parser.add_argument_group(title="model")
 
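Because `_add_rlhf_args` derives flags from `DPOConfig.__init__`, DPO hyperparameters such as `beta` or `max_prompt_length` should become addressable from the YAML/CLI without being listed by hand; the exact set depends on the installed trl version. A standalone sketch of the same signature-inspection pattern (not project code):

import argparse
import inspect
from trl import DPOConfig

parser = argparse.ArgumentParser()
existing = {action.dest for action in parser._actions}
for name, param in inspect.signature(DPOConfig.__init__).parameters.items():
    if name in ("self", "kwargs") or name in existing:
        continue
    default = param.default if param.default is not inspect.Parameter.empty else None
    # Deriving the type from the default mis-handles booleans (any non-empty string is truthy);
    # the code in this MR inherits the same limitation.
    parser.add_argument(f"--{name}", type=type(default) if default is not None else str, default=default)

args = parser.parse_args(["--beta", "0.2"])  # beta has a float default in current trl releases
print(args.beta)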
@@ -18,13 +18,13 @@ from openmind.utils import _LazyModule
 
 if TYPE_CHECKING:
     from .loader import get_dataset_module
-    from .template import get_template
+    from .template import get_template, fix_tokenizer_with_template
 else:
     import sys
 
     _import_structure = {
         "loader": ["get_dataset_module"],
-        "template": ["get_template"],
+        "template": ["get_template", "fix_tokenizer_with_template"],
     }
 
     sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
@@ -26,6 +26,7 @@ from openmind.flow.datasets.preprocess import (
     merge_datasets,
     preprocess_supervised_dataset,
     preprocess_pretrain_dataset,
+    preprocess_pairwise_dataset,
 )
 from openmind.flow.arguments import get_args
 from openmind.flow.datasets.template import Template
@@ -57,6 +58,10 @@ DATASET_FORMAT_REGISTRY: Dict[str, DatasetFormatConfig] = {
         required_columns=["conversations"], allowed_columns=["conversations", "system", "tools"]
     ),
     "text": DatasetFormatConfig(required_columns=["text"], allowed_columns=["text"]),
+    "pairwise": DatasetFormatConfig(
+        required_columns=["chosen", "rejected"],
+        allowed_columns=["prompt", "chosen", "rejected", "response", "system", "tools"],
+    ),
 }
 
 
@@ -159,6 +164,7 @@ def _get_preprocessed_dataset(
         desc="Start running tokenizer on datasets",
     )
     preprocess_func = _get_preprocess_func(template, tokenizer, processor)
+    logger.info_rank0(f"\n******removed columns: {column_names} *********\n")
     dataset = dataset.map(
         preprocess_func,
         batched=True,
@@ -166,9 +172,11 @@
         remove_columns=column_names,
         **preprocess_kwargs,
     )
+    logger.info_rank0(f"\n******processed new columns: {dataset.column_names} *********\n")
     # print datasets example applied template
-    logger.info_rank0("\ninput:\n{}".format(tokenizer.decode(dataset["input_ids"][0])))
-    logger.info_rank0("\ninput_ids:\n{}\n".format(dataset["input_ids"][0]))
+    if args.stage in [Stages.SFT, Stages.PT]:
+        logger.info_rank0("\ninput:\n{}".format(tokenizer.decode(dataset["input_ids"][0])))
+        logger.info_rank0("\ninput_ids:\n{}\n".format(dataset["input_ids"][0]))
     return dataset
 
 
@@ -181,6 +189,10 @@ def _get_preprocess_func(template, tokenizer, processor):
         preprocess_func = partial(
             preprocess_supervised_dataset, template=template, tokenizer=tokenizer, processor=processor
         )
+    elif args.stage == Stages.DPO:
+        preprocess_func = partial(
+            preprocess_pairwise_dataset, template=template, tokenizer=tokenizer, cutoff_len=args.cutoff_len
+        )
     else:
         raise NotImplementedError
     return preprocess_func
@@ -28,7 +28,7 @@ class InstructionDatasetAttr:
     name: Optional[str] = None
     load_from: Optional[str] = "om_hub"
     file_name: Optional[str] = None
-    formatting: Literal["alpaca", "sharegpt", "text"] = "alpaca"
+    formatting: Literal["alpaca", "sharegpt", "pairwise", "text"] = "alpaca"
     ranking: bool = False
     is_custom = False
     # extra configs
@@ -50,6 +50,9 @@ class InstructionDatasetAttr:
     query: Optional[str] = "input"
     response: Optional[str] = "output"
    history: Optional[str] = "history"
+    # pairwise columns
+    chosen: Optional[str] = "chosen"
+    rejected: Optional[str] = "rejected"
     # sharegpt columns
     messages: Optional[str] = "conversations"
     # sharegpt tags
@@ -113,6 +116,8 @@ def get_dataset_attr(dataset: Optional[str], dataset_info) -> "InstructionDatasetAttr":
         column_names.extend(["prompt", "query", "response", "history"])
     elif dataset_attr.formatting == "text":
         column_names.extend(["text_column"])
+    elif dataset_attr.formatting == "pairwise":
+        column_names.extend(["prompt", "query", "response", "chosen", "rejected"])
     else:
         column_names.extend(["messages"])
 
@@ -19,6 +19,7 @@ from openmind.utils import _LazyModule
 if TYPE_CHECKING:
     from .sft import preprocess_supervised_dataset
     from .pt import preprocess_pretrain_dataset
+    from .dpo import preprocess_pairwise_dataset
     from .generic import align_dataset, merge_datasets
 else:
     import sys
@@ -26,6 +27,7 @@ else:
     _import_structure = {
         "sft": ["preprocess_supervised_dataset"],
         "pt": ["preprocess_pretrain_dataset"],
+        "dpo": ["preprocess_pairwise_dataset"],
         "generic": ["align_dataset", "merge_datasets"],
     }
 
src/openmind/flow/datasets/preprocess/dpo.py (new file, 41 lines)
@@ -0,0 +1,41 @@
# Copyright 2024 the LlamaFactory team.
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
#
# This code is inspired by the LLaMA-Factory.
# https://github.com/hiyouga/LLaMA-Factory/blob/main/src/src/llamafactory/data/processors/supervised.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Dict, Any
from collections import defaultdict

from openmind.archived.models.auto import AutoTokenizer
from openmind.utils import get_logger
from openmind.flow.datasets.template import Template

logger = get_logger(__name__)  # pylint: disable=invalid-name


def preprocess_pairwise_dataset(
    examples, template: Template, tokenizer: AutoTokenizer, cutoff_len: int
) -> Dict[str, List[Any]]:
    model_inputs = defaultdict(list)
    for i in range(len(examples["_prompt"])):
        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
            continue
        model_inputs["prompt"].append(examples["_prompt"][i])
        model_inputs["chosen"].append([examples["_response"][i][0][0]])
        model_inputs["rejected"].append([examples["_response"][i][0][1]])

    return model_inputs
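Two helpers cooperate here: `convert_pairwise` (added to the dataset conversion module in a later hunk) first maps the raw prompt/chosen/rejected columns into aligned `_prompt`/`_response` lists, and `preprocess_pairwise_dataset` above then reshapes them into untokenized `prompt`/`chosen`/`rejected` columns for the DPO trainer. A sketch of the data flow for one made-up record, using the hypothetical column mapping from the earlier registration example:

# Raw record in, aligned columns out (values are invented).
raw = {
    "prompt": ["What is the capital of France?"],
    "chosen": ["Paris."],
    "rejected": ["I don't know."],
}
# After convert_pairwise(raw, datasets_attr):
converted = {
    "_prompt": [[{"role": "user", "content": "What is the capital of France?"}]],
    "_response": [[[
        {"role": "assistant", "content": "Paris."},         # chosen
        {"role": "assistant", "content": "I don't know."},  # rejected
    ]]],
}
# After preprocess_pairwise_dataset(converted, template, tokenizer, cutoff_len=1024):
# {
#     "prompt":   [[{"role": "user", "content": "What is the capital of France?"}]],
#     "chosen":   [[{"role": "assistant", "content": "Paris."}]],
#     "rejected": [[{"role": "assistant", "content": "I don't know."}]],
# }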
@@ -38,6 +38,43 @@ class ConversionOutput(TypedDict):
     audios: List
 
 
+def convert_pairwise(examples, datasets_attr: InstructionDatasetAttr):
+    """
+    Convert the dataset to pairwise format.
+    Args:
+        examples: examples of datasets
+        datasets_attr: The attributes of datasets.
+
+    Returns:
+        Outputs with aligned `_prompt` and `_response` columns.
+    """
+    outputs: ConversionOutput = {"_prompt": [], "_response": []}
+
+    for i in range(len(examples[datasets_attr.prompt])):
+        prompt = []
+        content = []
+        response = []
+
+        if examples[datasets_attr.prompt][i]:
+            content.append(examples[datasets_attr.prompt][i])
+
+        prompt.append({"role": "user", "content": "\n".join(content)})
+
+        if examples[datasets_attr.chosen][i] and examples[datasets_attr.rejected][i]:
+            # response.append([examples[datasets_attr.chosen][i], examples[datasets_attr.rejected][i]])
+            response.append(
+                [
+                    {"role": "assistant", "content": examples[datasets_attr.chosen][i]},
+                    {"role": "assistant", "content": examples[datasets_attr.rejected][i]},
+                ]
+            )
+
+        outputs["_prompt"].append(prompt)
+        outputs["_response"].append(response)
+
+    return outputs
+
+
 def convert_alpaca(examples, datasets_attr: InstructionDatasetAttr, convert_system=False, convert_tools=False):
     """
     Convert the dataset to alpaca format.
@@ -80,6 +117,7 @@ def convert_alpaca(examples, datasets_attr: InstructionDatasetAttr, convert_system=False, convert_tools=False):
             content.append(examples[datasets_attr.query][i])
 
         prompt.append({"role": "user", "content": "\n".join(content)})
+
         if isinstance(examples[datasets_attr.response][i], str):
             response = [{"role": "assistant", "content": examples[datasets_attr.response][i]}]
         else:  # unsupervised
@@ -251,10 +289,15 @@ def align_dataset(
             convert_system=True if "system" in dataset.column_names else False,
             convert_tools=True if "tools" in dataset.column_names else False,
         )
+    elif dataset_attr.formatting == "pairwise":
+        convert_func = partial(
+            convert_pairwise,
+            datasets_attr=dataset_attr,
+        )
     elif dataset_attr.formatting == "text":
         convert_func = partial(convert_text, text_column=dataset_attr.text_column)
     else:
-        raise ValueError("Currently, Dataset formats only support alpaca, sharegpt, text.")
+        raise ValueError("Currently, Dataset formats only support alpaca, sharegpt, pairwise, text.")
 
     # The following code is consistent with the format of datasets in llama factory.
     column_names = list(next(iter(dataset)).keys())
@@ -29,6 +29,8 @@ from openmind.utils import get_logger
 from openmind.flow.arguments import get_args
 from openmind.flow.datasets.mm_plugin import BasePlugin, parse_mm_plugin
+
+from transformers import PreTrainedTokenizer
 
 logger = get_logger(__name__)
 
 # {"qwen": openmind.flow.datasets.template.Template object}
@@ -104,16 +106,31 @@ class Template:
 
         return encoded_pairs
 
+    @staticmethod
+    def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
+        r"""Add or replace eos token to the tokenizer."""
+        is_added = tokenizer.eos_token_id is None
+        num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
+
+        if is_added:
+            logger.info_rank0(f"Add eos token: {tokenizer.eos_token}.")
+        else:
+            logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}.")
+
+        if num_added_tokens > 0:
+            logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
+
     def encode_oneturn(
         self,
         tokenizer,
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
     ):
         r"""
         Returns a single pair of token ids representing prompt and response respectively.
         """
-        encoded_messages = self.encode(tokenizer, messages, system)
+        encoded_messages = self.encode(tokenizer, messages, system, tools)
         prompt_ids = []
         for encoded_ids in encoded_messages[:-1]:
             prompt_ids += encoded_ids
@@ -153,6 +170,17 @@ class Template:
 
         return self._make_pairs(encoded_messages, args.cutoff_len, args.reserved_label_len)
 
+    def fix_special_tokens(self, tokenizer: PreTrainedTokenizer) -> None:
+        r"""
+        Add eos token and pad token to the tokenizer.
+        """
+        if tokenizer.eos_token_id is None:
+            self._add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
+
+        if tokenizer.pad_token_id is None:
+            tokenizer.pad_token = tokenizer.eos_token
+            logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
+
 
 def str_to_dict_for_mm_plugin(name):
     EMPTY = {"plugin_name": "base"}
@@ -264,3 +292,10 @@ def get_template():
     logger.info_rank0(f"Apply template {template_type}")
 
     return template
+
+
+def fix_tokenizer_with_template(tokenizer: PreTrainedTokenizer, template: Template) -> None:
+    """
+    Fix tokenizer with chat template
+    """
+    template.fix_special_tokens(tokenizer)
@@ -1,2 +1,3 @@
 from .sft import run_sft
 from .pt import run_pt
+from .dpo import run_dpo
src/openmind/flow/train/dpo/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .workflow import run_dpo
src/openmind/flow/train/dpo/workflow.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
#
# openMind is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
# http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.


from typing import List, Optional

from transformers import TrainerCallback

from trl import DPOTrainer

from openmind.utils import get_logger
from openmind.flow.model import get_model, get_tokenizer
from openmind.flow.datasets import get_template, get_dataset_module, fix_tokenizer_with_template
from openmind.flow.arguments import get_args, get_peft_config

logger = get_logger(__name__)


def run_dpo(
    callbacks: Optional[List["TrainerCallback"]] = None,
):
    tokenizer = get_tokenizer()

    template = get_template()

    fix_tokenizer_with_template(tokenizer, template)

    dataset_module = get_dataset_module(tokenizer, template)

    args = get_args()

    peft_config = get_peft_config()

    model = get_model()

    logger.info_rank0(f"*******DPO Args: {args.dpo_args} ***********")

    # if peft config provided, ref model should be None
    if peft_config is not None:
        ref_model = None
    else:
        ref_model = get_model()

    trainer = DPOTrainer(
        args=args.dpo_args,
        processing_class=tokenizer,
        model=model,
        ref_model=ref_model,
        peft_config=peft_config,
        **dataset_module,
    )

    if args.do_train:
        logger.info_rank0("Start DPO training.")
        train_result = trainer.train()
        trainer.save_model(args.output_dir)
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
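With `peft_config` set, `trainer.save_model(args.output_dir)` should write only the LoRA adapter rather than full weights, and the implicit reference model is handled by trl. A hedged sketch of loading the result for inference, assuming the base-model path and output directory from the example YAML:

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Paths are taken from the example YAML; adjust to your setup.
base = AutoModelForCausalLM.from_pretrained("Qwen2.5-7B")
model = PeftModel.from_pretrained(base, "saves/qwen2.5-7b-dpo-lora")
tokenizer = AutoTokenizer.from_pretrained("Qwen2.5-7B")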
@@ -18,6 +18,8 @@ __all__ = [
     "is_torch_npu_available",
     "is_mindformers_available",
     "is_transformers_available",
+    "is_trl_available",
+    "is_peft_available",
     "is_diffusers_available",
     "is_mindone_available",
     "is_mindnlp_available",
@@ -38,10 +40,12 @@ from .import_utils import (
     is_torch_npu_available,
     is_ms_available,
     is_transformers_available,
+    is_trl_available,
     is_mindformers_available,
     is_diffusers_available,
     is_mindone_available,
     is_mindnlp_available,
+    is_peft_available,
     is_sentencepiece_available,
     is_timm_available,
     is_vision_available,
@@ -137,6 +137,16 @@ def is_transformers_available():
     return _is_package_available("transformers")
 
 
+@lru_cache
+def is_trl_available():
+    return _is_package_available("trl")
+
+
+@lru_cache
+def is_peft_available():
+    return _is_package_available("peft")
+
+
 @lru_cache
 def is_mindformers_available():
     return _is_package_available("mindformers")