Mirror of https://github.com/huggingface/peft.git (synced 2025-10-20 23:43:47 +08:00)

Compare commits: v0.6.1 ... docs_relat (25 commits)
Commits (SHA1):
b169484659, 8351331d78, f1ecfa6ae6, b5a8a294ed, 9cdaed2769, 18a0910113, 99e1a55f54,
21df968fd1, 5a3a5acff2, 70302d7b4f, 3ff90626b6, 1877329093, 98429b8184, d350a00ece,
ad756173f1, 94877b5008, f020404ee6, 79298c7c24, b25ce8a0cd, 5d84484079, 49ddefa834,
3af469eeea, 5e7e5ad836, 9d8287f3e3, 2efd02769b
.github/workflows/nightly.yml (vendored, 4 changes)

@@ -15,6 +15,8 @@ env:

jobs:
run_all_tests_single_gpu:
strategy:
fail-fast: false
runs-on: [self-hosted, docker-gpu, multi-gpu]
env:
CUDA_VISIBLE_DEVICES: "0"

@@ -57,6 +59,8 @@ jobs:
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY

run_all_tests_multi_gpu:
strategy:
fail-fast: false
runs-on: [self-hosted, docker-gpu, multi-gpu]
env:
CUDA_VISIBLE_DEVICES: "0,1"

.github/workflows/tests.yml (vendored, 2 changes)

@@ -28,7 +28,7 @@ jobs:
needs: check_code_quality
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
os: ["ubuntu-latest", "macos-latest", "windows-latest"]
runs-on: ${{ matrix.os }}
steps:

@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.

Some fine-tuning techniques, such as prompt tuning, are specific to language models. That means in 🤗 PEFT, it is
assumed a 🤗 Transformers model is being used. However, other fine-tuning techniques - like
[LoRA](./conceptual_guides/lora) - are not restricted to specific model types.
[LoRA](../conceptual_guides/lora) - are not restricted to specific model types.

In this guide, we will see how LoRA can be applied to a multilayer perceptron and a computer vision model from the [timm](https://huggingface.co/docs/timm/index) library.
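To make the guide's claim concrete for readers skimming this diff, here is a minimal sketch of applying LoRA to a plain multilayer perceptron with PEFT's public `LoraConfig`/`get_peft_model` API. The `MLP` class and its layer names (`lin0`, `lin1`) are invented here for illustration and are not part of the changed file.

```python
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class MLP(nn.Module):
    """A toy non-Transformers model; only the Linear layers are LoRA targets."""

    def __init__(self, in_dim=64, hidden_dim=128, out_dim=10):
        super().__init__()
        self.lin0 = nn.Linear(in_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        return self.lin1(self.relu(self.lin0(x)))

# LoRA is attached by module name, so no 🤗 Transformers model is required.
config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, target_modules=["lin0", "lin1"])
peft_model = get_peft_model(MLP(), config)
peft_model.print_trainable_parameters()
```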

@@ -17,7 +17,7 @@ The development of this API has been motivated by the need for super users to no

## Supported tuner types

Currently the supported adapter types are the 'injectable' adapters, meaning adapters where an inplace modification of the model is sufficient to correctly perform the fine tuning. As such, only [LoRA](./conceptual_guides/lora), AdaLoRA and [IA3](./conceptual_guides/ia3) are currently supported in this API.
Currently the supported adapter types are the 'injectable' adapters, meaning adapters where an inplace modification of the model is sufficient to correctly perform the fine tuning. As such, only [LoRA](../conceptual_guides/lora), AdaLoRA and [IA3](../conceptual_guides/ia3) are currently supported in this API.

## `inject_adapter_in_model` method
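Since this page only shows the link fix, a small sketch of the method the heading refers to may help, assuming the `inject_adapter_in_model` helper exported by PEFT; the `DummyModel` and its `linear` target name are invented for illustration.

```python
import torch
import torch.nn as nn
from peft import LoraConfig, inject_adapter_in_model

class DummyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(10, 10)
        self.linear = nn.Linear(10, 10)
        self.lm_head = nn.Linear(10, 10)

    def forward(self, input_ids):
        return self.lm_head(self.linear(self.embedding(input_ids)))

lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", target_modules=["linear"])

# The adapter layers are injected in place; no PeftModel wrapper is returned.
model = inject_adapter_in_model(lora_config, DummyModel())
output = model(torch.randint(0, 10, (1, 4)))
```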

@@ -83,6 +83,7 @@ accelerate launch train_dreambooth.py \
--output_dir=$OUTPUT_DIR \
--train_text_encoder \
--with_prior_preservation --prior_loss_weight=1.0 \
--num_dataloader_workers=1 \
--instance_prompt="a photo of sks dog" \
--class_prompt="a photo of dog" \
--resolution=512 \

@@ -101,6 +102,8 @@ accelerate launch train_dreambooth.py \
--max_train_steps=800
```

If you are running this script on Windows, you may need to set the `--num_dataloader_workers` to 0.

## Inference with a single adapter

To run inference with the fine-tuned model, first specify the base model with which the fine-tuned LoRA weights will be combined:

@@ -171,7 +174,7 @@ image.save("DESTINATION_PATH_FOR_THE_IMAGE")
## Multi-adapter inference

With PEFT you can combine multiple adapters for inference. In the previous example you have fine-tuned Stable Diffusion on
some dog images. The pipeline created based on these weights got a name - `adapter_name="dog`. Now, suppose you also fine-tuned
some dog images. The pipeline created based on these weights got a name - `adapter_name="dog"`. Now, suppose you also fine-tuned
this base model on images of a crochet toy. Let's see how we can use both adapters.

First, you'll need to perform all the steps as in the single adapter inference example:
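For context around the inference sections touched above, a hedged sketch of the single-adapter case: load the base Stable Diffusion pipeline and wrap its UNet with the fine-tuned LoRA weights under the adapter name "dog". The base checkpoint id and the local paths are placeholders, and if `--train_text_encoder` was used, the text encoder needs the same treatment.

```python
import torch
from diffusers import StableDiffusionPipeline
from peft import PeftModel

base_model = "runwayml/stable-diffusion-v1-5"  # placeholder base checkpoint
pipe = StableDiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16).to("cuda")

# Combine the base UNet with the LoRA weights produced by train_dreambooth.py.
pipe.unet = PeftModel.from_pretrained(pipe.unet, "path/to/dog_dreambooth_lora/unet", adapter_name="dog")

image = pipe("a photo of sks dog in a bucket", num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("DESTINATION_PATH_FOR_THE_IMAGE")
```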

@@ -7,6 +7,7 @@ import math
import os
import threading
import warnings
from contextlib import nullcontext
from pathlib import Path
from typing import Optional

@@ -213,6 +214,17 @@ def parse_args(input_args=None):
help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
)

parser.add_argument(
"--num_dataloader_workers", type=int, default=1, help="Num of workers for the training dataloader."
)

parser.add_argument(
"--no_tracemalloc",
default=False,
action="store_true",
help="Flag to stop memory allocation tracing during training. This could speed up training on Windows.",
)

parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)

@@ -799,7 +811,7 @@ def main(args):
batch_size=args.train_batch_size,
shuffle=True,
collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
num_workers=1,
num_workers=args.num_dataloader_workers,
)

# Scheduler and math around the number of training steps.

@@ -893,7 +905,7 @@ def main(args):
unet.train()
if args.train_text_encoder:
text_encoder.train()
with TorchTracemalloc() as tracemalloc:
with TorchTracemalloc() if not args.no_tracemalloc else nullcontext() as tracemalloc:
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:

@@ -1034,23 +1046,29 @@ def main(args):
if global_step >= args.max_train_steps:
break
# Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage
accelerator.print("GPU Memory before entering the train : {}".format(b2mb(tracemalloc.begin)))
accelerator.print("GPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.used))
accelerator.print("GPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.peaked))
accelerator.print(
"GPU Total Peak Memory consumed during the train (max): {}".format(
tracemalloc.peaked + b2mb(tracemalloc.begin)
)
)

accelerator.print("CPU Memory before entering the train : {}".format(b2mb(tracemalloc.cpu_begin)))
accelerator.print("CPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.cpu_used))
accelerator.print("CPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.cpu_peaked))
accelerator.print(
"CPU Total Peak Memory consumed during the train (max): {}".format(
tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin)
if not args.no_tracemalloc:
accelerator.print("GPU Memory before entering the train : {}".format(b2mb(tracemalloc.begin)))
accelerator.print("GPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.used))
accelerator.print("GPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.peaked))
accelerator.print(
"GPU Total Peak Memory consumed during the train (max): {}".format(
tracemalloc.peaked + b2mb(tracemalloc.begin)
)
)

accelerator.print("CPU Memory before entering the train : {}".format(b2mb(tracemalloc.cpu_begin)))
accelerator.print(
"CPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.cpu_used)
)
accelerator.print(
"CPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.cpu_peaked)
)
accelerator.print(
"CPU Total Peak Memory consumed during the train (max): {}".format(
tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin)
)
)
)

# Create the pipeline using using the trained modules and save it.
accelerator.wait_for_everyone()
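The core of the script change above is the conditional context manager. Below is a self-contained sketch of the same pattern, with a stub standing in for the script's `TorchTracemalloc` helper: when tracing is disabled, `nullcontext()` is used instead, and the `as` target becomes `None`, which is why the reporting block is now also guarded by `if not args.no_tracemalloc`.

```python
from contextlib import nullcontext

class TracemallocStub:
    """Stand-in for the script's TorchTracemalloc helper, for illustration only."""

    def __enter__(self):
        print("memory tracking started")
        return self

    def __exit__(self, *exc):
        print("memory tracking stopped")
        return False

no_tracemalloc = True  # corresponds to passing --no_tracemalloc

with TracemallocStub() if not no_tracemalloc else nullcontext() as tracemalloc:
    pass  # training loop goes here

if not no_tracemalloc:
    print("report stats from", tracemalloc)  # tracemalloc is None when tracing is disabled
```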

setup.py (3 changes)

@@ -22,7 +22,7 @@ extras["test"] = extras["dev"] + ["pytest", "pytest-cov", "pytest-xdist", "param

setup(
name="peft",
version="0.6.1.dev0",
version="0.6.3.dev0",
description="Parameter-Efficient Fine-Tuning (PEFT)",
license_files=["LICENSE"],
long_description=open("README.md", "r", encoding="utf-8").read(),

@@ -47,6 +47,7 @@ setup(
"tqdm",
"accelerate>=0.21.0",
"safetensors",
"huggingface_hub>=0.17.0",
],
extras_require=extras,
classifiers=[

@@ -17,7 +17,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.6.1.dev0"
__version__ = "0.6.3.dev0"

from .auto import (
AutoPeftModel,

@@ -14,6 +14,7 @@
# limitations under the License.
import importlib
import importlib.metadata as importlib_metadata
from functools import lru_cache

import packaging.version

@@ -46,3 +47,20 @@ def is_auto_gptq_available():

def is_optimum_available() -> bool:
return importlib.util.find_spec("optimum") is not None


@lru_cache()
def is_torch_tpu_available(check_device=True):
"Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
if importlib.util.find_spec("torch_xla") is not None:
if check_device:
# We need to check if `xla_device` can be found, will raise a RuntimeError if not
try:
import torch_xla.core.xla_model as xm

_ = xm.xla_device()
return True
except RuntimeError:
return False
return True
return False

@@ -32,7 +32,6 @@ from safetensors.torch import save_file as safe_save_file
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import PreTrainedModel
from transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput
from transformers.pytorch_utils import id_tensor_storage
from transformers.utils import PushToHubMixin

from . import __version__

@@ -60,6 +59,7 @@ from .utils import (
_set_adapter,
_set_trainable,
get_peft_model_state_dict,
id_tensor_storage,
infer_device,
load_peft_weights,
set_peft_model_state_dict,

@@ -157,7 +157,7 @@ class PeftModel(PushToHubMixin, torch.nn.Module):
def save_pretrained(
self,
save_directory: str,
safe_serialization: bool = False,
safe_serialization: bool = True,
selected_adapters: Optional[List[str]] = None,
**kwargs: Any,
):

@@ -573,7 +573,7 @@ class PeftModel(PushToHubMixin, torch.nn.Module):
self.base_model.add_adapter(adapter_name, peft_config)
else:
self.peft_config[adapter_name] = peft_config
self.base_model.inject_adapter(self, adapter_name)
self.base_model.inject_adapter(self.base_model.model, adapter_name)
except Exception:  # somthing went wrong, roll back
if adapter_name in self.peft_config:
del self.peft_config[adapter_name]
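The `save_pretrained` hunk above changes the default to safetensors output. A minimal sketch of what that means in practice, using an invented toy module (`Tiny`) and placeholder output directories:

```python
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)

model = get_peft_model(Tiny(), LoraConfig(target_modules=["proj"]))

# safetensors is now the default (adapter_model.safetensors); pass
# safe_serialization=False to keep the legacy pickle-based adapter_model.bin.
model.save_pretrained("tiny-adapter")
model.save_pretrained("tiny-adapter-bin", safe_serialization=False)
```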

@@ -27,10 +27,3 @@ from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparamet
from .prefix_tuning import PrefixEncoder, PrefixTuningConfig
from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit
from .multitask_prompt_tuning import MultitaskPromptEmbedding, MultitaskPromptTuningConfig, MultitaskPromptTuningInit

# Mapping of tuners that support direct plugging
TUNERS_MAPPING = {
"LORA": LoraModel,
"IA3": IA3Model,
"ADALORA": AdaLoraModel,
}

@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import bitsandbytes as bnb
from typing import Any

import torch

from peft.import_utils import is_bnb_4bit_available, is_bnb_available

@@ -23,38 +24,28 @@ from .layer import AdaLoraLayer

if is_bnb_available():

class SVDLinear8bitLt(bnb.nn.Linear8bitLt, AdaLoraLayer):
class SVDLinear8bitLt(torch.nn.Module, AdaLoraLayer):
# Low-rank matrix for SVD-based adaptation
def __init__(
self,
adapter_name,
in_features,
out_features,
base_layer: torch.nn.Module,
adapter_name: str,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
init_lora_weights: bool = True,
**kwargs,
) -> None:
bnb.nn.Linear8bitLt.__init__(
self,
in_features,
out_features,
bias=kwargs.get("bias", True),
has_fp16_weights=kwargs.get("has_fp16_weights", True),
memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
threshold=kwargs.get("threshold", 0.0),
index=kwargs.get("index", None),
)
AdaLoraLayer.__init__(self, in_features=in_features, out_features=out_features)
super().__init__()
AdaLoraLayer.__init__(self, base_layer)
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.get_base_layer().weight.requires_grad = False

init_lora_weights = kwargs.pop("init_lora_weights", True)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.set_adapter(adapter_name)

def forward(self, x: torch.Tensor) -> torch.Tensor:
result = super().forward(x)
# note: no check for self.merged because merging is not supported (yet)
result = self.base_layer(x)

if self.disable_adapters:
return result

@@ -79,43 +70,39 @@ if is_bnb_available():
if requires_conversion:
output = output.to(expected_dtype)
output = output * scaling / ranknum
result += output
# inplace operation on view is forbidden for MatMul8bitLtBackward, so avoid it
result = result + output
return result

def __repr__(self) -> str:
rep = super().__repr__()
return "adalora." + rep


if is_bnb_4bit_available():

class SVDLinear4bit(bnb.nn.Linear4bit, AdaLoraLayer):
class SVDLinear4bit(torch.nn.Module, AdaLoraLayer):
# Low-rank matrix for SVD-based adaptation
def __init__(
self,
adapter_name,
in_features,
out_features,
base_layer: torch.nn.Module,
adapter_name: str,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
init_lora_weights: bool = True,
**kwargs,
) -> None:
bnb.nn.Linear4bit.__init__(
self,
in_features,
out_features,
bias=kwargs.get("bias", True),
compute_dtype=kwargs.get("compute_dtype", torch.float32),
compress_statistics=kwargs.get("compress_statistics", True),
quant_type=kwargs.get("quant_type", "nf4"),
)
AdaLoraLayer.__init__(self, in_features=in_features, out_features=out_features)
super().__init__()
AdaLoraLayer.__init__(self, base_layer)
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.get_base_layer().weight.requires_grad = False

init_lora_weights = kwargs.pop("init_lora_weights", True)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.set_adapter(adapter_name)

def forward(self, x: torch.Tensor) -> torch.Tensor:
result = super().forward(x)
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
# note: no check for self.merged because merging is not supported (yet)
result = self.base_layer(x, *args, **kwargs)

if self.disable_adapters:
return result

@@ -141,7 +128,7 @@ if is_bnb_4bit_available():
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
compute_dtype = lora_A.weight.dtype
compute_dtype = lora_A.dtype
if x.dtype != compute_dtype:
x = x.to(compute_dtype)

@@ -151,3 +138,7 @@ if is_bnb_4bit_available():
output = output * scaling / ranknum
result += output
return result

def __repr__(self) -> str:
rep = super().__repr__()
return "adalora." + rep

@@ -20,22 +20,21 @@ from .layer import AdaLoraLayer
class SVDQuantLinear(torch.nn.Module, AdaLoraLayer):
def __init__(
self,
base_layer,
adapter_name,
quant_linear_module,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
init_lora_weights: bool = True,
**kwargs,
) -> None:
torch.nn.Module.__init__(self)
AdaLoraLayer.__init__(
self, in_features=quant_linear_module.infeatures, out_features=quant_linear_module.outfeatures
)
self.quant_linear_module = quant_linear_module
self.weight = quant_linear_module.qweight
init_lora_weights = kwargs.pop("init_lora_weights", True)
super().__init__()
AdaLoraLayer.__init__(self, base_layer)

# self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
# for backwards compatibility
self.quant_linear_module = base_layer
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.set_adapter(adapter_name)

def forward(self, x: torch.Tensor) -> torch.Tensor:
result = self.quant_linear_module(x)

@@ -67,3 +66,7 @@ class SVDQuantLinear(torch.nn.Module, AdaLoraLayer):
output = output.to(expected_dtype)
result += output
return result

def __repr__(self) -> str:
rep = super().__repr__()
return "adalora." + rep

@@ -14,9 +14,9 @@
# limitations under the License.

import warnings
from typing import Any, List, Optional

import torch
import torch.nn.functional as F
from torch import nn

from peft.tuners.lora import LoraLayer

@@ -26,14 +26,11 @@ from peft.utils import transpose
class AdaLoraLayer(LoraLayer):
# List all names of layers that may contain adapter weights
# Note: ranknum doesn't need to be included as it is not an nn.Module
adapter_layer_names = ["lora_A", "lora_B", "lora_E", "lora_embedding_A", "lora_embedding_B"]
adapter_layer_names = ("lora_A", "lora_B", "lora_E", "lora_embedding_A", "lora_embedding_B")
# other_param_names is defined in LoraLayer

def __init__(
self,
in_features: int,
out_features: int,
):
super().__init__(in_features, out_features)
def __init__(self, base_layer: nn.Module) -> None:
super().__init__(base_layer)
self.lora_E = nn.ParameterDict({})
self.lora_A = nn.ParameterDict({})
self.lora_B = nn.ParameterDict({})

@@ -62,7 +59,12 @@ class AdaLoraLayer(LoraLayer):
self.scaling[adapter_name] = lora_alpha if lora_alpha > 0 else float(r)
if init_lora_weights:
self.reset_lora_parameters(adapter_name)
self.to(self.weight.device)

if hasattr(self.get_base_layer(), "qweight"):
# QuantLinear
self.to(self.get_base_layer().qweight.device)
else:
self.to(self.get_base_layer().weight.device)
self.set_adapter(self.active_adapters)

def reset_lora_parameters(self, adapter_name):

@@ -72,34 +74,29 @@ class AdaLoraLayer(LoraLayer):
nn.init.normal_(self.lora_B[adapter_name], mean=0.0, std=0.02)


class SVDLinear(nn.Linear, AdaLoraLayer):
class SVDLinear(nn.Module, AdaLoraLayer):
# SVD-based adaptation by a dense layer
def __init__(
self,
base_layer: nn.Module,
adapter_name: str,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
fan_in_fan_out: bool = False,
init_lora_weights: bool = True,
**kwargs,
) -> None:
init_lora_weights = kwargs.pop("init_lora_weights", True)
nn.Linear.__init__(self, in_features, out_features, **kwargs)
AdaLoraLayer.__init__(self, in_features=in_features, out_features=out_features)
super().__init__()
AdaLoraLayer.__init__(self, base_layer)
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.get_base_layer().weight.requires_grad = False

self.fan_in_fan_out = fan_in_fan_out
if fan_in_fan_out:
self.weight.data = self.weight.data.T

nn.Linear.reset_parameters(self)
self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.set_adapter(adapter_name)

def merge(self, safe_merge: bool = False) -> None:
def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights

@@ -108,18 +105,26 @@ class SVDLinear(nn.Linear, AdaLoraLayer):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
"""
if self.merged:
warnings.warn(
f"Already following adapters were merged {','.join(self.merged_adapters)}. "
f"You are now additionally merging {','.join(self.active_adapters)}."
)
for active_adapter in self.active_adapters:

if adapter_names is None:
adapter_names = self.active_adapters

for active_adapter in adapter_names:
base_layer = self.get_base_layer()
if active_adapter in self.lora_A.keys():
if safe_merge:
# Note that safe_merge will be slower than the normal merge
# because of the copy operation.
orig_weights = self.weight.data.clone()
orig_weights = base_layer.weight.data.clone()
orig_weights += self.get_delta_weight(active_adapter)

if not torch.isfinite(orig_weights).all():

@@ -127,9 +132,9 @@ class SVDLinear(nn.Linear, AdaLoraLayer):
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)

self.weight.data = orig_weights
base_layer.weight.data = orig_weights
else:
self.weight.data += self.get_delta_weight(active_adapter)
base_layer.weight.data += self.get_delta_weight(active_adapter)
self.merged_adapters.append(active_adapter)

def unmerge(self) -> None:

@@ -139,7 +144,7 @@ class SVDLinear(nn.Linear, AdaLoraLayer):
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.lora_A.keys():
self.weight.data -= self.get_delta_weight(active_adapter)
self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)

def get_delta_weight(self, adapter) -> torch.Tensor:
return (

@@ -148,19 +153,16 @@ class SVDLinear(nn.Linear, AdaLoraLayer):
/ (self.ranknum[adapter] + 1e-5)
)

def _linear(self, input: torch.Tensor) -> torch.Tensor:
return F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
# TODO: SVDLinear does not convert dtype, unlike lora linear, is that correct?
if self.disable_adapters:
if self.merged:
self.unmerge()
result = self._linear(x)
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self._linear(x)
result = self.base_layer(x, *args, **kwargs)
else:
result = self._linear(x)
result = self.base_layer(x, *args, **kwargs)
for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
continue

@@ -175,8 +177,12 @@ class SVDLinear(nn.Linear, AdaLoraLayer):

return result

def __repr__(self) -> str:
rep = super().__repr__()
return "adalora." + rep

class RankAllocator(object):

class RankAllocator:
"""
The RankAllocator for AdaLoraModel. Paper: https://openreview.net/pdf?id=lq62uWRJjiY
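The recurring change across these adapter-layer hunks is structural: instead of subclassing `nn.Linear` or the bitsandbytes layers, each adapter layer now holds the wrapped module as `self.base_layer`, freezes its weight, and calls it inside `forward`. A deliberately simplified sketch of that composition pattern follows; it is not PEFT's actual class, and the names `WrappedLinear` and `delta` are invented.

```python
import torch
import torch.nn as nn

class WrappedLinear(nn.Module):
    """Minimal illustration of the base_layer composition pattern used above."""

    def __init__(self, base_layer: nn.Linear, r: int = 8):
        super().__init__()
        self.base_layer = base_layer                  # keep the original (possibly quantized) module intact
        self.base_layer.weight.requires_grad = False  # freeze the pre-trained weight
        self.lora_A = nn.Parameter(torch.zeros(r, base_layer.in_features))
        self.lora_B = nn.Parameter(torch.zeros(base_layer.out_features, r))

    def get_base_layer(self) -> nn.Module:
        return self.base_layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = self.base_layer(x)               # delegate instead of calling super().forward(x)
        delta = x @ self.lora_A.T @ self.lora_B.T  # low-rank update on top of the frozen output
        return result + delta
```

Because the wrapper only ever calls `self.base_layer(...)`, the same class works whether the wrapped module is `nn.Linear`, `Conv1D`, or a quantized bitsandbytes layer, which is what lets the hunks above delete the per-backend `__init__` boilerplate.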

@@ -20,6 +20,7 @@ from transformers.pytorch_utils import Conv1D

from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.tuners.lora import LoraConfig, LoraModel
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import (
TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
_freeze_adapter,

@@ -67,6 +68,8 @@ class AdaLoraModel(LoraModel):
- **peft_config** ([`AdaLoraConfig`]): The configuration of the AdaLora model.
"""

# Note: don't redefine prefix here, it should be inherited from LoraModel

def __init__(self, model, config, adapter_name):
super().__init__(model, config, adapter_name)

@@ -121,7 +124,7 @@ class AdaLoraModel(LoraModel):
loaded_in_4bit = optional_kwargs.get("loaded_in_4bit", False)
if (loaded_in_8bit or loaded_in_4bit) and not is_bnb_available():
raise ImportError(
"To use Lora with 8-bit quantization, please install the `bitsandbytes` package. "
"To use AdaLora with 8-bit quantization, please install the `bitsandbytes` package. "
"You can install it with `pip install bitsandbytes`."
)
kwargs = {

@@ -138,7 +141,7 @@ class AdaLoraModel(LoraModel):
if quantization_config is not None:
kwargs["gptq_quantization_config"] = quantization_config

# If it is not a LoraLayer, create a new module, else update it with new adapters
# If it is not an AdaLoraLayer, create a new module, else update it with new adapters
if not isinstance(target, AdaLoraLayer):
new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
if adapter_name != self.active_adapter:

@@ -159,11 +162,15 @@ class AdaLoraModel(LoraModel):
gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)

bias = target.bias is not None
loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)

if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target

if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt):
kwargs.update(
{
"has_fp16_weights": target.state.has_fp16_weights,

@@ -172,8 +179,8 @@ class AdaLoraModel(LoraModel):
"index": target.index,
}
)
new_module = SVDLinear8bitLt(adapter_name, target.in_features, target.out_features, bias=bias, **kwargs)
elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit):
new_module = SVDLinear8bitLt(target, adapter_name, **kwargs)
elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit):
fourbit_kwargs = kwargs.copy()
fourbit_kwargs.update(
{

@@ -182,25 +189,18 @@ class AdaLoraModel(LoraModel):
"quant_type": target.weight.quant_type,
}
)
new_module = SVDLinear4bit(
adapter_name, target.in_features, target.out_features, bias=bias, **fourbit_kwargs
)
new_module = SVDLinear4bit(target, adapter_name, **fourbit_kwargs)
elif AutoGPTQQuantLinear is not None and isinstance(target, AutoGPTQQuantLinear):
new_module = SVDQuantLinear(adapter_name, target, **kwargs)
target.weight = target.qweight
new_module = SVDQuantLinear(target, adapter_name, **kwargs)
else:
if isinstance(target, torch.nn.Linear):
in_features, out_features = target.in_features, target.out_features
if isinstance(target_base_layer, torch.nn.Linear):
if kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
"Setting fan_in_fan_out to False."
)
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
elif isinstance(target, Conv1D):
in_features, out_features = (
target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
)
elif isinstance(target_base_layer, Conv1D):
if not kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to False but the target module is `Conv1D`. "

@@ -212,7 +212,7 @@ class AdaLoraModel(LoraModel):
f"Target module {target} is not supported. "
f"Currently, only `torch.nn.Linear` and `Conv1D` are supported."
)
new_module = SVDLinear(adapter_name, in_features, out_features, bias=bias, **kwargs)
new_module = SVDLinear(target, adapter_name, **kwargs)

return new_module

@@ -236,7 +236,7 @@ class AdaLoraModel(LoraModel):
def forward(self, *args, **kwargs):
outputs = self.model.forward(*args, **kwargs)

if getattr(outputs, "loss", None) is not None:
if (getattr(outputs, "loss", None) is not None) and isinstance(outputs.loss, torch.Tensor):
# Calculate the orthogonal regularization
orth_reg_weight = self.peft_config[self.trainable_adapter_name].orth_reg_weight
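A small sketch of the unwrapping step these `_create_new_module` hunks introduce, using the `BaseTunerLayer` and `get_base_layer` names that appear in the diff; the helper function name is invented for illustration.

```python
from peft.tuners.tuners_utils import BaseTunerLayer

def resolve_base_layer(target):
    # When a second adapter is added, `target` may already be a tuner layer; type
    # dispatch (nn.Linear, Conv1D, bnb layers, ...) must then look at the module it wraps.
    if isinstance(target, BaseTunerLayer):
        return target.get_base_layer()
    return target
```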

@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import bitsandbytes as bnb
from typing import Any

import torch

from peft.import_utils import is_bnb_4bit_available, is_bnb_available

@@ -23,39 +24,27 @@ from .layer import IA3Layer

if is_bnb_available():

class Linear8bitLt(bnb.nn.Linear8bitLt, IA3Layer):
class Linear8bitLt(torch.nn.Module, IA3Layer):
# (IA)^3 implemented in a dense layer
def __init__(
self,
adapter_name,
in_features,
out_features,
is_feedforward,
base_layer: torch.nn.Module,
adapter_name: str,
is_feedforward: bool,
init_ia3_weights: bool = True,
**kwargs,
) -> None:
bnb.nn.Linear8bitLt.__init__(
self,
in_features,
out_features,
bias=kwargs.get("bias", True),
has_fp16_weights=kwargs.get("has_fp16_weights", True),
memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
threshold=kwargs.get("threshold", 0.0),
index=kwargs.get("index", None),
)
IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward)
self.is_feedforward = is_feedforward
super().__init__()
IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward)

# Freezing the pre-trained weight matrix
self.weight.requires_grad = False

init_ia3_weights = kwargs.pop("init_ia3_weights", True)
self.get_base_layer().weight.requires_grad = False
self.update_layer(adapter_name, init_ia3_weights)
self.set_adapter(adapter_name)

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
# note: no check for self.merged because merging is not supported (yet)
if self.disable_adapters:
return super().forward(x)
return self.base_layer(x)

ia3_scaling = 1
for active_adapter in self.active_adapters:

@@ -67,10 +56,10 @@ if is_bnb_available():
if requires_conversion:
x = x.float()
if self.is_feedforward:
result = super().forward(x * ia3_scaling)
result = self.base_layer(x * ia3_scaling)
expected_dtype = result.dtype
else:
result = super().forward(x)
result = self.base_layer(x)
expected_dtype = result.dtype
result = result * ia3_scaling

@@ -79,41 +68,34 @@ if is_bnb_available():

return result

def __repr__(self) -> str:
rep = super().__repr__()
return "ia3." + rep


if is_bnb_4bit_available():

class Linear4bit(bnb.nn.Linear4bit, IA3Layer):
class Linear4bit(torch.nn.Module, IA3Layer):
# IA3 implemented in a dense layer
def __init__(
self,
adapter_name,
in_features,
out_features,
is_feedforward,
base_layer: torch.nn.Module,
adapter_name: str,
is_feedforward: bool,
init_ia3_weights: bool = True,
**kwargs,
) -> None:
bnb.nn.Linear4bit.__init__(
self,
in_features,
out_features,
bias=kwargs.get("bias", True),
compute_dtype=kwargs.get("compute_dtype", torch.float32),
compress_statistics=kwargs.get("compress_statistics", True),
quant_type=kwargs.get("quant_type", "nf4"),
)
IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward)
self.is_feedforward = is_feedforward
super().__init__()
IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward)

# Freezing the pre-trained weight matrix
self.weight.requires_grad = False

init_ia3_weights = kwargs.pop("init_ia3_weights", True)
self.get_base_layer().weight.requires_grad = False
self.update_layer(adapter_name, init_ia3_weights)
self.set_adapter(adapter_name)

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
# note: no check for self.merged because merging is not supported (yet)
if self.disable_adapters:
return super().forward(x)
return self.base_layer(x)

ia3_scaling = 1
for active_adapter in self.active_adapters:

@@ -125,10 +107,10 @@ if is_bnb_4bit_available():
if requires_conversion:
x = x.float()
if self.is_feedforward:
result = super().forward(x * ia3_scaling)
result = self.base_layer(x * ia3_scaling)
expected_dtype = result.dtype
else:
result = super().forward(x)
result = self.base_layer(x)
expected_dtype = result.dtype
result = result * ia3_scaling

@@ -140,3 +122,7 @@ if is_bnb_4bit_available():
result = result.to(expected_dtype)

return result

def __repr__(self) -> str:
rep = super().__repr__()
return "ia3." + rep

@@ -14,34 +14,43 @@
# limitations under the License.

import warnings
from typing import Tuple, Union
from typing import Any, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.pytorch_utils import Conv1D

from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import transpose


class IA3Layer(BaseTunerLayer):
# List all names of layers that may contain adapter weights
adapter_layer_names = ["ia3_l"]
# All names of layers that may contain adapter weights
adapter_layer_names = ("ia3_l",)

def __init__(
self,
in_features: int,
out_features: int,
is_feedforward: bool,
):
self.scaling = {}
def __init__(self, base_layer: nn.Module, is_feedforward: bool, **kwargs) -> None:
self.base_layer = base_layer
self.ia3_l = nn.ParameterDict({})
# Mark the weight as unmerged
self._disable_adapters = False
self.merged_adapters = []
self.is_feedforward = is_feedforward

base_layer = self.get_base_layer()
if isinstance(base_layer, nn.Linear):
in_features, out_features = base_layer.in_features, base_layer.out_features
elif isinstance(base_layer, nn.Conv2d):
in_features, out_features = base_layer.in_channels, base_layer.out_channels
elif isinstance(base_layer, nn.Embedding):
in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim
elif isinstance(base_layer, Conv1D):
in_features, out_features = (
base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
)
else:
raise ValueError(f"Unsupported layer type {type(base_layer)}")
self.in_features = in_features
self.out_features = out_features
self.is_feedforward = is_feedforward

def update_layer(self, adapter_name, init_ia3_weights):
# Actual trainable parameters

@@ -52,7 +61,7 @@ class IA3Layer(BaseTunerLayer):
self.ia3_l[adapter_name] = nn.Parameter(weight)
if init_ia3_weights:
self.reset_ia3_parameters(adapter_name)
self.to(self.weight.device)
self.to(self.get_base_layer().weight.device)
self.set_adapter(self.active_adapters)

def reset_ia3_parameters(self, adapter_name):

@@ -61,35 +70,24 @@ class IA3Layer(BaseTunerLayer):
nn.init.constant_(self.ia3_l[adapter_name], 1.0)


class Linear(nn.Linear, IA3Layer):
class Linear(nn.Module, IA3Layer):
# (IA)^3 implemented in a dense layer
def __init__(
self,
base_layer: nn.Module,
adapter_name: str,
in_features: int,
out_features: int,
fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
is_feedforward: bool = False,  # Set to True if the layer is treated as a feedforward layer
is_target_conv_1d_layer: bool = False,  # whether target module is a conv1d layer. useful while unloading later
init_ia3_weights: bool = True,  # whether to initialize IA3 weights
**kwargs,
) -> None:
init_ia3_weights = kwargs.pop("init_ia3_weights", True)

nn.Linear.__init__(self, in_features, out_features, **kwargs)
IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward)
self.is_feedforward = is_feedforward
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False

super().__init__()
IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward)
self.fan_in_fan_out = fan_in_fan_out
if fan_in_fan_out:
self.weight.data = self.weight.data.T

self.is_target_conv_1d_layer = is_target_conv_1d_layer

nn.Linear.reset_parameters(self)
self._active_adapter = adapter_name
self.update_layer(adapter_name, init_ia3_weights)
self.set_adapter(adapter_name)

def update_layer(self, adapter_name, init_ia3_weights):
# Actual trainable parameters

@@ -100,10 +98,10 @@ class Linear(nn.Linear, IA3Layer):
self.ia3_l[adapter_name] = nn.Parameter(weight)
if init_ia3_weights:
self.reset_ia3_parameters(adapter_name)
self.to(self.weight.device)
self.to(self.get_base_layer().weight.device)
self.set_adapter(self.active_adapters)

def merge(self, safe_merge: bool = False) -> None:
def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights

@@ -112,6 +110,9 @@ class Linear(nn.Linear, IA3Layer):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
"""
if self.merged:
warnings.warn(

@@ -119,26 +120,28 @@ class Linear(nn.Linear, IA3Layer):
f"You are now additionally merging {','.join(self.active_adapters)}."
)

for active_adapter in self.active_adapters:
if adapter_names is None:
adapter_names = self.active_adapters

for active_adapter in adapter_names:
if active_adapter in self.ia3_l.keys():
base_layer = self.get_base_layer()
ia3_l = transpose(self.ia3_l[active_adapter].data, self.fan_in_fan_out)
if safe_merge:
orig_weights = transpose(self.weight, self.fan_in_fan_out).clone()
orig_weights = torch.mul(orig_weights.data, self.ia3_l[active_adapter].data)
orig_weights = base_layer.weight.data
orig_weights = torch.mul(orig_weights, ia3_l)

if not torch.isfinite(orig_weights).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
self.weight.data = orig_weights
self.weight = transpose(self.weight, self.fan_in_fan_out)
base_layer.weight.data = orig_weights
else:
self.weight = transpose(self.weight, self.fan_in_fan_out)
self.weight.data = torch.mul(self.weight.data, self.ia3_l[active_adapter].data)
self.weight = transpose(self.weight, self.fan_in_fan_out)
base_layer.weight.data = torch.mul(base_layer.weight.data, ia3_l)

if not self.is_feedforward and (self.bias is not None):
scaling = self.ia3_l[active_adapter].reshape(self.bias.shape)
self.bias.data = torch.mul(self.bias.data, scaling.data)
if not self.is_feedforward and (base_layer.bias is not None):
scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape)
base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data)

self.merged_adapters.append(active_adapter)

@@ -151,27 +154,24 @@ class Linear(nn.Linear, IA3Layer):
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.ia3_l.keys():
self.weight = transpose(self.weight, self.fan_in_fan_out)
# divide by (IA)^3 vector. Add tolerace to avoid division by zero
self.weight.data = torch.div(self.weight.data, self.ia3_l[active_adapter].data + 1e-8)
self.weight = transpose(self.weight, self.fan_in_fan_out)
base_layer = self.get_base_layer()
# Add tolerace to avoid division by zero
ia3_l = transpose(self.ia3_l[active_adapter].data, self.fan_in_fan_out) + 1e-8
base_layer.weight.data = torch.div(base_layer.weight.data, ia3_l)

if not self.is_feedforward and (self.bias is not None):
scaling = self.ia3_l[active_adapter].reshape(self.bias.shape)
self.bias.data = torch.div(self.bias.data, scaling.data + 1e-8)
if not self.is_feedforward and (base_layer.bias is not None):
scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape)
base_layer.bias.data = torch.div(base_layer.bias.data, scaling.data + 1e-8)

def _linear(self, input: torch.Tensor) -> torch.Tensor:
return F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
dtype = previous_dtype = x.dtype

if self.disable_adapters:
if self.merged:
self.unmerge()
result = self._linear(x)
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self._linear(x)
result = self.base_layer(x, *args, **kwargs)
else:
ia3_scaling = 1
for active_adapter in self.active_adapters:

@@ -182,46 +182,34 @@ class Linear(nn.Linear, IA3Layer):

if self.is_feedforward:
x = x.to(dtype)
# TODO: self.weight.dtype can be != self.ia3_l[self.active_adapters].dtype
# TODO: weight.dtype can be != self.ia3_l[self.active_adapters].dtype
# e.g. bf16 vs fp32. Is that okay?
interm = (x * ia3_scaling).to(self.weight.dtype)
result = self._linear(interm)
interm = (x * ia3_scaling).to(self.get_base_layer().weight.dtype)
result = self.base_layer(interm, *args, **kwargs)
else:
result = self._linear(x)
result = self.base_layer(x, *args, **kwargs)
result = result.to(dtype) * ia3_scaling

result = result.to(previous_dtype)
return result


class Conv2d(nn.Conv2d, IA3Layer):
class Conv2d(nn.Module, IA3Layer):
def __init__(
self,
base_layer: nn.Module,
adapter_name: str,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int]],
stride: Union[int, Tuple[int]] = 1,
padding: Union[int, Tuple[int]] = 0,
fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
is_feedforward: bool = False,  # Set to True if the layer is treated as a feedforward layer
init_ia3_weights: bool = True,
**kwargs,
) -> None:
init_ia3_weights = kwargs.pop("init_ia3_weights", True)

nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
IA3Layer.__init__(self, in_features=in_channels, out_features=out_channels, is_feedforward=is_feedforward)
self.is_feedforward = is_feedforward
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False

super().__init__()
IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward)
self.fan_in_fan_out = fan_in_fan_out
if fan_in_fan_out:
self.weight.data = self.weight.data.T
self._active_adapter = adapter_name

nn.Conv2d.reset_parameters(self)
self.update_layer(adapter_name, init_ia3_weights)
self.set_adapter(adapter_name)

def update_layer(self, adapter_name, init_ia3_weights):
# Actual trainable parameters

@@ -232,10 +220,10 @@ class Conv2d(nn.Conv2d, IA3Layer):
self.ia3_l[adapter_name] = nn.Parameter(weight)
if init_ia3_weights:
self.reset_ia3_parameters(adapter_name)
self.to(self.weight.device)
self.to(self.get_base_layer().weight.device)
self.set_adapter(self.active_adapters)

def merge(self, safe_merge: bool = False) -> None:
def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights

@@ -244,6 +232,9 @@ class Conv2d(nn.Conv2d, IA3Layer):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
"""
if self.merged:
warnings.warn(

@@ -251,27 +242,31 @@ class Conv2d(nn.Conv2d, IA3Layer):
f"You are now additionally merging {','.join(self.active_adapters)}."
)

for active_adapter in self.active_adapters:
if adapter_names is None:
adapter_names = self.active_adapters

for active_adapter in adapter_names:
if active_adapter in self.ia3_l.keys():
base_layer = self.get_base_layer()
ia3_scaling = self.ia3_l[active_adapter].data
if not self.is_feedforward:
ia3_scaling = ia3_scaling.permute(1, 0, 2, 3)

if safe_merge:
output_weight = torch.mul(self.weight.data, ia3_scaling).clone()
output_weight = torch.mul(base_layer.weight.data, ia3_scaling).clone()

if not torch.isfinite(output_weight).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)

self.weight.data = output_weight
base_layer.weight.data = output_weight
else:
self.weight.data = torch.mul(self.weight.data, ia3_scaling)
base_layer.weight.data = torch.mul(base_layer.weight.data, ia3_scaling)

if not self.is_feedforward and (self.bias is not None):
scaling = self.ia3_l[active_adapter].reshape(self.bias.shape)
self.bias.data = torch.mul(self.bias.data, scaling.data)
if not self.is_feedforward and (base_layer.bias is not None):
scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape)
base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data)

self.merged_adapters.append(active_adapter)

@@ -284,36 +279,26 @@ class Conv2d(nn.Conv2d, IA3Layer):
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.ia3_l.keys():
base_layer = self.get_base_layer()
# divide by (IA)^3 vector. Add tolerace to avoid division by zero
ia3_scaling = self.ia3_l[active_adapter].data
if not self.is_feedforward:
ia3_scaling = ia3_scaling.permute(1, 0, 2, 3)
self.weight.data = torch.div(self.weight.data, ia3_scaling + 1e-8)
base_layer.weight.data = torch.div(base_layer.weight.data, ia3_scaling + 1e-8)

if not self.is_feedforward and (self.bias is not None):
scaling = self.ia3_l[active_adapter].reshape(self.bias.shape)
self.bias.data = torch.mul(self.bias.data, scaling.data)
if not self.is_feedforward and (base_layer.bias is not None):
scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape)
base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data)

def _conv2d(self, input: torch.Tensor) -> torch.Tensor:
return F.conv2d(
input,
self.weight,
bias=self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
previous_dtype = x.dtype
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
dtype = previous_dtype = x.dtype

if self.disable_adapters:
if self.merged:
self.unmerge()
result = self._conv2d(x)
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self._conv2d(x)
result = self.base_layer(x, *args, **kwargs)
else:
ia3_scaling = 1
for active_adapter in self.active_adapters:

@@ -324,12 +309,12 @@ class Conv2d(nn.Conv2d, IA3Layer):

if self.is_feedforward:
x = x.to(dtype)
# TODO: self.weight.dtype can be != self.ia3_l[self.active_adapters].dtype
# TODO: weight.dtype can be != self.ia3_l[self.active_adapters].dtype
# e.g. bf16 vs fp32. Is that okay?
interm = (x * ia3_scaling).to(self.weight.dtype)
result = self._conv2d(interm)
interm = (x * ia3_scaling).to(self.get_base_layer().weight.dtype)
result = self.base_layer(interm, *args, **kwargs)
else:
result = self._conv2d(x)
result = self.base_layer(x, *args, **kwargs)
result = result.to(dtype) * ia3_scaling

result = result.to(previous_dtype)
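The merge and unmerge changes above are easier to follow with the arithmetic spelled out: merging an (IA)^3 adapter multiplies the base weight by the learned scaling vector, and unmerging divides it back out with a small tolerance. A standalone numeric sketch (not PEFT's code) follows; the recovery is exact only up to that tolerance.

```python
import torch

weight = torch.randn(4, 3)      # stands in for base_layer.weight.data, shape (out_features, in_features)
ia3_l = torch.rand(4, 1) + 0.5  # learned per-output scaling vector

merged = weight * ia3_l               # merge: W <- W * l
restored = merged / (ia3_l + 1e-8)    # unmerge: W <- W / (l + eps), eps avoids division by zero

print(torch.allclose(weight, restored, atol=1e-5))  # True (up to the 1e-8 tolerance)
```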
@ -17,12 +17,13 @@ import re
|
||||
import warnings
|
||||
from dataclasses import asdict
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from transformers.pytorch_utils import Conv1D
|
||||
|
||||
from peft.import_utils import is_bnb_4bit_available, is_bnb_available
|
||||
from peft.tuners.tuners_utils import BaseTuner, check_target_module_exists
|
||||
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists
|
||||
from peft.utils import (
|
||||
TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING,
|
||||
TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING,
|
||||
@ -77,17 +78,23 @@ class IA3Model(BaseTuner):
|
||||
- **peft_config** ([`ia3Config`]): The configuration of the (IA)^3 model.
|
||||
"""
|
||||
|
||||
prefix: str = "ia3_"
|
||||
|
||||
def __init__(self, model, config, adapter_name):
|
||||
super().__init__(model, config, adapter_name)
|
||||
|
||||
@staticmethod
|
||||
def _create_new_module(ia3_config, adapter_name, target, **kwargs):
|
||||
bias = hasattr(target, "bias") and target.bias is not None
|
||||
loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
|
||||
loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)
|
||||
is_feedforward = kwargs.pop("is_feedforward", False)
|
||||
|
||||
if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
|
||||
if isinstance(target, BaseTunerLayer):
|
||||
target_base_layer = target.get_base_layer()
|
||||
else:
|
||||
target_base_layer = target
|
||||
|
||||
if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt):
|
||||
eightbit_kwargs = kwargs.copy()
|
||||
eightbit_kwargs.update(
|
||||
{
|
||||
@ -97,15 +104,8 @@ class IA3Model(BaseTuner):
|
||||
"index": target.index,
|
||||
}
|
||||
)
|
||||
new_module = Linear8bitLt(
|
||||
adapter_name,
|
||||
target.in_features,
|
||||
target.out_features,
|
||||
is_feedforward,
|
||||
bias=bias,
|
||||
**eightbit_kwargs,
|
||||
)
|
||||
elif loaded_in_4bit and isinstance(target, bnb.nn.Linear4bit):
|
||||
new_module = Linear8bitLt(target, adapter_name, is_feedforward=is_feedforward, **eightbit_kwargs)
|
||||
elif loaded_in_4bit and isinstance(target_base_layer, bnb.nn.Linear4bit):
|
||||
fourbit_kwargs = kwargs.copy()
|
||||
fourbit_kwargs.update(
|
||||
{
|
||||
@ -114,56 +114,31 @@ class IA3Model(BaseTuner):
|
||||
"quant_type": target.weight.quant_type,
|
||||
}
|
||||
)
|
||||
new_module = Linear4bit(
|
||||
adapter_name,
|
||||
target.in_features,
|
||||
target.out_features,
|
||||
is_feedforward,
|
||||
bias=bias,
|
||||
**fourbit_kwargs,
|
||||
)
|
||||
new_module = Linear4bit(target, adapter_name, is_feedforward=is_feedforward, **fourbit_kwargs)
|
||||
elif isinstance(target, torch.nn.Conv2d):
|
||||
out_channels, in_channels = target.weight.size()[:2]
|
||||
kernel_size = target.weight.size()[2:]
|
||||
stride = target.stride
|
||||
padding = target.padding
|
||||
new_module = Conv2d(
|
||||
adapter_name=adapter_name,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
is_feedforward=is_feedforward,
|
||||
**kwargs,
|
||||
new_module = Conv2d(target, adapter_name, is_feedforward=is_feedforward, **kwargs)
|
||||
elif isinstance(target_base_layer, torch.nn.Linear):
|
||||
if kwargs["fan_in_fan_out"]:
|
||||
warnings.warn(
|
||||
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
|
||||
"Setting fan_in_fan_out to False."
|
||||
)
|
||||
kwargs["fan_in_fan_out"] = ia3_config.fan_in_fan_out = False
|
||||
new_module = Linear(target, adapter_name, is_feedforward=is_feedforward, **kwargs)
|
||||
elif isinstance(target_base_layer, Conv1D):
|
||||
if not kwargs["fan_in_fan_out"]:
|
||||
warnings.warn(
|
||||
"fan_in_fan_out is set to False but the target module is `Conv1D`. "
|
||||
"Setting fan_in_fan_out to True."
|
||||
)
|
||||
kwargs["fan_in_fan_out"] = ia3_config.fan_in_fan_out = True
|
||||
new_module = Linear(
|
||||
target, adapter_name, is_feedforward=is_feedforward, is_target_conv_1d_layer=True, **kwargs
|
||||
)
|
||||
else:
|
||||
if isinstance(target, torch.nn.Linear):
|
||||
in_features, out_features = target.in_features, target.out_features
|
||||
if kwargs["fan_in_fan_out"]:
|
||||
warnings.warn(
|
||||
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
|
||||
"Setting fan_in_fan_out to False."
|
||||
)
|
||||
kwargs["fan_in_fan_out"] = ia3_config.fan_in_fan_out = False
|
||||
elif isinstance(target, Conv1D):
|
||||
in_features, out_features = (
|
||||
target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
|
||||
)
|
||||
kwargs["is_target_conv_1d_layer"] = True # useful for unloading later
|
||||
if not kwargs["fan_in_fan_out"]:
|
||||
warnings.warn(
|
||||
"fan_in_fan_out is set to False but the target module is `Conv1D`. "
|
||||
"Setting fan_in_fan_out to True."
|
||||
)
|
||||
kwargs["fan_in_fan_out"] = ia3_config.fan_in_fan_out = True
|
||||
else:
raise ValueError(
f"Target module {target} is not supported. "
f"Currently, only `torch.nn.Linear`, `torch.nn.Conv2d`, and `Conv1D` are supported."
)
new_module = Linear(
adapter_name, in_features, out_features, is_feedforward=is_feedforward, bias=bias, **kwargs
raise ValueError(
f"Target module {target} is not supported. "
f"Currently, only `torch.nn.Linear`, `torch.nn.Conv2d`, and `Conv1D` are supported."
)
return new_module
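The dispatch above now inspects the module wrapped by an existing adapter layer instead of the adapter itself. A simplified sketch of that unwrap-then-dispatch pattern; the helper name and return values are placeholders:

```py
# Simplified sketch of the dispatch in `_create_new_module` after the refactor: if the target
# is already an adapter layer, inspect the module it wraps, then pick the matching adapter class.
import torch.nn as nn


def pick_adapter_class(target: nn.Module) -> str:
    # unwrap an already-wrapped layer (BaseTunerLayer exposes get_base_layer() in PEFT)
    target_base_layer = target.get_base_layer() if hasattr(target, "get_base_layer") else target

    if isinstance(target_base_layer, nn.Conv2d):
        return "Conv2d adapter"
    elif isinstance(target_base_layer, nn.Linear):
        return "Linear adapter"
    else:
        raise ValueError(f"Target module {target} is not supported.")
```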
@ -173,7 +148,7 @@ class IA3Model(BaseTuner):
|
||||
|
||||
def _mark_only_adapters_as_trainable(self) -> None:
|
||||
for n, p in self.model.named_parameters():
|
||||
if "ia3_" not in n:
|
||||
if self.prefix not in n:
|
||||
p.requires_grad = False
|
||||
|
||||
def _create_and_replace(
|
||||
@ -200,21 +175,16 @@ class IA3Model(BaseTuner):
|
||||
"is_feedforward": is_feedforward,
|
||||
}
|
||||
|
||||
if isinstance(target, IA3Layer):
|
||||
if target.is_feedforward != is_feedforward:
|
||||
raise ValueError(
|
||||
"New adapter should have the same value for `is_feedforward` as previously added adapter."
|
||||
)
|
||||
if isinstance(target, torch.nn.Conv2d):
|
||||
target.update_layer_conv2d(
|
||||
adapter_name,
|
||||
ia3_config.init_ia3_weights,
|
||||
)
|
||||
else: # Linear
|
||||
target.update_layer(
|
||||
adapter_name,
|
||||
ia3_config.init_ia3_weights,
|
||||
)
|
||||
if isinstance(target, Conv2d):
|
||||
target.update_layer(
|
||||
adapter_name,
|
||||
ia3_config.init_ia3_weights,
|
||||
)
|
||||
elif isinstance(target, Linear):
|
||||
target.update_layer(
|
||||
adapter_name,
|
||||
ia3_config.init_ia3_weights,
|
||||
)
|
||||
else:
|
||||
new_module = self._create_new_module(ia3_config, adapter_name, target, **kwargs)
|
||||
if adapter_name != self.active_adapter:
|
||||
@ -234,19 +204,29 @@ class IA3Model(BaseTuner):
|
||||
is_feedforward = any(key.endswith(target_key) for target_key in ia3_config.feedforward_modules)
|
||||
return is_feedforward
|
||||
|
||||
@staticmethod
def _replace_module(parent, child_name, new_module, child):
def _replace_module(self, parent, child_name, new_module, child):
setattr(parent, child_name, new_module)
new_module.weight = child.weight
if child.bias is not None:
new_module.bias = child.bias

# child layer wraps the original module, unpack it
if hasattr(child, "base_layer"):
child = child.base_layer

# layers with base_layer don't need the weight to be copied, as they have a reference already
if not hasattr(new_module, "base_layer"):
new_module.weight = child.weight
if hasattr(child, "bias"):
new_module.bias = child.bias

if getattr(child, "state", None) is not None:
new_module.state = child.state
if hasattr(new_module, "base_layer"):
new_module.base_layer.state = child.state
else:
new_module.state = child.state
new_module.to(child.weight.device)

# dispatch to correct device
for name, module in new_module.named_modules():
if "ia3_" in name:
if self.prefix in name:
module.to(child.weight.device)
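`_replace_module` became an instance method so it can use `self.prefix` when deciding which submodules to move. A reduced sketch of what it does, assuming a simplified signature:

```py
# Sketch of the replacement step in isolation (assumed simplified signature): swap the child
# module on its parent, then move only the adapter-specific submodules to the original device.
import torch.nn as nn


def replace_child(parent: nn.Module, child_name: str, new_module: nn.Module, child: nn.Module, prefix: str = "ia3_"):
    setattr(parent, child_name, new_module)
    device = child.weight.device
    for name, module in new_module.named_modules():
        # the wrapped base layer already sits on the right device; only adapter params move
        if prefix in name:
            module.to(device)
```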
def __getattr__(self, name: str):
|
||||
@ -297,7 +277,9 @@ class IA3Model(BaseTuner):
|
||||
]
|
||||
return peft_config
|
||||
|
||||
def merge_and_unload(self, safe_merge: bool = False):
|
||||
def _unload_and_optionally_merge(
|
||||
self, merge: bool = True, safe_merge: bool = False, adapter_names: Optional[List[str]] = None
|
||||
):
|
||||
r"""
|
||||
This method merges the (IA)^3 layers into the base model. This is needed if someone wants to use the base model
|
||||
as a standalone model.
|
||||
@ -307,6 +289,9 @@ class IA3Model(BaseTuner):
|
||||
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
|
||||
before merging the weights. This is useful if you want to check if the merge operation will produce
|
||||
NaNs. Defaults to `False`.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
|
||||
to `None`.
|
||||
"""
|
||||
if getattr(self.model, "is_loaded_in_8bit", False):
|
||||
raise ValueError("Cannot merge ia3 layers when the model is loaded in 8-bit mode")
|
||||
@ -314,38 +299,75 @@ class IA3Model(BaseTuner):
|
||||
if getattr(self.model, "is_loaded_in_4bit", False):
|
||||
raise ValueError("Cannot merge ia3 layers when the model is loaded in 4-bit mode")
|
||||
|
||||
key_list = [key for key, _ in self.model.named_modules() if "ia3" not in key]
|
||||
key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
|
||||
for key in key_list:
|
||||
try:
|
||||
parent, target, target_name = _get_submodules(self.model, key)
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
# save any additional trainable modules part of `modules_to_save`
|
||||
if isinstance(target, ModulesToSaveWrapper):
|
||||
if hasattr(target, "base_layer"):
|
||||
if merge:
|
||||
target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
|
||||
self._replace_module(parent, target_name, target.get_base_layer(), target)
|
||||
elif isinstance(target, ModulesToSaveWrapper):
|
||||
# save any additional trainable modules part of `modules_to_save`
|
||||
setattr(parent, target_name, target.modules_to_save[target.active_adapter])
|
||||
continue
|
||||
|
||||
if not isinstance(target, IA3Layer):
|
||||
continue
|
||||
|
||||
if isinstance(target, torch.nn.Conv2d):
|
||||
new_module = torch.nn.Conv2d(
|
||||
target.in_channels,
|
||||
target.out_channels,
|
||||
kernel_size=target.kernel_size,
|
||||
stride=target.stride,
|
||||
padding=target.padding,
|
||||
dilation=target.dilation,
|
||||
)
|
||||
else:
|
||||
bias = target.bias is not None
|
||||
if getattr(target, "is_target_conv_1d_layer", False):
|
||||
new_module = Conv1D(target.out_features, target.in_features)
|
||||
else:
|
||||
new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias)
|
||||
|
||||
target.merge(safe_merge=safe_merge)
|
||||
self._replace_module(parent, target_name, new_module, target)
|
||||
|
||||
return self.model
|
||||
|
||||
def merge_and_unload(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None):
|
||||
r"""
|
||||
This method merges the IA³ layers into the base model. This is needed if someone wants to use the base model as
|
||||
a standalone model.
|
||||
|
||||
Args:
|
||||
safe_merge (`bool`):
|
||||
whether to activate the safe merging check to check if there is any potential Nan in the adapter
|
||||
weights
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
|
||||
to `None`.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
>>> from transformers import AutoModelForCausalLM
|
||||
>>> from peft import PeftModel
|
||||
|
||||
>>> base_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-40b")
|
||||
>>> peft_model_id = "smangrul/falcon-40B-int4-peft-lora-sfttrainer-sample"
|
||||
>>> model = PeftModel.from_pretrained(base_model, peft_model_id)
|
||||
>>> merged_model = model.merge_and_unload()
|
||||
```
|
||||
"""
|
||||
return self._unload_and_optionally_merge(safe_merge=safe_merge, adapter_names=adapter_names)
|
||||
|
||||
def unload(self):
"""
Gets back the base model by removing all the IA³ modules without merging. This gives back the original base
model.
"""
return self._unload_and_optionally_merge(merge=False)

def delete_adapter(self, adapter_name: str):
"""
Deletes an existing adapter.

Args:
adapter_name (str): Name of the adapter to be deleted.
"""
if adapter_name not in self.peft_config:
raise ValueError(f"Adapter {adapter_name} does not exist")
del self.peft_config[adapter_name]

key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
new_adapter = None
for key in key_list:
_, target, _ = _get_submodules(self.model, key)
if isinstance(target, IA3Layer):
target.delete_adapter(adapter_name)
if new_adapter is None:
new_adapter = target.active_adapters[:]

self.active_adapter = new_adapter or []
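A hedged usage sketch of the new `delete_adapter` path: the checkpoint name and adapter names are placeholders, and it assumes the chosen model type has a default (IA)^3 target-module mapping:

```py
# Placeholder model and adapter names; shown only to illustrate the add/delete flow that the
# methods above enable on the IA3 tuner.
from transformers import AutoModelForCausalLM
from peft import IA3Config, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder checkpoint
model = get_peft_model(base, IA3Config(task_type="CAUSAL_LM"), adapter_name="first")
model.add_adapter("second", IA3Config(task_type="CAUSAL_LM"))

# remove "second"; the tuner falls back to whatever adapters remain active
model.base_model.delete_adapter("second")
```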
@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional, Set, Tuple, Union
|
||||
from typing import Any, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -23,13 +23,14 @@ import torch.nn.functional as F
|
||||
from peft.tuners.lycoris_utils import LycorisLayer
|
||||
|
||||
|
||||
class LoHaLayer(LycorisLayer, nn.Module):
|
||||
# List all names of layers that may contain adapter weights
|
||||
adapter_layer_names = ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b", "hada_t1", "hada_t2"]
|
||||
class LoHaLayer(nn.Module, LycorisLayer):
|
||||
# All names of layers that may contain adapter weights
|
||||
adapter_layer_names = ("hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b", "hada_t1", "hada_t2")
|
||||
# other_param_names is defined on parent class
|
||||
|
||||
def __init__(self):
|
||||
LycorisLayer.__init__(self)
|
||||
super(nn.Module, self).__init__()
|
||||
def __init__(self, base_layer: nn.Module):
|
||||
super().__init__()
|
||||
LycorisLayer.__init__(self, base_layer)
|
||||
|
||||
# LoHa info
|
||||
self.hada_w1_a = nn.ParameterDict({})
|
||||
@ -75,6 +76,21 @@ class LoHaLayer(LycorisLayer, nn.Module):
|
||||
nn.init.kaiming_uniform_(self.hada_t1[adapter_name], a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self.hada_t2[adapter_name], a=math.sqrt(5))
|
||||
|
||||
def reset_adapter_parameters_random(self, adapter_name: str):
|
||||
# Original implementation performs initialization with normal distribution
|
||||
# https://github.com/KohakuBlueleaf/LyCORIS/blob/3549fdef8f564761d68b695a08ef88b1122fdedc/lycoris/modules/loha.py#L158
|
||||
|
||||
# FedPara paper proposes to perform He initialization, let's stick with it
|
||||
# It is enough to initialize only single matrix with zeros to make adapter do nothing after initialization
|
||||
if adapter_name in self.hada_w1_a.keys():
|
||||
nn.init.kaiming_uniform_(self.hada_w1_a[adapter_name], a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self.hada_w1_b[adapter_name], a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self.hada_w2_a[adapter_name], a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self.hada_w2_b[adapter_name], a=math.sqrt(5))
|
||||
if adapter_name in self.hada_t1.keys():
|
||||
nn.init.kaiming_uniform_(self.hada_t1[adapter_name], a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self.hada_t2[adapter_name], a=math.sqrt(5))
|
||||
|
||||
def update_layer(
|
||||
self,
|
||||
adapter_name: str,
|
||||
@@ -106,16 +122,20 @@ class LoHaLayer(LycorisLayer, nn.Module):
self.module_dropout[adapter_name] = module_dropout

# Determine shape of LoHa weights
if isinstance(self, nn.Linear):
shape = tuple(self.weight.shape)
elif isinstance(self, nn.Conv2d):
use_effective_conv2d = use_effective_conv2d and self.kernel_size != (1, 1)
base_layer = self.get_base_layer()
if isinstance(base_layer, nn.Linear):
shape = tuple(base_layer.weight.shape)
elif isinstance(base_layer, nn.Conv2d):
use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size != (1, 1)
if use_effective_conv2d:
shape = (self.out_channels, self.in_channels, *self.kernel_size)
shape = (base_layer.out_channels, base_layer.in_channels, *base_layer.kernel_size)
else:
shape = (self.out_channels, self.in_channels * self.kernel_size[0] * self.kernel_size[1])
shape = (
base_layer.out_channels,
base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1],
)
else:
raise TypeError(f"LoHa is not implemented for {type(self).__name__} layer")
raise TypeError(f"LoHa is not implemented for base layers of type {type(base_layer).__name__}")

# Create weights with provided shape
self.create_adapter_parameters(adapter_name, r, shape)
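The adapter shape is now derived from the wrapped base layer rather than from `self`. The same logic as a standalone helper (helper name is illustrative):

```py
# The LoHa factors must multiply out to the shape of the wrapped layer's weight; this mirrors
# the branch above for Linear and Conv2d base layers.
import torch.nn as nn


def loha_weight_shape(base_layer: nn.Module, use_effective_conv2d: bool = False):
    if isinstance(base_layer, nn.Linear):
        return tuple(base_layer.weight.shape)  # (out_features, in_features)
    if isinstance(base_layer, nn.Conv2d):
        if use_effective_conv2d and base_layer.kernel_size != (1, 1):
            return (base_layer.out_channels, base_layer.in_channels, *base_layer.kernel_size)
        return (
            base_layer.out_channels,
            base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1],
        )
    raise TypeError(f"LoHa is not implemented for base layers of type {type(base_layer).__name__}")
```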
@ -123,9 +143,11 @@ class LoHaLayer(LycorisLayer, nn.Module):
|
||||
# Initialize weights
|
||||
if init_weights:
|
||||
self.reset_adapter_parameters(adapter_name)
|
||||
else:
|
||||
self.reset_adapter_parameters_random(adapter_name)
|
||||
|
||||
# Move new weights to device
|
||||
weight = getattr(self, "weight", None)
|
||||
weight = getattr(self.get_base_layer(), "weight", None)
|
||||
if weight is not None:
|
||||
# the layer is already completely initialized, this is an update
|
||||
if weight.dtype.is_floating_point or weight.dtype.is_complex:
|
||||
@ -155,7 +177,8 @@ class LoHaLayer(LycorisLayer, nn.Module):
|
||||
scale=torch.tensor(self.scaling[adapter_name]),
|
||||
)
|
||||
|
||||
weight = weight.reshape(self.weight.shape)
|
||||
base_layer = self.get_base_layer()
|
||||
weight = weight.reshape(base_layer.weight.shape)
|
||||
|
||||
# Perform rank dropout during training - drop rows of addition weights
|
||||
rank_dropout = self.rank_dropout[adapter_name]
|
||||
@ -170,96 +193,107 @@ class LoHaLayer(LycorisLayer, nn.Module):
|
||||
|
||||
return weight
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
previous_dtype = x.dtype

class Linear(LoHaLayer, nn.Linear):
if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
result = self.base_layer(x, *args, **kwargs)

# Execute all the adapters
for active_adapter in self.active_adapters:
if active_adapter not in self._available_adapters:
continue

module_dropout = self.module_dropout[active_adapter]

# Modify current execution weights
if (not self.training) or (self.training and torch.rand(1) > module_dropout):
result = result + self._get_delta_activations(active_adapter, x, *args, **kwargs)

result = result.to(previous_dtype)
return result
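The new shared `forward` gates the whole adapter contribution with `module_dropout`: during training the delta is skipped with that probability, at inference it is always added. A minimal sketch of just that gate, where `delta` stands in for `_get_delta_activations(...)`:

```py
# Module-dropout gate in isolation; `delta` is a stand-in for the adapter's delta activations.
import torch


def apply_adapter(result: torch.Tensor, delta: torch.Tensor, module_dropout: float, training: bool) -> torch.Tensor:
    # at inference the adapter always contributes; during training it is dropped as a whole
    # with probability `module_dropout`
    if (not training) or (torch.rand(1) > module_dropout):
        result = result + delta
    return result
```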
class Linear(LoHaLayer):
|
||||
"""LoHa implemented in Linear layer"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
base_layer: nn.Module,
|
||||
adapter_name: str = "default",
|
||||
r: int = 0,
|
||||
alpha: float = 0.0,
|
||||
rank_dropout: float = 0.0,
|
||||
module_dropout: float = 0.0,
|
||||
init_weights: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
init_weights = kwargs.pop("init_weights", True)
|
||||
self._init_empty_weights(nn.Linear, in_features, out_features, bias, device=device, dtype=dtype)
|
||||
|
||||
LoHaLayer.__init__(self)
|
||||
super().__init__(base_layer)
|
||||
|
||||
# Create adapter and set it active
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer(adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, **kwargs)
|
||||
self.set_adapter(adapter_name)
|
||||
|
||||
def _op(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
||||
return F.linear(input, weight, bias=self.bias)
|
||||
def _get_delta_activations(
|
||||
self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
|
||||
) -> torch.Tensor:
|
||||
delta_weight = self.get_delta_weight(adapter_name)
|
||||
# don't add bias here, because the bias is already included in the output of the base_layer
|
||||
return F.linear(input, delta_weight)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "loha." + rep
|
||||
|
||||
|
||||
class Conv2d(LoHaLayer, nn.Conv2d):
|
||||
class Conv2d(LoHaLayer):
|
||||
"""LoHa implemented in Conv2d layer"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int]],
|
||||
stride: Union[int, Tuple[int]] = 1,
|
||||
padding: Union[int, Tuple[int]] = 0,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = "zeros",
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
base_layer: nn.Module,
|
||||
adapter_name: str = "default",
|
||||
r: int = 0,
|
||||
alpha: float = 0.0,
|
||||
rank_dropout: float = 0.0,
|
||||
module_dropout: float = 0.0,
|
||||
use_effective_conv2d: bool = False,
|
||||
init_weights: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
init_weights = kwargs.pop("init_weights", True)
|
||||
self._init_empty_weights(
|
||||
nn.Conv2d,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
padding_mode=padding_mode,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
LoHaLayer.__init__(self)
|
||||
super().__init__(base_layer)
|
||||
|
||||
# Create adapter and set it active
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer(
|
||||
adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, use_effective_conv2d, **kwargs
|
||||
)
|
||||
self.set_adapter(adapter_name)
|
||||
|
||||
def _op(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
||||
def _get_delta_activations(
|
||||
self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
|
||||
) -> torch.Tensor:
|
||||
delta_weight = self.get_delta_weight(adapter_name)
|
||||
# don't add bias here, because the bias is already included in the output of the base_layer
|
||||
base_layer = self.get_base_layer()
|
||||
return F.conv2d(
|
||||
input,
|
||||
weight,
|
||||
bias=self.bias,
|
||||
stride=self.stride,
|
||||
padding=self.padding,
|
||||
dilation=self.dilation,
|
||||
groups=self.groups,
|
||||
delta_weight,
|
||||
stride=base_layer.stride,
|
||||
padding=base_layer.padding,
|
||||
dilation=base_layer.dilation,
|
||||
groups=base_layer.groups,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "loha." + rep
|
||||
|
||||
|
||||
# Below code is a direct copy from https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/loha.py#L9
|
||||
|
||||
|
@ -13,11 +13,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, Type
|
||||
import re
|
||||
from itertools import chain
|
||||
from typing import Dict, Type, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner
|
||||
|
||||
from ..lycoris_utils import LycorisTuner
|
||||
from .layer import Conv2d, Linear, LoHaLayer
|
||||
|
||||
|
||||
@ -82,3 +86,31 @@ class LoHaModel(LycorisTuner):
|
||||
torch.nn.Conv2d: Conv2d,
|
||||
torch.nn.Linear: Linear,
|
||||
}
|
||||
|
||||
def _create_and_replace(
self,
config: LycorisConfig,
adapter_name: str,
target: Union[LoHaLayer, nn.Module],
target_name: str,
parent: nn.Module,
current_key: str,
**optional_kwargs,
) -> None:
"""
A private method to create and replace the target module with the adapter module.
"""

# Regexp matching - Find key which matches current target_name in patterns provided
pattern_keys = list(chain(config.rank_pattern.keys(), config.alpha_pattern.keys()))
target_name_key = next(filter(lambda key: re.match(f"(.*\.)?{key}$", current_key), pattern_keys), target_name)

kwargs = config.to_dict()
kwargs["r"] = config.rank_pattern.get(target_name_key, config.r)
kwargs["alpha"] = config.alpha_pattern.get(target_name_key, config.alpha)

if isinstance(target, LoHaLayer):
target.update_layer(adapter_name, **kwargs)
else:
new_module = self._create_new_module(config, adapter_name, target, **kwargs)
self._replace_module(parent, target_name, new_module, target)
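`rank_pattern` and `alpha_pattern` keys are matched as suffixes of the full module key, so a short key overrides every layer whose name ends with it. A standalone sketch with illustrative values:

```py
# Illustrative values: a rank_pattern key of "fc1" overrides r for every "...fc1" layer,
# while unmatched layers keep the config default.
import re
from itertools import chain

rank_pattern = {"fc1": 8}
alpha_pattern = {}
default_r = 4
current_key = "model.decoder.layers.0.fc1"

pattern_keys = list(chain(rank_pattern.keys(), alpha_pattern.keys()))
target_name_key = next(filter(lambda key: re.match(rf"(.*\.)?{key}$", current_key), pattern_keys), "fc1")
r = rank_pattern.get(target_name_key, default_r)
print(r)  # -> 8 for this layer; layers that match no pattern keep r=4
```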
@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional, Set, Tuple, Union
|
||||
from typing import Any, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -23,9 +23,9 @@ import torch.nn.functional as F
|
||||
from peft.tuners.lycoris_utils import LycorisLayer
|
||||
|
||||
|
||||
class LoKrLayer(LycorisLayer, nn.Module):
|
||||
# List all names of layers that may contain adapter weights
|
||||
adapter_layer_names = [
|
||||
class LoKrLayer(nn.Module, LycorisLayer):
|
||||
# All names of layers that may contain adapter weights
|
||||
adapter_layer_names = (
|
||||
"lokr_w1",
|
||||
"lokr_w1_a",
|
||||
"lokr_w1_b",
|
||||
@ -33,11 +33,12 @@ class LoKrLayer(LycorisLayer, nn.Module):
|
||||
"lokr_w2_a",
|
||||
"lokr_w2_b",
|
||||
"lokr_t2",
|
||||
]
|
||||
)
|
||||
# other_param_names is defined on parent class
|
||||
|
||||
def __init__(self):
|
||||
LycorisLayer.__init__(self)
|
||||
super(nn.Module, self).__init__()
|
||||
def __init__(self, base_layer: nn.Module) -> None:
|
||||
super().__init__()
|
||||
LycorisLayer.__init__(self, base_layer)
|
||||
|
||||
# LoKr info
|
||||
self.lokr_w1 = nn.ParameterDict({})
|
||||
@ -110,6 +111,22 @@ class LoKrLayer(LycorisLayer, nn.Module):
|
||||
if adapter_name in self.lokr_t2:
|
||||
nn.init.kaiming_uniform_(self.lokr_t2[adapter_name], a=math.sqrt(5))
|
||||
|
||||
def reset_adapter_parameters_random(self, adapter_name: str):
|
||||
if adapter_name in self.lokr_w1:
|
||||
nn.init.kaiming_uniform_(self.lokr_w1[adapter_name], a=math.sqrt(5))
|
||||
else:
|
||||
nn.init.kaiming_uniform_(self.lokr_w1_a[adapter_name], a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self.lokr_w1_b[adapter_name], a=math.sqrt(5))
|
||||
|
||||
if adapter_name in self.lokr_w2:
|
||||
nn.init.kaiming_uniform_(self.lokr_w2[adapter_name], a=math.sqrt(5))
|
||||
else:
|
||||
nn.init.kaiming_uniform_(self.lokr_w2_a[adapter_name], a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self.lokr_w2_b[adapter_name], a=math.sqrt(5))
|
||||
|
||||
if adapter_name in self.lokr_t2:
|
||||
nn.init.kaiming_uniform_(self.lokr_t2[adapter_name], a=math.sqrt(5))
|
||||
|
||||
def update_layer(
|
||||
self,
|
||||
adapter_name: str,
|
||||
@ -142,10 +159,11 @@ class LoKrLayer(LycorisLayer, nn.Module):
|
||||
self.scaling[adapter_name] = alpha / r
|
||||
self.rank_dropout[adapter_name] = rank_dropout
|
||||
self.module_dropout[adapter_name] = module_dropout
|
||||
base_layer = self.get_base_layer()
|
||||
|
||||
# Determine shape of LoKr weights
|
||||
if isinstance(self, nn.Linear):
|
||||
in_dim, out_dim = self.in_features, self.out_features
|
||||
if isinstance(base_layer, nn.Linear):
|
||||
in_dim, out_dim = base_layer.in_features, base_layer.out_features
|
||||
|
||||
in_m, in_n = factorization(in_dim, decompose_factor)
|
||||
out_l, out_k = factorization(out_dim, decompose_factor)
|
||||
@ -154,9 +172,9 @@ class LoKrLayer(LycorisLayer, nn.Module):
|
||||
use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2)
|
||||
use_w2 = not (r < max(shape[0][1], shape[1][1]) / 2)
|
||||
use_effective_conv2d = False
|
||||
elif isinstance(self, nn.Conv2d):
|
||||
in_dim, out_dim = self.in_channels, self.out_channels
|
||||
k_size = self.kernel_size
|
||||
elif isinstance(base_layer, nn.Conv2d):
|
||||
in_dim, out_dim = base_layer.in_channels, base_layer.out_channels
|
||||
k_size = base_layer.kernel_size
|
||||
|
||||
in_m, in_n = factorization(in_dim, decompose_factor)
|
||||
out_l, out_k = factorization(out_dim, decompose_factor)
|
||||
@ -164,9 +182,9 @@ class LoKrLayer(LycorisLayer, nn.Module):
|
||||
|
||||
use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2)
|
||||
use_w2 = r >= max(shape[0][1], shape[1][1]) / 2
|
||||
use_effective_conv2d = use_effective_conv2d and self.kernel_size != (1, 1)
|
||||
use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size != (1, 1)
|
||||
else:
|
||||
raise TypeError(f"LoKr is not implemented for {type(self).__name__} layer")
|
||||
raise TypeError(f"LoKr is not implemented for base layers of type {type(base_layer).__name__}")
|
||||
|
||||
# Create weights with provided shape
|
||||
self.create_adapter_parameters(adapter_name, r, shape, use_w1, use_w2, use_effective_conv2d)
|
||||
@ -174,9 +192,11 @@ class LoKrLayer(LycorisLayer, nn.Module):
|
||||
# Initialize weights
|
||||
if init_weights:
|
||||
self.reset_adapter_parameters(adapter_name)
|
||||
else:
|
||||
self.reset_adapter_parameters_random(adapter_name)
|
||||
|
||||
# Move new weights to device
|
||||
weight = getattr(self, "weight", None)
|
||||
weight = getattr(self.get_base_layer(), "weight", None)
|
||||
if weight is not None:
|
||||
# the layer is already completely initialized, this is an update
|
||||
if weight.dtype.is_floating_point or weight.dtype.is_complex:
|
||||
@@ -201,7 +221,7 @@ class LoKrLayer(LycorisLayer, nn.Module):

# Make weights with Kronecker product
weight = make_kron(w1, w2)
weight = weight.reshape(self.weight.shape)
weight = weight.reshape(self.get_base_layer().weight.shape)

# Perform rank dropout during training - drop rows of addition weights
rank_dropout = self.rank_dropout[adapter_name]
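The LoKr delta weight is the Kronecker product of two small factors, reshaped to the wrapped layer's weight shape (`make_kron` in the library). A sketch of the idea with example dimensions, not taken from the library:

```py
# Example dimensions only: a 768x768 Linear weight built from two much smaller Kronecker factors.
import torch

out_l, out_k = 16, 48   # factorization of out_features = 768
in_m, in_n = 16, 48     # factorization of in_features  = 768
w1 = torch.randn(out_l, in_m)   # (16, 16)
w2 = torch.randn(out_k, in_n)   # (48, 48)

delta = torch.kron(w1, w2)      # (16 * 48, 16 * 48) == (768, 768)
delta = delta.reshape(768, 768) # matches base_layer.weight.shape for a 768x768 Linear
print(delta.shape)
```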
@ -213,15 +233,39 @@ class LoKrLayer(LycorisLayer, nn.Module):
|
||||
|
||||
return weight
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
previous_dtype = x.dtype
|
||||
|
||||
class Linear(LoKrLayer, nn.Linear):
|
||||
if self.disable_adapters:
|
||||
if self.merged:
|
||||
self.unmerge()
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
elif self.merged:
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
else:
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
|
||||
# Execute all the adapters
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self._available_adapters:
|
||||
continue
|
||||
|
||||
module_dropout = self.module_dropout[active_adapter]
|
||||
|
||||
# Modify current execution weights
|
||||
if (not self.training) or (self.training and torch.rand(1) > module_dropout):
|
||||
result = result + self._get_delta_activations(active_adapter, x, *args, **kwargs)
|
||||
|
||||
result = result.to(previous_dtype)
|
||||
return result
|
||||
|
||||
|
||||
class Linear(LoKrLayer):
|
||||
"""LoKr implemented in Linear layer"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
base_layer: nn.Module,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
adapter_name: str = "default",
|
||||
@ -229,35 +273,33 @@ class Linear(LoKrLayer, nn.Linear):
|
||||
alpha: float = 0.0,
|
||||
rank_dropout: float = 0.0,
|
||||
module_dropout: float = 0.0,
|
||||
init_weights: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
init_weights = kwargs.pop("init_weights", True)
|
||||
self._init_empty_weights(nn.Linear, in_features, out_features, bias, device=device, dtype=dtype)
|
||||
|
||||
LoKrLayer.__init__(self)
|
||||
super().__init__(base_layer)
|
||||
|
||||
# Create adapter and set it active
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer(adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, **kwargs)
|
||||
self.set_adapter(adapter_name)
|
||||
|
||||
def _op(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
||||
return F.linear(input, weight, bias=self.bias)
|
||||
def _get_delta_activations(
|
||||
self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
|
||||
) -> torch.Tensor:
|
||||
delta_weight = self.get_delta_weight(adapter_name)
|
||||
# don't add bias here, because the bias is already included in the output of the base_layer
|
||||
return F.linear(input, delta_weight)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "lokr." + rep
|
||||
|
||||
|
||||
class Conv2d(LoKrLayer, nn.Conv2d):
|
||||
class Conv2d(LoKrLayer):
|
||||
"""LoKr implemented in Conv2d layer"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int]],
|
||||
stride: Union[int, Tuple[int]] = 1,
|
||||
padding: Union[int, Tuple[int]] = 0,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = "zeros",
|
||||
base_layer: nn.Module,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
adapter_name: str = "default",
|
||||
@ -266,43 +308,36 @@ class Conv2d(LoKrLayer, nn.Conv2d):
|
||||
rank_dropout: float = 0.0,
|
||||
module_dropout: float = 0.0,
|
||||
use_effective_conv2d: bool = False,
|
||||
init_weights: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
init_weights = kwargs.pop("init_weights", True)
|
||||
self._init_empty_weights(
|
||||
nn.Conv2d,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
padding_mode=padding_mode,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
LoKrLayer.__init__(self)
|
||||
super().__init__(base_layer)
|
||||
|
||||
# Create adapter and set it active
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer(
|
||||
adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, use_effective_conv2d, **kwargs
|
||||
)
|
||||
self.set_adapter(adapter_name)
|
||||
|
||||
def _op(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
||||
def _get_delta_activations(
|
||||
self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
|
||||
) -> torch.Tensor:
|
||||
delta_weight = self.get_delta_weight(adapter_name)
|
||||
# don't add bias here, because the bias is already included in the output of the base_layer
|
||||
base_layer = self.get_base_layer()
|
||||
return F.conv2d(
|
||||
input,
|
||||
weight,
|
||||
bias=self.bias,
|
||||
stride=self.stride,
|
||||
padding=self.padding,
|
||||
dilation=self.dilation,
|
||||
groups=self.groups,
|
||||
delta_weight,
|
||||
stride=base_layer.stride,
|
||||
padding=base_layer.padding,
|
||||
dilation=base_layer.dilation,
|
||||
groups=base_layer.groups,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "lokr." + rep
|
||||
|
||||
|
||||
# Below code is a direct copy from https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/lokr.py#L11
|
||||
|
||||
|
@ -13,11 +13,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, Type
|
||||
import re
|
||||
from itertools import chain
|
||||
from typing import Dict, Type, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner
|
||||
|
||||
from ..lycoris_utils import LycorisTuner
|
||||
from .layer import Conv2d, Linear, LoKrLayer
|
||||
|
||||
|
||||
@ -83,3 +87,31 @@ class LoKrModel(LycorisTuner):
|
||||
torch.nn.Conv2d: Conv2d,
|
||||
torch.nn.Linear: Linear,
|
||||
}
|
||||
|
||||
def _create_and_replace(
|
||||
self,
|
||||
config: LycorisConfig,
|
||||
adapter_name: str,
|
||||
target: Union[LoKrLayer, nn.Module],
|
||||
target_name: str,
|
||||
parent: nn.Module,
|
||||
current_key: str,
|
||||
**optional_kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
A private method to create and replace the target module with the adapter module.
|
||||
"""
|
||||
|
||||
# Regexp matching - Find key which matches current target_name in patterns provided
|
||||
pattern_keys = list(chain(config.rank_pattern.keys(), config.alpha_pattern.keys()))
|
||||
target_name_key = next(filter(lambda key: re.match(f"(.*\.)?{key}$", current_key), pattern_keys), target_name)
|
||||
|
||||
kwargs = config.to_dict()
|
||||
kwargs["r"] = config.rank_pattern.get(target_name_key, config.r)
|
||||
kwargs["alpha"] = config.alpha_pattern.get(target_name_key, config.alpha)
|
||||
|
||||
if isinstance(target, LoKrLayer):
|
||||
target.update_layer(adapter_name, **kwargs)
|
||||
else:
|
||||
new_module = self._create_new_module(config, adapter_name, target, **kwargs)
|
||||
self._replace_module(parent, target_name, new_module, target)
|
||||
|
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from typing import List, Optional
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
@ -30,22 +31,20 @@ if is_bnb_available():
|
||||
# Lora implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
adapter_name,
|
||||
base_layer,
|
||||
base_layer: torch.nn.Module,
|
||||
adapter_name: str,
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.0,
|
||||
init_lora_weights: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
LoraLayer.__init__(self, in_features=base_layer.in_features, out_features=base_layer.out_features)
|
||||
self.base_layer = base_layer
|
||||
LoraLayer.__init__(self, base_layer)
|
||||
|
||||
init_lora_weights = kwargs.pop("init_lora_weights", True)
|
||||
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
|
||||
self.set_adapter(adapter_name)
|
||||
|
||||
def merge(self, safe_merge: bool = False):
|
||||
def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
|
||||
"""
|
||||
Merge the active adapter weights into the base weights
|
||||
|
||||
@ -54,6 +53,9 @@ if is_bnb_available():
|
||||
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
|
||||
before merging the weights. This is useful if you want to check if the merge operation will produce
|
||||
NaNs. Defaults to `False`.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
The list of adapter names that should be merged. If None, all active adapters will be merged.
|
||||
Defaults to `None`.
|
||||
"""
|
||||
if self.merged:
|
||||
warnings.warn(
|
||||
@ -61,7 +63,10 @@ if is_bnb_available():
|
||||
f"You are now additionally merging {','.join(self.active_adapters)}."
|
||||
)
|
||||
|
||||
for active_adapter in self.active_adapters:
|
||||
if adapter_names is None:
|
||||
adapter_names = self.active_adapters
|
||||
|
||||
for active_adapter in adapter_names:
|
||||
if active_adapter not in self.lora_A.keys():
|
||||
continue
|
||||
warnings.warn(
|
||||
@ -69,8 +74,8 @@ if is_bnb_available():
|
||||
)
|
||||
lora_data = self.get_delta_weight(active_adapter)
|
||||
|
||||
weight = self.base_layer.weight
|
||||
state = self.base_layer.state
|
||||
weight = self.get_base_layer().weight
|
||||
state = self.get_base_layer().state
|
||||
if state.SCB is None:
|
||||
state.SCB = weight.SCB
|
||||
|
||||
@ -90,7 +95,7 @@ if is_bnb_available():
|
||||
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
||||
)
|
||||
|
||||
self.base_layer.weight = bnb.nn.Int8Params(
|
||||
self.get_base_layer().weight = bnb.nn.Int8Params(
|
||||
w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
|
||||
).to(weight.device)
|
||||
state.reset_grads()
|
||||
@ -110,8 +115,8 @@ if is_bnb_available():
|
||||
)
|
||||
lora_data = self.get_delta_weight(active_adapter)
|
||||
|
||||
weight = self.base_layer.weight
|
||||
state = self.base_layer.state
|
||||
weight = self.get_base_layer().weight
|
||||
state = self.get_base_layer().state
|
||||
if state.SCB is None:
|
||||
state.SCB = weight.SCB
|
||||
im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
|
||||
@ -124,7 +129,7 @@ if is_bnb_available():
|
||||
output = bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
|
||||
|
||||
w_data = output.to(lora_data.dtype).to(lora_data.device) - lora_data
|
||||
self.base_layer.weight = bnb.nn.Int8Params(
|
||||
self.get_base_layer().weight = bnb.nn.Int8Params(
|
||||
w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
|
||||
).to(weight.device)
|
||||
state.reset_grads()
|
||||
@ -169,6 +174,10 @@ if is_bnb_available():
|
||||
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "lora." + rep
|
||||
|
||||
|
||||
if is_bnb_4bit_available():
|
||||
|
||||
@ -176,22 +185,20 @@ if is_bnb_4bit_available():
|
||||
# Lora implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
adapter_name,
|
||||
base_layer,
|
||||
base_layer: torch.nn.Module,
|
||||
adapter_name: str,
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.0,
|
||||
init_lora_weights: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
LoraLayer.__init__(self, in_features=base_layer.in_features, out_features=base_layer.out_features)
|
||||
self.base_layer = base_layer
|
||||
LoraLayer.__init__(self, base_layer)
|
||||
|
||||
init_lora_weights = kwargs.pop("init_lora_weights", True)
|
||||
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
|
||||
self.set_adapter(adapter_name)
|
||||
|
||||
def merge(self, safe_merge: bool = False):
|
||||
def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
|
||||
"""
|
||||
Merge the active adapter weights into the base weights
|
||||
|
||||
@ -200,6 +207,9 @@ if is_bnb_4bit_available():
|
||||
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
|
||||
before merging the weights. This is useful if you want to check if the merge operation will produce
|
||||
NaNs. Defaults to `False`.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
The list of adapter names that should be merged. If None, all active adapters will be merged.
|
||||
Defaults to `None`.
|
||||
"""
|
||||
if self.merged:
|
||||
warnings.warn(
|
||||
@ -207,14 +217,17 @@ if is_bnb_4bit_available():
|
||||
f"You are now additionally merging {','.join(self.active_adapters)}."
|
||||
)
|
||||
|
||||
for active_adapter in self.active_adapters:
|
||||
if adapter_names is None:
|
||||
adapter_names = self.active_adapters
|
||||
|
||||
for active_adapter in adapter_names:
|
||||
if active_adapter not in self.lora_A.keys():
|
||||
continue
|
||||
warnings.warn(
|
||||
"Merge lora module to 4-bit linear may get different generations due to rounding errors."
|
||||
)
|
||||
# Refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
|
||||
weight = self.base_layer.weight
|
||||
weight = self.get_base_layer().weight
|
||||
kwargs = weight.__dict__
|
||||
lora_data = self.get_delta_weight(active_adapter)
|
||||
|
||||
@ -224,7 +237,7 @@ if is_bnb_4bit_available():
|
||||
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
||||
)
|
||||
|
||||
self.base_layer.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(
|
||||
self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(
|
||||
weight.device
|
||||
)
|
||||
self.merged_adapters.append(active_adapter)
|
||||
@ -241,11 +254,11 @@ if is_bnb_4bit_available():
|
||||
warnings.warn(
|
||||
"Unmerge lora module to 4-bit linear may get different generations due to rounding errors."
|
||||
)
|
||||
weight = self.base_layer.weight
|
||||
weight = self.get_base_layer().weight
|
||||
kwargs = weight.__dict__
|
||||
lora_data = self.get_delta_weight(active_adapter)
|
||||
w_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) - lora_data
|
||||
self.base_layer.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(
|
||||
self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(
|
||||
weight.device
|
||||
)
|
||||
|
||||
@ -262,11 +275,11 @@ if is_bnb_4bit_available():
|
||||
if self.disable_adapters:
|
||||
if self.merged:
|
||||
self.unmerge()
|
||||
result = self.base_layer.forward(x, *args, **kwargs)
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
elif self.merged:
|
||||
result = self.base_layer.forward(x, *args, **kwargs)
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
else:
|
||||
result = self.base_layer.forward(x, *args, **kwargs)
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
# As per Tim Dettmers, for 4bit, we need to defensively clone here.
|
||||
# The reason is that in some cases, an error can occur that backprop
|
||||
# does not work on a manipulated view. This issue may be solved with
|
||||
@ -294,3 +307,7 @@ if is_bnb_4bit_available():
|
||||
result += output
|
||||
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "lora." + rep
|
||||
|
@ -21,22 +21,21 @@ from peft.tuners.lora.layer import LoraLayer
|
||||
class QuantLinear(torch.nn.Module, LoraLayer):
|
||||
def __init__(
|
||||
self,
|
||||
adapter_name,
|
||||
quant_linear_module,
|
||||
base_layer,
|
||||
adapter_name: str,
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.0,
|
||||
init_lora_weights: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
torch.nn.Module.__init__(self)
|
||||
LoraLayer.__init__(
|
||||
self, in_features=quant_linear_module.infeatures, out_features=quant_linear_module.outfeatures
|
||||
)
|
||||
self.quant_linear_module = quant_linear_module
|
||||
self.weight = quant_linear_module.qweight
|
||||
init_lora_weights = kwargs.pop("init_lora_weights", True)
super().__init__()
LoraLayer.__init__(self, base_layer)

# self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
# for backwards compatibility
self.quant_linear_module = base_layer
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.set_adapter(adapter_name)

def forward(self, x: torch.Tensor):
# note: logic differs from default Linear because merging is not supported
@@ -65,6 +64,10 @@ class QuantLinear(torch.nn.Module, LoraLayer):
result += output
return result

def __repr__(self) -> str:
rep = super().__repr__()
return "lora." + rep

# TODO: Check if it is better as suggested by users https://github.com/PanQiWei/AutoGPTQ/pull/102
# def reset_lora_parameters(self, adapter_name):
# if adapter_name in self.lora_A.keys():
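Because GPTQ weights cannot be merged, the quantized LoRA layer always runs the base module and adds the low-rank update on top. A hedged sketch of that pattern; the class below is illustrative, not the PEFT implementation:

```py
# Illustrative "no merging" LoRA wrapper: the quantized path always runs as-is and the
# low-rank update is added on top of its output.
import torch
import torch.nn as nn


class NoMergeLoraSketch(nn.Module):
    def __init__(self, base_layer: nn.Module, in_features: int, out_features: int, r: int = 8, scaling: float = 1.0):
        super().__init__()
        self.base_layer = base_layer          # stands in for the GPTQ QuantLinear module
        self.lora_A = nn.Linear(in_features, r, bias=False)
        self.lora_B = nn.Linear(r, out_features, bias=False)
        self.scaling = scaling

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = self.base_layer(x)           # quantized path, never merged
        result = result + self.lora_B(self.lora_A(x)) * self.scaling
        return result
```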
@ -15,21 +15,25 @@
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers.pytorch_utils import Conv1D
|
||||
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
from peft.utils.other import transpose
|
||||
|
||||
|
||||
class LoraLayer(BaseTunerLayer):
|
||||
# List all names of layers that may contain adapter weights
|
||||
adapter_layer_names = ["lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B"]
|
||||
# All names of layers that may contain (trainable) adapter weights
|
||||
adapter_layer_names = ("lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B")
|
||||
# All names of other parameters that may contain adapter-related parameters
|
||||
other_param_names = ("r", "lora_alpha", "scaling", "lora_dropout")
|
||||
|
||||
def __init__(self, in_features: int, out_features: int, **kwargs):
|
||||
def __init__(self, base_layer: nn.Module, **kwargs) -> None:
|
||||
self.base_layer = base_layer
|
||||
self.r = {}
|
||||
self.lora_alpha = {}
|
||||
self.scaling = {}
|
||||
@ -42,21 +46,26 @@ class LoraLayer(BaseTunerLayer):
|
||||
# Mark the weight as unmerged
|
||||
self._disable_adapters = False
|
||||
self.merged_adapters = []

base_layer = self.get_base_layer()
if isinstance(base_layer, nn.Linear):
in_features, out_features = base_layer.in_features, base_layer.out_features
elif isinstance(base_layer, nn.Conv2d):
in_features, out_features = base_layer.in_channels, base_layer.out_channels
elif isinstance(base_layer, nn.Embedding):
in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim
elif isinstance(base_layer, Conv1D):
in_features, out_features = (
base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
)
elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"):
# QuantLinear
in_features, out_features = base_layer.infeatures, base_layer.outfeatures
else:
raise ValueError(f"Unsupported layer type {type(base_layer)}")

self.in_features = in_features
self.out_features = out_features
self.kwargs = kwargs
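`LoraLayer` now infers its feature dimensions from whichever module it wraps. The same resolution for the plain torch module types, as a standalone helper (the Conv1D and QuantLinear branches are omitted; helper name is illustrative):

```py
# Mirrors the branch above for the standard torch layers that LoRA can wrap.
import torch.nn as nn


def resolve_features(base_layer: nn.Module):
    if isinstance(base_layer, nn.Linear):
        return base_layer.in_features, base_layer.out_features
    if isinstance(base_layer, nn.Conv2d):
        return base_layer.in_channels, base_layer.out_channels
    if isinstance(base_layer, nn.Embedding):
        return base_layer.num_embeddings, base_layer.embedding_dim
    raise ValueError(f"Unsupported layer type {type(base_layer)}")


print(resolve_features(nn.Linear(768, 3072)))      # (768, 3072)
print(resolve_features(nn.Embedding(50257, 768)))  # (50257, 768)
```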
def _init_empty_weights(self, cls, *args, **kwargs) -> None:
|
||||
# A helper method that allows to initialize the layer of the given class without spending time to initialize the
|
||||
# model weights. The implementation is inspired by
|
||||
# https://pytorch.org/docs/stable/generated/torch.nn.utils.skip_init.html but this function cannot be used
|
||||
# directly.
|
||||
# Instead of this approach, it would be possible to bypass the __init__ of the class but that runs the risk of
|
||||
# omitting important logic inside that __init__.
|
||||
kwargs = kwargs.copy()
|
||||
final_device = kwargs.pop("device", "cpu")
|
||||
cls.__init__(self, *args, device="meta", **kwargs)
|
||||
self.to_empty(device=final_device)
|
||||
|
||||
def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
|
||||
if r <= 0:
|
||||
@ -77,7 +86,7 @@ class LoraLayer(BaseTunerLayer):
|
||||
if init_lora_weights:
|
||||
self.reset_lora_parameters(adapter_name)
|
||||
|
||||
weight = getattr(self, "weight", None)
|
||||
weight = getattr(self.get_base_layer(), "weight", None)
|
||||
if weight is not None:
|
||||
# the layer is already completely initialized, this is an update
|
||||
if weight.dtype.is_floating_point or weight.dtype.is_complex:
|
||||
@ -98,20 +107,22 @@ class LoraLayer(BaseTunerLayer):
|
||||
|
||||
self.lora_dropout[adapter_name] = lora_dropout_layer
|
||||
# Actual trainable parameters
|
||||
base_layer = self.get_base_layer()
|
||||
if r > 0:
|
||||
kernel_size = self.kwargs["kernel_size"]
|
||||
stride = self.kwargs["stride"]
|
||||
padding = self.kwargs["padding"]
|
||||
kernel_size = base_layer.kernel_size
|
||||
stride = base_layer.stride
|
||||
padding = base_layer.padding
|
||||
self.lora_A[adapter_name] = nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False)
|
||||
self.lora_B[adapter_name] = nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False)
|
||||
self.scaling[adapter_name] = lora_alpha / r
|
||||
if init_lora_weights:
|
||||
self.reset_lora_parameters(adapter_name)
|
||||
|
||||
weight = getattr(self, "weight", None)
|
||||
weight = getattr(base_layer, "weight", None)
|
||||
if weight is not None:
|
||||
# the layer is already completely initialized, this is an update
|
||||
self.to(self.weight.device, dtype=weight.dtype)
|
||||
self.to(base_layer.weight.device, dtype=weight.dtype)
|
||||
self.set_adapter(self.active_adapters)
|
||||
|
||||
def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
|
||||
if r <= 0:
|
||||
@ -134,10 +145,12 @@ class LoraLayer(BaseTunerLayer):
|
||||
if init_lora_weights:
|
||||
self.reset_lora_parameters(adapter_name)
|
||||
|
||||
weight = getattr(self, "weight", None)
|
||||
base_layer = self.get_base_layer()
|
||||
weight = getattr(base_layer, "weight", None)
|
||||
if weight is not None:
|
||||
# the layer is already completely initialized, this is an update
|
||||
self.to(self.weight.device, dtype=weight.dtype)
|
||||
self.to(base_layer.weight.device, dtype=weight.dtype)
|
||||
self.set_adapter(self.active_adapters)
|
||||
|
||||
def reset_lora_parameters(self, adapter_name):
|
||||
if adapter_name in self.lora_A.keys():
|
||||
@ -186,37 +199,29 @@ class LoraLayer(BaseTunerLayer):
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Linear(nn.Linear, LoraLayer):
|
||||
class Linear(nn.Module, LoraLayer):
|
||||
# Lora implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
base_layer,
|
||||
adapter_name: str,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.0,
|
||||
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
||||
is_target_conv_1d_layer: bool = False,
|
||||
init_lora_weights: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
init_lora_weights = kwargs.pop("init_lora_weights", True)
|
||||
# this gets the init from nn.Linear's super perspective, i.e.
|
||||
# nn.Module.__init__, which should always be called
|
||||
super(nn.Linear, self).__init__()
|
||||
# Note that we don't use self._init_empty_weights() for Linear because it is a bit slower and the benefit of
|
||||
# added robustness is not big enough for Linear.
|
||||
|
||||
LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
|
||||
# Freezing the pre-trained weight matrix
|
||||
|
||||
super().__init__()
|
||||
LoraLayer.__init__(self, base_layer)
|
||||
self.fan_in_fan_out = fan_in_fan_out
|
||||
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
|
||||
self.is_target_conv_1d_layer = is_target_conv_1d_layer
|
||||
self.set_adapter(adapter_name)
|
||||
|
||||
def merge(self, safe_merge: bool = False) -> None:
|
||||
def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
|
||||
"""
|
||||
Merge the active adapter weights into the base weights
|
||||
|
||||
@ -225,18 +230,26 @@ class Linear(nn.Linear, LoraLayer):
|
||||
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
|
||||
before merging the weights. This is useful if you want to check if the merge operation will produce
|
||||
NaNs. Defaults to `False`.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
|
||||
to `None`.
|
||||
"""
|
||||
if self.merged:
|
||||
warnings.warn(
|
||||
f"Already following adapters were merged {','.join(self.merged_adapters)}. "
|
||||
f"You are now additionally merging {','.join(self.active_adapters)}."
|
||||
)
|
||||
for active_adapter in self.active_adapters:
|
||||
|
||||
if adapter_names is None:
|
||||
adapter_names = self.active_adapters
|
||||
|
||||
for active_adapter in adapter_names:
if active_adapter in self.lora_A.keys():
base_layer = self.get_base_layer()
if safe_merge:
# Note that safe_merge will be slower than the normal merge
# because of the copy operation.
orig_weights = self.weight.data.clone()
orig_weights = base_layer.weight.data.clone()
orig_weights += self.get_delta_weight(active_adapter)

if not torch.isfinite(orig_weights).all():
@@ -244,9 +257,9 @@ class Linear(nn.Linear, LoraLayer):
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)

self.weight.data = orig_weights
base_layer.weight.data = orig_weights
else:
self.weight.data += self.get_delta_weight(active_adapter)
base_layer.weight.data += self.get_delta_weight(active_adapter)
self.merged_adapters.append(active_adapter)

def unmerge(self) -> None:
@@ -256,7 +269,7 @@ class Linear(nn.Linear, LoraLayer):
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.lora_A.keys():
self.weight.data -= self.get_delta_weight(active_adapter)
self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)
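`safe_merge` merges into a copy and checks for NaNs before committing to the base weight. The same pattern in isolation, with toy tensors standing in for the real weights:

```py
# Toy tensors stand in for base_layer.weight and get_delta_weight(adapter).
import torch

base_weight = torch.randn(4, 4)
delta = torch.randn(4, 4)

merged = base_weight.clone() + delta          # merge into a copy first
if not torch.isfinite(merged).all():
    raise ValueError("NaNs detected in the merged weights. The adapter seems to be broken")
base_weight.data = merged                     # commit only after the check passes
```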
def get_delta_weight(self, adapter) -> torch.Tensor:
|
||||
"""
|
||||
@ -292,20 +305,17 @@ class Linear(nn.Linear, LoraLayer):
|
||||
|
||||
return output_tensor
|
||||
|
||||
def _linear(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
previous_dtype = x.dtype
|
||||
|
||||
if self.disable_adapters:
|
||||
if self.merged:
|
||||
self.unmerge()
|
||||
result = self._linear(x)
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
elif self.merged:
|
||||
result = self._linear(x)
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
else:
|
||||
result = self._linear(x)
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.lora_A.keys():
|
||||
continue
|
||||
@ -319,26 +329,30 @@ class Linear(nn.Linear, LoraLayer):
|
||||
result = result.to(previous_dtype)
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "lora." + rep
|
||||
|
||||
class Embedding(nn.Embedding, LoraLayer):
|
||||
|
||||
class Embedding(nn.Module, LoraLayer):
|
||||
# LoRA implemented in a Embedding layer
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: nn.Module,
|
||||
adapter_name: str,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.0,
|
||||
init_lora_weights: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
init_lora_weights = kwargs.pop("init_lora_weights", True)
|
||||
self._init_empty_weights(nn.Embedding, num_embeddings, embedding_dim, **kwargs)
|
||||
LoraLayer.__init__(self, in_features=num_embeddings, out_features=embedding_dim)
|
||||
self.update_layer_embedding(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
|
||||
self.set_adapter(adapter_name)
|
||||
super().__init__()
|
||||
LoraLayer.__init__(self, base_layer)
|
||||
|
||||
def merge(self, safe_merge: bool = False) -> None:
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer_embedding(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
|
||||
|
||||
def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
|
||||
"""
|
||||
Merge the active adapter weights into the base weights
|
||||
|
||||
@ -347,18 +361,26 @@ class Embedding(nn.Embedding, LoraLayer):
|
||||
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
|
||||
before merging the weights. This is useful if you want to check if the merge operation will produce
|
||||
NaNs. Defaults to `False`.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
|
||||
to `None`.
|
||||
"""
|
||||
if self.merged:
|
||||
warnings.warn(
|
||||
f"Already following adapters were merged {','.join(self.merged_adapters)}. "
|
||||
f"You are now additionally merging {','.join(self.active_adapters)}."
|
||||
)
|
||||
for active_adapter in self.active_adapters:
|
||||
|
||||
if adapter_names is None:
|
||||
adapter_names = self.active_adapters
|
||||
|
||||
for active_adapter in adapter_names:
|
||||
if active_adapter in self.lora_embedding_A.keys():
|
||||
base_layer = self.get_base_layer()
|
||||
if safe_merge:
|
||||
# Note that safe_merge will be slower than the normal merge
|
||||
# because of the copy operation.
|
||||
orig_weights = self.weight.data.copy()
|
||||
orig_weights = base_layer.weight.data.copy()
|
||||
orig_weights += self.get_delta_weight(active_adapter)
|
||||
|
||||
if not torch.isfinite(orig_weights).all():
|
||||
@ -366,9 +388,9 @@ class Embedding(nn.Embedding, LoraLayer):
|
||||
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
||||
)
|
||||
|
||||
self.weight.data = orig_weights
|
||||
base_layer.weight.data = orig_weights
|
||||
else:
|
||||
self.weight.data += self.get_delta_weight(active_adapter)
|
||||
base_layer.weight.data += self.get_delta_weight(active_adapter)
|
||||
self.merged_adapters.append(active_adapter)
|
||||
|
||||
def unmerge(self) -> None:
|
||||
@ -378,7 +400,7 @@ class Embedding(nn.Embedding, LoraLayer):
|
||||
while len(self.merged_adapters) > 0:
|
||||
active_adapter = self.merged_adapters.pop()
|
||||
if active_adapter in self.lora_embedding_A.keys():
|
||||
self.weight.data -= self.get_delta_weight(active_adapter)
|
||||
self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)
|
||||
|
||||
def get_delta_weight(self, adapter) -> torch.Tensor:
|
||||
"""
|
||||
@ -414,28 +436,28 @@ class Embedding(nn.Embedding, LoraLayer):
|
||||
|
||||
return output_tensor
|
||||
|
||||
def _embed(self, input: torch.Tensor, weight: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
weight = self.weight if weight is None else weight
|
||||
def _embed(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
||||
base_layer = self.get_base_layer()
|
||||
return F.embedding(
|
||||
input,
|
||||
weight,
|
||||
padding_idx=self.padding_idx,
|
||||
max_norm=self.max_norm,
|
||||
norm_type=self.norm_type,
|
||||
scale_grad_by_freq=self.scale_grad_by_freq,
|
||||
sparse=self.sparse,
|
||||
padding_idx=base_layer.padding_idx,
|
||||
max_norm=base_layer.max_norm,
|
||||
norm_type=base_layer.norm_type,
|
||||
scale_grad_by_freq=base_layer.scale_grad_by_freq,
|
||||
sparse=base_layer.sparse,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
# TODO: no dtype conversion here, unlike in Linear, is that correct?
|
||||
if self.disable_adapters:
|
||||
if self.merged:
|
||||
self.unmerge()
|
||||
result = self._embed(x)
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
elif self.merged:
|
||||
result = self._embed(x)
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
else:
|
||||
result = self._embed(x)
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.lora_embedding_A:
|
||||
continue
|
||||
@ -447,38 +469,30 @@ class Embedding(nn.Embedding, LoraLayer):
|
||||
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "lora." + rep
|
||||
|
||||
class Conv2d(nn.Conv2d, LoraLayer):
|
||||
|
||||
class Conv2d(nn.Module, LoraLayer):
|
||||
# Lora implemented in a conv2d layer
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: nn.Module,
|
||||
adapter_name: str,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int]],
|
||||
stride: Union[int, Tuple[int]] = 1,
|
||||
padding: Union[int, Tuple[int]] = 0,
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.0,
|
||||
init_lora_weights: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
init_lora_weights = kwargs.pop("init_lora_weights", True)
|
||||
self._init_empty_weights(nn.Conv2d, in_channels, out_channels, kernel_size, stride=stride, padding=padding)
|
||||
|
||||
LoraLayer.__init__(
|
||||
self,
|
||||
in_features=in_channels,
|
||||
out_features=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
)
|
||||
super().__init__()
|
||||
LoraLayer.__init__(self, base_layer)
|
||||
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer_conv2d(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
|
||||
self.set_adapter(adapter_name)
|
||||
|
||||
def merge(self, safe_merge: bool = False) -> None:
|
||||
def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
|
||||
"""
|
||||
Merge the active adapter weights inside the base weights
|
||||
|
||||
@ -487,27 +501,35 @@ class Conv2d(nn.Conv2d, LoraLayer):
|
||||
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
|
||||
before merging the weights. This is useful if you want to check if the merge operation will produce
|
||||
NaNs. Defaults to `False`.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
|
||||
to `None`.
|
||||
"""
|
||||
if self.merged:
|
||||
warnings.warn(
|
||||
f"Already following adapters were merged {','.join(self.merged_adapters)}. "
|
||||
f"You are now additionally merging {','.join(self.active_adapters)}."
|
||||
)
|
||||
for active_adapter in self.active_adapters:
|
||||
|
||||
if adapter_names is None:
|
||||
adapter_names = self.active_adapters
|
||||
|
||||
for active_adapter in adapter_names:
|
||||
if active_adapter in self.lora_A.keys():
|
||||
base_layer = self.get_base_layer()
|
||||
if safe_merge:
|
||||
# Note that safe_merge will be slower than the normal merge
|
||||
# because of the copy operation.
|
||||
orig_weights = self.weight.data.copy()
|
||||
orig_weights = base_layer.weight.data.copy()
|
||||
orig_weights += self.get_delta_weight(active_adapter)
|
||||
|
||||
if not torch.isfinite(orig_weights).all():
|
||||
raise ValueError(
|
||||
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
||||
)
|
||||
self.weight.data = orig_weights
|
||||
base_layer.weight.data = orig_weights
|
||||
else:
|
||||
self.weight.data += self.get_delta_weight(active_adapter)
|
||||
base_layer.weight.data += self.get_delta_weight(active_adapter)
|
||||
self.merged_adapters.append(active_adapter)
|
||||
|
||||
def unmerge(self) -> None:
|
||||
@ -517,7 +539,7 @@ class Conv2d(nn.Conv2d, LoraLayer):
|
||||
while len(self.merged_adapters) > 0:
|
||||
active_adapter = self.merged_adapters.pop()
|
||||
if active_adapter in self.lora_A.keys():
|
||||
self.weight.data -= self.get_delta_weight(active_adapter)
|
||||
self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)
|
||||
|
||||
def get_delta_weight(self, adapter) -> torch.Tensor:
|
||||
"""
|
||||
@ -543,7 +565,7 @@ class Conv2d(nn.Conv2d, LoraLayer):
|
||||
weight_B = weight_B.float()
|
||||
|
||||
# https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117
|
||||
if self.weight.size()[2:4] == (1, 1):
|
||||
if self.get_base_layer().weight.size()[2:4] == (1, 1):
|
||||
# conv2d 1x1
|
||||
output_tensor = (weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(
|
||||
3
|
||||
@ -567,28 +589,17 @@ class Conv2d(nn.Conv2d, LoraLayer):
|
||||
|
||||
return output_tensor
|
||||
|
||||
def _conv2d(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return F.conv2d(
|
||||
input,
|
||||
self.weight,
|
||||
bias=self.bias,
|
||||
stride=self.stride,
|
||||
padding=self.padding,
|
||||
dilation=self.dilation,
|
||||
groups=self.groups,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
previous_dtype = x.dtype
|
||||
|
||||
if self.disable_adapters:
|
||||
if self.merged:
|
||||
self.unmerge()
|
||||
result = self._conv2d(x)
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
elif self.merged:
|
||||
result = self._conv2d(x)
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
else:
|
||||
result = self._conv2d(x)
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.lora_A.keys():
|
||||
continue
|
||||
@ -601,3 +612,7 @@ class Conv2d(nn.Conv2d, LoraLayer):
|
||||
|
||||
result = result.to(previous_dtype)
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "lora." + rep
|
||||
|
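The LoRA layer diff above replaces subclassing of `nn.Linear`/`nn.Embedding`/`nn.Conv2d` with a wrapper that keeps the original module in `base_layer` and delegates `forward` to it before adding the low-rank update. The following is a minimal, self-contained sketch of that wrapper pattern, not the PEFT implementation itself; the class and hyperparameter names (`MiniLoraLinear`, `r`, `lora_alpha`) are illustrative.

```python
# Minimal sketch of the base_layer wrapper pattern, assuming only torch.
import torch
import torch.nn as nn


class MiniLoraLinear(nn.Module):
    def __init__(self, base_layer: nn.Linear, r: int = 8, lora_alpha: int = 16):
        super().__init__()
        self.base_layer = base_layer  # the frozen original module is kept as-is
        self.lora_A = nn.Linear(base_layer.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base_layer.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)  # the delta starts at zero
        self.scaling = lora_alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # same call path as the merged/disabled branches above: base layer first,
        # then the low-rank delta on top
        result = self.base_layer(x)
        return result + self.lora_B(self.lora_A(x)) * self.scaling


layer = MiniLoraLinear(nn.Linear(32, 64))
print(layer(torch.randn(2, 32)).shape)  # torch.Size([2, 64])
```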
@@ -19,9 +19,9 @@ from dataclasses import asdict, replace
from enum import Enum
from functools import reduce
from itertools import chain
from typing import List, Optional

import torch
from torch import nn
from tqdm import tqdm
from transformers.pytorch_utils import Conv1D

@@ -107,6 +107,8 @@ class LoraModel(BaseTuner):
- **peft_config** ([`LoraConfig`]): The configuration of the Lora model.
"""

prefix: str = "lora_"

def __init__(self, model, config, adapter_name) -> None:
super().__init__(model, config, adapter_name)

@@ -164,7 +166,7 @@ class LoraModel(BaseTuner):
kwargs["gptq_quantization_config"] = quantization_config

# TODO: better deal with that
if isinstance(target, LoraLayer) and isinstance(target, torch.nn.Conv2d):
if isinstance(target, Conv2d):
target.update_layer_conv2d(
adapter_name,
r,

@@ -172,7 +174,7 @@ class LoraModel(BaseTuner):
lora_config.lora_dropout,
lora_config.init_lora_weights,
)
elif isinstance(target, LoraLayer) and isinstance(target, torch.nn.Embedding):
elif isinstance(target, Embedding):
target.update_layer_embedding(
adapter_name,
r,

@@ -180,8 +182,7 @@ class LoraModel(BaseTuner):
lora_config.lora_dropout,
lora_config.init_lora_weights,
)

elif isinstance(target, LoraLayer):
elif isinstance(target, Linear):
target.update_layer(
adapter_name,
r,

@@ -196,8 +197,7 @@ class LoraModel(BaseTuner):
new_module.requires_grad_(False)
self._replace_module(parent, target_name, new_module, target)

@staticmethod
def _replace_module(parent, child_name, new_module, child):
def _replace_module(self, parent, child_name, new_module, child):
setattr(parent, child_name, new_module)
# It's not necessary to set requires_grad here, as that is handled by
# _mark_only_adapters_as_trainable

@@ -205,10 +205,7 @@ class LoraModel(BaseTuner):
# child layer wraps the original module, unpack it
if hasattr(child, "base_layer"):
child = child.base_layer
elif hasattr(child, "quant_linear_module"):
child = child.quant_linear_module

# TODO: layers with base_layer don't need the weight to be copied, as they have a reference already
if not hasattr(new_module, "base_layer"):
new_module.weight = child.weight
if hasattr(child, "bias"):

@@ -223,14 +220,13 @@ class LoraModel(BaseTuner):

# dispatch to correct device
for name, module in new_module.named_modules():
if "lora_" in name:
module.to(child.weight.device)
if "ranknum" in name:
module.to(child.weight.device)
if (self.prefix in name) or ("ranknum" in name):
weight = child.qweight if hasattr(child, "qweight") else child.weight
module.to(weight.device)

def _mark_only_adapters_as_trainable(self) -> None:
for n, p in self.model.named_parameters():
if "lora_" not in n:
if self.prefix not in n:
p.requires_grad = False

for active_adapter in self.active_adapters:

@@ -256,9 +252,13 @@ class LoraModel(BaseTuner):

loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)
bias = kwargs.pop("bias", False)

if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target

if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt):
eightbit_kwargs = kwargs.copy()
eightbit_kwargs.update(
{

@@ -268,8 +268,8 @@ class LoraModel(BaseTuner):
"index": target.index,
}
)
new_module = Linear8bitLt(adapter_name, target, **eightbit_kwargs)
elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit):
new_module = Linear8bitLt(target, adapter_name, **eightbit_kwargs)
elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit):
fourbit_kwargs = kwargs.copy()
fourbit_kwargs.update(
{

@@ -278,47 +278,37 @@ class LoraModel(BaseTuner):
"quant_type": target.weight.quant_type,
}
)
new_module = Linear4bit(adapter_name, target, **fourbit_kwargs)
elif AutoGPTQQuantLinear is not None and isinstance(target, AutoGPTQQuantLinear):
new_module = QuantLinear(adapter_name, target, **kwargs)
new_module = Linear4bit(target, adapter_name, **fourbit_kwargs)
elif AutoGPTQQuantLinear is not None and isinstance(target_base_layer, AutoGPTQQuantLinear):
new_module = QuantLinear(target, adapter_name, **kwargs)
target.weight = target.qweight
elif isinstance(target, torch.nn.Embedding):
elif isinstance(target_base_layer, torch.nn.Embedding):
embedding_kwargs = kwargs.copy()
embedding_kwargs.pop("fan_in_fan_out", None)
in_features, out_features = target.num_embeddings, target.embedding_dim
new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs)
elif isinstance(target, torch.nn.Conv2d):
out_channels, in_channels = target.weight.size()[:2]
kernel_size = target.weight.size()[2:]
stride = target.stride
padding = target.padding
new_module = Conv2d(adapter_name, in_channels, out_channels, kernel_size, stride, padding, **kwargs)
new_module = Embedding(target, adapter_name, **embedding_kwargs)
elif isinstance(target_base_layer, torch.nn.Conv2d):
new_module = Conv2d(target, adapter_name, **kwargs)
elif isinstance(target_base_layer, torch.nn.Linear):
if kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
"Setting fan_in_fan_out to False."
)
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
new_module = Linear(target, adapter_name, **kwargs)
elif isinstance(target_base_layer, Conv1D):
if not kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to False but the target module is `Conv1D`. "
"Setting fan_in_fan_out to True."
)
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
new_module = Linear(target, adapter_name, is_target_conv_1d_layer=True, **kwargs)
else:
if isinstance(target, torch.nn.Linear):
in_features, out_features = target.in_features, target.out_features
if kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
"Setting fan_in_fan_out to False."
)
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
elif isinstance(target, Conv1D):
in_features, out_features = (
target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
)
kwargs["is_target_conv_1d_layer"] = True
if not kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to False but the target module is `Conv1D`. "
"Setting fan_in_fan_out to True."
)
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
else:
raise ValueError(
f"Target module {target} is not supported. Currently, only the following modules are supported: "
"`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`."
)
new_module = Linear(adapter_name, in_features, out_features, bias=bias, **kwargs)
raise ValueError(
f"Target module {target} is not supported. Currently, only the following modules are supported: "
"`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`."
)

return new_module

@@ -376,65 +366,31 @@ class LoraModel(BaseTuner):
)
return peft_config

def _unload_and_optionally_merge(self, merge=True, progressbar: bool = False, safe_merge: bool = False):
def _unload_and_optionally_merge(
self,
merge=True,
progressbar: bool = False,
safe_merge: bool = False,
adapter_names: Optional[List[str]] = None,
):
if merge:
if getattr(self.model, "quantization_method", None) == "gptq":
raise ValueError("Cannot merge LORA layers when the model is gptq quantized")

key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
desc = "Unloading " + ("and merging " if merge else "") + "model"
for key in tqdm(key_list, disable=not progressbar, desc=desc):
try:
parent, target, target_name = _get_submodules(self.model, key)
except AttributeError:
continue
if isinstance(target, LoraLayer):
if isinstance(target, nn.Embedding):
new_module = torch.nn.Embedding(target.in_features, target.out_features)
elif isinstance(target, nn.Conv2d):
new_module = torch.nn.Conv2d(
target.in_channels,
target.out_channels,
kernel_size=target.kernel_size,
stride=target.stride,
padding=target.padding,
dilation=target.dilation,
)
elif is_bnb_available() and isinstance(target, Linear8bitLt):
bias = target.base_layer.bias is not None
new_module = bnb.nn.Linear8bitLt(
target.in_features,
target.out_features,
bias=bias,
has_fp16_weights=target.base_layer.state.has_fp16_weights,
memory_efficient_backward=target.base_layer.state.memory_efficient_backward,
threshold=target.base_layer.state.threshold,
index=target.base_layer.index,
device=target.base_layer.weight.device,
)
elif is_bnb_4bit_available() and isinstance(target, Linear4bit):
bias = target.base_layer.bias is not None
new_module = bnb.nn.Linear4bit(
target.in_features,
target.out_features,
bias=bias,
compute_dtype=target.base_layer.compute_dtype,
compress_statistics=target.base_layer.weight.compress_statistics,
quant_type=target.base_layer.weight.quant_type,
device=target.base_layer.weight.device,
)
else:
bias = target.bias is not None
if getattr(target, "is_target_conv_1d_layer", False):
new_module = Conv1D(target.out_features, target.in_features)
else:
new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias)
if merge:
target.merge(safe_merge=safe_merge)
self._replace_module(parent, target_name, new_module, target)

# save any additional trainable modules part of `modules_to_save`
if isinstance(target, ModulesToSaveWrapper):
if hasattr(target, "base_layer"):
if merge:
target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
self._replace_module(parent, target_name, target.get_base_layer(), target)
elif isinstance(target, ModulesToSaveWrapper):
# save any additional trainable modules part of `modules_to_save`
setattr(parent, target_name, target.modules_to_save[target.active_adapter])

return self.model

@@ -536,7 +492,7 @@ class LoraModel(BaseTuner):
# Do we really need that?
_freeze_adapter(self.model, adapter_name)

key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
for key in key_list:
_, target, _ = _get_submodules(self.model, key)
if isinstance(target, LoraLayer):

@@ -660,32 +616,20 @@ class LoraModel(BaseTuner):
raise ValueError(f"Adapter {adapter_name} does not exist")
del self.peft_config[adapter_name]

key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
new_adapter = None
for key in key_list:
_, target, _ = _get_submodules(self.model, key)
if isinstance(target, LoraLayer):
for attr in [
"r",
"lora_alpha",
"scaling",
"lora_A",
"lora_B",
"lora_embedding_A",
"lora_embedding_B",
"lora_dropout",
]:
if adapter_name in getattr(target, attr):
getattr(target, attr).pop(adapter_name)
if adapter_name in target.active_adapters:
resetting_active_adapter = (
list(self.peft_config.keys())[0] if len(self.peft_config) > 0 else "default"
)
warnings.warn(
f"Adapter {adapter_name} was active which is now deleted. Setting active adapter to {resetting_active_adapter}. "
)
target.set_adapter(resetting_active_adapter)
target.delete_adapter(adapter_name)
if new_adapter is None:
new_adapter = target.active_adapters[:]

def merge_and_unload(self, progressbar: bool = False, safe_merge: bool = False):
self.active_adapter = new_adapter or []

def merge_and_unload(
self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[List[str]] = None
):
r"""
This method merges the LoRa layers into the base model. This is needed if someone wants to use the base model
as a standalone model.

@@ -696,7 +640,9 @@ class LoraModel(BaseTuner):
safe_merge (`bool`):
whether to activate the safe merging check to check if there is any potential Nan in the adapter
weights
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
Example:

```py

@@ -709,7 +655,9 @@ class LoraModel(BaseTuner):
>>> merged_model = model.merge_and_unload()
```
"""
return self._unload_and_optionally_merge(progressbar=progressbar, safe_merge=safe_merge)
return self._unload_and_optionally_merge(
progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names
)

def unload(self):
"""

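The model-level diff above threads the new `adapter_names` and `safe_merge` arguments through `_unload_and_optionally_merge` and `merge_and_unload`. A hedged usage sketch follows; the base model and adapter names ("adapter_a", "adapter_b", "gpt2") are placeholders, not values taken from this changeset.

```python
# Sketch of merging only selected LoRA adapters into the base weights.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")
model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"), adapter_name="adapter_a")
model.add_adapter("adapter_b", LoraConfig(task_type="CAUSAL_LM"))

# Merge only "adapter_a" and get the plain transformers model back; safe_merge
# performs the NaN check on a copy of the weights before committing them.
merged = model.merge_and_unload(safe_merge=True, adapter_names=["adapter_a"])
```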
@@ -13,12 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import warnings
from abc import abstractmethod
from dataclasses import dataclass, field
from itertools import chain
from typing import Dict, Optional, Set, Type, Union
from typing import Any, Dict, List, Optional, Set, Type, Union

import torch
import torch.nn as nn

@@ -58,12 +56,15 @@ class LycorisConfig(PeftConfig):
)


class LycorisLayer(BaseTunerLayer, nn.Module):
class LycorisLayer(BaseTunerLayer):
r"""
A base layer for LyCORIS like adapters
"""
# adapter_layer_names needs to be defined on the child class
other_param_names = ("r", "alpha", "scaling", "rank_dropout", "module_dropout")

def __init__(self):
def __init__(self, base_layer: nn.Module) -> None:
self.base_layer = base_layer
self.r = {}
self.alpha = {}
self.scaling = {}

@@ -91,56 +92,44 @@ class LycorisLayer(BaseTunerLayer, nn.Module):
cls.__init__(self, *args, device="meta", **kwargs)
self.to_empty(device=final_device)

def _op(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
raise NotImplementedError

@abstractmethod
def create_adapter_parameters(self, adapter_name: str, r: int, **kwargs):
...

def forward(self, x: torch.Tensor) -> torch.Tensor:
previous_dtype = x.dtype

if self.disable_adapters:
if self.merged:
self.unmerge()
result = self._op(x, self.weight)
elif self.merged:
result = self._op(x, self.weight)
else:
# Get base weights
weight = self.weight.data

# Execute all the adapters
for active_adapter in self.active_adapters:
if active_adapter not in self._available_adapters:
continue

module_dropout = self.module_dropout[active_adapter]

# Modify current execution weights
if (not self.training) or (self.training and torch.rand(1) > module_dropout):
weight = weight + self.get_delta_weight(active_adapter)

# Perform actual operation
result = self._op(x, weight)

result = result.to(previous_dtype)
return result
# TODO: refactor LoRA to use the same approach
@abstractmethod
def _get_delta_activations(self, adapter_name: str, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
"""Activations added on top of the base layer output (i.e. after the base layer forward pass)"""

@abstractmethod
def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
...

def merge(self) -> None:
def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
if self.merged:
warnings.warn(
f"Already following adapters were merged {','.join(self.merged_adapters)}. "
f"You are now additionally merging {','.join(self.active_adapters)}."
)
for active_adapter in self.active_adapters:
if adapter_names is None:
adapter_names = self.active_adapters

for active_adapter in adapter_names:
if active_adapter in self._available_adapters:
self.weight.data += self.get_delta_weight(active_adapter)
base_layer = self.get_base_layer()

if safe_merge:
orig_weights = base_layer.weight.data
orig_weights += self.get_delta_weight(active_adapter)

if not torch.isfinite(orig_weights).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)

base_layer.weight.data = orig_weights
else:
base_layer.weight.data += self.get_delta_weight(active_adapter)
self.merged_adapters.append(active_adapter)

@abstractmethod

@@ -170,7 +159,7 @@ class LycorisLayer(BaseTunerLayer, nn.Module):
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self._available_adapters:
self.weight.data -= self.get_delta_weight(active_adapter)
self.base_layer.weight.data -= self.get_delta_weight(active_adapter)

def unscale_layer(self, scale=None) -> None:
for active_adapter in self.active_adapters:

@@ -209,6 +198,7 @@ class LycorisTuner(BaseTuner):
def _check_target_module_exists(config, key):
return check_target_module_exists(config, key)

@abstractmethod
def _create_and_replace(
self,
config: LycorisConfig,

@@ -219,68 +209,47 @@ class LycorisTuner(BaseTuner):
current_key,
**optional_kwargs,
):
"""
A private method to create and replace the target module with the adapter module.
"""

# Regexp matching - Find key which matches current target_name in patterns provided
pattern_keys = list(chain(config.rank_pattern.keys(), config.alpha_pattern.keys()))
target_name_key = next(filter(lambda key: re.match(f"(.*\.)?{key}$", current_key), pattern_keys), target_name)

kwargs = config.to_dict()
kwargs["r"] = config.rank_pattern.get(target_name_key, config.r)
kwargs["alpha"] = config.alpha_pattern.get(target_name_key, config.alpha)

if isinstance(target, LycorisLayer):
target.update_layer(adapter_name, **kwargs)
else:
new_module = self._create_new_module(config, adapter_name, target, **kwargs)
self._replace_module(parent, target_name, new_module, target)
...

@classmethod
def _create_new_module(cls, config: LycorisConfig, adapter_name: str, target: nn.Module, **kwargs) -> LycorisLayer:
# Find corresponding subtype of provided target module
new_module_cls = None
for subtype, target_cls in cls.layers_mapping.items():
if isinstance(target, subtype):
if (
hasattr(target, "base_layer")
and isinstance(target.get_base_layer(), subtype)
and isinstance(target, BaseTunerLayer)
):
# nested tuner layers are allowed
new_module_cls = target_cls
break
elif isinstance(target, subtype):
new_module_cls = target_cls
break

# We didn't find corresponding type, so adapter for this layer is not supported
if new_module_cls is None:
supported_modules = ", ".join(layer.__name__ for layer in cls.layers_mapping.keys())
raise ValueError(
f"Target module not found, currently only adapters for {', '.join([x.__name__ for x in cls.modules_mapping.keys()])} are supported"
f"Target module of type {type(target)} not supported, "
f"currently only adapters for {supported_modules} are supported"
)

if isinstance(target, torch.nn.Conv2d):
new_module = new_module_cls(
target.in_channels,
target.out_channels,
target.weight.size()[2:],
stride=target.stride,
padding=target.padding,
dilation=target.dilation,
groups=target.groups,
bias=target.bias is not None,
padding_mode=target.padding_mode,
device=target.weight.device,
dtype=target.weight.dtype,
adapter_name=adapter_name,
**kwargs,
)
elif isinstance(target, torch.nn.Linear):
new_module = new_module_cls(
target.in_features,
target.out_features,
bias=target.bias is not None,
device=target.weight.device,
dtype=target.weight.dtype,
adapter_name=adapter_name,
**kwargs,
)
if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target

if isinstance(target_base_layer, torch.nn.Conv2d):
new_module = new_module_cls(target, adapter_name=adapter_name, **kwargs)
elif isinstance(target_base_layer, torch.nn.Linear):
new_module = new_module_cls(target, adapter_name=adapter_name, **kwargs)
else:
supported_modules = ", ".join(layer.__name__ for layer in cls.layers_mapping.keys())
raise ValueError(
"Target module not found, currently only adapters for nn.Linear and nn.Conv2d are supported"
f"Target module of type {type(target)} not supported, "
f"currently only adapters for {supported_modules} are supported"
)

return new_module

@@ -300,12 +269,17 @@ class LycorisTuner(BaseTuner):
setattr(parent, child_name, new_module)
# It's not necessary to set requires_grad here, as that is handled by
# _mark_only_adapters_as_trainable
new_module.weight = child.weight
if hasattr(child, "bias"):
new_module.bias = child.bias

if not hasattr(new_module, "base_layer"):
new_module.weight = child.weight
if hasattr(child, "bias"):
new_module.bias = child.bias

if getattr(child, "state", None) is not None:
new_module.state = child.state
if hasattr(new_module, "base_layer"):
new_module.base_layer.state = child.state
else:
new_module.state = child.state
new_module.to(child.weight.device)

# dispatch to correct device

@@ -318,46 +292,31 @@ class LycorisTuner(BaseTuner):
if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
module.enable_adapters(enabled)

def _unload_and_optionally_merge(self, merge=True, progressbar: bool = False):
def _unload_and_optionally_merge(
self,
merge: bool = True,
progressbar: bool = False,
safe_merge: bool = False,
adapter_names: Optional[List[str]] = None,
):
if merge:
if getattr(self.model, "quantization_method", None) == "gptq":
raise ValueError("Cannot merge LOHA layers when the model is gptq quantized")

key_list = [key for key, _ in self.model.named_modules() if "hada" not in key]
key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
desc = "Unloading " + ("and merging " if merge else "") + "model"
for key in tqdm(key_list, disable=not progressbar, desc=desc):
try:
parent, target, target_name = _get_submodules(self.model, key)
except AttributeError:
continue
if isinstance(target, LycorisLayer):
if isinstance(target, nn.Conv2d):
new_module = torch.nn.Conv2d(
target.in_channels,
target.out_channels,
kernel_size=target.kernel_size,
stride=target.stride,
padding=target.padding,
dilation=target.dilation,
)
elif isinstance(target, nn.Linear):
bias = target.bias is not None
new_module = torch.nn.Linear(
target.in_features,
target.out_features,
bias=bias,
device=target.weight.device,
)
else:
raise ValueError(
"Cannot convert current module to torch module, currently only adapters for nn.Linear and nn.Conv2d are supported"
)
if merge:
target.merge()
self._replace_module(parent, target_name, new_module, target)

# save any additional trainable modules part of `modules_to_save`
if isinstance(target, ModulesToSaveWrapper):
if hasattr(target, "base_layer"):
if merge:
target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
self._replace_module(parent, target_name, target.get_base_layer(), target)
elif isinstance(target, ModulesToSaveWrapper):
# save any additional trainable modules part of `modules_to_save`
setattr(parent, target_name, target.modules_to_save[target.active_adapter])

return self.model

@@ -368,8 +327,34 @@ class LycorisTuner(BaseTuner):
def disable_adapter_layers(self):
self._set_adapter_layers(enabled=False)

def merge_and_unload(self, progressbar: bool = False):
return self._unload_and_optionally_merge(progressbar=progressbar)
def merge_and_unload(
self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[List[str]] = None
):
r"""
This method merges the adapter layers into the base model. This is needed if someone wants to use the base
model as a standalone model.

Args:
progressbar (`bool`):
whether to show a progressbar indicating the unload and merge process
safe_merge (`bool`):
whether to activate the safe merging check to check if there is any potential Nan in the adapter
weights
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.

"""
return self._unload_and_optionally_merge(
progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names
)

def unload(self):
"""
Gets back the base model by removing all the lora modules without merging. This gives back the original base
model.
"""
return self._unload_and_optionally_merge(merge=False)

def set_adapter(self, adapter_name):
for module in self.model.modules():

@@ -391,17 +376,12 @@ class LycorisTuner(BaseTuner):
del self.peft_config[adapter_name]

key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
new_adapter = None
for key in key_list:
_, target, _ = _get_submodules(self.model, key)
if isinstance(target, LycorisLayer):
for attr in target.adapter_layer_names:
if adapter_name in getattr(target, attr):
getattr(target, attr).pop(adapter_name)
if adapter_name in target.active_adapters:
resetting_active_adapter = (
list(self.peft_config.keys())[0] if len(self.peft_config) > 0 else "default"
)
warnings.warn(
f"Adapter {adapter_name} was active which is now deleted. Setting active adapter to {resetting_active_adapter}. "
)
target.set_adapter(resetting_active_adapter)
target.delete_adapter(adapter_name)
if new_adapter is None:
new_adapter = target.active_adapters[:]

self.active_adapter = new_adapter or []

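The LyCORIS tuner gains the same `safe_merge`/`adapter_names` merge path as LoRA. A hedged sketch of how this might be exercised with LoHa follows; the model name and target module are placeholders.

```python
# Sketch of the LyCORIS-side counterpart of the merge API, using LoHa as an example.
from transformers import AutoModelForCausalLM
from peft import LoHaConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")
config = LoHaConfig(r=8, alpha=16, target_modules=["c_attn"], task_type="CAUSAL_LM")
model = get_peft_model(base, config)

# After training, fold the adapter into the base weights with the NaN check enabled.
merged = model.merge_and_unload(safe_merge=True)
```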
@@ -104,7 +104,7 @@ class PromptEncoder(torch.nn.Module):
encoder_num_layers_default = PromptEncoderConfig.encoder_num_layers
if config.encoder_num_layers != encoder_num_layers_default:
warnings.warn(
f"for {self.encoder_type}, the argument `encoder_num_layers` is ignored. "
f"for {self.encoder_type.value}, the argument `encoder_num_layers` is ignored. "
f"Exactly {encoder_num_layers_default} MLP layers are used."
)
layers = [

@@ -37,6 +37,9 @@ class PromptTuningConfig(PromptLearningConfig):
The text to initialize the prompt embedding. Only used if `prompt_tuning_init` is `TEXT`.
tokenizer_name_or_path (`str`, *optional*):
The name or path of the tokenizer. Only used if `prompt_tuning_init` is `TEXT`.
tokenizer_kwargs (`dict`, *optional*):
The keyword arguments to pass to `AutoTokenizer.from_pretrained`. Only used if `prompt_tuning_init` is
`TEXT`.
"""

prompt_tuning_init: Union[PromptTuningInit, str] = field(

@@ -56,5 +59,20 @@ class PromptTuningConfig(PromptLearningConfig):
},
)

tokenizer_kwargs: Optional[dict] = field(
default=None,
metadata={
"help": (
"The keyword arguments to pass to `AutoTokenizer.from_pretrained`. Only used if prompt_tuning_init is "
"`TEXT`"
),
},
)

def __post_init__(self):
self.peft_type = PeftType.PROMPT_TUNING

if self.tokenizer_kwargs and (self.prompt_tuning_init != PromptTuningInit.TEXT):
raise ValueError(
f"tokenizer_kwargs only valid when using prompt_tuning_init='{PromptTuningInit.TEXT.value}'."
)

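The config diff above adds a `tokenizer_kwargs` field that is forwarded to `AutoTokenizer.from_pretrained` and is only accepted with TEXT initialization. A short sketch of how a config using it could look; the model name and init text are placeholders.

```python
# Sketch of a PromptTuningConfig using the new tokenizer_kwargs field.
from peft import PromptTuningConfig, PromptTuningInit

config = PromptTuningConfig(
    task_type="CAUSAL_LM",
    num_virtual_tokens=8,
    prompt_tuning_init=PromptTuningInit.TEXT,
    prompt_tuning_init_text="Classify the sentiment of this review:",
    tokenizer_name_or_path="gpt2",
    tokenizer_kwargs={"use_fast": True},  # forwarded to AutoTokenizer.from_pretrained
)
```

Passing `tokenizer_kwargs` with any other `prompt_tuning_init` raises a `ValueError`, as enforced in `__post_init__` above.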
@@ -66,7 +66,8 @@ class PromptEmbedding(torch.nn.Module):
if config.prompt_tuning_init == PromptTuningInit.TEXT:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)
tokenizer_kwargs = config.tokenizer_kwargs or {}
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path, **tokenizer_kwargs)
init_text = config.prompt_tuning_init_text
init_token_ids = tokenizer(init_text)["input_ids"]
# Trim or iterate until num_text_tokens matches total_virtual_tokens

@@ -77,8 +78,9 @@ class PromptEmbedding(torch.nn.Module):
num_reps = math.ceil(total_virtual_tokens / num_text_tokens)
init_token_ids = init_token_ids * num_reps
init_token_ids = init_token_ids[:total_virtual_tokens]
init_token_ids = torch.LongTensor(init_token_ids).to(word_embeddings.weight.device)

word_embedding_weights = word_embeddings(torch.LongTensor(init_token_ids)).detach().clone()
word_embedding_weights = word_embeddings(init_token_ids).detach().clone()
word_embedding_weights = word_embedding_weights.to(torch.float32)
self.embedding.weight = torch.nn.Parameter(word_embedding_weights)

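The TEXT initialization above trims or repeats the tokenized init text so that it exactly fills the virtual token slots. A standalone sketch of just that trim-or-repeat arithmetic, with illustrative numbers:

```python
# Sketch of the trim-or-repeat logic used to fit init token ids to the prompt length.
import math
from typing import List


def fit_to_virtual_tokens(init_token_ids: List[int], total_virtual_tokens: int) -> List[int]:
    num_text_tokens = len(init_token_ids)
    if num_text_tokens > total_virtual_tokens:
        # too many tokens: keep only the first total_virtual_tokens
        return init_token_ids[:total_virtual_tokens]
    # too few tokens: tile the sequence and cut it to length
    num_reps = math.ceil(total_virtual_tokens / num_text_tokens)
    return (init_token_ids * num_reps)[:total_virtual_tokens]


print(fit_to_virtual_tokens([5, 6, 7], 8))  # [5, 6, 7, 5, 6, 7, 5, 6]
```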
@@ -16,15 +16,17 @@ from __future__ import annotations

import logging
import re
import warnings
from abc import ABC, abstractmethod
from typing import Any, Union

import torch
from torch import nn

from peft.utils import COMMON_LAYERS_PATTERN

from ..config import PeftConfig
from ..utils import _get_submodules
from ..utils import ModulesToSaveWrapper, _get_submodules


logger = logging.getLogger(__name__)

@@ -210,6 +212,9 @@ class BaseTuner(nn.Module, ABC):
is_target_modules_in_base_model = False
key_list = [key for key, _ in model.named_modules()]

_check_for_modules_to_save = getattr(peft_config, "modules_to_save", None) is not None
_has_modules_to_save = False

model_config = getattr(model, "config", {"model_type": "custom"})
if hasattr(model_config, "to_dict"):
model_config = model_config.to_dict()

@@ -217,6 +222,22 @@ class BaseTuner(nn.Module, ABC):
peft_config = self._prepare_adapter_config(peft_config, model_config)

for key in key_list:
# Check for modules_to_save in case
if _check_for_modules_to_save and any(
key.endswith(f"{module_to_save}") for module_to_save in peft_config.modules_to_save
):
# Optionally set the modules to save
parent, target, target_name = _get_submodules(model, key)

if not isinstance(target, ModulesToSaveWrapper):
new_module = ModulesToSaveWrapper(target, adapter_name)
setattr(parent, target_name, new_module)
else:
target.update(adapter_name)

_has_modules_to_save = True
continue

if not self._check_target_module_exists(peft_config, key):
continue

@@ -243,6 +264,12 @@ class BaseTuner(nn.Module, ABC):
if adapter_name in n:
p.requires_grad = False

if _has_modules_to_save:
if not hasattr(model, "modules_to_save"):
model.modules_to_save = set(peft_config.modules_to_save)
else:
model.modules_to_save.update(set(peft_config.modules_to_save))

def merge_adapter(self):
"""
This method merges the LoRa layers into the base model.

@@ -272,8 +299,10 @@ class BaseTunerLayer(ABC):
"""
active_adapter = None

# List all names of layers that may contain adapter weights
adapter_layer_names: list[str] = []
# All names of layers that may contain adapter (trainable) weights
adapter_layer_names: tuple[str] = ()
# All names of other parameters that may contain adapter-related parameters
other_param_names: tuple[str] = ()

# indicates whether all adapters should be disabled
_disable_adapters: bool = False

@@ -284,6 +313,34 @@ class BaseTunerLayer(ABC):
# List all merged adapters
merged_adapters: list[str] = []

def get_base_layer(self) -> nn.Module:
"""
(Recursively) get the base_layer.

This is necessary for the case that the tuner layer wraps another tuner layer.

"""
base_layer = self
while hasattr(base_layer, "base_layer"):
base_layer = base_layer.base_layer
return base_layer

@property
def weight(self) -> torch.Tensor:
# This is required for some transformers code, e.g. for T5, weight is accessed as:
# self.wo.weight
# where "wo" is the adapter layer.
# https://github.com/huggingface/transformers/blob/78f6ed6c70b29c1560780e3869a7ad4c6b3d2710/src/transformers
# /models/t5/modeling_t5.py#L292
base_layer = self.get_base_layer()
if hasattr(base_layer, "qweight"):
# QuantLinear
weight = base_layer.qweight
else:
# Other layers
weight = base_layer.weight
return weight

def merge(self, *args) -> None:
raise NotImplementedError

@@ -351,6 +408,54 @@ class BaseTunerLayer(ABC):

self._active_adapter = adapter_names

def _all_available_adapter_names(self) -> list[str]:
"""Return a sorted list of all available adapter names"""
adapter_names = set()
for name in self.adapter_layer_names + self.other_param_names:
# we check each possible attribute and if it's a dict or ModuleDict, we assume that the keys are the adapter
# names
attr = getattr(self, name)
if hasattr(attr, "keys"):
adapter_names.update(attr.keys())
return sorted(adapter_names)

def delete_adapter(self, adapter_name: str) -> None:
"""
Delete an adapter from the layer

This should be called on all adapter layers, or else we will get an inconsistent state.

This method will also set a new active adapter if the deleted adapter was an active adapter. It is important
that the new adapter is chosen in a deterministic way, so that the same adapter is chosen on all layers.

Args:
adapter_name (`str`): The name of the adapter to delete

"""
for attr in self.adapter_layer_names + self.other_param_names:
if adapter_name in getattr(self, attr):
del getattr(self, attr)[adapter_name]

if adapter_name in self.active_adapters:
# choose a new active adapter
active_adapters = self.active_adapters[:]
active_adapters.remove(adapter_name)
if active_adapters:
self.set_adapter(active_adapters)
else:
# no active adapters left, set a new default adapter
# here we get the list of all adapters existing adapter names and choose the first one
remaining_adapters = self._all_available_adapter_names()
if not remaining_adapters:
self.set_adapter([])
else:
new_active_adapter = remaining_adapters[0]
warnings.warn(
f"Adapter {adapter_name} was active which is now deleted. Setting active adapter to "
f"{new_active_adapter}."
)
self.set_adapter(remaining_adapters[0])


def check_target_module_exists(config, key: str) -> bool | re.Match[str] | None:
"""A helper method to check if the passed module's key name matches any of the target modules in the adapter_config.

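The new `BaseTunerLayer.delete_adapter` above deletes an adapter from every layer and, if the deleted adapter was active, deterministically falls back to the first remaining adapter name (sorted), so all layers agree. A hedged sketch of the intended behaviour; model and adapter names are placeholders, and the exact entry point may differ depending on the PEFT version.

```python
# Sketch of deleting an active adapter and observing the deterministic fallback.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")
model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"), adapter_name="first")
model.add_adapter("second", LoraConfig(task_type="CAUSAL_LM"))
model.set_adapter("second")

# "second" is active; deleting it makes the layers warn and fall back to "first".
model.base_model.delete_adapter("second")
print(model.base_model.active_adapter)  # expected: ["first"]
```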
@@ -45,6 +45,6 @@ from .other import (
infer_device,
get_auto_gptq_quant_linear,
get_quantization_config,
id_tensor_storage,
)
from .hub_utils import hub_file_exists
from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict, load_peft_weights

@@ -1,29 +0,0 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from huggingface_hub import get_hf_file_metadata, hf_hub_url
from huggingface_hub.utils import EntryNotFoundError


def hub_file_exists(repo_id: str, filename: str, revision: str = None, repo_type: str = None) -> bool:
r"""
Checks if a file exists in a remote Hub repository.
"""
url = hf_hub_url(repo_id=repo_id, filename=filename, repo_type=repo_type, revision=revision)
try:
    get_hf_file_metadata(url)
    return True
except EntryNotFoundError:
    return False

@@ -15,14 +15,15 @@
import copy
import inspect
import warnings
from typing import Optional
from typing import Optional, Tuple

import accelerate
import torch
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
from accelerate.utils import is_npu_available, is_xpu_available
from safetensors.torch import storage_ptr, storage_size

from ..import_utils import is_auto_gptq_available
from ..import_utils import is_auto_gptq_available, is_torch_tpu_available


# Get current device name based on available devices

@@ -276,8 +277,22 @@ def _set_trainable(model, adapter_name):


def _set_adapter(model, adapter_name):
def check_adapter_name(adapter_name):
if isinstance(adapter_name, str):
return adapter_name

# adapter_name is a list of str
if len(adapter_name) > 1:
raise ValueError("Only one adapter can be set at a time for modules_to_save")
elif len(adapter_name) == 0:
raise ValueError("Please specify at least one adapter to set")
adapter_name = adapter_name[0]
return adapter_name

for module in model.modules():
if isinstance(module, ModulesToSaveWrapper):
# only check the adapter_name if we actually encounter a ModulesToSaveWrapper, otherwise we don't care
adapter_name = check_adapter_name(adapter_name)
module.set_adapter(adapter_name)


@@ -412,33 +427,57 @@ def get_auto_gptq_quant_linear(gptq_quantization_config):
"""
Get the right AutoGPTQQuantLinear class based on the quantization config file
"""
if is_auto_gptq_available():
if gptq_quantization_config is not None and is_auto_gptq_available():
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear

if gptq_quantization_config is not None:
desc_act = gptq_quantization_config.desc_act
group_size = gptq_quantization_config.group_size
bits = gptq_quantization_config.bits
if hasattr(gptq_quantization_config, "use_exllama"):
use_exllama = gptq_quantization_config.use_exllama
else:
use_exllama = not gptq_quantization_config.disable_exllama
if hasattr(gptq_quantization_config, "exllama_config"):
exllama_version = gptq_quantization_config.exllama_config["version"]
else:
exllama_version = 1
AutoGPTQQuantLinear = dynamically_import_QuantLinear(
use_triton=False,
desc_act=desc_act,
group_size=group_size,
bits=bits,
disable_exllama=not (use_exllama and exllama_version == 1),
disable_exllamav2=not (use_exllama and exllama_version == 2),
)
return AutoGPTQQuantLinear
desc_act = gptq_quantization_config.desc_act
group_size = gptq_quantization_config.group_size
bits = gptq_quantization_config.bits
if hasattr(gptq_quantization_config, "use_exllama"):
use_exllama = gptq_quantization_config.use_exllama
else:
use_exllama = not gptq_quantization_config.disable_exllama
if hasattr(gptq_quantization_config, "exllama_config"):
exllama_version = gptq_quantization_config.exllama_config["version"]
else:
exllama_version = 1
AutoGPTQQuantLinear = dynamically_import_QuantLinear(
use_triton=False,
desc_act=desc_act,
group_size=group_size,
bits=bits,
disable_exllama=not (use_exllama and exllama_version == 1),
disable_exllamav2=not (use_exllama and exllama_version == 2),
)
return AutoGPTQQuantLinear
return None


def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
"""
Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For
example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
non-overlapping lifetimes may have the same id.

This method is the exact same copy of
https://github.com/huggingface/transformers/blob/main/src/transformers/pytorch_utils.py#L282C1-L300C58 but we added
it here manually to avoid import issue with old versions of transformers.
"""
if tensor.device.type == "xla" and is_torch_tpu_available():
# NOTE: xla tensors dont have storage
# use some other unique id to distinguish.
# this is a XLA tensor, it must be created using torch_xla's
# device. So the following import is safe:
import torch_xla

unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor)
else:
unique_id = storage_ptr(tensor)

return tensor.device, unique_id, storage_size(tensor)


TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = {
"t5": ["q", "v"],
"mt5": ["q", "v"],

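`id_tensor_storage`, added above, maps tensors that share the same underlying storage to the same identifier, which is what safetensors-style serialization needs to detect aliased weights. A small sketch of what that looks like in practice; it assumes the function is re-exported from `peft.utils` as the import diff earlier indicates.

```python
# Sketch: tensors sharing storage get the same identifier, independent tensors do not.
import torch
from peft.utils import id_tensor_storage

a = torch.zeros(4)
b = a.view(2, 2)       # a view shares storage with `a`
c = torch.zeros(4)     # a fresh tensor has its own storage

print(id_tensor_storage(a) == id_tensor_storage(b))  # True
print(id_tensor_storage(a) == id_tensor_storage(c))  # False
```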
@@ -16,11 +16,10 @@ import os
from typing import Optional

import torch
from huggingface_hub import hf_hub_download
from huggingface_hub import file_exists, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from safetensors.torch import load_file as safe_load_file

from .hub_utils import hub_file_exists
from .other import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, infer_device
from .peft_types import PeftType

@@ -194,9 +193,9 @@ def load_peft_weights(model_id: str, device: Optional[str] = None, **hf_hub_download_kwargs):
filename = os.path.join(path, WEIGHTS_NAME)
use_safetensors = False
else:
has_remote_safetensors_file = hub_file_exists(
model_id,
SAFETENSORS_WEIGHTS_NAME,
has_remote_safetensors_file = file_exists(
repo_id=model_id,
filename=SAFETENSORS_WEIGHTS_NAME,
revision=hf_hub_download_kwargs.get("revision", None),
repo_type=hf_hub_download_kwargs.get("repo_type", None),
)

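The change above replaces the local `hub_utils.hub_file_exists` helper (deleted earlier) with `huggingface_hub.file_exists`. A short sketch of the replacement helper in isolation; the repository id is a placeholder.

```python
# Sketch: checking for a remote safetensors adapter file with huggingface_hub.
from huggingface_hub import file_exists

has_safetensors = file_exists(
    repo_id="ybelkada/opt-350m-lora",        # placeholder adapter repo
    filename="adapter_model.safetensors",
)
print(has_safetensors)
```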
@ -115,6 +115,51 @@ class AdaptionPromptTester(TestCase, PeftCommonTester):

        self.assertTrue(dummy_output.requires_grad)

    def test_save_pretrained_regression(self) -> None:
        seed = 420
        torch.manual_seed(seed)
        model = LlamaForCausalLM(self._create_test_llama_config())
        config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
        model = get_peft_model(model, config)
        model = model.to(self.torch_device)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname, safe_serialization=False)

            torch.manual_seed(seed)
            model_from_pretrained = LlamaForCausalLM(self._create_test_llama_config())
            model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)

            # check if the state dicts are equal
            state_dict = get_peft_model_state_dict(model)
            state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained)

            # check if same keys
            self.assertEqual(state_dict.keys(), state_dict_from_pretrained.keys())

            # Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate).
            self.assertEqual(len(list(state_dict.keys())), 4)

            # check if tensors equal
            for key in state_dict.keys():
                self.assertTrue(
                    torch.allclose(
                        state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device)
                    )
                )

            # check if `adapter_model.bin` is present
            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin")))

            # check if `adapter_config.json` is present
            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json")))

            # check if `model.safetensors` is not present
            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "model.safetensors")))

            # check if `config.json` is not present
            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))

    def test_save_pretrained(self) -> None:
        seed = 420
        torch.manual_seed(seed)
@ -149,13 +194,13 @@ class AdaptionPromptTester(TestCase, PeftCommonTester):
                )

            # check if `adapter_model.bin` is present
            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin")))
            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors")))

            # check if `adapter_config.json` is present
            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json")))

            # check if `pytorch_model.bin` is not present
            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "pytorch_model.bin")))
            # check if `model.safetensors` is not present
            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "model.safetensors")))

            # check if `config.json` is not present
            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))
@ -199,13 +244,13 @@ class AdaptionPromptTester(TestCase, PeftCommonTester):
                )

            # check if `adapter_model.bin` is present
            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin")))
            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors")))

            # check if `adapter_config.json` is present
            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json")))

            # check if `pytorch_model.bin` is not present
            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "pytorch_model.bin")))
            # check if `model.safetensors` is not present
            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "model.safetensors")))

            # check if `config.json` is not present
            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))
File diff suppressed because it is too large
@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from unittest.mock import Mock, call, patch

import torch
from parameterized import parameterized
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import AdaLoraConfig
from peft import AdaLoraConfig, PromptTuningConfig, PromptTuningInit, get_peft_model

from .testing_common import PeftCommonTester, PeftTestConfigManager

@ -76,14 +77,77 @@ class PeftDecoderModelTester(unittest.TestCase, PeftCommonTester):
    def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls, config_kwargs):
        self._test_prepare_for_training(model_id, config_cls, config_kwargs)

    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_prompt_tuning_text_prepare_for_training(self, test_name, model_id, config_cls, config_kwargs):
        # Test that prompt tuning works with text init
        if config_cls != PromptTuningConfig:
            return

        config_kwargs = config_kwargs.copy()
        config_kwargs["prompt_tuning_init"] = PromptTuningInit.TEXT
        config_kwargs["prompt_tuning_init_text"] = "This is a test prompt."
        config_kwargs["tokenizer_name_or_path"] = model_id
        self._test_prepare_for_training(model_id, config_cls, config_kwargs)

    def test_prompt_tuning_text_tokenizer_kwargs(self):
        # Allow users to pass additional arguments to Tokenizer.from_pretrained
        # Fix for #1032
        mock = Mock()
        orig_from_pretrained = AutoTokenizer.from_pretrained

        def mock_autotokenizer_from_pretrained(*args, **kwargs):
            mock(*args, **kwargs)
            return orig_from_pretrained(config.tokenizer_name_or_path)

        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
        config = PromptTuningConfig(
            base_model_name_or_path=model_id,
            tokenizer_name_or_path=model_id,
            num_virtual_tokens=10,
            prompt_tuning_init=PromptTuningInit.TEXT,
            task_type="CAUSAL_LM",
            prompt_tuning_init_text="This is a test prompt.",
            tokenizer_kwargs={"trust_remote_code": True, "foo": "bar"},
        )
        model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
        with patch("transformers.AutoTokenizer.from_pretrained", mock_autotokenizer_from_pretrained):
            model = get_peft_model(model, config)

        expected_call = call(model_id, trust_remote_code=True, foo="bar")
        self.assertEqual(mock.call_args, expected_call)

    def test_prompt_tuning_config_invalid_args(self):
        # Raise an error when tokenizer_kwargs is used with prompt_tuning_init!='TEXT', because this argument has no
        # function in that case
        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
        msg = "tokenizer_kwargs only valid when using prompt_tuning_init='TEXT'."
        with self.assertRaisesRegex(ValueError, expected_regex=msg):
            PromptTuningConfig(
                base_model_name_or_path=model_id,
                tokenizer_name_or_path=model_id,
                num_virtual_tokens=10,
                task_type="CAUSAL_LM",
                prompt_tuning_init_text="This is a test prompt.",
                prompt_tuning_init=PromptTuningInit.RANDOM,  # <= should not be used together with tokenizer_kwargs
                tokenizer_kwargs={"trust_remote_code": True, "foo": "bar"},
            )

    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs):
        self._test_save_pretrained(model_id, config_cls, config_kwargs)

    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_save_pretrained_pickle(self, test_name, model_id, config_cls, config_kwargs):
        self._test_save_pretrained(model_id, config_cls, config_kwargs, safe_serialization=False)

    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_save_pretrained_selected_adapters(self, test_name, model_id, config_cls, config_kwargs):
        self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs)

    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_save_pretrained_selected_adapters_pickle(self, test_name, model_id, config_cls, config_kwargs):
        self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs, safe_serialization=False)

    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_from_pretrained_config_construction(self, test_name, model_id, config_cls, config_kwargs):
        self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs)
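A hedged usage sketch (not taken from the diff) of the behaviour the prompt-tuning tests above exercise: `tokenizer_kwargs` on `PromptTuningConfig` is forwarded to `AutoTokenizer.from_pretrained` when `prompt_tuning_init` is `TEXT`. The prompt text and the `use_fast` flag below are illustrative placeholders; the tiny model id is the same one the tests use.

```python
from transformers import AutoModelForCausalLM
from peft import PromptTuningConfig, PromptTuningInit, get_peft_model

model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"  # placeholder model

config = PromptTuningConfig(
    task_type="CAUSAL_LM",
    num_virtual_tokens=10,
    prompt_tuning_init=PromptTuningInit.TEXT,
    prompt_tuning_init_text="Classify the sentiment of this quote:",
    tokenizer_name_or_path=model_id,
    tokenizer_kwargs={"use_fast": True},  # forwarded to AutoTokenizer.from_pretrained
)
model = get_peft_model(AutoModelForCausalLM.from_pretrained(model_id), config)
```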
@ -101,6 +165,19 @@ class PeftDecoderModelTester(unittest.TestCase, PeftCommonTester):
    def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
        self._test_merge_layers(model_id, config_cls, config_kwargs)

    @parameterized.expand(
        PeftTestConfigManager.get_grid_parameters(
            {
                "model_ids": PEFT_DECODER_MODELS_TO_TEST,
                "lora_kwargs": {"init_lora_weights": [False]},
                "ia3_kwargs": {"init_ia3_weights": [False]},
                "task_type": "CAUSAL_LM",
            },
        )
    )
    def test_merge_layers_multi(self, test_name, model_id, config_cls, config_kwargs):
        self._test_merge_layers_multi(model_id, config_cls, config_kwargs)

    @parameterized.expand(
        PeftTestConfigManager.get_grid_parameters(
            {
@ -154,6 +231,10 @@ class PeftDecoderModelTester(unittest.TestCase, PeftCommonTester):
    def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
        self._test_delete_adapter(model_id, config_cls, config_kwargs)

    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_kwargs):
        self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs)

    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs):
        self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs)
@ -164,6 +245,7 @@ class PeftDecoderModelTester(unittest.TestCase, PeftCommonTester):
                "model_ids": PEFT_DECODER_MODELS_TO_TEST,
                "lora_kwargs": {"init_lora_weights": [False]},
                "adalora_kwargs": {"init_lora_weights": [False]},
                "ia3_kwargs": {"init_ia3_weights": [False]},
                "task_type": "CAUSAL_LM",
            },
            filter_params_func=skip_adalora_and_gpt2,
@ -70,10 +70,18 @@ class PeftEncoderDecoderModelTester(unittest.TestCase, PeftCommonTester):
    def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs):
        self._test_save_pretrained(model_id, config_cls, config_kwargs)

    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_save_pretrained_pickle(self, test_name, model_id, config_cls, config_kwargs):
        self._test_save_pretrained(model_id, config_cls, config_kwargs, safe_serialization=False)

    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_save_pretrained_selected_adapters(self, test_name, model_id, config_cls, config_kwargs):
        self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs)

    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_save_pretrained_selected_adapters_pickle(self, test_name, model_id, config_cls, config_kwargs):
        self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs, safe_serialization=False)

    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_from_pretrained_config_construction(self, test_name, model_id, config_cls, config_kwargs):
        self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs)
@ -128,6 +136,10 @@ class PeftEncoderDecoderModelTester(unittest.TestCase, PeftCommonTester):
    def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
        self._test_delete_adapter(model_id, config_cls, config_kwargs)

    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_kwargs):
        self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs)

    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs):
        self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs)
@ -146,12 +146,17 @@ class PeftFeatureExtractionModelTester(unittest.TestCase, PeftCommonTester):
    def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
        self._test_delete_adapter(model_id, config_cls, config_kwargs)

    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_kwargs):
        self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs)

    @parameterized.expand(
        PeftTestConfigManager.get_grid_parameters(
            {
                "model_ids": PEFT_FEATURE_EXTRACTION_MODELS_TO_TEST,
                "lora_kwargs": {"init_lora_weights": [False]},
                "adalora_kwargs": {"init_lora_weights": [False]},
                "ia3_kwargs": {"init_ia3_weights": [False]},
                "task_type": "FEATURE_EXTRACTION",
            },
        )
@ -44,6 +44,7 @@ from peft (
    prepare_model_for_int8_training,
    prepare_model_for_kbit_training,
)
from peft.utils import SAFETENSORS_WEIGHTS_NAME

from .testing_utils import (
    require_auto_gptq,
@ -124,6 +125,14 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
        torch.cuda.empty_cache()
        gc.collect()

    def _check_inference_finite(self, model, batch):
        # try inference without Trainer class
        training = model.training
        model.eval()
        output = model(**batch.to(model.device))
        self.assertTrue(torch.isfinite(output.logits).all())
        model.train(training)

    @pytest.mark.single_gpu_tests
    def test_causal_lm_training(self):
        r"""
@ -177,7 +186,7 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
            model.cpu().save_pretrained(tmp_dir)

            self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
            self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir))
            self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))

            # assert loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@ -235,7 +244,7 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
            model.cpu().save_pretrained(tmp_dir)

            self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
            self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir))
            self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))

            # assert loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@ -296,7 +305,7 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
            model.cpu().save_pretrained(tmp_dir)

            self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
            self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir))
            self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))

            # assert loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@ -334,6 +343,8 @@ class PeftBnbGPUExampleTests(unittest.TestCase):

        data = load_dataset("ybelkada/english_quotes_copy")
        data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
        batch = tokenizer(data["train"][:3]["quote"], return_tensors="pt", padding=True)
        self._check_inference_finite(model, batch)

        with tempfile.TemporaryDirectory() as tmp_dir:
            trainer = Trainer(
@ -357,7 +368,70 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
            model.cpu().save_pretrained(tmp_dir)

            self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
            self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir))
            self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))

            # assert loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

    @pytest.mark.single_gpu_tests
    @require_torch_gpu
    def test_8bit_adalora_causalLM(self):
        r"""
        Tests the 8bit training with adalora
        """
        model_id = "facebook/opt-350m"

        model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True)
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        model.gradient_checkpointing_enable()
        model = prepare_model_for_kbit_training(model)

        peft_config = AdaLoraConfig(
            init_r=6,
            target_r=4,
            tinit=50,
            tfinal=100,
            deltaT=5,
            beta1=0.3,
            beta2=0.3,
            orth_reg_weight=0.2,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )

        model = get_peft_model(model, peft_config)

        data = load_dataset("ybelkada/english_quotes_copy")
        data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
        batch = tokenizer(data["train"][:3]["quote"], return_tensors="pt", padding=True)
        self._check_inference_finite(model, batch)

        with tempfile.TemporaryDirectory() as tmp_dir:
            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=2e-4,
                    fp16=True,
                    logging_steps=1,
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()

            model.cpu().save_pretrained(tmp_dir)

            self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
            self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))

            # assert loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@ -421,7 +495,7 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
            model.cpu().save_pretrained(tmp_dir)

            self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
            self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir))
            self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))

            # assert loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@ -481,7 +555,7 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
            model.cpu().save_pretrained(tmp_dir)

            self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
            self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir))
            self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))

            # assert loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@ -542,7 +616,7 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
            model.cpu().save_pretrained(tmp_dir)

            self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
            self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir))
            self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))

            # assert loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@ -640,7 +714,7 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
            model.cpu().save_pretrained(tmp_dir)

            self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
            self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir))
            self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))

            # assert loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@ -670,6 +744,14 @@ class PeftGPTQGPUTests(unittest.TestCase):
        gc.collect()
        torch.cuda.empty_cache()

    def _check_inference_finite(self, model, batch):
        # try inference without Trainer class
        training = model.training
        model.eval()
        output = model(**batch.to(model.device))
        self.assertTrue(torch.isfinite(output.logits).all())
        model.train(training)

    @pytest.mark.single_gpu_tests
    def test_causal_lm_training(self):
        r"""
@ -719,7 +801,7 @@ class PeftGPTQGPUTests(unittest.TestCase):
            model.cpu().save_pretrained(tmp_dir)

            self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
            self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir))
            self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))

            # assert loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@ -737,6 +819,7 @@ class PeftGPTQGPUTests(unittest.TestCase):
            quantization_config=self.quantization_config,
        )

        tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
        model = prepare_model_for_kbit_training(model)

        peft_config = AdaLoraConfig(
@ -758,6 +841,8 @@ class PeftGPTQGPUTests(unittest.TestCase):

        data = load_dataset("ybelkada/english_quotes_copy")
        data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
        batch = tokenizer(data["train"][:3]["quote"], return_tensors="pt", padding=True)
        self._check_inference_finite(model, batch)

        with tempfile.TemporaryDirectory() as tmp_dir:
            trainer = Trainer(
@ -781,7 +866,7 @@ class PeftGPTQGPUTests(unittest.TestCase):
            model.cpu().save_pretrained(tmp_dir)

            self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
            self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir))
            self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))

            # assert loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@ -844,7 +929,7 @@ class PeftGPTQGPUTests(unittest.TestCase):
            model.cpu().save_pretrained(tmp_dir)

            self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
            self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir))
            self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))

            # assert loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@ -19,6 +19,7 @@ import unittest
import torch

from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model
from peft.utils import ModulesToSaveWrapper


class DummyModel(torch.nn.Module):
@ -63,3 +64,28 @@ class TestPeft(unittest.TestCase):

        for key in peft_state_dict.keys():
            self.assertTrue("lora" in key)

    def test_modules_to_save(self):
        self.model = DummyModel()

        lora_config = LoraConfig(
            lora_alpha=16,
            lora_dropout=0.1,
            r=64,
            bias="none",
            target_modules=["linear"],
            modules_to_save=["embedding"],
        )

        self.model = inject_adapter_in_model(lora_config, self.model)

        for name, module in self.model.named_modules():
            if name == "linear":
                self.assertTrue(hasattr(module, "lora_A"))
                self.assertTrue(hasattr(module, "lora_B"))
            elif name == "embedding":
                self.assertTrue(isinstance(module, ModulesToSaveWrapper))

        state_dict = get_peft_model_state_dict(self.model)

        self.assertTrue("embedding.weight" in state_dict.keys())
@ -145,7 +145,52 @@ class MultiTaskPromptTuningTester(TestCase, PeftCommonTester):
                )
            )

            # check if `adapter_model.bin` is present
            # check if `adapter_model.safetensors` is present
            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors")))

            # check if `adapter_config.json` is present
            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json")))

            # check if `pytorch_model.bin` is not present
            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "pytorch_model.bin")))

            # check if `config.json` is not present
            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))

    def test_save_pretrained_regression(self) -> None:
        seed = 420
        torch.manual_seed(seed)
        model = LlamaForCausalLM(self._create_test_llama_config())
        model = get_peft_model(model, self._create_multitask_prompt_tuning_config())
        model = model.to(self.torch_device)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname, safe_serialization=False)

            torch.manual_seed(seed)
            model_from_pretrained = LlamaForCausalLM(self._create_test_llama_config())
            model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)

            # check if the state dicts are equal
            state_dict = get_peft_model_state_dict(model)

            state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained)

            # check if same keys
            self.assertEqual(state_dict.keys(), state_dict_from_pretrained.keys())

            # Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate).
            self.assertEqual(len(list(state_dict.keys())), 3)

            # check if tensors equal
            for key in state_dict.keys():
                self.assertTrue(
                    torch.allclose(
                        state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device)
                    )
                )

            # check if `adapter_model.bin` is present for regression
            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin")))

            # check if `adapter_config.json` is present
@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import os
import pickle
@ -29,6 +30,7 @@ from peft import (
    IA3Config,
    LoraConfig,
    PeftModel,
    PeftType,
    PrefixTuningConfig,
    PromptEncoderConfig,
    PromptLearningConfig,
@ -43,13 +45,6 @@ from peft.utils import _get_submodules, infer_device
from .testing_utils import get_state_dict


CONFIG_CLASSES = (
    IA3Config,
    LoraConfig,
    PrefixTuningConfig,
    PromptEncoderConfig,
    PromptTuningConfig,
)
CONFIG_TESTING_KWARGS = (
    # IA³
    {
@ -269,7 +264,7 @@ class PeftCommonTester:

        self.assertTrue(dummy_output.requires_grad)

    def _test_save_pretrained(self, model_id, config_cls, config_kwargs):
    def _test_save_pretrained(self, model_id, config_cls, config_kwargs, safe_serialization=True):
        # ensure that the weights are randomly initialized
        if issubclass(config_cls, LoraConfig):
            config_kwargs = config_kwargs.copy()
@ -287,7 +282,10 @@ class PeftCommonTester:
        model = model.to(self.torch_device)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname)
            if safe_serialization:
                model.save_pretrained(tmp_dirname)
            else:
                model.save_pretrained(tmp_dirname, safe_serialization=False)

            model_from_pretrained = self.transformers_class.from_pretrained(model_id)
            model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)
@ -311,14 +309,16 @@ class PeftCommonTester:
                )
            )

            # check if `adapter_model.bin` is present
            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin")))
            target_adapter_filename = "adapter_model.safetensors" if safe_serialization else "adapter_model.bin"

            # check if `adapter_model.safetensors` is present
            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, target_adapter_filename)))

            # check if `adapter_config.json` is present
            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json")))

            # check if `pytorch_model.bin` is not present
            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "pytorch_model.bin")))
            # check if `model.safetensors` is not present
            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "model.safetensors")))

            # check if `config.json` is not present
            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))
@ -326,7 +326,7 @@ class PeftCommonTester:
            self.check_modelcard(tmp_dirname, model)
            self.check_config_json(tmp_dirname, model)

    def _test_save_pretrained_selected_adapters(self, model_id, config_cls, config_kwargs):
    def _test_save_pretrained_selected_adapters(self, model_id, config_cls, config_kwargs, safe_serialization=True):
        if issubclass(config_cls, AdaLoraConfig):
            # AdaLora does not support adding more than 1 adapter
            return
@ -355,7 +355,10 @@ class PeftCommonTester:
        model.add_adapter("new_adapter", new_adapter_config)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname)
            if safe_serialization:
                model.save_pretrained(tmp_dirname)
            else:
                model.save_pretrained(tmp_dirname, safe_serialization=False)

            model_from_pretrained = self.transformers_class.from_pretrained(model_id)
            model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)
@ -385,17 +388,19 @@ class PeftCommonTester:
                )
            )

            # check if `adapter_model.bin` is present
            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin")))
            self.assertTrue(os.path.exists(os.path.join(new_adapter_dir, "adapter_model.bin")))
            target_adapter_filename = "adapter_model.safetensors" if safe_serialization else "adapter_model.bin"

            # check if `adapter_model.safetensors` is present
            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, target_adapter_filename)))
            self.assertTrue(os.path.exists(os.path.join(new_adapter_dir, target_adapter_filename)))

            # check if `adapter_config.json` is present
            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json")))
            self.assertTrue(os.path.exists(os.path.join(new_adapter_dir, "adapter_config.json")))

            # check if `pytorch_model.bin` is not present
            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "pytorch_model.bin")))
            self.assertFalse(os.path.exists(os.path.join(new_adapter_dir, "pytorch_model.bin")))
            # check if `model.safetensors` is not present
            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "model.safetensors")))
            self.assertFalse(os.path.exists(os.path.join(new_adapter_dir, "model.safetensors")))

            # check if `config.json` is not present
            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))
@ -567,6 +572,71 @@ class PeftCommonTester:
            logits_merged_from_pretrained = model_from_pretrained(**dummy_input)[0]
            self.assertTrue(torch.allclose(logits_merged, logits_merged_from_pretrained, atol=atol, rtol=rtol))

    def _test_merge_layers_multi(self, model_id, config_cls, config_kwargs):
        supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3]

        if ("gpt2" in model_id.lower()) and (config_cls == IA3Config):
            self.skipTest("Merging GPT2 adapters not supported for IA³ (yet)")

        config = config_cls(
            base_model_name_or_path=model_id,
            **config_kwargs,
        )

        if config.peft_type not in supported_peft_types:
            return

        model = self.transformers_class.from_pretrained(model_id)
        model = get_peft_model(model, config)

        model = model.to(self.torch_device)

        dummy_input = self.prepare_inputs_for_testing()
        model.eval()

        with torch.inference_mode():
            logits_adapter_1 = model(**dummy_input)[0]

        model.add_adapter("adapter-2", config)
        model.set_adapter("adapter-2")
        model.eval()

        with torch.inference_mode():
            logits_adapter_2 = model(**dummy_input)[0]

        self.assertFalse(torch.allclose(logits_adapter_1, logits_adapter_2, atol=1e-3, rtol=1e-3))

        model.set_adapter("default")

        with torch.inference_mode():
            logits_adapter_1_after_set = model(**dummy_input)[0]

        self.assertTrue(torch.allclose(logits_adapter_1_after_set, logits_adapter_1, atol=1e-3, rtol=1e-3))

        model_copy = copy.deepcopy(model)
        model_copy_2 = copy.deepcopy(model)
        model_merged_all = model.merge_and_unload(adapter_names=["adapter-2", "default"])

        with torch.inference_mode():
            logits_merged_all = model_merged_all(**dummy_input)[0]

        self.assertFalse(torch.allclose(logits_merged_all, logits_adapter_2, atol=1e-3, rtol=1e-3))
        self.assertFalse(torch.allclose(logits_merged_all, logits_adapter_1, atol=1e-3, rtol=1e-3))

        model_merged_adapter_2 = model_copy.merge_and_unload(adapter_names=["adapter-2"])

        with torch.inference_mode():
            logits_merged_adapter_2 = model_merged_adapter_2(**dummy_input)[0]

        self.assertTrue(torch.allclose(logits_merged_adapter_2, logits_adapter_2, atol=1e-3, rtol=1e-3))

        model_merged_adapter_default = model_copy_2.merge_and_unload(adapter_names=["default"])

        with torch.inference_mode():
            logits_merged_adapter_default = model_merged_adapter_default(**dummy_input)[0]

        self.assertTrue(torch.allclose(logits_merged_adapter_default, logits_adapter_1, atol=1e-3, rtol=1e-3))

    def _test_generate(self, model_id, config_cls, config_kwargs):
        model = self.transformers_class.from_pretrained(model_id)
        config = config_cls(
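For readers skimming `_test_merge_layers_multi` above, a hedged sketch of the user-facing call it exercises: `merge_and_unload` restricted to a subset of adapters via `adapter_names`. The adapter name "other", the LoRA hyperparameters, and the tiny test model are illustrative placeholders, not part of the diff.

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=16, init_lora_weights=False)

model = get_peft_model(base, config)  # first adapter is registered as "default"
model.add_adapter("other", config)    # add a second adapter

# Merge only the "other" adapter into the base weights and get back a plain transformers model.
merged = model.merge_and_unload(adapter_names=["other"])
```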
@ -815,42 +885,79 @@ class PeftCommonTester:
                self.assertIsNotNone(param.grad)

    def _test_delete_adapter(self, model_id, config_cls, config_kwargs):
        if issubclass(config_cls, AdaLoraConfig):
            # AdaLora does not support adding more than 1 adapter
            return

        model = self.transformers_class.from_pretrained(model_id)
        supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3]
        # IA3 does not support deleting adapters yet, but it just needs to be added
        # AdaLora does not support multiple adapters
        config = config_cls(
            base_model_name_or_path=model_id,
            **config_kwargs,
        )
        if config.peft_type not in supported_peft_types:
            return

        model = self.transformers_class.from_pretrained(model_id)
        adapter_to_delete = "delete_me"
        model = get_peft_model(model, config)
        model.add_adapter(adapter_to_delete, config)
        model.set_adapter(adapter_to_delete)
        model = model.to(self.torch_device)
        model.delete_adapter(adapter_to_delete)
        self.assertFalse(adapter_to_delete in model.peft_config)
        self.assertEqual(model.active_adapters, ["default"])

        if config.peft_type not in ("LORA"):
            with self.assertRaises(AttributeError):
                model.delete_adapter(adapter_to_delete)
        else:
            model.delete_adapter(adapter_to_delete)
            self.assertFalse(adapter_to_delete in model.peft_config)
            key_list = [key for key, _ in model.named_modules() if "lora" not in key]
            for key in key_list:
                _, target, _ = _get_submodules(model, key)
                if isinstance(target, LoraLayer):
                    for attr in [
                        "r",
                        "lora_alpha",
                        "scaling",
                        "lora_A",
                        "lora_B",
                        "lora_embedding_A",
                        "lora_embedding_B",
                        "lora_dropout",
                    ]:
                        self.assertFalse(adapter_to_delete in getattr(target, attr))
        key_list = [key for key, _ in model.named_modules()]
        for key in key_list:
            _, target, _ = _get_submodules(model, key)
            attributes_to_check = getattr(target, "adapter_layer_names", []) + getattr(target, "other_param_names", [])
            for attr in attributes_to_check:
                self.assertFalse(adapter_to_delete in getattr(target, attr))

        # check that we can also delete the last remaining adapter
        model.delete_adapter("default")
        self.assertFalse("default" in model.peft_config)
        self.assertEqual(model.active_adapters, [])

        input = self.prepare_inputs_for_testing()
        # note: we cannot call model(**input) because PeftModel always expects there to be at least one adapter
        model.base_model(**input)  # should not raise an error

    def _test_delete_inactive_adapter(self, model_id, config_cls, config_kwargs):
        # same as test_delete_adapter, but this time an inactive adapter is deleted
        supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3]
        # IA3 does not support deleting adapters yet, but it just needs to be added
        # AdaLora does not support multiple adapters
        config = config_cls(
            base_model_name_or_path=model_id,
            **config_kwargs,
        )
        if config.peft_type not in supported_peft_types:
            return

        model = self.transformers_class.from_pretrained(model_id)
        adapter_to_delete = "delete_me"
        model = get_peft_model(model, config)
        model.add_adapter(adapter_to_delete, config)
        # "delete_me" is added but not activated
        model = model.to(self.torch_device)
        model.delete_adapter(adapter_to_delete)
        self.assertFalse(adapter_to_delete in model.peft_config)
        self.assertEqual(model.active_adapters, ["default"])

        key_list = [key for key, _ in model.named_modules()]
        for key in key_list:
            _, target, _ = _get_submodules(model, key)
            attributes_to_check = getattr(target, "adapter_layer_names", []) + getattr(target, "other_param_names", [])
            for attr in attributes_to_check:
                self.assertFalse(adapter_to_delete in getattr(target, attr))

        # check that we can also delete the last remaining adapter
        model.delete_adapter("default")
        self.assertFalse("default" in model.peft_config)
        self.assertEqual(model.active_adapters, [])

        input = self.prepare_inputs_for_testing()
        # note: we cannot call model(**input) because PeftModel always expects there to be at least one adapter
        model.base_model(**input)  # should not raise an error

    def _test_unload_adapter(self, model_id, config_cls, config_kwargs):
        model = self.transformers_class.from_pretrained(model_id)
@ -861,12 +968,12 @@ class PeftCommonTester:
        model = get_peft_model(model, config)
        model = model.to(self.torch_device)

        if config.peft_type not in ("LORA", "ADALORA"):
        if config.peft_type not in ("LORA", "ADALORA", "IA3"):
            with self.assertRaises(AttributeError):
                model = model.unload()
        else:
            dummy_input = self.prepare_inputs_for_testing()
            logits_with_lora = model(**dummy_input)[0]
            logits_with_adapter = model(**dummy_input)[0]

            transformers_model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
            logits_transformers = transformers_model(**dummy_input)[0]
@ -875,7 +982,7 @@ class PeftCommonTester:
            model = model.unload()
            logits_unload = model(**dummy_input)[0]

            self.assertFalse(torch.allclose(logits_with_lora, logits_unload, atol=1e-10, rtol=1e-10))
            self.assertFalse(torch.allclose(logits_with_adapter, logits_unload, atol=1e-10, rtol=1e-10))
            self.assertTrue(torch.allclose(logits_transformers, logits_unload, atol=1e-4, rtol=1e-4))

    def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kwargs):
@ -885,13 +992,14 @@ class PeftCommonTester:

        adapter_list = ["adapter1", "adapter_2", "adapter_3"]
        weight_list = [0.5, 1.5, 1.5]
        model = self.transformers_class.from_pretrained(model_id)
        config = config_cls(
            base_model_name_or_path=model_id,
            **config_kwargs,
        )
        if not isinstance(config, (LoraConfig)):
            return

        model = self.transformers_class.from_pretrained(model_id)
        model = get_peft_model(model, config, adapter_list[0])
        model.add_adapter(adapter_list[1], config)
        model.add_adapter(adapter_list[2], replace(config, r=20))
@ -930,7 +1038,7 @@ class PeftCommonTester:
        for new_adapter in new_adapters:
            self.assertTrue(new_adapter in model.peft_config)

        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
        key_list = [key for key, _ in model.named_modules()]
        for key in key_list:
            _, target, _ = _get_submodules(model, key)
            if isinstance(target, LoraLayer):
@ -1006,7 +1114,7 @@ class PeftCommonTester:
        # must be False
        if isinstance(peft_model, StableDiffusionPipeline):
            # for SD, check that most pixels have different values
            self.assertTrue((output_before != output_peft).float().mean() > 0.9)
            self.assertTrue((output_before != output_peft).float().mean() > 0.8)
        else:
            self.assertFalse(torch.allclose(output_before, output_peft))