import argparse
import json
import logging
import os
from collections import Counter
from dataclasses import dataclass
from operator import attrgetter
from typing import Optional, Union

import safetensors
import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel

from peft import LoHaConfig, LoKrConfig, LoraConfig, PeftType, get_peft_model, set_peft_model_state_dict
from peft.tuners.lokr.layer import factorization


# Default kohya_ss LoRA replacement modules
# https://github.com/kohya-ss/sd-scripts/blob/c924c47f374ac1b6e33e71f82948eb1853e2243f/networks/lora.py#L661
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
PREFIX_UNET = "lora_unet"
PREFIX_TEXT_ENCODER = "lora_te"


@dataclass
class LoRAInfo:
    kohya_key: str
    peft_key: str
    alpha: Optional[float] = None
    rank: Optional[int] = None
    lora_A: Optional[torch.Tensor] = None
    lora_B: Optional[torch.Tensor] = None

    def peft_state_dict(self) -> dict[str, torch.Tensor]:
        if self.lora_A is None or self.lora_B is None:
            raise ValueError("At least one of lora_A or lora_B is None; they must both be provided")
        return {
            f"base_model.model.{self.peft_key}.lora_A.weight": self.lora_A,
            f"base_model.model.{self.peft_key}.lora_B.weight": self.lora_B,
        }


@dataclass
class LoHaInfo:
    kohya_key: str
    peft_key: str
    alpha: Optional[float] = None
    rank: Optional[int] = None
    hada_w1_a: Optional[torch.Tensor] = None
    hada_w1_b: Optional[torch.Tensor] = None
    hada_w2_a: Optional[torch.Tensor] = None
    hada_w2_b: Optional[torch.Tensor] = None
    hada_t1: Optional[torch.Tensor] = None
    hada_t2: Optional[torch.Tensor] = None

    def peft_state_dict(self) -> dict[str, torch.Tensor]:
        if self.hada_w1_a is None or self.hada_w1_b is None or self.hada_w2_a is None or self.hada_w2_b is None:
            raise ValueError(
                "At least one of hada_w1_a, hada_w1_b, hada_w2_a, hada_w2_b is missing; they all must be provided"
            )
        state_dict = {
            f"base_model.model.{self.peft_key}.hada_w1_a": self.hada_w1_a,
            f"base_model.model.{self.peft_key}.hada_w1_b": self.hada_w1_b,
            f"base_model.model.{self.peft_key}.hada_w2_a": self.hada_w2_a,
            f"base_model.model.{self.peft_key}.hada_w2_b": self.hada_w2_b,
        }
        if not (
            (self.hada_t1 is None and self.hada_t2 is None) or (self.hada_t1 is not None and self.hada_t2 is not None)
        ):
            raise ValueError("hada_t1 and hada_t2 must either both be present or both be absent")
        if self.hada_t1 is not None and self.hada_t2 is not None:
            state_dict[f"base_model.model.{self.peft_key}.hada_t1"] = self.hada_t1
            state_dict[f"base_model.model.{self.peft_key}.hada_t2"] = self.hada_t2
        return state_dict


@dataclass
class LoKrInfo:
    kohya_key: str
    peft_key: str
    alpha: Optional[float] = None
    rank: Optional[int] = None
    lokr_w1: Optional[torch.Tensor] = None
    lokr_w1_a: Optional[torch.Tensor] = None
    lokr_w1_b: Optional[torch.Tensor] = None
    lokr_w2: Optional[torch.Tensor] = None
    lokr_w2_a: Optional[torch.Tensor] = None
    lokr_w2_b: Optional[torch.Tensor] = None
    lokr_t2: Optional[torch.Tensor] = None

    def peft_state_dict(self) -> dict[str, torch.Tensor]:
        if (self.lokr_w1 is None) and ((self.lokr_w1_a is None) or (self.lokr_w1_b is None)):
            raise ValueError("Either lokr_w1 or both lokr_w1_a and lokr_w1_b should be provided")

        if (self.lokr_w2 is None) and ((self.lokr_w2_a is None) or (self.lokr_w2_b is None)):
            raise ValueError("Either lokr_w2 or both lokr_w2_a and lokr_w2_b should be provided")

        state_dict = {}

        if self.lokr_w1 is not None:
            state_dict[f"base_model.model.{self.peft_key}.lokr_w1"] = self.lokr_w1
        elif self.lokr_w1_a is not None:
            state_dict[f"base_model.model.{self.peft_key}.lokr_w1_a"] = self.lokr_w1_a
            state_dict[f"base_model.model.{self.peft_key}.lokr_w1_b"] = self.lokr_w1_b

        if self.lokr_w2 is not None:
            state_dict[f"base_model.model.{self.peft_key}.lokr_w2"] = self.lokr_w2
        elif self.lokr_w2_a is not None:
            state_dict[f"base_model.model.{self.peft_key}.lokr_w2_a"] = self.lokr_w2_a
            state_dict[f"base_model.model.{self.peft_key}.lokr_w2_b"] = self.lokr_w2_b

        if self.lokr_t2 is not None:
            state_dict[f"base_model.model.{self.peft_key}.lokr_t2"] = self.lokr_t2

        return state_dict


def construct_peft_loraconfig(info: dict[str, LoRAInfo], **kwargs) -> LoraConfig:
    """Constructs LoraConfig from data extracted from adapter checkpoint

    Args:
        info (Dict[str, LoRAInfo]): Information extracted from adapter checkpoint

    Returns:
        LoraConfig: config for constructing LoRA
    """

    # Unpack all ranks and alphas
    ranks = {key: val.rank for key, val in info.items()}
    alphas = {x[0]: x[1].alpha or x[1].rank for x in info.items()}

    # Determine which modules need to be transformed
    target_modules = sorted(info.keys())

    # Determine the most common rank and alpha
    r = int(Counter(ranks.values()).most_common(1)[0][0])
    lora_alpha = Counter(alphas.values()).most_common(1)[0][0]

    # Determine which modules have a different rank and alpha
    rank_pattern = dict(sorted(filter(lambda x: x[1] != r, ranks.items()), key=lambda x: x[0]))
    alpha_pattern = dict(sorted(filter(lambda x: x[1] != lora_alpha, alphas.items()), key=lambda x: x[0]))
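    # Illustrative example (hypothetical module names): with ranks {"a": 8, "b": 8, "c": 4}, the most common
    # rank 8 becomes r and rank_pattern becomes {"c": 4}; alpha_pattern is built the same way for alphas that
    # deviate from the most common value.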

    config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        target_modules=target_modules,
        lora_dropout=0.0,
        bias="none",
        init_lora_weights=False,
        rank_pattern=rank_pattern,
        alpha_pattern=alpha_pattern,
    )

    return config


def construct_peft_lohaconfig(info: dict[str, LoHaInfo], **kwargs) -> LoHaConfig:
    """Constructs LoHaConfig from data extracted from adapter checkpoint

    Args:
        info (Dict[str, LoHaInfo]): Information extracted from adapter checkpoint

    Returns:
        LoHaConfig: config for constructing LoHA
    """

    # Unpack all ranks and alphas
    ranks = {x[0]: x[1].rank for x in info.items()}
    alphas = {x[0]: x[1].alpha or x[1].rank for x in info.items()}

    # Determine which modules need to be transformed
    target_modules = sorted(info.keys())

    # Determine the most common rank and alpha
    r = int(Counter(ranks.values()).most_common(1)[0][0])
    alpha = Counter(alphas.values()).most_common(1)[0][0]

    # Determine which modules have a different rank and alpha
    rank_pattern = dict(sorted(filter(lambda x: x[1] != r, ranks.items()), key=lambda x: x[0]))
    alpha_pattern = dict(sorted(filter(lambda x: x[1] != alpha, alphas.items()), key=lambda x: x[0]))

    # Determine whether any of the modules use an effective conv2d decomposition
    use_effective_conv2d = any((val.hada_t1 is not None) or (val.hada_t2 is not None) for val in info.values())

    config = LoHaConfig(
        r=r,
        alpha=alpha,
        target_modules=target_modules,
        rank_dropout=0.0,
        module_dropout=0.0,
        init_weights=False,
        rank_pattern=rank_pattern,
        alpha_pattern=alpha_pattern,
        use_effective_conv2d=use_effective_conv2d,
    )

    return config


def construct_peft_lokrconfig(info: dict[str, LoKrInfo], decompose_factor: int = -1, **kwargs) -> LoKrConfig:
    """Constructs LoKrConfig from data extracted from adapter checkpoint

    Args:
        info (Dict[str, LoKrInfo]): Information extracted from adapter checkpoint

    Returns:
        LoKrConfig: config for constructing LoKr
    """

    # Unpack all ranks and alphas
    ranks = {x[0]: x[1].rank for x in info.items()}
    alphas = {x[0]: x[1].alpha or x[1].rank for x in info.items()}

    # Determine which modules need to be transformed
    target_modules = sorted(info.keys())

    # Determine the most common rank and alpha
    r = int(Counter(ranks.values()).most_common(1)[0][0])
    alpha = Counter(alphas.values()).most_common(1)[0][0]

    # Determine which modules have a different rank and alpha
    rank_pattern = dict(sorted(filter(lambda x: x[1] != r, ranks.items()), key=lambda x: x[0]))
    alpha_pattern = dict(sorted(filter(lambda x: x[1] != alpha, alphas.items()), key=lambda x: x[0]))

    # Determine whether any of the modules use an effective conv2d decomposition
    use_effective_conv2d = any((val.lokr_t2 is not None) for val in info.values())

    # decompose_both should be enabled if any w1 matrix in any layer is decomposed into two factors
    decompose_both = any((val.lokr_w1_a is not None and val.lokr_w1_b is not None) for val in info.values())

    # Determining the decompose factor is a bit tricky (but it is most often -1).
    # Check that the checkpoint shapes are consistent with the default decompose_factor=-1.
    for val in info.values():
        # Determine the shape of the first matrix
        if val.lokr_w1 is not None:
            w1_shape = tuple(val.lokr_w1.shape)
        else:
            w1_shape = (val.lokr_w1_a.shape[0], val.lokr_w1_b.shape[1])

        # Determine the shape of the second matrix
        if val.lokr_w2 is not None:
            w2_shape = tuple(val.lokr_w2.shape[:2])
        elif val.lokr_t2 is not None:
            w2_shape = (val.lokr_w2_a.shape[1], val.lokr_w2_b.shape[1])
        else:
            # We may be iterating over a Conv2d layer, whose second shape dimension is multiplied by ksize^2
            w2_shape = (val.lokr_w2_a.shape[0], val.lokr_w2_b.shape[1])

        # We need to check whether decompose_factor is really -1
        shape = (w1_shape[0], w2_shape[0])
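        # A brief note on the check below (a sketch of expected behavior, not guaranteed for every peft version):
        # peft's factorization(n, factor=-1) splits n into a near-square pair (m, k) with m * k == n and m <= k,
        # e.g. factorization(320, factor=-1) is expected to give (16, 20). If the stored Kronecker factor shapes
        # do not match that default split, the adapter was likely trained with an explicit factor that cannot be
        # recovered here.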
        if factorization(shape[0] * shape[1], factor=-1) != shape:
            raise ValueError("Cannot infer decompose_factor, probably it is not equal to -1")

    config = LoKrConfig(
        r=r,
        alpha=alpha,
        target_modules=target_modules,
        rank_dropout=0.0,
        module_dropout=0.0,
        init_weights=False,
        rank_pattern=rank_pattern,
        alpha_pattern=alpha_pattern,
        use_effective_conv2d=use_effective_conv2d,
        decompose_both=decompose_both,
        decompose_factor=decompose_factor,
    )

    return config


def combine_peft_state_dict(info: dict[str, Union[LoRAInfo, LoHaInfo, LoKrInfo]]) -> dict[str, torch.Tensor]:
    result = {}
    for key_info in info.values():
        result.update(key_info.peft_state_dict())
    return result


def detect_adapter_type(keys: list[str]) -> PeftType:
    # Detect type of adapter by keys
    # Inspired by this:
    # https://github.com/bmaltais/kohya_ss/blob/ed4e3b0239a40506de9a17e550e6cf2d0b867a4f/tools/lycoris_utils.py#L312
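    # For reference (illustrative, not exhaustive): kohya/LyCORIS checkpoints store entries such as
    # "lora_unet_<module_path>.lora_down.weight", "<prefix>.hada_w1_a", "<prefix>.lokr_w1" and matching
    # "<prefix>.alpha" scalars; this function looks for those substrings to decide the adapter type.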
    for key in keys:
        if "alpha" in key:
            continue
        elif any(x in key for x in ["lora_down", "lora_up"]):
            # LoRA
            return PeftType.LORA
        elif any(x in key for x in ["hada_w1", "hada_w2", "hada_t1", "hada_t2"]):
            # LoHa may have the following keys:
            # hada_w1_a, hada_w1_b, hada_w2_a, hada_w2_b, hada_t1, hada_t2
            return PeftType.LOHA
        elif any(x in key for x in ["lokr_w1", "lokr_w2", "lokr_t1", "lokr_t2"]):
            # LoKr may have the following keys:
            # lokr_w1, lokr_w2, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t1, lokr_t2
            return PeftType.LOKR
        elif "diff" in key:
            raise ValueError("Currently full diff adapters are not implemented")
        else:
            raise ValueError("Unknown adapter type, probably not implemented")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--sd_checkpoint", default=None, type=str, required=True, help="SD checkpoint to use")

    parser.add_argument(
        "--adapter_path",
        default=None,
        type=str,
        required=True,
        help="Path to downloaded adapter to convert",
    )

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output peft adapter.")

    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
    parser.add_argument(
        "--loha_conv2d_weights_fix",
        action="store_true",
        help="""LoHa checkpoints trained with lycoris-lora<=1.9.0 contain a bug described in this PR https://github.com/KohakuBlueleaf/LyCORIS/pull/115.
This option fixes this bug during weight conversion (replaces hada_t2 with hada_t1 for Conv2d 3x3 layers).
The output results may differ from webui, but in general, they should be better in terms of quality.
Set this option if the provided checkpoint was trained with a lycoris-lora version into which the mentioned PR had not yet been merged.
Leave it unset if the checkpoint was trained with a lycoris-lora version that includes the PR, or if full compatibility with webui outputs is required.""",
    )
    args = parser.parse_args()
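
    # Example invocation (a sketch; the script name and the paths/model ID below are illustrative, not prescriptive):
    #   python convert_sd_adapter_to_peft.py \
    #       --sd_checkpoint runwayml/stable-diffusion-v1-5 \
    #       --adapter_path my_adapter.safetensors \
    #       --dump_path ./peft_adapter \
    #       --half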

    # Load all models that we need to add adapter to
    text_encoder = CLIPTextModel.from_pretrained(args.sd_checkpoint, subfolder="text_encoder")
    unet = UNet2DConditionModel.from_pretrained(args.sd_checkpoint, subfolder="unet")

    # Construct possible mapping from kohya keys to peft keys
    models_keys = {}
    for model, model_key, model_name in [
        (text_encoder, PREFIX_TEXT_ENCODER, "text_encoder"),
        (unet, PREFIX_UNET, "unet"),
    ]:
        models_keys.update(
            {
                f"{model_key}.{peft_key}".replace(".", "_"): peft_key
                for peft_key in (x[0] for x in model.named_modules())
            }
        )
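
    # Illustrative example of the mapping built above (the module name is hypothetical but follows the usual
    # diffusers layout): the UNet module "down_blocks.0.attentions.0.proj_in" is stored under the kohya key
    # "lora_unet_down_blocks_0_attentions_0_proj_in" and maps back to the peft key "down_blocks.0.attentions.0.proj_in".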

    # Store conversion info (model_type -> peft_key -> LoRAInfo | LoHaInfo | LoKrInfo)
    adapter_info: dict[str, dict[str, Union[LoRAInfo, LoHaInfo, LoKrInfo]]] = {
        "text_encoder": {},
        "unet": {},
    }

    # Store decompose_factor for LoKr
    decompose_factor = -1

    # Open adapter checkpoint
    with safetensors.safe_open(args.adapter_path, framework="pt", device="cpu") as f:
        # Extract information about adapter structure
        metadata = f.metadata()
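        # Illustrative metadata shape (hypothetical values, only the fields read below matter here):
        #   {"ss_network_dim": "8", "ss_network_args": '{"conv_dim": "4", "factor": "-1"}', ...}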

        # It may be difficult to determine the rank for LoKr adapters:
        # if the checkpoint was trained with a large rank, it may not be fully utilized during weight creation,
        # so we need to get it from the checkpoint metadata (along with decompose_factor)
        rank, conv_rank = None, None
        if metadata is not None:
            rank = metadata.get("ss_network_dim", None)
            rank = int(rank) if rank else None
            if "ss_network_args" in metadata:
                network_args = json.loads(metadata["ss_network_args"])
                conv_rank = network_args.get("conv_dim", None)
                conv_rank = int(conv_rank) if conv_rank else rank
                decompose_factor = network_args.get("factor", -1)
                decompose_factor = int(decompose_factor)

        # Detect adapter type based on keys
        adapter_type = detect_adapter_type(f.keys())
        adapter_info_cls = {
            PeftType.LORA: LoRAInfo,
            PeftType.LOHA: LoHaInfo,
            PeftType.LOKR: LoKrInfo,
        }[adapter_type]

        # Iterate through available info and unpack all the values
        for key in f.keys():
            kohya_key, kohya_type = key.split(".")[:2]

            # Find which model this key belongs to
            if kohya_key.startswith(PREFIX_TEXT_ENCODER):
                model_type, model = "text_encoder", text_encoder
            elif kohya_key.startswith(PREFIX_UNET):
                model_type, model = "unet", unet
            else:
                raise ValueError(f"Cannot determine model for key: {key}")

            # Find corresponding peft key
            if kohya_key not in models_keys:
                raise ValueError(f"Cannot find corresponding key for diffusers/transformers model: {kohya_key}")
            peft_key = models_keys[kohya_key]

            # Retrieve corresponding layer of model
            layer = attrgetter(peft_key)(model)
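            # Note: attrgetter resolves the dotted peft key through the module tree, e.g. (illustratively)
            # attrgetter("down_blocks.0.attentions.0.proj_in")(unet); numeric path segments work because
            # torch containers such as nn.ModuleList expose their children under string indices.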

            # Create a corresponding adapter info
            if peft_key not in adapter_info[model_type]:
                adapter_info[model_type][peft_key] = adapter_info_cls(kohya_key=kohya_key, peft_key=peft_key)

            tensor = f.get_tensor(key)
            if kohya_type == "alpha":
                adapter_info[model_type][peft_key].alpha = tensor.item()
            elif kohya_type == "lora_down":
                adapter_info[model_type][peft_key].lora_A = tensor
                adapter_info[model_type][peft_key].rank = tensor.shape[0]
            elif kohya_type == "lora_up":
                adapter_info[model_type][peft_key].lora_B = tensor
                adapter_info[model_type][peft_key].rank = tensor.shape[1]
            elif kohya_type == "hada_w1_a":
                adapter_info[model_type][peft_key].hada_w1_a = tensor
            elif kohya_type == "hada_w1_b":
                adapter_info[model_type][peft_key].hada_w1_b = tensor
                adapter_info[model_type][peft_key].rank = tensor.shape[0]
            elif kohya_type == "hada_w2_a":
                adapter_info[model_type][peft_key].hada_w2_a = tensor
            elif kohya_type == "hada_w2_b":
                adapter_info[model_type][peft_key].hada_w2_b = tensor
                adapter_info[model_type][peft_key].rank = tensor.shape[0]
            elif kohya_type in {"hada_t1", "hada_t2"}:
                if args.loha_conv2d_weights_fix:
                    if kohya_type == "hada_t1":
                        # This code block fixes a bug that exists for some LoHa checkpoints
                        # that resulted in accidentally using hada_t1 weight instead of hada_t2, see
                        # https://github.com/KohakuBlueleaf/LyCORIS/pull/115
                        adapter_info[model_type][peft_key].hada_t1 = tensor
                        adapter_info[model_type][peft_key].hada_t2 = tensor
                        adapter_info[model_type][peft_key].rank = tensor.shape[0]
                else:
                    if kohya_type == "hada_t1":
                        adapter_info[model_type][peft_key].hada_t1 = tensor
                        adapter_info[model_type][peft_key].rank = tensor.shape[0]
                    elif kohya_type == "hada_t2":
                        adapter_info[model_type][peft_key].hada_t2 = tensor
                        adapter_info[model_type][peft_key].rank = tensor.shape[0]
            elif kohya_type == "lokr_t2":
                adapter_info[model_type][peft_key].lokr_t2 = tensor
                adapter_info[model_type][peft_key].rank = tensor.shape[0]
            elif kohya_type == "lokr_w1":
                adapter_info[model_type][peft_key].lokr_w1 = tensor
                if isinstance(layer, nn.Linear) or (
                    isinstance(layer, nn.Conv2d) and tuple(layer.weight.shape[2:]) == (1, 1)
                ):
                    adapter_info[model_type][peft_key].rank = rank
                elif isinstance(layer, nn.Conv2d):
                    adapter_info[model_type][peft_key].rank = conv_rank
            elif kohya_type == "lokr_w2":
                adapter_info[model_type][peft_key].lokr_w2 = tensor
                if isinstance(layer, nn.Linear) or (
                    isinstance(layer, nn.Conv2d) and tuple(layer.weight.shape[2:]) == (1, 1)
                ):
                    adapter_info[model_type][peft_key].rank = rank
                elif isinstance(layer, nn.Conv2d):
                    adapter_info[model_type][peft_key].rank = conv_rank
            elif kohya_type == "lokr_w1_a":
                adapter_info[model_type][peft_key].lokr_w1_a = tensor
                adapter_info[model_type][peft_key].rank = tensor.shape[1]
            elif kohya_type == "lokr_w1_b":
                adapter_info[model_type][peft_key].lokr_w1_b = tensor
                adapter_info[model_type][peft_key].rank = tensor.shape[0]
            elif kohya_type == "lokr_w2_a":
                adapter_info[model_type][peft_key].lokr_w2_a = tensor
            elif kohya_type == "lokr_w2_b":
                adapter_info[model_type][peft_key].lokr_w2_b = tensor
            else:
                raise ValueError(f"Unknown weight name in key: {key} - {kohya_type}")

    # Get function which will create adapter config based on extracted info
    construct_config_fn = {
        PeftType.LORA: construct_peft_loraconfig,
        PeftType.LOHA: construct_peft_lohaconfig,
        PeftType.LOKR: construct_peft_lokrconfig,
    }[adapter_type]

    # Process each model sequentially
    for model, model_name in [(text_encoder, "text_encoder"), (unet, "unet")]:
        # Skip model if no data was provided
        if len(adapter_info[model_name]) == 0:
            continue

        config = construct_config_fn(adapter_info[model_name], decompose_factor=decompose_factor)

        # Output warning for LoHa with use_effective_conv2d
        if (
            isinstance(config, LoHaConfig)
            and getattr(config, "use_effective_conv2d", False)
            and args.loha_conv2d_weights_fix is False
        ):
            logging.warning(
                'lycoris-lora<=1.9.0 LoHa implementation contains a bug, which can be fixed with "--loha_conv2d_weights_fix".\n'
                "For more info, please refer to https://github.com/huggingface/peft/pull/1021 and https://github.com/KohakuBlueleaf/LyCORIS/pull/115"
            )

        model = get_peft_model(model, config)
        missing_keys, unexpected_keys = set_peft_model_state_dict(
            model, combine_peft_state_dict(adapter_info[model_name])
        )
        if len(unexpected_keys) > 0:
            raise ValueError(f"Unexpected keys {unexpected_keys} found during conversion")

        if args.half:
            model.to(torch.float16)

        # Save model to disk
        model.save_pretrained(os.path.join(args.dump_path, model_name))
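
    # Resulting layout (sketch): one PEFT adapter per model, e.g. <dump_path>/unet and <dump_path>/text_encoder,
    # each containing the usual PEFT artifacts (adapter_config.json plus the adapter weights). A later session can
    # reload them with PeftModel.from_pretrained(<freshly loaded base model>, <adapter folder>); this is an
    # illustrative usage note, not part of the conversion itself.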