@ -19,12 +19,11 @@ import importlib.metadata
|
||||
import re
|
||||
|
||||
import yaml
|
||||
from typing import Optional
|
||||
|
||||
from openmind.utils.constants import Stages, FinetuneType, Frameworks
|
||||
from openmind.utils.import_utils import is_swanlab_available
|
||||
from openmind.utils.arguments_utils import str2bool
|
||||
from openmind.utils import logging, is_transformers_available, is_torch_available, is_trl_available, is_peft_available
|
||||
from openmind.utils import logging, is_transformers_available, is_torch_available, is_trl_available
|
||||
from openmind.flow.legacy_arguments import _add_legacy_args, _migrate_legacy_args
|
||||
|
||||
if is_torch_available():
|
||||
@ -34,12 +33,10 @@ if is_torch_available():
|
||||
else:
|
||||
from mindformers.trainer.utils import get_last_checkpoint
|
||||
|
||||
if is_trl_available():
|
||||
if is_trl_available() and is_torch_available():
|
||||
from trl.trainer.dpo_config import DPOConfig
|
||||
from trl.trainer.reward_config import RewardConfig
|
||||
|
||||
if is_peft_available():
|
||||
from peft import PeftConfig, LoraConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@ -52,28 +49,6 @@ def get_args():
|
||||
return _GLOBAL_ARGS
|
||||
|
||||
|
||||
def get_peft_config() -> "Optional[PeftConfig]":
    """Build the LoRA ``PeftConfig`` for the current run, if any.

    Reads the global arguments and returns ``None`` whenever the
    configured finetuning type is not LoRA. Otherwise assembles a
    ``LoraConfig`` from the LoRA-related arguments.

    Returns:
        Optional[PeftConfig]: a ``LoraConfig`` for LoRA runs, ``None`` otherwise.

    Raises:
        ValueError: if LoRA finetuning is requested but `peft` is not installed.
    """
    args = get_args()

    # Only LoRA finetuning needs a PEFT config; every other mode gets None.
    if args.finetuning_type != FinetuneType.LORA:
        return None

    if not is_peft_available():
        raise ValueError(
            "You need to have PEFT library installed in your environment, make sure to install `peft`. "
            "Make sure to run `pip install -U peft`."
        )

    # Collect the kwargs first so the construction site stays compact.
    lora_kwargs = dict(
        task_type="CAUSAL_LM",
        r=args.lora_rank,
        target_modules=args.lora_target_modules,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        use_dora=args.use_dora,
    )
    return LoraConfig(**lora_kwargs)
|
||||
|
||||
|
||||
def initialize_openmind(yaml_path=None, ignore_unknown_args=False, **kwargs):
|
||||
args = parse_args(yaml_path, ignore_unknown_args, custom_args=kwargs)
|
||||
global _GLOBAL_ARGS
|
||||
|
@ -18,13 +18,39 @@ from transformers import TrainerCallback
|
||||
|
||||
from trl import DPOTrainer
|
||||
|
||||
from openmind.utils import get_logger
|
||||
from openmind.utils import get_logger, is_peft_available
|
||||
from openmind.utils.constants import FinetuneType
|
||||
from openmind.flow.model import get_model, get_tokenizer
|
||||
from openmind.flow.datasets import get_template, get_dataset_module, fix_tokenizer_with_template
|
||||
from openmind.flow.arguments import get_args, get_peft_config
|
||||
from openmind.flow.arguments import get_args
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
if is_peft_available():
|
||||
from peft import PeftConfig, LoraConfig
|
||||
|
||||
|
||||
def get_peft_config() -> "Optional[PeftConfig]":
    """Return a LoRA ``PeftConfig`` derived from the global args, or ``None``.

    Non-LoRA finetuning types yield ``None``; a LoRA run without the
    `peft` package installed is an error.

    Returns:
        Optional[PeftConfig]: a ``LoraConfig`` when finetuning type is LoRA.

    Raises:
        ValueError: if `peft` is missing while LoRA finetuning is requested.
    """
    args = get_args()

    # Guard clause: PEFT configuration only applies to LoRA finetuning.
    if args.finetuning_type != FinetuneType.LORA:
        return None

    if not is_peft_available():
        raise ValueError(
            "You need to have PEFT library installed in your environment, make sure to install `peft`. "
            "Make sure to run `pip install -U peft`."
        )

    # All LoRA hyperparameters come straight from the parsed arguments.
    return LoraConfig(
        task_type="CAUSAL_LM",
        r=args.lora_rank,
        target_modules=args.lora_target_modules,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        use_dora=args.use_dora,
    )
|
||||
|
||||
|
||||
def run_dpo(
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
@ -46,7 +72,7 @@ def run_dpo(
|
||||
logger.info_rank0(f"*******DPO Args: {args.dpo_args} ***********")
|
||||
|
||||
# if peft config provided, ref model should be None
|
||||
if peft_config is not None:
|
||||
if args.finetuning_type == FinetuneType.LORA:
|
||||
ref_model = None
|
||||
else:
|
||||
ref_model = get_model()
|
||||
|
Reference in New Issue
Block a user