!204 fix dpo issue

Merge pull request !204 from 幽若/master-fixdpo
2025-05-10 02:09:31 +00:00
committed by i-robot
parent 2ede23881f
commit 9f25c83026
2 changed files with 31 additions and 30 deletions

View File: openmind/flow/arguments.py

@@ -19,12 +19,11 @@ import importlib.metadata
 import re
 import yaml
 from typing import Optional
 from openmind.utils.constants import Stages, FinetuneType, Frameworks
 from openmind.utils.import_utils import is_swanlab_available
 from openmind.utils.arguments_utils import str2bool
-from openmind.utils import logging, is_transformers_available, is_torch_available, is_trl_available, is_peft_available
+from openmind.utils import logging, is_transformers_available, is_torch_available, is_trl_available
 from openmind.flow.legacy_arguments import _add_legacy_args, _migrate_legacy_args

 if is_torch_available():
@@ -34,12 +33,10 @@ if is_torch_available():
 else:
     from mindformers.trainer.utils import get_last_checkpoint

-if is_trl_available():
+if is_trl_available() and is_torch_available():
     from trl.trainer.dpo_config import DPOConfig
     from trl.trainer.reward_config import RewardConfig

-if is_peft_available():
-    from peft import PeftConfig, LoraConfig

 logger = logging.get_logger(__name__)
@@ -52,28 +49,6 @@ def get_args():
     return _GLOBAL_ARGS


-def get_peft_config() -> "Optional[PeftConfig]":
-    args = get_args()
-    if args.finetuning_type != FinetuneType.LORA:
-        return None
-    if not is_peft_available():
-        raise ValueError(
-            "You need to have PEFT library installed in your environment, make sure to install `peft`. "
-            "Make sure to run `pip install -U peft`."
-        )
-    peft_config = LoraConfig(
-        task_type="CAUSAL_LM",
-        r=args.lora_rank,
-        target_modules=args.lora_target_modules,
-        lora_alpha=args.lora_alpha,
-        lora_dropout=args.lora_dropout,
-        use_dora=args.use_dora,
-    )
-    return peft_config
-
-
 def initialize_openmind(yaml_path=None, ignore_unknown_args=False, **kwargs):
     args = parse_args(yaml_path, ignore_unknown_args, custom_args=kwargs)

     global _GLOBAL_ARGS
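Note on the guard change above: `trl` imports `torch` at module-import time, so `is_trl_available()` alone can pass (the package is installed) while the actual import still fails on a torch-free MindSpore/mindformers environment, which is exactly the fallback path in this file. A minimal sketch of the combined-guard pattern, assuming find_spec-style availability helpers (the real openmind.utils implementations are not shown in this diff):

import importlib.util

def is_torch_available() -> bool:
    # Assumed implementation; openmind's real helper may differ.
    return importlib.util.find_spec("torch") is not None

def is_trl_available() -> bool:
    # Assumed implementation; openmind's real helper may differ.
    return importlib.util.find_spec("trl") is not None

# trl itself pulls in torch, so checking trl alone is not enough
# on a torch-free (MindSpore/mindformers) install.
if is_trl_available() and is_torch_available():
    from trl.trainer.dpo_config import DPOConfig
    from trl.trainer.reward_config import RewardConfig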

View File

@@ -18,13 +18,39 @@ from transformers import TrainerCallback
 from trl import DPOTrainer

-from openmind.utils import get_logger
+from openmind.utils import get_logger, is_peft_available
 from openmind.utils.constants import FinetuneType
 from openmind.flow.model import get_model, get_tokenizer
 from openmind.flow.datasets import get_template, get_dataset_module, fix_tokenizer_with_template
-from openmind.flow.arguments import get_args, get_peft_config
+from openmind.flow.arguments import get_args

 logger = get_logger(__name__)

+if is_peft_available():
+    from peft import PeftConfig, LoraConfig
+
+
+def get_peft_config() -> "Optional[PeftConfig]":
+    args = get_args()
+    if args.finetuning_type != FinetuneType.LORA:
+        return None
+    if not is_peft_available():
+        raise ValueError(
+            "You need to have PEFT library installed in your environment, make sure to install `peft`. "
+            "Make sure to run `pip install -U peft`."
+        )
+    peft_config = LoraConfig(
+        task_type="CAUSAL_LM",
+        r=args.lora_rank,
+        target_modules=args.lora_target_modules,
+        lora_alpha=args.lora_alpha,
+        lora_dropout=args.lora_dropout,
+        use_dora=args.use_dora,
+    )
+    return peft_config
+
+
 def run_dpo(
     callbacks: Optional[List["TrainerCallback"]] = None,
@@ -46,7 +72,7 @@ def run_dpo(
     logger.info_rank0(f"*******DPO Args: {args.dpo_args} ***********")

     # if peft config provided, ref model should be None
-    if peft_config is not None:
+    if args.finetuning_type == FinetuneType.LORA:
         ref_model = None
     else:
         ref_model = get_model()
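With the reference-model choice now keyed to `args.finetuning_type` rather than a `peft_config` value, `run_dpo` passes `ref_model=None` under LoRA and lets TRL derive the reference policy by disabling the adapters on the base model. A hedged sketch of how the trainer is plausibly wired up after this hunk; the call below is an illustration, not code from the diff (keyword names vary across trl releases, e.g. `tokenizer` vs. `processing_class`):

# Hypothetical continuation of run_dpo, for illustration only.
trainer = DPOTrainer(
    model=get_model(),
    ref_model=ref_model,               # None under LoRA finetuning
    args=args.dpo_args,                # the DPOConfig logged above
    processing_class=get_tokenizer(),  # `tokenizer=` on older trl releases
    peft_config=get_peft_config(),     # LoraConfig for LORA, else None
    callbacks=callbacks,
)
trainer.train()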