!190 Update datasets loader and pissa algorithm related
Merge pull request !190 from 金勇旭/master
This commit is contained in:
@ -51,7 +51,7 @@ class DatasetFormatConfig:
|
||||
DATASET_FORMAT_REGISTRY: Dict[str, DatasetFormatConfig] = {
|
||||
"alpaca": DatasetFormatConfig(
|
||||
required_columns=["instruction", "output"],
|
||||
allowed_columns=["instruction", "input", "output", "history", "system", "tools"],
|
||||
allowed_columns=["instruction", "input", "output", "history", "system", "tools", "text"],
|
||||
),
|
||||
"sharegpt": DatasetFormatConfig(
|
||||
required_columns=["conversations"], allowed_columns=["conversations", "system", "tools"]
|
||||
|
@ -60,9 +60,6 @@ def apply_lora(
|
||||
if not is_trainable:
|
||||
return model
|
||||
|
||||
logger.info_rank0(f"Fine-tuning method: {'DoRA' if args.use_dora else 'LoRA'}")
|
||||
if args.init_lora_weights:
|
||||
logger.info_rank0(f"Initializing lora weights method: {args.init_lora_weights}")
|
||||
all_target_modules = get_all_lora_target_modules(model)
|
||||
|
||||
if not args.lora_target_modules:
|
||||
@ -77,8 +74,6 @@ def apply_lora(
|
||||
f"The input lora modules {input_target_modules} is not supported. The possible lora modules list is {all_target_modules}"
|
||||
)
|
||||
|
||||
logger.info_rank0(f"Lora target modeules {target_modules} are applied to model")
|
||||
|
||||
peft_kwargs = {
|
||||
"r": args.lora_rank,
|
||||
"target_modules": target_modules,
|
||||
@ -98,6 +93,16 @@ def apply_lora(
|
||||
model.enable_input_require_grads()
|
||||
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
finetuning_method = "DoRA" if model.get_peft_config_as_dict()["use_dora"] else "LoRA"
|
||||
logger.info_rank0(f"Fine-tuning method: {finetuning_method}")
|
||||
|
||||
if args.init_lora_weights:
|
||||
init_lora_weights_method = model.get_peft_config_as_dict()["init_lora_weights"]
|
||||
logger.info_rank0(f"Initializing lora weights method: {init_lora_weights_method}")
|
||||
|
||||
logger.info_rank0(f"Lora target modeules {target_modules} are applied to model")
|
||||
|
||||
model.print_trainable_parameters()
|
||||
# Cast trainable parameters to fp32. According to the source code of PEFT, it has an improvement in stability.
|
||||
|
||||
|
Reference in New Issue
Block a user