Compare commits

...

25 Commits

Author SHA1 Message Date
b169484659 [Docs fix] Relative path issue 2023-11-21 10:31:18 +01:00
8351331d78 ENH Delete IA3 adapters (#1153) 2023-11-20 18:22:52 +01:00
f1ecfa6ae6 Use huggingface_hub.file_exists instead of custom helper (#1145)
* Use 'huggingface_hub.file_exists' instead of custom helper

* make quality
2023-11-17 15:48:02 +01:00
b5a8a294ed FIX A few issues with AdaLora, adding tests (#1146)
This PR fixes a handful of issues with AdaLora and should resolve #1113.

Description

1. lora_A.weight.device was called, but for AdaLora, lora_A is an
   nn.Parameter, not an nn.Module, so the weight attribute does not
   exist. lora_A.device is sufficient.
2. For 8bit, an inplace operation failed because it was performed on a
   view, so the operation is no longer done inplace (see the sketch
   after this list).
3. The loss term of the model output is not necessarily a torch tensor.
   In the test, it was a dict and did not contain an actual loss.
   Therefore, I added a check to make sure the loss is a torch tensor.
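
A minimal sketch (not code from this PR) of the kind of failure behind point 2: an in-place addition on a tensor view that participates in autograd raises a RuntimeError, so the fix performs the addition out of place instead.

```python
import torch

base = torch.zeros(4, requires_grad=True)
result = base[:2]          # `result` is a view of `base`
update = torch.ones(2)

# result += update         # RuntimeError: in-place operation on a view of a leaf tensor that requires grad
result = result + update   # out-of-place addition avoids the problem
```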
2023-11-17 15:18:34 +01:00
9cdaed2769 CI Add Python 3.11 to test matrix (#1143)
Only required change was to call .value on some enums when used in
messages, as their repr has changed in Python 3.11.
2023-11-17 14:11:54 +01:00
18a0910113 [Tests] Do not stop tests if a job failed (#1141)
* Update nightly.yml

* Update nightly.yml
2023-11-16 18:11:19 +01:00
99e1a55f54 [core / LoRA] Add adapter_names in bnb layers (#1139)
* Update bnb.py

* fix style
2023-11-16 17:12:39 +01:00
21df968fd1 [Tests] Fix daily CI (#1136)
* fix daily CI

* adapt from suggestion
2023-11-16 14:43:36 +01:00
5a3a5acff2 Refactor base layer pattern (#1106)
Description

Refactor all tuners (where it applies, i.e. not prompt tuning) to use
the "base layer pattern". This means that the adapter layer will always
hold a reference to the original layer that it modifies. This pattern is
already partly used (e.g. LoRA bnb, gptq layers), now it is consistently
used everywhere when applicable.

This PR is a companion PR to #1069, where I first added these changes.
They are now extracted to a separate PR to make code review easier and
to advance more quickly.

Implementation

The main change is that the adapter layer wraps the original layer and
calls forward on that layer, instead of doing something like this:

F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

which completely circumvents the call to the target layer's forward
method. With the base layer pattern, we now call the target layer's
forward method. Therefore, if the target layer is another adapter
layer (which will be crucial for mixed adapters), we call its forward
method correctly. Also, this should allow passing extra arguments, like
lora_scale to forward.

This change has the nice side benefit that we no longer need to use
_init_empty_weights -- in fact, we don't initialize any of the target
layer's weights anymore, since we have a reference to it. There is thus
no risk of having slow but superfluous initialization of layers.

Moreover, I could greatly simplify merge_and_unload by just using the
base_layer instead of having to create a completely new layer. For
OPT-350m, this results in a 15x speedup.

Note that, same as for the bnb layers, this should be backwards
compatible, since the adapter weights and their state_dicts are not
affected by this change. I used #1115 for regression testing.

Somewhat unrelated changes

During debugging, I got very annoyed with the fact that the reprs of
adapter layers and normal PyTorch layers are hard to distinguish, e.g.
the type is just "Linear". Now, for adapter layers, it is prefixed by
the adapter type, e.g. "lora.Linear". This should have no further
implications except for the repr (e.g. state_dict remains unaffected).

For LoHa and LoKr, I had to change the init of weights when using
init_weights=False. This is because of what is discussed in Numerical
instabilities with LoHa #1058.

IA³ now has the unload method too.

LoHa and LoKr now support safe_merge=True when merging layers.

Migration guide

For 99% of users, the code should continue working as usual, because
the API stays the same. Only low-level details have been changed.

Code that relies on isinstance checks on specific PEFT classes may
break. E.g. the LoRA Linear layer no longer inherits from nn.Linear. It
is, however, still a BaseTunerLayer. The same logic applies for other
layer types like Conv2d and for other tuners like IA³.

To retrieve the base layer of an adapter layer, you should now call
module.get_base_layer() if you deal with a BaseTunerLayer. Don't rely on
something like module.weight being present (though it might be).
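
A small, hypothetical sketch of what the migration guide above means in practice (the helper name is made up; only `BaseTunerLayer` and `get_base_layer()` come from PEFT):

```python
from peft.tuners.tuners_utils import BaseTunerLayer

def collect_base_layers(model):
    """Map adapter-layer names to the original layers they wrap (illustrative only)."""
    base_layers = {}
    for name, module in model.named_modules():
        if isinstance(module, BaseTunerLayer):
            # Don't isinstance-check against nn.Linear or rely on module.weight;
            # ask the adapter layer for the layer it wraps instead.
            base_layers[name] = module.get_base_layer()
    return base_layers
```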
2023-11-16 12:45:12 +01:00
70302d7b4f FEAT: Merging only specified adapter_names when calling merge (#1132)
* working v1

* add tests

* remove

* add it also for lokr and loha, left a todo

* Update tests/testing_common.py

Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>

* better test

* up

* fix tests

* credits contrib and suggestions from discussions

* credits contrib and suggestions from discussions

* address last comments

---------

Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
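
A short usage sketch of the new argument, based on the layer-level `merge(..., adapter_names=...)` signature shown in the diffs further down (the adapter name is made up, and `model` is assumed to be a PEFT model with that adapter already loaded):

```python
from peft.tuners.tuners_utils import BaseTunerLayer

# Merge only the "style" adapter into the base weights, leaving other adapters untouched.
for module in model.modules():
    if isinstance(module, BaseTunerLayer) and hasattr(module, "merge"):
        module.merge(adapter_names=["style"])
```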
2023-11-16 12:05:22 +01:00
3ff90626b6 FEAT: Make safe serialization the default one (#1088)
* make safe serialization the default one

* adapt tests

* fix final tests

* adapt from suggestion
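
In practice this means `save_pretrained` now writes the adapter weights as safetensors unless told otherwise; a brief sketch (directory names are placeholders, and the filename is typically adapter_model.safetensors):

```python
# New default: adapter weights are saved in the safetensors format.
model.save_pretrained("my-adapter")

# Opt back into the previous pickle-based format if needed.
model.save_pretrained("my-adapter-pt", safe_serialization=False)
```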
2023-11-15 11:21:23 +01:00
1877329093 TST Improve requires grad testing: (#1131)
Previously, the corresponding tests were testing only whether specific
parameters had requires_grad True or False. Now, all parameters are
being checked. This is more rigorous.

Also, tests for Embedding, Conv1D, Conv2d were added, thus superseding
PR #1115.

Finally, tests for LoHa and LoKr were added.

Note

I considered moving the tests to a separate module, as they were getting
quite big and this would help with readability. For now, I left them in
the same module because it leads to a better diff view and is thus
easier to review. LMK if I should move the tests to a separate file.
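
A rough sketch (not the actual test code) of the more rigorous approach described above: every parameter is checked, not just a hand-picked subset.

```python
def check_requires_grad(model, expected_trainable):
    # `expected_trainable` is an assumed set of substrings identifying the params that should train.
    for name, param in model.named_parameters():
        should_be_trainable = any(key in name for key in expected_trainable)
        assert param.requires_grad == should_be_trainable, f"unexpected requires_grad for {name}"
```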
2023-11-14 17:44:49 +05:30
98429b8184 Fix: TorchTracemalloc ruins Windows performance (#1126)
* feat: added tracemalloc arg to train_dreambooth

* fix: added help for arg

* fix: changed arg name

* fix formatting

* fix: import order
2023-11-14 17:04:32 +05:30
d350a00ece Prompt tuning: fix AutoTokenizer.from_pretrained (#1053)
Fixes #1032

Description

Currently, when using prompt tuning with TEXT, we call
AutoTokenizer.from_pretrained with only the model id. However, it may be
necessary to pass additional arguments, e.g. trust_remote_code=True.
This fix allows passing additional arguments by setting the
tokenizer_kwargs argument in the PromptTuningConfig.

I also added a check that when tokenizer_kwargs is set, the TEXT option
is actually being used.

Moreover, I noticed that we have no tests for prompt tuning with TEXT,
so I added those tests for decoder models.

Additional changes

There was a bug in PromptEmbedding where the device of the
init_token_ids was not set, which resulted in errors when using CUDA.

Finally, I removed an unused constant CONFIG_CLASSES from a test.
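
A hedged configuration sketch of the new `tokenizer_kwargs` argument (model id, init text, and kwargs are placeholders):

```python
from peft import PromptTuningConfig, PromptTuningInit

config = PromptTuningConfig(
    task_type="CAUSAL_LM",
    num_virtual_tokens=8,
    prompt_tuning_init=PromptTuningInit.TEXT,
    prompt_tuning_init_text="Classify if the tweet is a complaint or not:",
    tokenizer_name_or_path="some-org/some-model",   # placeholder model id
    tokenizer_kwargs={"trust_remote_code": True},   # only valid together with the TEXT init option
)
```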
2023-11-14 16:58:55 +05:30
ad756173f1 FIX: Adding 2 adapters when target_modules is a str fails (#1111)
* Fix adding 2 adapters when target_modules is a str

Problem description

Adding two adapters (e.g. LoRA) when using a list for `target_modules`
works, but passing a str fails. The issue is that for str, we do a
`re.fullmatch`, whereas for list, we just check `endswith`. After adding
the first adapter, though, the naming pattern of the modules changes. In
the example below, the name for the linear layer changes from `"lin0"`
to `"base_model.model.lin0"`, which is why the `fullmatch` fails but the
`endswith` still works.

Reproduction

from peft import LoraConfig, get_peft_model
from torch import nn

class MLP(nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.lin0 = nn.Linear(10, 20, bias=bias)

def test_target_modules_list():
    config = LoraConfig(target_modules=["lin0"])
    test_it(config)
    print("Adding two adapters with target_module being a list works")

def test_target_modules_str():
    config = LoraConfig(target_modules="lin0")
    test_it(config)

def test_it(config):
    model = MLP()
    model = get_peft_model(model, config, "adapter0")
    model.add_adapter("adapter1", config)
    print("Adding two adapters with target_module being a str works")

if __name__ == "__main__":
    # works
    test_target_modules_list()
    # ValueError: Target modules lin0 not found in the base model
    test_target_modules_str()

I think that most users would be surprised that:

1. Adding the first adapter works but adding the second fails, even
   though they use the same config.
2. Using `target_modules=["lin0"]` works but `target_modules="lin0"`
   fails for the 2nd adapter.

Solution

We could change the logic to not use `re.fullmatch` for str, but I
think that could be tricky to achieve without breaking BC. Instead, I
chose to change the inject_adapter call in add_adapter to pass the base
model, not the whole peft model. This way, the naming pattern is
preserved.

Tests

I haven't added extra tests for this. The script above could serve as a
test. However, it will be sufficient to remove the guard added in #1105:

    if isinstance(config.target_modules, str):
        # TODO this should be doable
        self.skipTest("Multiple adapters cannot currently be added when target_modules is a string.")

as that will test exactly this behavior and was how the bug was
originally uncovered. Depending on which PR lands first, the guard has to
be removed in this PR or in #1105.

* Enable tests for adding 2 adapters with str
2023-11-14 15:00:52 +05:30
94877b5008 Release: v0.6.3.dev0 (#1128) 2023-11-14 14:59:55 +05:30
f020404ee6 Release: v0.6.2 (#1125) 2023-11-14 11:13:21 +05:30
79298c7c24 fix doc typo (#1121) 2023-11-13 10:48:50 +01:00
b25ce8a0cd Correctly deal with ModulesToSaveWrapper when using Low-level API (#1112)
* correctly deal with  `ModulesToSaveWrapper`

* style

* fix tests (#1117)
2023-11-13 12:22:30 +05:30
5d84484079 fix import issue transformers (#1116) 2023-11-10 18:37:38 +01:00
49ddefa834 Add num_dataloader_workers arg to dreambooth script (#1107)
This is especially important for Windows users, who may have to set the
number of workers to 0.
2023-11-10 14:21:14 +01:00
3af469eeea Refactor adapter deletion (#1105)
Description

The job of deleting an adapter is now transferred to the adapter layer,
instead of the adapter model. This makes it easier for users or other
libraries who don't use the adapter model to delete adapters.

Implementation

The code should now be more generic, relying less on hard-coded
attributes.

As a precaution, I also changed the type of adapter_layer_names from
list to tuple, as it should not be mutated.

When deleting the active adapter, the logic for choosing the new active
adapter has been changed slightly to ensure consistency across layers.
In practice, this should rarely make a difference. An error is now
raised if the last remaining adapter is deleted.

Test coverage has been increased:

- Deleting adapters is now also tested for custom models.
- It is also tested for LoHa, LoKr, not only LoRA.
- I added a test for deleting the non-active adapter.

Not implemented

I did not add adapter deletion to IA³, since it is included in #980. LMK
if it should be added here instead.
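
A minimal usage sketch (adapter names made up) of what this looks like from the model side; per the description above, deleting the last remaining adapter now raises an error.

```python
# Assumes `model` is a PeftModel with adapters "default" and "other" already added.
model.delete_adapter("other")      # fine, "default" remains
# model.delete_adapter("default")  # would now raise, since it is the last remaining adapter
```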
2023-11-10 13:33:56 +01:00
5e7e5ad836 Avoid over-eager auto-gptq import (#1109) 2023-11-10 12:35:18 +01:00
9d8287f3e3 set dev version (#1104) 2023-11-09 15:44:28 +01:00
2efd02769b Release: 0.6.1 (#1103) 2023-11-09 15:16:33 +01:00
44 changed files with 2452 additions and 1259 deletions

View File

@@ -15,6 +15,8 @@ env:
 jobs:
   run_all_tests_single_gpu:
+    strategy:
+      fail-fast: false
     runs-on: [self-hosted, docker-gpu, multi-gpu]
     env:
       CUDA_VISIBLE_DEVICES: "0"
@@ -57,6 +59,8 @@ jobs:
       python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
   run_all_tests_multi_gpu:
+    strategy:
+      fail-fast: false
     runs-on: [self-hosted, docker-gpu, multi-gpu]
     env:
       CUDA_VISIBLE_DEVICES: "0,1"

View File

@@ -28,7 +28,7 @@ jobs:
     needs: check_code_quality
     strategy:
       matrix:
-        python-version: ["3.8", "3.9", "3.10"]
+        python-version: ["3.8", "3.9", "3.10", "3.11"]
         os: ["ubuntu-latest", "macos-latest", "windows-latest"]
     runs-on: ${{ matrix.os }}
     steps:

View File

@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
 Some fine-tuning techniques, such as prompt tuning, are specific to language models. That means in 🤗 PEFT, it is
 assumed a 🤗 Transformers model is being used. However, other fine-tuning techniques - like
-[LoRA](./conceptual_guides/lora) - are not restricted to specific model types.
+[LoRA](../conceptual_guides/lora) - are not restricted to specific model types.
 In this guide, we will see how LoRA can be applied to a multilayer perceptron and a computer vision model from the [timm](https://huggingface.co/docs/timm/index) library.

View File

@@ -17,7 +17,7 @@ The development of this API has been motivated by the need for super users to no
 ## Supported tuner types
-Currently the supported adapter types are the 'injectable' adapters, meaning adapters where an inplace modification of the model is sufficient to correctly perform the fine tuning. As such, only [LoRA](./conceptual_guides/lora), AdaLoRA and [IA3](./conceptual_guides/ia3) are currently supported in this API.
+Currently the supported adapter types are the 'injectable' adapters, meaning adapters where an inplace modification of the model is sufficient to correctly perform the fine tuning. As such, only [LoRA](../conceptual_guides/lora), AdaLoRA and [IA3](../conceptual_guides/ia3) are currently supported in this API.
 ## `inject_adapter_in_model` method

View File

@@ -83,6 +83,7 @@ accelerate launch train_dreambooth.py \
   --output_dir=$OUTPUT_DIR \
   --train_text_encoder \
   --with_prior_preservation --prior_loss_weight=1.0 \
+  --num_dataloader_workers=1 \
   --instance_prompt="a photo of sks dog" \
   --class_prompt="a photo of dog" \
   --resolution=512 \
@@ -101,6 +102,8 @@ accelerate launch train_dreambooth.py \
   --max_train_steps=800
 ```
+If you are running this script on Windows, you may need to set the `--num_dataloader_workers` to 0.
 ## Inference with a single adapter
 To run inference with the fine-tuned model, first specify the base model with which the fine-tuned LoRA weights will be combined:
@@ -171,7 +174,7 @@ image.save("DESTINATION_PATH_FOR_THE_IMAGE")
 ## Multi-adapter inference
 With PEFT you can combine multiple adapters for inference. In the previous example you have fine-tuned Stable Diffusion on
-some dog images. The pipeline created based on these weights got a name - `adapter_name="dog`. Now, suppose you also fine-tuned
+some dog images. The pipeline created based on these weights got a name - `adapter_name="dog"`. Now, suppose you also fine-tuned
 this base model on images of a crochet toy. Let's see how we can use both adapters.
 First, you'll need to perform all the steps as in the single adapter inference example:

View File

@@ -7,6 +7,7 @@ import math
 import os
 import threading
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 from typing import Optional
@@ -213,6 +214,17 @@ def parse_args(input_args=None):
         help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
     )
+    parser.add_argument(
+        "--num_dataloader_workers", type=int, default=1, help="Num of workers for the training dataloader."
+    )
+    parser.add_argument(
+        "--no_tracemalloc",
+        default=False,
+        action="store_true",
+        help="Flag to stop memory allocation tracing during training. This could speed up training on Windows.",
+    )
     parser.add_argument(
         "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
     )
@@ -799,7 +811,7 @@ def main(args):
         batch_size=args.train_batch_size,
         shuffle=True,
         collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
-        num_workers=1,
+        num_workers=args.num_dataloader_workers,
     )
     # Scheduler and math around the number of training steps.
@@ -893,7 +905,7 @@ def main(args):
         unet.train()
         if args.train_text_encoder:
             text_encoder.train()
-        with TorchTracemalloc() as tracemalloc:
+        with TorchTracemalloc() if not args.no_tracemalloc else nullcontext() as tracemalloc:
             for step, batch in enumerate(train_dataloader):
                 # Skip steps until we reach the resumed step
                 if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
@@ -1034,23 +1046,29 @@ def main(args):
                 if global_step >= args.max_train_steps:
                     break
         # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage
-        accelerator.print("GPU Memory before entering the train : {}".format(b2mb(tracemalloc.begin)))
-        accelerator.print("GPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.used))
-        accelerator.print("GPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.peaked))
-        accelerator.print(
-            "GPU Total Peak Memory consumed during the train (max): {}".format(
-                tracemalloc.peaked + b2mb(tracemalloc.begin)
-            )
-        )
-        accelerator.print("CPU Memory before entering the train : {}".format(b2mb(tracemalloc.cpu_begin)))
-        accelerator.print("CPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.cpu_used))
-        accelerator.print("CPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.cpu_peaked))
-        accelerator.print(
-            "CPU Total Peak Memory consumed during the train (max): {}".format(
-                tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin)
-            )
-        )
+        if not args.no_tracemalloc:
+            accelerator.print("GPU Memory before entering the train : {}".format(b2mb(tracemalloc.begin)))
+            accelerator.print("GPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.used))
+            accelerator.print("GPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.peaked))
+            accelerator.print(
+                "GPU Total Peak Memory consumed during the train (max): {}".format(
+                    tracemalloc.peaked + b2mb(tracemalloc.begin)
+                )
+            )
+            accelerator.print("CPU Memory before entering the train : {}".format(b2mb(tracemalloc.cpu_begin)))
+            accelerator.print(
+                "CPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.cpu_used)
+            )
+            accelerator.print(
+                "CPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.cpu_peaked)
+            )
+            accelerator.print(
+                "CPU Total Peak Memory consumed during the train (max): {}".format(
+                    tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin)
+                )
+            )
     # Create the pipeline using using the trained modules and save it.
     accelerator.wait_for_everyone()

View File

@@ -22,7 +22,7 @@ extras["test"] = extras["dev"] + ["pytest", "pytest-cov", "pytest-xdist", "param
 setup(
     name="peft",
-    version="0.6.1.dev0",
+    version="0.6.3.dev0",
     description="Parameter-Efficient Fine-Tuning (PEFT)",
     license_files=["LICENSE"],
     long_description=open("README.md", "r", encoding="utf-8").read(),
@@ -47,6 +47,7 @@ setup(
         "tqdm",
         "accelerate>=0.21.0",
         "safetensors",
+        "huggingface_hub>=0.17.0",
     ],
     extras_require=extras,
     classifiers=[

View File

@@ -17,7 +17,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-__version__ = "0.6.1.dev0"
+__version__ = "0.6.3.dev0"
 from .auto import (
     AutoPeftModel,

View File

@@ -14,6 +14,7 @@
 # limitations under the License.
 import importlib
 import importlib.metadata as importlib_metadata
+from functools import lru_cache
 import packaging.version
@@ -46,3 +47,20 @@ def is_auto_gptq_available():
 def is_optimum_available() -> bool:
     return importlib.util.find_spec("optimum") is not None
+
+
+@lru_cache()
+def is_torch_tpu_available(check_device=True):
+    "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
+    if importlib.util.find_spec("torch_xla") is not None:
+        if check_device:
+            # We need to check if `xla_device` can be found, will raise a RuntimeError if not
+            try:
+                import torch_xla.core.xla_model as xm
+
+                _ = xm.xla_device()
+                return True
+            except RuntimeError:
+                return False
+        return True
+    return False

View File

@@ -32,7 +32,6 @@ from safetensors.torch import save_file as safe_save_file
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers import PreTrainedModel
 from transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput
-from transformers.pytorch_utils import id_tensor_storage
 from transformers.utils import PushToHubMixin
 from . import __version__
@@ -60,6 +59,7 @@ from .utils import (
     _set_adapter,
     _set_trainable,
     get_peft_model_state_dict,
+    id_tensor_storage,
     infer_device,
     load_peft_weights,
     set_peft_model_state_dict,
@@ -157,7 +157,7 @@ class PeftModel(PushToHubMixin, torch.nn.Module):
     def save_pretrained(
         self,
         save_directory: str,
-        safe_serialization: bool = False,
+        safe_serialization: bool = True,
         selected_adapters: Optional[List[str]] = None,
         **kwargs: Any,
     ):
@@ -573,7 +573,7 @@ class PeftModel(PushToHubMixin, torch.nn.Module):
                 self.base_model.add_adapter(adapter_name, peft_config)
             else:
                 self.peft_config[adapter_name] = peft_config
-                self.base_model.inject_adapter(self, adapter_name)
+                self.base_model.inject_adapter(self.base_model.model, adapter_name)
         except Exception:  # somthing went wrong, roll back
             if adapter_name in self.peft_config:
                 del self.peft_config[adapter_name]

View File

@@ -27,10 +27,3 @@ from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparamet
 from .prefix_tuning import PrefixEncoder, PrefixTuningConfig
 from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit
 from .multitask_prompt_tuning import MultitaskPromptEmbedding, MultitaskPromptTuningConfig, MultitaskPromptTuningInit
-
-# Mapping of tuners that support direct plugging
-TUNERS_MAPPING = {
-    "LORA": LoraModel,
-    "IA3": IA3Model,
-    "ADALORA": AdaLoraModel,
-}

View File

@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import bitsandbytes as bnb
+from typing import Any
+
 import torch
 from peft.import_utils import is_bnb_4bit_available, is_bnb_available
@@ -23,38 +24,28 @@ from .layer import AdaLoraLayer
 if is_bnb_available():
-    class SVDLinear8bitLt(bnb.nn.Linear8bitLt, AdaLoraLayer):
+    class SVDLinear8bitLt(torch.nn.Module, AdaLoraLayer):
         # Low-rank matrix for SVD-based adaptation
         def __init__(
             self,
-            adapter_name,
-            in_features,
-            out_features,
+            base_layer: torch.nn.Module,
+            adapter_name: str,
             r: int = 0,
             lora_alpha: int = 1,
             lora_dropout: float = 0.0,
+            init_lora_weights: bool = True,
             **kwargs,
         ) -> None:
-            bnb.nn.Linear8bitLt.__init__(
-                self,
-                in_features,
-                out_features,
-                bias=kwargs.get("bias", True),
-                has_fp16_weights=kwargs.get("has_fp16_weights", True),
-                memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
-                threshold=kwargs.get("threshold", 0.0),
-                index=kwargs.get("index", None),
-            )
-            AdaLoraLayer.__init__(self, in_features=in_features, out_features=out_features)
+            super().__init__()
+            AdaLoraLayer.__init__(self, base_layer)
             # Freezing the pre-trained weight matrix
-            self.weight.requires_grad = False
+            self.get_base_layer().weight.requires_grad = False
-            init_lora_weights = kwargs.pop("init_lora_weights", True)
             self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
+            self.set_adapter(adapter_name)
         def forward(self, x: torch.Tensor) -> torch.Tensor:
-            result = super().forward(x)
+            # note: no check for self.merged because merging is not supported (yet)
+            result = self.base_layer(x)
             if self.disable_adapters:
                 return result
@@ -79,43 +70,39 @@ if is_bnb_available():
                 if requires_conversion:
                     output = output.to(expected_dtype)
                 output = output * scaling / ranknum
-                result += output
+                # inplace operation on view is forbidden for MatMul8bitLtBackward, so avoid it
+                result = result + output
             return result
+        def __repr__(self) -> str:
+            rep = super().__repr__()
+            return "adalora." + rep
 if is_bnb_4bit_available():
-    class SVDLinear4bit(bnb.nn.Linear4bit, AdaLoraLayer):
+    class SVDLinear4bit(torch.nn.Module, AdaLoraLayer):
         # Low-rank matrix for SVD-based adaptation
         def __init__(
             self,
-            adapter_name,
-            in_features,
-            out_features,
+            base_layer: torch.nn.Module,
+            adapter_name: str,
             r: int = 0,
             lora_alpha: int = 1,
             lora_dropout: float = 0.0,
+            init_lora_weights: bool = True,
            **kwargs,
         ) -> None:
-            bnb.nn.Linear4bit.__init__(
-                self,
-                in_features,
-                out_features,
-                bias=kwargs.get("bias", True),
-                compute_dtype=kwargs.get("compute_dtype", torch.float32),
-                compress_statistics=kwargs.get("compress_statistics", True),
-                quant_type=kwargs.get("quant_type", "nf4"),
-            )
-            AdaLoraLayer.__init__(self, in_features=in_features, out_features=out_features)
+            super().__init__()
+            AdaLoraLayer.__init__(self, base_layer)
             # Freezing the pre-trained weight matrix
-            self.weight.requires_grad = False
+            self.get_base_layer().weight.requires_grad = False
-            init_lora_weights = kwargs.pop("init_lora_weights", True)
             self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
+            self.set_adapter(adapter_name)
-        def forward(self, x: torch.Tensor) -> torch.Tensor:
+        def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
-            result = super().forward(x)
+            # note: no check for self.merged because merging is not supported (yet)
+            result = self.base_layer(x, *args, **kwargs)
             if self.disable_adapters:
                 return result
@@ -141,7 +128,7 @@ if is_bnb_4bit_available():
                 requires_conversion = not torch.is_autocast_enabled()
                 if requires_conversion:
                     expected_dtype = result.dtype
-                    compute_dtype = lora_A.weight.dtype
+                    compute_dtype = lora_A.dtype
                     if x.dtype != compute_dtype:
                         x = x.to(compute_dtype)
@@ -151,3 +138,7 @@ if is_bnb_4bit_available():
                 output = output * scaling / ranknum
                 result += output
             return result
+        def __repr__(self) -> str:
+            rep = super().__repr__()
+            return "adalora." + rep

View File

@@ -20,22 +20,21 @@ from .layer import AdaLoraLayer
 class SVDQuantLinear(torch.nn.Module, AdaLoraLayer):
     def __init__(
         self,
+        base_layer,
         adapter_name,
-        quant_linear_module,
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
+        init_lora_weights: bool = True,
         **kwargs,
     ) -> None:
-        torch.nn.Module.__init__(self)
-        AdaLoraLayer.__init__(
-            self, in_features=quant_linear_module.infeatures, out_features=quant_linear_module.outfeatures
-        )
-        self.quant_linear_module = quant_linear_module
-        self.weight = quant_linear_module.qweight
-        init_lora_weights = kwargs.pop("init_lora_weights", True)
+        super().__init__()
+        AdaLoraLayer.__init__(self, base_layer)
+        # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
+        # for backwards compatibility
+        self.quant_linear_module = base_layer
         self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
+        self.set_adapter(adapter_name)
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         result = self.quant_linear_module(x)
@@ -67,3 +66,7 @@ class SVDQuantLinear(torch.nn.Module, AdaLoraLayer):
             output = output.to(expected_dtype)
         result += output
         return result
+    def __repr__(self) -> str:
+        rep = super().__repr__()
+        return "adalora." + rep

View File

@@ -14,9 +14,9 @@
 # limitations under the License.
 import warnings
+from typing import Any, List, Optional
 import torch
-import torch.nn.functional as F
 from torch import nn
 from peft.tuners.lora import LoraLayer
@@ -26,14 +26,11 @@ from peft.utils import transpose
 class AdaLoraLayer(LoraLayer):
     # List all names of layers that may contain adapter weights
     # Note: ranknum doesn't need to be included as it is not an nn.Module
-    adapter_layer_names = ["lora_A", "lora_B", "lora_E", "lora_embedding_A", "lora_embedding_B"]
+    adapter_layer_names = ("lora_A", "lora_B", "lora_E", "lora_embedding_A", "lora_embedding_B")
+    # other_param_names is defined in LoraLayer
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-    ):
-        super().__init__(in_features, out_features)
+    def __init__(self, base_layer: nn.Module) -> None:
+        super().__init__(base_layer)
         self.lora_E = nn.ParameterDict({})
         self.lora_A = nn.ParameterDict({})
         self.lora_B = nn.ParameterDict({})
@@ -62,7 +59,12 @@ class AdaLoraLayer(LoraLayer):
         self.scaling[adapter_name] = lora_alpha if lora_alpha > 0 else float(r)
         if init_lora_weights:
             self.reset_lora_parameters(adapter_name)
-        self.to(self.weight.device)
+        if hasattr(self.get_base_layer(), "qweight"):
+            # QuantLinear
+            self.to(self.get_base_layer().qweight.device)
+        else:
+            self.to(self.get_base_layer().weight.device)
         self.set_adapter(self.active_adapters)
     def reset_lora_parameters(self, adapter_name):
@@ -72,34 +74,29 @@ class AdaLoraLayer(LoraLayer):
             nn.init.normal_(self.lora_B[adapter_name], mean=0.0, std=0.02)
-class SVDLinear(nn.Linear, AdaLoraLayer):
+class SVDLinear(nn.Module, AdaLoraLayer):
     # SVD-based adaptation by a dense layer
     def __init__(
         self,
+        base_layer: nn.Module,
         adapter_name: str,
-        in_features: int,
-        out_features: int,
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
+        init_lora_weights: bool = True,
         **kwargs,
     ) -> None:
-        init_lora_weights = kwargs.pop("init_lora_weights", True)
-        nn.Linear.__init__(self, in_features, out_features, **kwargs)
-        AdaLoraLayer.__init__(self, in_features=in_features, out_features=out_features)
+        super().__init__()
+        AdaLoraLayer.__init__(self, base_layer)
         # Freezing the pre-trained weight matrix
-        self.weight.requires_grad = False
+        self.get_base_layer().weight.requires_grad = False
         self.fan_in_fan_out = fan_in_fan_out
-        if fan_in_fan_out:
-            self.weight.data = self.weight.data.T
-        nn.Linear.reset_parameters(self)
+        self._active_adapter = adapter_name
         self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-        self.set_adapter(adapter_name)
-    def merge(self, safe_merge: bool = False) -> None:
+    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
         """
         Merge the active adapter weights into the base weights
@@ -108,18 +105,26 @@ class SVDLinear(nn.Linear, AdaLoraLayer):
                 If True, the merge operation will be performed in a copy of the original weights and check for NaNs
                 before merging the weights. This is useful if you want to check if the merge operation will produce
                 NaNs. Defaults to `False`.
+            adapter_names (`List[str]`, *optional*):
+                The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
+                to `None`.
         """
         if self.merged:
             warnings.warn(
                 f"Already following adapters were merged {','.join(self.merged_adapters)}. "
                 f"You are now additionally merging {','.join(self.active_adapters)}."
             )
-        for active_adapter in self.active_adapters:
+        if adapter_names is None:
+            adapter_names = self.active_adapters
+        for active_adapter in adapter_names:
+            base_layer = self.get_base_layer()
             if active_adapter in self.lora_A.keys():
                 if safe_merge:
                     # Note that safe_merge will be slower than the normal merge
                     # because of the copy operation.
-                    orig_weights = self.weight.data.clone()
+                    orig_weights = base_layer.weight.data.clone()
                     orig_weights += self.get_delta_weight(active_adapter)
                     if not torch.isfinite(orig_weights).all():
@@ -127,9 +132,9 @@ class SVDLinear(nn.Linear, AdaLoraLayer):
                             f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                         )
-                    self.weight.data = orig_weights
+                    base_layer.weight.data = orig_weights
                 else:
-                    self.weight.data += self.get_delta_weight(active_adapter)
+                    base_layer.weight.data += self.get_delta_weight(active_adapter)
                 self.merged_adapters.append(active_adapter)
     def unmerge(self) -> None:
@@ -139,7 +144,7 @@ class SVDLinear(nn.Linear, AdaLoraLayer):
         while len(self.merged_adapters) > 0:
             active_adapter = self.merged_adapters.pop()
             if active_adapter in self.lora_A.keys():
-                self.weight.data -= self.get_delta_weight(active_adapter)
+                self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)
     def get_delta_weight(self, adapter) -> torch.Tensor:
         return (
@@ -148,19 +153,16 @@ class SVDLinear(nn.Linear, AdaLoraLayer):
             / (self.ranknum[adapter] + 1e-5)
         )
-    def _linear(self, input: torch.Tensor) -> torch.Tensor:
-        return F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
         # TODO: SVDLinear does not convert dtype, unlike lora linear, is that correct?
         if self.disable_adapters:
             if self.merged:
                 self.unmerge()
-            result = self._linear(x)
+            result = self.base_layer(x, *args, **kwargs)
         elif self.merged:
-            result = self._linear(x)
+            result = self.base_layer(x, *args, **kwargs)
         else:
-            result = self._linear(x)
+            result = self.base_layer(x, *args, **kwargs)
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.lora_A.keys():
                     continue
@@ -175,8 +177,12 @@ class SVDLinear(nn.Linear, AdaLoraLayer):
         return result
+    def __repr__(self) -> str:
+        rep = super().__repr__()
+        return "adalora." + rep
-class RankAllocator(object):
+class RankAllocator:
     """
     The RankAllocator for AdaLoraModel. Paper: https://openreview.net/pdf?id=lq62uWRJjiY

View File

@@ -20,6 +20,7 @@ from transformers.pytorch_utils import Conv1D
 from peft.import_utils import is_bnb_4bit_available, is_bnb_available
 from peft.tuners.lora import LoraConfig, LoraModel
+from peft.tuners.tuners_utils import BaseTunerLayer
 from peft.utils import (
     TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
     _freeze_adapter,
@@ -67,6 +68,8 @@ class AdaLoraModel(LoraModel):
         - **peft_config** ([`AdaLoraConfig`]): The configuration of the AdaLora model.
     """
+    # Note: don't redefine prefix here, it should be inherited from LoraModel
     def __init__(self, model, config, adapter_name):
         super().__init__(model, config, adapter_name)
@@ -121,7 +124,7 @@ class AdaLoraModel(LoraModel):
         loaded_in_4bit = optional_kwargs.get("loaded_in_4bit", False)
         if (loaded_in_8bit or loaded_in_4bit) and not is_bnb_available():
             raise ImportError(
-                "To use Lora with 8-bit quantization, please install the `bitsandbytes` package. "
+                "To use AdaLora with 8-bit quantization, please install the `bitsandbytes` package. "
                 "You can install it with `pip install bitsandbytes`."
             )
         kwargs = {
@@ -138,7 +141,7 @@ class AdaLoraModel(LoraModel):
         if quantization_config is not None:
             kwargs["gptq_quantization_config"] = quantization_config
-        # If it is not a LoraLayer, create a new module, else update it with new adapters
+        # If it is not an AdaLoraLayer, create a new module, else update it with new adapters
         if not isinstance(target, AdaLoraLayer):
             new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
             if adapter_name != self.active_adapter:
@@ -159,11 +162,15 @@ class AdaLoraModel(LoraModel):
         gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
         AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)
-        bias = target.bias is not None
         loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
         loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)
-        if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
+        if isinstance(target, BaseTunerLayer):
+            target_base_layer = target.get_base_layer()
+        else:
+            target_base_layer = target
+        if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt):
             kwargs.update(
                 {
                     "has_fp16_weights": target.state.has_fp16_weights,
@@ -172,8 +179,8 @@ class AdaLoraModel(LoraModel):
                     "index": target.index,
                 }
             )
-            new_module = SVDLinear8bitLt(adapter_name, target.in_features, target.out_features, bias=bias, **kwargs)
+            new_module = SVDLinear8bitLt(target, adapter_name, **kwargs)
-        elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit):
+        elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit):
             fourbit_kwargs = kwargs.copy()
             fourbit_kwargs.update(
                 {
@@ -182,25 +189,18 @@ class AdaLoraModel(LoraModel):
                     "quant_type": target.weight.quant_type,
                 }
             )
-            new_module = SVDLinear4bit(
-                adapter_name, target.in_features, target.out_features, bias=bias, **fourbit_kwargs
-            )
+            new_module = SVDLinear4bit(target, adapter_name, **fourbit_kwargs)
         elif AutoGPTQQuantLinear is not None and isinstance(target, AutoGPTQQuantLinear):
-            new_module = SVDQuantLinear(adapter_name, target, **kwargs)
-            target.weight = target.qweight
+            new_module = SVDQuantLinear(target, adapter_name, **kwargs)
         else:
-            if isinstance(target, torch.nn.Linear):
-                in_features, out_features = target.in_features, target.out_features
+            if isinstance(target_base_layer, torch.nn.Linear):
                 if kwargs["fan_in_fan_out"]:
                     warnings.warn(
                         "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                         "Setting fan_in_fan_out to False."
                     )
                     kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
-            elif isinstance(target, Conv1D):
-                in_features, out_features = (
-                    target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
-                )
+            elif isinstance(target_base_layer, Conv1D):
                 if not kwargs["fan_in_fan_out"]:
                     warnings.warn(
                         "fan_in_fan_out is set to False but the target module is `Conv1D`. "
@@ -212,7 +212,7 @@ class AdaLoraModel(LoraModel):
                     f"Target module {target} is not supported. "
                     f"Currently, only `torch.nn.Linear` and `Conv1D` are supported."
                 )
-            new_module = SVDLinear(adapter_name, in_features, out_features, bias=bias, **kwargs)
+            new_module = SVDLinear(target, adapter_name, **kwargs)
         return new_module
@@ -236,7 +236,7 @@ class AdaLoraModel(LoraModel):
     def forward(self, *args, **kwargs):
         outputs = self.model.forward(*args, **kwargs)
-        if getattr(outputs, "loss", None) is not None:
+        if (getattr(outputs, "loss", None) is not None) and isinstance(outputs.loss, torch.Tensor):
             # Calculate the orthogonal regularization
             orth_reg_weight = self.peft_config[self.trainable_adapter_name].orth_reg_weight

View File

@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import bitsandbytes as bnb
+from typing import Any
+
 import torch
 from peft.import_utils import is_bnb_4bit_available, is_bnb_available
@@ -23,39 +24,27 @@ from .layer import IA3Layer
 if is_bnb_available():
-    class Linear8bitLt(bnb.nn.Linear8bitLt, IA3Layer):
+    class Linear8bitLt(torch.nn.Module, IA3Layer):
         # (IA)^3 implemented in a dense layer
         def __init__(
             self,
-            adapter_name,
-            in_features,
-            out_features,
-            is_feedforward,
+            base_layer: torch.nn.Module,
+            adapter_name: str,
+            is_feedforward: bool,
+            init_ia3_weights: bool = True,
             **kwargs,
         ) -> None:
-            bnb.nn.Linear8bitLt.__init__(
-                self,
-                in_features,
-                out_features,
-                bias=kwargs.get("bias", True),
-                has_fp16_weights=kwargs.get("has_fp16_weights", True),
-                memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
-                threshold=kwargs.get("threshold", 0.0),
-                index=kwargs.get("index", None),
-            )
-            IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward)
-            self.is_feedforward = is_feedforward
+            super().__init__()
+            IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward)
             # Freezing the pre-trained weight matrix
-            self.weight.requires_grad = False
+            self.get_base_layer().weight.requires_grad = False
-            init_ia3_weights = kwargs.pop("init_ia3_weights", True)
             self.update_layer(adapter_name, init_ia3_weights)
+            self.set_adapter(adapter_name)
-        def forward(self, x: torch.Tensor) -> torch.Tensor:
+        def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
+            # note: no check for self.merged because merging is not supported (yet)
             if self.disable_adapters:
-                return super().forward(x)
+                return self.base_layer(x)
             ia3_scaling = 1
             for active_adapter in self.active_adapters:
@@ -67,10 +56,10 @@ if is_bnb_available():
             if requires_conversion:
                 x = x.float()
             if self.is_feedforward:
-                result = super().forward(x * ia3_scaling)
+                result = self.base_layer(x * ia3_scaling)
                 expected_dtype = result.dtype
             else:
-                result = super().forward(x)
+                result = self.base_layer(x)
                 expected_dtype = result.dtype
                 result = result * ia3_scaling
@@ -79,41 +68,34 @@ if is_bnb_available():
             return result
+        def __repr__(self) -> str:
+            rep = super().__repr__()
+            return "ia3." + rep
 if is_bnb_4bit_available():
-    class Linear4bit(bnb.nn.Linear4bit, IA3Layer):
+    class Linear4bit(torch.nn.Module, IA3Layer):
         # IA3 implemented in a dense layer
         def __init__(
             self,
-            adapter_name,
-            in_features,
-            out_features,
-            is_feedforward,
+            base_layer: torch.nn.Module,
+            adapter_name: str,
+            is_feedforward: bool,
+            init_ia3_weights: bool = True,
            **kwargs,
         ) -> None:
-            bnb.nn.Linear4bit.__init__(
-                self,
-                in_features,
-                out_features,
-                bias=kwargs.get("bias", True),
-                compute_dtype=kwargs.get("compute_dtype", torch.float32),
-                compress_statistics=kwargs.get("compress_statistics", True),
-                quant_type=kwargs.get("quant_type", "nf4"),
-            )
-            IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward)
-            self.is_feedforward = is_feedforward
+            super().__init__()
+            IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward)
             # Freezing the pre-trained weight matrix
-            self.weight.requires_grad = False
+            self.get_base_layer().weight.requires_grad = False
-            init_ia3_weights = kwargs.pop("init_ia3_weights", True)
             self.update_layer(adapter_name, init_ia3_weights)
+            self.set_adapter(adapter_name)
-        def forward(self, x: torch.Tensor) -> torch.Tensor:
+        def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
+            # note: no check for self.merged because merging is not supported (yet)
             if self.disable_adapters:
-                return super().forward(x)
+                return self.base_layer(x)
             ia3_scaling = 1
             for active_adapter in self.active_adapters:
@@ -125,10 +107,10 @@ if is_bnb_4bit_available():
             if requires_conversion:
                 x = x.float()
             if self.is_feedforward:
-                result = super().forward(x * ia3_scaling)
+                result = self.base_layer(x * ia3_scaling)
                 expected_dtype = result.dtype
             else:
-                result = super().forward(x)
+                result = self.base_layer(x)
                 expected_dtype = result.dtype
                 result = result * ia3_scaling
@@ -140,3 +122,7 @@ if is_bnb_4bit_available():
                 result = result.to(expected_dtype)
             return result
+        def __repr__(self) -> str:
+            rep = super().__repr__()
+            return "ia3." + rep

View File

@ -14,34 +14,43 @@
# limitations under the License. # limitations under the License.
import warnings import warnings
from typing import Tuple, Union from typing import Any, List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F from transformers.pytorch_utils import Conv1D
from peft.tuners.tuners_utils import BaseTunerLayer from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import transpose from peft.utils import transpose
class IA3Layer(BaseTunerLayer): class IA3Layer(BaseTunerLayer):
# List all names of layers that may contain adapter weights # All names of layers that may contain adapter weights
adapter_layer_names = ["ia3_l"] adapter_layer_names = ("ia3_l",)
def __init__( def __init__(self, base_layer: nn.Module, is_feedforward: bool, **kwargs) -> None:
self, self.base_layer = base_layer
in_features: int,
out_features: int,
is_feedforward: bool,
):
self.scaling = {}
self.ia3_l = nn.ParameterDict({}) self.ia3_l = nn.ParameterDict({})
# Mark the weight as unmerged # Mark the weight as unmerged
self._disable_adapters = False self._disable_adapters = False
self.merged_adapters = [] self.merged_adapters = []
self.is_feedforward = is_feedforward
base_layer = self.get_base_layer()
if isinstance(base_layer, nn.Linear):
in_features, out_features = base_layer.in_features, base_layer.out_features
elif isinstance(base_layer, nn.Conv2d):
in_features, out_features = base_layer.in_channels, base_layer.out_channels
elif isinstance(base_layer, nn.Embedding):
in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim
elif isinstance(base_layer, Conv1D):
in_features, out_features = (
base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
)
else:
raise ValueError(f"Unsupported layer type {type(base_layer)}")
self.in_features = in_features self.in_features = in_features
self.out_features = out_features self.out_features = out_features
self.is_feedforward = is_feedforward
def update_layer(self, adapter_name, init_ia3_weights): def update_layer(self, adapter_name, init_ia3_weights):
# Actual trainable parameters # Actual trainable parameters
@ -52,7 +61,7 @@ class IA3Layer(BaseTunerLayer):
self.ia3_l[adapter_name] = nn.Parameter(weight) self.ia3_l[adapter_name] = nn.Parameter(weight)
if init_ia3_weights: if init_ia3_weights:
self.reset_ia3_parameters(adapter_name) self.reset_ia3_parameters(adapter_name)
self.to(self.weight.device) self.to(self.get_base_layer().weight.device)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters)
def reset_ia3_parameters(self, adapter_name): def reset_ia3_parameters(self, adapter_name):
@ -61,35 +70,24 @@ class IA3Layer(BaseTunerLayer):
nn.init.constant_(self.ia3_l[adapter_name], 1.0) nn.init.constant_(self.ia3_l[adapter_name], 1.0)
-class Linear(nn.Linear, IA3Layer):
+class Linear(nn.Module, IA3Layer):
     # (IA)^3 implemented in a dense layer
     def __init__(
         self,
+        base_layer: nn.Module,
         adapter_name: str,
-        in_features: int,
-        out_features: int,
         fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
         is_feedforward: bool = False,  # Set to True if the layer is treated as a feedforward layer
         is_target_conv_1d_layer: bool = False,  # whether target module is a conv1d layer. useful while unloading later
+        init_ia3_weights: bool = True,  # whether to initialize IA3 weights
         **kwargs,
     ) -> None:
-        init_ia3_weights = kwargs.pop("init_ia3_weights", True)
-        nn.Linear.__init__(self, in_features, out_features, **kwargs)
-        IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward)
-        self.is_feedforward = is_feedforward
-        # Freezing the pre-trained weight matrix
-        self.weight.requires_grad = False
+        super().__init__()
+        IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward)
         self.fan_in_fan_out = fan_in_fan_out
-        if fan_in_fan_out:
-            self.weight.data = self.weight.data.T
         self.is_target_conv_1d_layer = is_target_conv_1d_layer
+        self._active_adapter = adapter_name
-        nn.Linear.reset_parameters(self)
         self.update_layer(adapter_name, init_ia3_weights)
-        self.set_adapter(adapter_name)
def update_layer(self, adapter_name, init_ia3_weights): def update_layer(self, adapter_name, init_ia3_weights):
# Actual trainable parameters # Actual trainable parameters
@ -100,10 +98,10 @@ class Linear(nn.Linear, IA3Layer):
self.ia3_l[adapter_name] = nn.Parameter(weight) self.ia3_l[adapter_name] = nn.Parameter(weight)
if init_ia3_weights: if init_ia3_weights:
self.reset_ia3_parameters(adapter_name) self.reset_ia3_parameters(adapter_name)
self.to(self.weight.device) self.to(self.get_base_layer().weight.device)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters)
def merge(self, safe_merge: bool = False) -> None: def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
""" """
Merge the active adapter weights into the base weights Merge the active adapter weights into the base weights
@ -112,6 +110,9 @@ class Linear(nn.Linear, IA3Layer):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`. NaNs. Defaults to `False`.
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
""" """
if self.merged: if self.merged:
warnings.warn( warnings.warn(
@ -119,26 +120,28 @@ class Linear(nn.Linear, IA3Layer):
f"You are now additionally merging {','.join(self.active_adapters)}." f"You are now additionally merging {','.join(self.active_adapters)}."
) )
for active_adapter in self.active_adapters: if adapter_names is None:
adapter_names = self.active_adapters
for active_adapter in adapter_names:
if active_adapter in self.ia3_l.keys(): if active_adapter in self.ia3_l.keys():
base_layer = self.get_base_layer()
ia3_l = transpose(self.ia3_l[active_adapter].data, self.fan_in_fan_out)
if safe_merge: if safe_merge:
orig_weights = transpose(self.weight, self.fan_in_fan_out).clone() orig_weights = base_layer.weight.data
orig_weights = torch.mul(orig_weights.data, self.ia3_l[active_adapter].data) orig_weights = torch.mul(orig_weights, ia3_l)
if not torch.isfinite(orig_weights).all(): if not torch.isfinite(orig_weights).all():
raise ValueError( raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
) )
self.weight.data = orig_weights base_layer.weight.data = orig_weights
self.weight = transpose(self.weight, self.fan_in_fan_out)
else: else:
self.weight = transpose(self.weight, self.fan_in_fan_out) base_layer.weight.data = torch.mul(base_layer.weight.data, ia3_l)
self.weight.data = torch.mul(self.weight.data, self.ia3_l[active_adapter].data)
self.weight = transpose(self.weight, self.fan_in_fan_out)
if not self.is_feedforward and (self.bias is not None): if not self.is_feedforward and (base_layer.bias is not None):
scaling = self.ia3_l[active_adapter].reshape(self.bias.shape) scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape)
self.bias.data = torch.mul(self.bias.data, scaling.data) base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data)
self.merged_adapters.append(active_adapter) self.merged_adapters.append(active_adapter)
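The merge above folds the learned (IA)^3 vector directly into base_layer.weight, with an optional NaN check when safe_merge=True. A plain-PyTorch sketch of that arithmetic for a feedforward-style Linear target (illustrative only; it ignores the fan_in_fan_out transpose):

```py
# Sketch of merging an (IA)^3 scaling vector into a dense weight (assumes a plain nn.Linear).
import torch
import torch.nn as nn

base = nn.Linear(8, 4)
ia3_l = torch.ones(1, 8) * 1.1  # learned per-input-feature scaling, broadcastable to the weight

x = torch.randn(2, 8)
scaled_out = base(x * ia3_l)  # adapter applied at runtime (feedforward-style)

# merge: fold the scaling into the weight so the base layer alone gives the same output
merged = torch.mul(base.weight.data, ia3_l)  # broadcasts over the weight's rows
if not torch.isfinite(merged).all():         # the safe_merge-style check
    raise ValueError("NaNs detected in the merged weights")
base.weight.data = merged

print(torch.allclose(base(x), scaled_out, atol=1e-6))  # True
```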
@ -151,27 +154,24 @@ class Linear(nn.Linear, IA3Layer):
while len(self.merged_adapters) > 0: while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop() active_adapter = self.merged_adapters.pop()
if active_adapter in self.ia3_l.keys(): if active_adapter in self.ia3_l.keys():
self.weight = transpose(self.weight, self.fan_in_fan_out) base_layer = self.get_base_layer()
# divide by (IA)^3 vector. Add tolerace to avoid division by zero # Add tolerace to avoid division by zero
self.weight.data = torch.div(self.weight.data, self.ia3_l[active_adapter].data + 1e-8) ia3_l = transpose(self.ia3_l[active_adapter].data, self.fan_in_fan_out) + 1e-8
self.weight = transpose(self.weight, self.fan_in_fan_out) base_layer.weight.data = torch.div(base_layer.weight.data, ia3_l)
if not self.is_feedforward and (self.bias is not None): if not self.is_feedforward and (base_layer.bias is not None):
scaling = self.ia3_l[active_adapter].reshape(self.bias.shape) scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape)
self.bias.data = torch.div(self.bias.data, scaling.data + 1e-8) base_layer.bias.data = torch.div(base_layer.bias.data, scaling.data + 1e-8)
def _linear(self, input: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
return F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
dtype = previous_dtype = x.dtype dtype = previous_dtype = x.dtype
if self.disable_adapters: if self.disable_adapters:
if self.merged: if self.merged:
self.unmerge() self.unmerge()
result = self._linear(x) result = self.base_layer(x, *args, **kwargs)
elif self.merged: elif self.merged:
result = self._linear(x) result = self.base_layer(x, *args, **kwargs)
else: else:
ia3_scaling = 1 ia3_scaling = 1
for active_adapter in self.active_adapters: for active_adapter in self.active_adapters:
@ -182,46 +182,34 @@ class Linear(nn.Linear, IA3Layer):
if self.is_feedforward: if self.is_feedforward:
x = x.to(dtype) x = x.to(dtype)
# TODO: self.weight.dtype can be != self.ia3_l[self.active_adapters].dtype # TODO: weight.dtype can be != self.ia3_l[self.active_adapters].dtype
# e.g. bf16 vs fp32. Is that okay? # e.g. bf16 vs fp32. Is that okay?
interm = (x * ia3_scaling).to(self.weight.dtype) interm = (x * ia3_scaling).to(self.get_base_layer().weight.dtype)
result = self._linear(interm) result = self.base_layer(interm, *args, **kwargs)
else: else:
result = self._linear(x) result = self.base_layer(x, *args, **kwargs)
result = result.to(dtype) * ia3_scaling result = result.to(dtype) * ia3_scaling
result = result.to(previous_dtype) result = result.to(previous_dtype)
return result return result
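With the base layer pattern, forward now delegates to self.base_layer and only multiplies in the product of the active adapters' vectors. A small sketch of how several active (IA)^3 adapters compose multiplicatively before the input reaches the base layer (the adapter names are made up):

```py
# Sketch of how multiple active (IA)^3 adapters compose at forward time.
import torch

ia3_l = {"default": torch.full((1, 8), 0.9), "other": torch.full((1, 8), 1.2)}
active_adapters = ["default", "other"]

ia3_scaling = 1
for name in active_adapters:
    if name in ia3_l:
        ia3_scaling = ia3_scaling * ia3_l[name]  # element-wise product of the learned vectors

x = torch.randn(2, 8)
x_scaled = x * ia3_scaling  # for feedforward modules, this is what gets passed to base_layer(...)
print(x_scaled.shape)  # torch.Size([2, 8])
```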
class Conv2d(nn.Conv2d, IA3Layer): class Conv2d(nn.Module, IA3Layer):
def __init__( def __init__(
self, self,
base_layer: nn.Module,
adapter_name: str, adapter_name: str,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int]],
stride: Union[int, Tuple[int]] = 1,
padding: Union[int, Tuple[int]] = 0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
is_feedforward: bool = False, # Set to True if the layer is treated as a feedforward layer is_feedforward: bool = False, # Set to True if the layer is treated as a feedforward layer
init_ia3_weights: bool = True,
**kwargs, **kwargs,
) -> None: ) -> None:
init_ia3_weights = kwargs.pop("init_ia3_weights", True) super().__init__()
IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward)
nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
IA3Layer.__init__(self, in_features=in_channels, out_features=out_channels, is_feedforward=is_feedforward)
self.is_feedforward = is_feedforward
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.fan_in_fan_out = fan_in_fan_out self.fan_in_fan_out = fan_in_fan_out
if fan_in_fan_out: self._active_adapter = adapter_name
self.weight.data = self.weight.data.T
nn.Conv2d.reset_parameters(self)
self.update_layer(adapter_name, init_ia3_weights) self.update_layer(adapter_name, init_ia3_weights)
self.set_adapter(adapter_name)
def update_layer(self, adapter_name, init_ia3_weights): def update_layer(self, adapter_name, init_ia3_weights):
# Actual trainable parameters # Actual trainable parameters
@ -232,10 +220,10 @@ class Conv2d(nn.Conv2d, IA3Layer):
self.ia3_l[adapter_name] = nn.Parameter(weight) self.ia3_l[adapter_name] = nn.Parameter(weight)
if init_ia3_weights: if init_ia3_weights:
self.reset_ia3_parameters(adapter_name) self.reset_ia3_parameters(adapter_name)
self.to(self.weight.device) self.to(self.get_base_layer().weight.device)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters)
def merge(self, safe_merge: bool = False) -> None: def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
""" """
Merge the active adapter weights into the base weights Merge the active adapter weights into the base weights
@ -244,6 +232,9 @@ class Conv2d(nn.Conv2d, IA3Layer):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`. NaNs. Defaults to `False`.
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
""" """
if self.merged: if self.merged:
warnings.warn( warnings.warn(
@ -251,27 +242,31 @@ class Conv2d(nn.Conv2d, IA3Layer):
f"You are now additionally merging {','.join(self.active_adapters)}." f"You are now additionally merging {','.join(self.active_adapters)}."
) )
for active_adapter in self.active_adapters: if adapter_names is None:
adapter_names = self.active_adapters
for active_adapter in adapter_names:
if active_adapter in self.ia3_l.keys(): if active_adapter in self.ia3_l.keys():
base_layer = self.get_base_layer()
ia3_scaling = self.ia3_l[active_adapter].data ia3_scaling = self.ia3_l[active_adapter].data
if not self.is_feedforward: if not self.is_feedforward:
ia3_scaling = ia3_scaling.permute(1, 0, 2, 3) ia3_scaling = ia3_scaling.permute(1, 0, 2, 3)
if safe_merge: if safe_merge:
output_weight = torch.mul(self.weight.data, ia3_scaling).clone() output_weight = torch.mul(base_layer.weight.data, ia3_scaling).clone()
if not torch.isfinite(output_weight).all(): if not torch.isfinite(output_weight).all():
raise ValueError( raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
) )
self.weight.data = output_weight base_layer.weight.data = output_weight
else: else:
self.weight.data = torch.mul(self.weight.data, ia3_scaling) base_layer.weight.data = torch.mul(base_layer.weight.data, ia3_scaling)
if not self.is_feedforward and (self.bias is not None): if not self.is_feedforward and (base_layer.bias is not None):
scaling = self.ia3_l[active_adapter].reshape(self.bias.shape) scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape)
self.bias.data = torch.mul(self.bias.data, scaling.data) base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data)
self.merged_adapters.append(active_adapter) self.merged_adapters.append(active_adapter)
@ -284,36 +279,26 @@ class Conv2d(nn.Conv2d, IA3Layer):
while len(self.merged_adapters) > 0: while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop() active_adapter = self.merged_adapters.pop()
if active_adapter in self.ia3_l.keys(): if active_adapter in self.ia3_l.keys():
base_layer = self.get_base_layer()
# divide by (IA)^3 vector. Add tolerace to avoid division by zero # divide by (IA)^3 vector. Add tolerace to avoid division by zero
ia3_scaling = self.ia3_l[active_adapter].data ia3_scaling = self.ia3_l[active_adapter].data
if not self.is_feedforward: if not self.is_feedforward:
ia3_scaling = ia3_scaling.permute(1, 0, 2, 3) ia3_scaling = ia3_scaling.permute(1, 0, 2, 3)
self.weight.data = torch.div(self.weight.data, ia3_scaling + 1e-8) base_layer.weight.data = torch.div(base_layer.weight.data, ia3_scaling + 1e-8)
if not self.is_feedforward and (self.bias is not None): if not self.is_feedforward and (base_layer.bias is not None):
scaling = self.ia3_l[active_adapter].reshape(self.bias.shape) scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape)
self.bias.data = torch.mul(self.bias.data, scaling.data) base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data)
def _conv2d(self, input: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
return F.conv2d( dtype = previous_dtype = x.dtype
input,
self.weight,
bias=self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
previous_dtype = x.dtype
if self.disable_adapters: if self.disable_adapters:
if self.merged: if self.merged:
self.unmerge() self.unmerge()
result = self._conv2d(x) result = self.base_layer(x, *args, **kwargs)
elif self.merged: elif self.merged:
result = self._conv2d(x) result = self.base_layer(x, *args, **kwargs)
else: else:
ia3_scaling = 1 ia3_scaling = 1
for active_adapter in self.active_adapters: for active_adapter in self.active_adapters:
@ -324,12 +309,12 @@ class Conv2d(nn.Conv2d, IA3Layer):
if self.is_feedforward: if self.is_feedforward:
x = x.to(dtype) x = x.to(dtype)
# TODO: self.weight.dtype can be != self.ia3_l[self.active_adapters].dtype # TODO: weight.dtype can be != self.ia3_l[self.active_adapters].dtype
# e.g. bf16 vs fp32. Is that okay? # e.g. bf16 vs fp32. Is that okay?
interm = (x * ia3_scaling).to(self.weight.dtype) interm = (x * ia3_scaling).to(self.get_base_layer().weight.dtype)
result = self._conv2d(interm) result = self.base_layer(interm, *args, **kwargs)
else: else:
result = self._conv2d(x) result = self.base_layer(x, *args, **kwargs)
result = result.to(dtype) * ia3_scaling result = result.to(dtype) * ia3_scaling
result = result.to(previous_dtype) result = result.to(previous_dtype)
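For Conv2d, the non-feedforward case permutes the stored vector so it broadcasts over the output channels of a (out_ch, in_ch, kH, kW) weight. A short sketch, assuming the vector is stored with shape (1, out_channels, 1, 1) as the permute in the hunk suggests:

```py
# Sketch: for a non-feedforward Conv2d, the (IA)^3 vector scales output channels,
# so it is permuted to broadcast against a weight of shape (out_ch, in_ch, kH, kW).
import torch

out_ch, in_ch, k = 6, 3, 3
weight = torch.randn(out_ch, in_ch, k, k)
ia3_scaling = torch.randn(1, out_ch, 1, 1)     # assumed storage layout

ia3_scaling = ia3_scaling.permute(1, 0, 2, 3)  # -> (out_ch, 1, 1, 1)
merged = torch.mul(weight, ia3_scaling)        # broadcasts over in_ch and the kernel dims
print(merged.shape)  # torch.Size([6, 3, 3, 3])
```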


@@ -17,12 +17,13 @@ import re
 import warnings
 from dataclasses import asdict
 from enum import Enum
+from typing import List, Optional

 import torch
 from transformers.pytorch_utils import Conv1D

 from peft.import_utils import is_bnb_4bit_available, is_bnb_available
-from peft.tuners.tuners_utils import BaseTuner, check_target_module_exists
+from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists
 from peft.utils import (
     TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING,
     TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING,
@@ -77,17 +78,23 @@ class IA3Model(BaseTuner):
         - **peft_config** ([`ia3Config`]): The configuration of the (IA)^3 model.
     """

+    prefix: str = "ia3_"
+
     def __init__(self, model, config, adapter_name):
         super().__init__(model, config, adapter_name)

     @staticmethod
     def _create_new_module(ia3_config, adapter_name, target, **kwargs):
-        bias = hasattr(target, "bias") and target.bias is not None
         loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
         loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)
         is_feedforward = kwargs.pop("is_feedforward", False)

-        if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
+        if isinstance(target, BaseTunerLayer):
+            target_base_layer = target.get_base_layer()
+        else:
+            target_base_layer = target
+
+        if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt):
eightbit_kwargs = kwargs.copy() eightbit_kwargs = kwargs.copy()
eightbit_kwargs.update( eightbit_kwargs.update(
{ {
@ -97,15 +104,8 @@ class IA3Model(BaseTuner):
"index": target.index, "index": target.index,
} }
) )
new_module = Linear8bitLt( new_module = Linear8bitLt(target, adapter_name, is_feedforward=is_feedforward, **eightbit_kwargs)
adapter_name, elif loaded_in_4bit and isinstance(target_base_layer, bnb.nn.Linear4bit):
target.in_features,
target.out_features,
is_feedforward,
bias=bias,
**eightbit_kwargs,
)
elif loaded_in_4bit and isinstance(target, bnb.nn.Linear4bit):
fourbit_kwargs = kwargs.copy() fourbit_kwargs = kwargs.copy()
fourbit_kwargs.update( fourbit_kwargs.update(
{ {
@ -114,56 +114,31 @@ class IA3Model(BaseTuner):
"quant_type": target.weight.quant_type, "quant_type": target.weight.quant_type,
} }
) )
new_module = Linear4bit( new_module = Linear4bit(target, adapter_name, is_feedforward=is_feedforward, **fourbit_kwargs)
adapter_name,
target.in_features,
target.out_features,
is_feedforward,
bias=bias,
**fourbit_kwargs,
)
elif isinstance(target, torch.nn.Conv2d): elif isinstance(target, torch.nn.Conv2d):
out_channels, in_channels = target.weight.size()[:2] new_module = Conv2d(target, adapter_name, is_feedforward=is_feedforward, **kwargs)
kernel_size = target.weight.size()[2:] elif isinstance(target_base_layer, torch.nn.Linear):
stride = target.stride if kwargs["fan_in_fan_out"]:
padding = target.padding warnings.warn(
new_module = Conv2d( "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
adapter_name=adapter_name, "Setting fan_in_fan_out to False."
in_channels=in_channels, )
out_channels=out_channels, kwargs["fan_in_fan_out"] = ia3_config.fan_in_fan_out = False
kernel_size=kernel_size, new_module = Linear(target, adapter_name, is_feedforward=is_feedforward, **kwargs)
stride=stride, elif isinstance(target_base_layer, Conv1D):
padding=padding, if not kwargs["fan_in_fan_out"]:
is_feedforward=is_feedforward, warnings.warn(
**kwargs, "fan_in_fan_out is set to False but the target module is `Conv1D`. "
"Setting fan_in_fan_out to True."
)
kwargs["fan_in_fan_out"] = ia3_config.fan_in_fan_out = True
new_module = Linear(
target, adapter_name, is_feedforward=is_feedforward, is_target_conv_1d_layer=True, **kwargs
) )
else: else:
if isinstance(target, torch.nn.Linear): raise ValueError(
in_features, out_features = target.in_features, target.out_features f"Target module {target} is not supported. "
if kwargs["fan_in_fan_out"]: f"Currently, only `torch.nn.Linear`, `torch.nn.Conv2d`, and `Conv1D` are supported."
warnings.warn(
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
"Setting fan_in_fan_out to False."
)
kwargs["fan_in_fan_out"] = ia3_config.fan_in_fan_out = False
elif isinstance(target, Conv1D):
in_features, out_features = (
target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
)
kwargs["is_target_conv_1d_layer"] = True # useful for unloading later
if not kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to False but the target module is `Conv1D`. "
"Setting fan_in_fan_out to True."
)
kwargs["fan_in_fan_out"] = ia3_config.fan_in_fan_out = True
else:
raise ValueError(
f"Target module {target} is not supported. "
f"Currently, only `torch.nn.Linear`, `torch.nn.Conv2d`, and `Conv1D` are supported."
)
new_module = Linear(
adapter_name, in_features, out_features, is_feedforward=is_feedforward, bias=bias, **kwargs
) )
return new_module return new_module
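_create_new_module now unwraps an existing adapter layer with get_base_layer() and dispatches on the wrapped module's type. A condensed sketch of that dispatch rule; DummyAdapterLayer and pick_module_kind are hypothetical names used only to illustrate the pattern:

```py
# Sketch of the "unwrap before dispatch" idea: if the target is already an adapter layer,
# decide the new module type from its wrapped base layer, not from the wrapper itself.
import torch.nn as nn


class DummyAdapterLayer(nn.Module):  # stand-in for a BaseTunerLayer subclass
    def __init__(self, base_layer: nn.Module) -> None:
        super().__init__()
        self.base_layer = base_layer

    def get_base_layer(self) -> nn.Module:
        return self.base_layer


def pick_module_kind(target: nn.Module) -> str:
    target_base_layer = target.get_base_layer() if hasattr(target, "get_base_layer") else target
    if isinstance(target_base_layer, nn.Linear):
        return "linear"
    if isinstance(target_base_layer, nn.Conv2d):
        return "conv2d"
    raise ValueError(f"Target module {target} is not supported.")


print(pick_module_kind(nn.Linear(4, 4)))                        # linear
print(pick_module_kind(DummyAdapterLayer(nn.Conv2d(3, 8, 3))))  # conv2d
```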
@ -173,7 +148,7 @@ class IA3Model(BaseTuner):
def _mark_only_adapters_as_trainable(self) -> None: def _mark_only_adapters_as_trainable(self) -> None:
for n, p in self.model.named_parameters(): for n, p in self.model.named_parameters():
if "ia3_" not in n: if self.prefix not in n:
p.requires_grad = False p.requires_grad = False
def _create_and_replace( def _create_and_replace(
@ -200,21 +175,16 @@ class IA3Model(BaseTuner):
"is_feedforward": is_feedforward, "is_feedforward": is_feedforward,
} }
if isinstance(target, IA3Layer): if isinstance(target, Conv2d):
if target.is_feedforward != is_feedforward: target.update_layer(
raise ValueError( adapter_name,
"New adapter should have the same value for `is_feedforward` as previously added adapter." ia3_config.init_ia3_weights,
) )
if isinstance(target, torch.nn.Conv2d): elif isinstance(target, Linear):
target.update_layer_conv2d( target.update_layer(
adapter_name, adapter_name,
ia3_config.init_ia3_weights, ia3_config.init_ia3_weights,
) )
else: # Linear
target.update_layer(
adapter_name,
ia3_config.init_ia3_weights,
)
else: else:
new_module = self._create_new_module(ia3_config, adapter_name, target, **kwargs) new_module = self._create_new_module(ia3_config, adapter_name, target, **kwargs)
if adapter_name != self.active_adapter: if adapter_name != self.active_adapter:
@ -234,19 +204,29 @@ class IA3Model(BaseTuner):
is_feedforward = any(key.endswith(target_key) for target_key in ia3_config.feedforward_modules) is_feedforward = any(key.endswith(target_key) for target_key in ia3_config.feedforward_modules)
return is_feedforward return is_feedforward
@staticmethod def _replace_module(self, parent, child_name, new_module, child):
def _replace_module(parent, child_name, new_module, child):
setattr(parent, child_name, new_module) setattr(parent, child_name, new_module)
new_module.weight = child.weight
if child.bias is not None: # child layer wraps the original module, unpack it
new_module.bias = child.bias if hasattr(child, "base_layer"):
child = child.base_layer
# layers with base_layer don't need the weight to be copied, as they have a reference already
if not hasattr(new_module, "base_layer"):
new_module.weight = child.weight
if hasattr(child, "bias"):
new_module.bias = child.bias
if getattr(child, "state", None) is not None: if getattr(child, "state", None) is not None:
new_module.state = child.state if hasattr(new_module, "base_layer"):
new_module.base_layer.state = child.state
else:
new_module.state = child.state
new_module.to(child.weight.device) new_module.to(child.weight.device)
# dispatch to correct device # dispatch to correct device
for name, module in new_module.named_modules(): for name, module in new_module.named_modules():
if "ia3_" in name: if self.prefix in name:
module.to(child.weight.device) module.to(child.weight.device)
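_replace_module follows one rule: a replacement that wraps a base_layer already shares the original parameters, so weights are copied only for plain modules. A condensed, hypothetical sketch of that rule (not the exact PEFT helper):

```py
# Sketch of the replacement rule: modules that hold a base_layer keep a live reference,
# so weights are only copied for "plain" replacements.
import torch.nn as nn


def replace_module(parent: nn.Module, child_name: str, new_module: nn.Module, child: nn.Module) -> None:
    setattr(parent, child_name, new_module)
    if hasattr(child, "base_layer"):
        child = child.base_layer  # unpack an adapter wrapper down to the original module
    if not hasattr(new_module, "base_layer"):
        new_module.weight = child.weight  # only copy when there is no shared reference
        if getattr(child, "bias", None) is not None:
            new_module.bias = child.bias


parent = nn.Sequential(nn.Linear(4, 4))
replace_module(parent, "0", nn.Linear(4, 4), parent[0])
```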
def __getattr__(self, name: str): def __getattr__(self, name: str):
@ -297,7 +277,9 @@ class IA3Model(BaseTuner):
] ]
return peft_config return peft_config
def merge_and_unload(self, safe_merge: bool = False): def _unload_and_optionally_merge(
self, merge: bool = True, safe_merge: bool = False, adapter_names: Optional[List[str]] = None
):
r""" r"""
This method merges the (IA)^3 layers into the base model. This is needed if someone wants to use the base model This method merges the (IA)^3 layers into the base model. This is needed if someone wants to use the base model
as a standalone model. as a standalone model.
@ -307,6 +289,9 @@ class IA3Model(BaseTuner):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`. NaNs. Defaults to `False`.
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
""" """
if getattr(self.model, "is_loaded_in_8bit", False): if getattr(self.model, "is_loaded_in_8bit", False):
raise ValueError("Cannot merge ia3 layers when the model is loaded in 8-bit mode") raise ValueError("Cannot merge ia3 layers when the model is loaded in 8-bit mode")
@ -314,38 +299,75 @@ class IA3Model(BaseTuner):
if getattr(self.model, "is_loaded_in_4bit", False): if getattr(self.model, "is_loaded_in_4bit", False):
raise ValueError("Cannot merge ia3 layers when the model is loaded in 4-bit mode") raise ValueError("Cannot merge ia3 layers when the model is loaded in 4-bit mode")
key_list = [key for key, _ in self.model.named_modules() if "ia3" not in key] key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
for key in key_list: for key in key_list:
try: try:
parent, target, target_name = _get_submodules(self.model, key) parent, target, target_name = _get_submodules(self.model, key)
except AttributeError: except AttributeError:
continue continue
# save any additional trainable modules part of `modules_to_save` if hasattr(target, "base_layer"):
if isinstance(target, ModulesToSaveWrapper): if merge:
target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
self._replace_module(parent, target_name, target.get_base_layer(), target)
elif isinstance(target, ModulesToSaveWrapper):
# save any additional trainable modules part of `modules_to_save`
setattr(parent, target_name, target.modules_to_save[target.active_adapter]) setattr(parent, target_name, target.modules_to_save[target.active_adapter])
continue
if not isinstance(target, IA3Layer):
continue
if isinstance(target, torch.nn.Conv2d):
new_module = torch.nn.Conv2d(
target.in_channels,
target.out_channels,
kernel_size=target.kernel_size,
stride=target.stride,
padding=target.padding,
dilation=target.dilation,
)
else:
bias = target.bias is not None
if getattr(target, "is_target_conv_1d_layer", False):
new_module = Conv1D(target.out_features, target.in_features)
else:
new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias)
target.merge(safe_merge=safe_merge)
self._replace_module(parent, target_name, new_module, target)
return self.model return self.model
def merge_and_unload(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None):
r"""
This method merges the IA³ layers into the base model. This is needed if someone wants to use the base model as
a standalone model.
Args:
safe_merge (`bool`):
whether to activate the safe merging check to check if there is any potential Nan in the adapter
weights
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
Example:
```py
>>> from transformers import AutoModelForCausalLM
>>> from peft import PeftModel
>>> base_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-40b")
>>> peft_model_id = "smangrul/falcon-40B-int4-peft-lora-sfttrainer-sample"
>>> model = PeftModel.from_pretrained(base_model, peft_model_id)
>>> merged_model = model.merge_and_unload()
```
"""
return self._unload_and_optionally_merge(safe_merge=safe_merge, adapter_names=adapter_names)
def unload(self):
"""
Gets back the base model by removing all the IA³ modules without merging. This gives back the original base
model.
"""
return self._unload_and_optionally_merge(merge=False)
def delete_adapter(self, adapter_name: str):
"""
Deletes an existing adapter.
Args:
adapter_name (str): Name of the adapter to be deleted.
"""
if adapter_name not in self.peft_config:
raise ValueError(f"Adapter {adapter_name} does not exist")
del self.peft_config[adapter_name]
key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
new_adapter = None
for key in key_list:
_, target, _ = _get_submodules(self.model, key)
if isinstance(target, IA3Layer):
target.delete_adapter(adapter_name)
if new_adapter is None:
new_adapter = target.active_adapters[:]
self.active_adapter = new_adapter or []
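The new model-level methods mirror what other tuners expose. A hedged usage sketch; "my-base-model" and "my-ia3-adapter" are placeholder identifiers, not real checkpoints:

```py
# Placeholder usage sketch of the model-level API refactored above.
from transformers import AutoModelForCausalLM
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained("my-base-model")       # placeholder
model = PeftModel.from_pretrained(base_model, "my-ia3-adapter")          # placeholder

merged = model.merge_and_unload()   # fold the (IA)^3 vectors into the base weights
# model.unload()                    # or: drop the adapter modules without merging
# model.delete_adapter("default")   # or: delete a single adapter by name
```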


@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import Optional, Set, Tuple, Union from typing import Any, Set, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -23,13 +23,14 @@ import torch.nn.functional as F
from peft.tuners.lycoris_utils import LycorisLayer from peft.tuners.lycoris_utils import LycorisLayer
-class LoHaLayer(LycorisLayer, nn.Module):
-    # List all names of layers that may contain adapter weights
-    adapter_layer_names = ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b", "hada_t1", "hada_t2"]
+class LoHaLayer(nn.Module, LycorisLayer):
+    # All names of layers that may contain adapter weights
+    adapter_layer_names = ("hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b", "hada_t1", "hada_t2")
+    # other_param_names is defined on parent class

-    def __init__(self):
-        LycorisLayer.__init__(self)
-        super(nn.Module, self).__init__()
+    def __init__(self, base_layer: nn.Module):
+        super().__init__()
+        LycorisLayer.__init__(self, base_layer)
# LoHa info # LoHa info
self.hada_w1_a = nn.ParameterDict({}) self.hada_w1_a = nn.ParameterDict({})
@ -75,6 +76,21 @@ class LoHaLayer(LycorisLayer, nn.Module):
nn.init.kaiming_uniform_(self.hada_t1[adapter_name], a=math.sqrt(5)) nn.init.kaiming_uniform_(self.hada_t1[adapter_name], a=math.sqrt(5))
nn.init.kaiming_uniform_(self.hada_t2[adapter_name], a=math.sqrt(5)) nn.init.kaiming_uniform_(self.hada_t2[adapter_name], a=math.sqrt(5))
def reset_adapter_parameters_random(self, adapter_name: str):
# Original implementation performs initialization with normal distribution
# https://github.com/KohakuBlueleaf/LyCORIS/blob/3549fdef8f564761d68b695a08ef88b1122fdedc/lycoris/modules/loha.py#L158
# FedPara paper proposes to perform He initialization, let's stick with it
# It is enough to initialize only single matrix with zeros to make adapter do nothing after initialization
if adapter_name in self.hada_w1_a.keys():
nn.init.kaiming_uniform_(self.hada_w1_a[adapter_name], a=math.sqrt(5))
nn.init.kaiming_uniform_(self.hada_w1_b[adapter_name], a=math.sqrt(5))
nn.init.kaiming_uniform_(self.hada_w2_a[adapter_name], a=math.sqrt(5))
nn.init.kaiming_uniform_(self.hada_w2_b[adapter_name], a=math.sqrt(5))
if adapter_name in self.hada_t1.keys():
nn.init.kaiming_uniform_(self.hada_t1[adapter_name], a=math.sqrt(5))
nn.init.kaiming_uniform_(self.hada_t2[adapter_name], a=math.sqrt(5))
def update_layer( def update_layer(
self, self,
adapter_name: str, adapter_name: str,
@ -106,16 +122,20 @@ class LoHaLayer(LycorisLayer, nn.Module):
self.module_dropout[adapter_name] = module_dropout self.module_dropout[adapter_name] = module_dropout
# Determine shape of LoHa weights # Determine shape of LoHa weights
if isinstance(self, nn.Linear): base_layer = self.get_base_layer()
shape = tuple(self.weight.shape) if isinstance(base_layer, nn.Linear):
elif isinstance(self, nn.Conv2d): shape = tuple(base_layer.weight.shape)
use_effective_conv2d = use_effective_conv2d and self.kernel_size != (1, 1) elif isinstance(base_layer, nn.Conv2d):
use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size != (1, 1)
if use_effective_conv2d: if use_effective_conv2d:
shape = (self.out_channels, self.in_channels, *self.kernel_size) shape = (base_layer.out_channels, base_layer.in_channels, *base_layer.kernel_size)
else: else:
shape = (self.out_channels, self.in_channels * self.kernel_size[0] * self.kernel_size[1]) shape = (
base_layer.out_channels,
base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1],
)
else: else:
raise TypeError(f"LoHa is not implemented for {type(self).__name__} layer") raise TypeError(f"LoHa is not implemented for base layers of type {type(base_layer).__name__}")
# Create weights with provided shape # Create weights with provided shape
self.create_adapter_parameters(adapter_name, r, shape) self.create_adapter_parameters(adapter_name, r, shape)
@ -123,9 +143,11 @@ class LoHaLayer(LycorisLayer, nn.Module):
# Initialize weights # Initialize weights
if init_weights: if init_weights:
self.reset_adapter_parameters(adapter_name) self.reset_adapter_parameters(adapter_name)
else:
self.reset_adapter_parameters_random(adapter_name)
# Move new weights to device # Move new weights to device
weight = getattr(self, "weight", None) weight = getattr(self.get_base_layer(), "weight", None)
if weight is not None: if weight is not None:
# the layer is already completely initialized, this is an update # the layer is already completely initialized, this is an update
if weight.dtype.is_floating_point or weight.dtype.is_complex: if weight.dtype.is_floating_point or weight.dtype.is_complex:
@ -155,7 +177,8 @@ class LoHaLayer(LycorisLayer, nn.Module):
scale=torch.tensor(self.scaling[adapter_name]), scale=torch.tensor(self.scaling[adapter_name]),
) )
weight = weight.reshape(self.weight.shape) base_layer = self.get_base_layer()
weight = weight.reshape(base_layer.weight.shape)
# Perform rank dropout during training - drop rows of addition weights # Perform rank dropout during training - drop rows of addition weights
rank_dropout = self.rank_dropout[adapter_name] rank_dropout = self.rank_dropout[adapter_name]
@ -170,96 +193,107 @@ class LoHaLayer(LycorisLayer, nn.Module):
return weight return weight
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
previous_dtype = x.dtype
class Linear(LoHaLayer, nn.Linear): if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
result = self.base_layer(x, *args, **kwargs)
# Execute all the adapters
for active_adapter in self.active_adapters:
if active_adapter not in self._available_adapters:
continue
module_dropout = self.module_dropout[active_adapter]
# Modify current execution weights
if (not self.training) or (self.training and torch.rand(1) > module_dropout):
result = result + self._get_delta_activations(active_adapter, x, *args, **kwargs)
result = result.to(previous_dtype)
return result
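The LyCORIS-style forward adds each adapter's delta activations on top of the base layer's output, and module_dropout can skip an adapter entirely during training. A minimal sketch of that gating with plain tensors and a hypothetical helper name:

```py
# Sketch of the LyCORIS-style forward: base output plus per-adapter delta activations,
# optionally skipped during training with probability `module_dropout`.
import torch


def adapted_forward(base_out, deltas, module_dropout: float, training: bool):
    result = base_out
    for delta in deltas:
        # the stochastic gate mirrors the check in the diff above
        if (not training) or (torch.rand(1) > module_dropout):
            result = result + delta
    return result


base_out = torch.zeros(2, 4)
deltas = [torch.ones(2, 4) * 0.1]
print(adapted_forward(base_out, deltas, module_dropout=0.0, training=True))
```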
class Linear(LoHaLayer):
"""LoHa implemented in Linear layer""" """LoHa implemented in Linear layer"""
def __init__( def __init__(
self, self,
in_features: int, base_layer: nn.Module,
out_features: int,
bias: bool = True,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
adapter_name: str = "default", adapter_name: str = "default",
r: int = 0, r: int = 0,
alpha: float = 0.0, alpha: float = 0.0,
rank_dropout: float = 0.0, rank_dropout: float = 0.0,
module_dropout: float = 0.0, module_dropout: float = 0.0,
init_weights: bool = True,
**kwargs, **kwargs,
): ):
init_weights = kwargs.pop("init_weights", True) super().__init__(base_layer)
self._init_empty_weights(nn.Linear, in_features, out_features, bias, device=device, dtype=dtype)
LoHaLayer.__init__(self)
# Create adapter and set it active # Create adapter and set it active
self._active_adapter = adapter_name
self.update_layer(adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, **kwargs) self.update_layer(adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, **kwargs)
self.set_adapter(adapter_name)
def _op(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: def _get_delta_activations(
return F.linear(input, weight, bias=self.bias) self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
) -> torch.Tensor:
delta_weight = self.get_delta_weight(adapter_name)
# don't add bias here, because the bias is already included in the output of the base_layer
return F.linear(input, delta_weight)
def __repr__(self) -> str:
rep = super().__repr__()
return "loha." + rep
class Conv2d(LoHaLayer, nn.Conv2d): class Conv2d(LoHaLayer):
"""LoHa implemented in Conv2d layer""" """LoHa implemented in Conv2d layer"""
def __init__( def __init__(
self, self,
in_channels: int, base_layer: nn.Module,
out_channels: int,
kernel_size: Union[int, Tuple[int]],
stride: Union[int, Tuple[int]] = 1,
padding: Union[int, Tuple[int]] = 0,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
adapter_name: str = "default", adapter_name: str = "default",
r: int = 0, r: int = 0,
alpha: float = 0.0, alpha: float = 0.0,
rank_dropout: float = 0.0, rank_dropout: float = 0.0,
module_dropout: float = 0.0, module_dropout: float = 0.0,
use_effective_conv2d: bool = False, use_effective_conv2d: bool = False,
init_weights: bool = True,
**kwargs, **kwargs,
): ):
init_weights = kwargs.pop("init_weights", True) super().__init__(base_layer)
self._init_empty_weights(
nn.Conv2d,
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
LoHaLayer.__init__(self)
# Create adapter and set it active # Create adapter and set it active
self._active_adapter = adapter_name
self.update_layer( self.update_layer(
adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, use_effective_conv2d, **kwargs adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, use_effective_conv2d, **kwargs
) )
self.set_adapter(adapter_name)
def _op(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: def _get_delta_activations(
self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
) -> torch.Tensor:
delta_weight = self.get_delta_weight(adapter_name)
# don't add bias here, because the bias is already included in the output of the base_layer
base_layer = self.get_base_layer()
return F.conv2d( return F.conv2d(
input, input,
weight, delta_weight,
bias=self.bias, stride=base_layer.stride,
stride=self.stride, padding=base_layer.padding,
padding=self.padding, dilation=base_layer.dilation,
dilation=self.dilation, groups=base_layer.groups,
groups=self.groups,
) )
def __repr__(self) -> str:
rep = super().__repr__()
return "loha." + rep
# Below code is a direct copy from https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/loha.py#L9 # Below code is a direct copy from https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/loha.py#L9


@ -13,11 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict, Type import re
from itertools import chain
from typing import Dict, Type, Union
import torch import torch
from torch import nn
from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner
from ..lycoris_utils import LycorisTuner
from .layer import Conv2d, Linear, LoHaLayer from .layer import Conv2d, Linear, LoHaLayer
@ -82,3 +86,31 @@ class LoHaModel(LycorisTuner):
torch.nn.Conv2d: Conv2d, torch.nn.Conv2d: Conv2d,
torch.nn.Linear: Linear, torch.nn.Linear: Linear,
} }
def _create_and_replace(
self,
config: LycorisConfig,
adapter_name: str,
target: Union[LoHaLayer, nn.Module],
target_name: str,
parent: nn.Module,
current_key: str,
**optional_kwargs,
) -> None:
"""
A private method to create and replace the target module with the adapter module.
"""
# Regexp matching - Find key which matches current target_name in patterns provided
pattern_keys = list(chain(config.rank_pattern.keys(), config.alpha_pattern.keys()))
target_name_key = next(filter(lambda key: re.match(f"(.*\.)?{key}$", current_key), pattern_keys), target_name)
kwargs = config.to_dict()
kwargs["r"] = config.rank_pattern.get(target_name_key, config.r)
kwargs["alpha"] = config.alpha_pattern.get(target_name_key, config.alpha)
if isinstance(target, LoHaLayer):
target.update_layer(adapter_name, **kwargs)
else:
new_module = self._create_new_module(config, adapter_name, target, **kwargs)
self._replace_module(parent, target_name, new_module, target)
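_create_and_replace resolves per-module r and alpha by matching the module's full name against rank_pattern/alpha_pattern keys; the first matching key wins, otherwise the defaults apply. A small sketch with made-up module names and patterns:

```py
# Sketch of the rank_pattern / alpha_pattern lookup used in _create_and_replace.
import re
from itertools import chain

rank_pattern = {"q_proj": 16}
alpha_pattern = {r"model\.layers\.0\..*": 32.0}
default_r, default_alpha = 8, 8.0

target_name = "q_proj"                            # the module's short name
current_key = "model.layers.0.self_attn.q_proj"   # the module's full dotted path

pattern_keys = list(chain(rank_pattern.keys(), alpha_pattern.keys()))
target_name_key = next(filter(lambda key: re.match(rf"(.*\.)?{key}$", current_key), pattern_keys), target_name)

r = rank_pattern.get(target_name_key, default_r)
alpha = alpha_pattern.get(target_name_key, default_alpha)
print(target_name_key, r, alpha)  # q_proj 16 8.0
```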


@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import Optional, Set, Tuple, Union from typing import Any, Optional, Set, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -23,9 +23,9 @@ import torch.nn.functional as F
from peft.tuners.lycoris_utils import LycorisLayer from peft.tuners.lycoris_utils import LycorisLayer
-class LoKrLayer(LycorisLayer, nn.Module):
-    # List all names of layers that may contain adapter weights
-    adapter_layer_names = [
+class LoKrLayer(nn.Module, LycorisLayer):
+    # All names of layers that may contain adapter weights
+    adapter_layer_names = (
         "lokr_w1",
         "lokr_w1_a",
         "lokr_w1_b",
@@ -33,11 +33,12 @@ class LoKrLayer(LycorisLayer, nn.Module):
         "lokr_w2_a",
         "lokr_w2_b",
         "lokr_t2",
-    ]
+    )
+    # other_param_names is defined on parent class

-    def __init__(self):
-        LycorisLayer.__init__(self)
-        super(nn.Module, self).__init__()
+    def __init__(self, base_layer: nn.Module) -> None:
+        super().__init__()
+        LycorisLayer.__init__(self, base_layer)
# LoKr info # LoKr info
self.lokr_w1 = nn.ParameterDict({}) self.lokr_w1 = nn.ParameterDict({})
@ -110,6 +111,22 @@ class LoKrLayer(LycorisLayer, nn.Module):
if adapter_name in self.lokr_t2: if adapter_name in self.lokr_t2:
nn.init.kaiming_uniform_(self.lokr_t2[adapter_name], a=math.sqrt(5)) nn.init.kaiming_uniform_(self.lokr_t2[adapter_name], a=math.sqrt(5))
def reset_adapter_parameters_random(self, adapter_name: str):
if adapter_name in self.lokr_w1:
nn.init.kaiming_uniform_(self.lokr_w1[adapter_name], a=math.sqrt(5))
else:
nn.init.kaiming_uniform_(self.lokr_w1_a[adapter_name], a=math.sqrt(5))
nn.init.kaiming_uniform_(self.lokr_w1_b[adapter_name], a=math.sqrt(5))
if adapter_name in self.lokr_w2:
nn.init.kaiming_uniform_(self.lokr_w2[adapter_name], a=math.sqrt(5))
else:
nn.init.kaiming_uniform_(self.lokr_w2_a[adapter_name], a=math.sqrt(5))
nn.init.kaiming_uniform_(self.lokr_w2_b[adapter_name], a=math.sqrt(5))
if adapter_name in self.lokr_t2:
nn.init.kaiming_uniform_(self.lokr_t2[adapter_name], a=math.sqrt(5))
def update_layer( def update_layer(
self, self,
adapter_name: str, adapter_name: str,
@ -142,10 +159,11 @@ class LoKrLayer(LycorisLayer, nn.Module):
self.scaling[adapter_name] = alpha / r self.scaling[adapter_name] = alpha / r
self.rank_dropout[adapter_name] = rank_dropout self.rank_dropout[adapter_name] = rank_dropout
self.module_dropout[adapter_name] = module_dropout self.module_dropout[adapter_name] = module_dropout
base_layer = self.get_base_layer()
# Determine shape of LoKr weights # Determine shape of LoKr weights
if isinstance(self, nn.Linear): if isinstance(base_layer, nn.Linear):
in_dim, out_dim = self.in_features, self.out_features in_dim, out_dim = base_layer.in_features, base_layer.out_features
in_m, in_n = factorization(in_dim, decompose_factor) in_m, in_n = factorization(in_dim, decompose_factor)
out_l, out_k = factorization(out_dim, decompose_factor) out_l, out_k = factorization(out_dim, decompose_factor)
@ -154,9 +172,9 @@ class LoKrLayer(LycorisLayer, nn.Module):
use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2) use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2)
use_w2 = not (r < max(shape[0][1], shape[1][1]) / 2) use_w2 = not (r < max(shape[0][1], shape[1][1]) / 2)
use_effective_conv2d = False use_effective_conv2d = False
elif isinstance(self, nn.Conv2d): elif isinstance(base_layer, nn.Conv2d):
in_dim, out_dim = self.in_channels, self.out_channels in_dim, out_dim = base_layer.in_channels, base_layer.out_channels
k_size = self.kernel_size k_size = base_layer.kernel_size
in_m, in_n = factorization(in_dim, decompose_factor) in_m, in_n = factorization(in_dim, decompose_factor)
out_l, out_k = factorization(out_dim, decompose_factor) out_l, out_k = factorization(out_dim, decompose_factor)
@ -164,9 +182,9 @@ class LoKrLayer(LycorisLayer, nn.Module):
use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2) use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2)
use_w2 = r >= max(shape[0][1], shape[1][1]) / 2 use_w2 = r >= max(shape[0][1], shape[1][1]) / 2
use_effective_conv2d = use_effective_conv2d and self.kernel_size != (1, 1) use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size != (1, 1)
else: else:
raise TypeError(f"LoKr is not implemented for {type(self).__name__} layer") raise TypeError(f"LoKr is not implemented for base layers of type {type(base_layer).__name__}")
# Create weights with provided shape # Create weights with provided shape
self.create_adapter_parameters(adapter_name, r, shape, use_w1, use_w2, use_effective_conv2d) self.create_adapter_parameters(adapter_name, r, shape, use_w1, use_w2, use_effective_conv2d)
@ -174,9 +192,11 @@ class LoKrLayer(LycorisLayer, nn.Module):
# Initialize weights # Initialize weights
if init_weights: if init_weights:
self.reset_adapter_parameters(adapter_name) self.reset_adapter_parameters(adapter_name)
else:
self.reset_adapter_parameters_random(adapter_name)
# Move new weights to device # Move new weights to device
weight = getattr(self, "weight", None) weight = getattr(self.get_base_layer(), "weight", None)
if weight is not None: if weight is not None:
# the layer is already completely initialized, this is an update # the layer is already completely initialized, this is an update
if weight.dtype.is_floating_point or weight.dtype.is_complex: if weight.dtype.is_floating_point or weight.dtype.is_complex:
@ -201,7 +221,7 @@ class LoKrLayer(LycorisLayer, nn.Module):
# Make weights with Kronecker product # Make weights with Kronecker product
weight = make_kron(w1, w2) weight = make_kron(w1, w2)
weight = weight.reshape(self.weight.shape) weight = weight.reshape(self.get_base_layer().weight.shape)
# Perform rank dropout during training - drop rows of addition weights # Perform rank dropout during training - drop rows of addition weights
rank_dropout = self.rank_dropout[adapter_name] rank_dropout = self.rank_dropout[adapter_name]
@ -213,15 +233,39 @@ class LoKrLayer(LycorisLayer, nn.Module):
return weight return weight
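get_delta_weight for LoKr builds the delta as a Kronecker product of two factors and reshapes it to the base weight's shape. A plain-torch sketch of that construction; it omits the optional Tucker factor and the scale that make_kron applies in the real helper:

```py
# Sketch of the LoKr weight construction: the delta weight is a Kronecker product of two
# factors, reshaped to the base layer's weight shape (plain torch, not the PEFT helper).
import torch

out_l, out_k = 4, 8  # factorization of out_features = 32
in_m, in_n = 4, 4    # factorization of in_features = 16

w1 = torch.randn(out_l, in_m)
w2 = torch.randn(out_k, in_n)

delta = torch.kron(w1, w2)     # shape (out_l * out_k, in_m * in_n) = (32, 16)
delta = delta.reshape(32, 16)  # the real code reshapes to base_layer.weight.shape
print(delta.shape)  # torch.Size([32, 16])
```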
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
previous_dtype = x.dtype
class Linear(LoKrLayer, nn.Linear): if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
result = self.base_layer(x, *args, **kwargs)
# Execute all the adapters
for active_adapter in self.active_adapters:
if active_adapter not in self._available_adapters:
continue
module_dropout = self.module_dropout[active_adapter]
# Modify current execution weights
if (not self.training) or (self.training and torch.rand(1) > module_dropout):
result = result + self._get_delta_activations(active_adapter, x, *args, **kwargs)
result = result.to(previous_dtype)
return result
class Linear(LoKrLayer):
"""LoKr implemented in Linear layer""" """LoKr implemented in Linear layer"""
def __init__( def __init__(
self, self,
in_features: int, base_layer: nn.Module,
out_features: int,
bias: bool = True,
device: Optional[Union[str, torch.device]] = None, device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
adapter_name: str = "default", adapter_name: str = "default",
@ -229,35 +273,33 @@ class Linear(LoKrLayer, nn.Linear):
alpha: float = 0.0, alpha: float = 0.0,
rank_dropout: float = 0.0, rank_dropout: float = 0.0,
module_dropout: float = 0.0, module_dropout: float = 0.0,
init_weights: bool = True,
**kwargs, **kwargs,
): ):
init_weights = kwargs.pop("init_weights", True) super().__init__(base_layer)
self._init_empty_weights(nn.Linear, in_features, out_features, bias, device=device, dtype=dtype)
LoKrLayer.__init__(self)
# Create adapter and set it active # Create adapter and set it active
self._active_adapter = adapter_name
self.update_layer(adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, **kwargs) self.update_layer(adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, **kwargs)
self.set_adapter(adapter_name)
def _op(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: def _get_delta_activations(
return F.linear(input, weight, bias=self.bias) self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
) -> torch.Tensor:
delta_weight = self.get_delta_weight(adapter_name)
# don't add bias here, because the bias is already included in the output of the base_layer
return F.linear(input, delta_weight)
def __repr__(self) -> str:
rep = super().__repr__()
return "lokr." + rep
class Conv2d(LoKrLayer, nn.Conv2d): class Conv2d(LoKrLayer):
"""LoKr implemented in Conv2d layer""" """LoKr implemented in Conv2d layer"""
def __init__( def __init__(
self, self,
in_channels: int, base_layer: nn.Module,
out_channels: int,
kernel_size: Union[int, Tuple[int]],
stride: Union[int, Tuple[int]] = 1,
padding: Union[int, Tuple[int]] = 0,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
device: Optional[Union[str, torch.device]] = None, device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
adapter_name: str = "default", adapter_name: str = "default",
@ -266,43 +308,36 @@ class Conv2d(LoKrLayer, nn.Conv2d):
rank_dropout: float = 0.0, rank_dropout: float = 0.0,
module_dropout: float = 0.0, module_dropout: float = 0.0,
use_effective_conv2d: bool = False, use_effective_conv2d: bool = False,
init_weights: bool = True,
**kwargs, **kwargs,
): ):
init_weights = kwargs.pop("init_weights", True) super().__init__(base_layer)
self._init_empty_weights(
nn.Conv2d,
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
LoKrLayer.__init__(self)
# Create adapter and set it active # Create adapter and set it active
self._active_adapter = adapter_name
self.update_layer( self.update_layer(
adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, use_effective_conv2d, **kwargs adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, use_effective_conv2d, **kwargs
) )
self.set_adapter(adapter_name)
def _op(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: def _get_delta_activations(
self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
) -> torch.Tensor:
delta_weight = self.get_delta_weight(adapter_name)
# don't add bias here, because the bias is already included in the output of the base_layer
base_layer = self.get_base_layer()
return F.conv2d( return F.conv2d(
input, input,
weight, delta_weight,
bias=self.bias, stride=base_layer.stride,
stride=self.stride, padding=base_layer.padding,
padding=self.padding, dilation=base_layer.dilation,
dilation=self.dilation, groups=base_layer.groups,
groups=self.groups,
) )
def __repr__(self) -> str:
rep = super().__repr__()
return "lokr." + rep
# Below code is a direct copy from https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/lokr.py#L11 # Below code is a direct copy from https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/lokr.py#L11


@ -13,11 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict, Type import re
from itertools import chain
from typing import Dict, Type, Union
import torch import torch
from torch import nn
from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner
from ..lycoris_utils import LycorisTuner
from .layer import Conv2d, Linear, LoKrLayer from .layer import Conv2d, Linear, LoKrLayer
@ -83,3 +87,31 @@ class LoKrModel(LycorisTuner):
torch.nn.Conv2d: Conv2d, torch.nn.Conv2d: Conv2d,
torch.nn.Linear: Linear, torch.nn.Linear: Linear,
} }
def _create_and_replace(
self,
config: LycorisConfig,
adapter_name: str,
target: Union[LoKrLayer, nn.Module],
target_name: str,
parent: nn.Module,
current_key: str,
**optional_kwargs,
) -> None:
"""
A private method to create and replace the target module with the adapter module.
"""
# Regexp matching - Find key which matches current target_name in patterns provided
pattern_keys = list(chain(config.rank_pattern.keys(), config.alpha_pattern.keys()))
target_name_key = next(filter(lambda key: re.match(f"(.*\.)?{key}$", current_key), pattern_keys), target_name)
kwargs = config.to_dict()
kwargs["r"] = config.rank_pattern.get(target_name_key, config.r)
kwargs["alpha"] = config.alpha_pattern.get(target_name_key, config.alpha)
if isinstance(target, LoKrLayer):
target.update_layer(adapter_name, **kwargs)
else:
new_module = self._create_new_module(config, adapter_name, target, **kwargs)
self._replace_module(parent, target_name, new_module, target)


@@ -14,6 +14,7 @@
# limitations under the License.
import warnings
+from typing import List, Optional
import bitsandbytes as bnb
import torch
@@ -30,22 +31,20 @@ if is_bnb_available():
# Lora implemented in a dense layer
def __init__(
self,
-adapter_name,
-base_layer,
+base_layer: torch.nn.Module,
+adapter_name: str,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
+init_lora_weights: bool = True,
**kwargs,
) -> None:
super().__init__()
-LoraLayer.__init__(self, in_features=base_layer.in_features, out_features=base_layer.out_features)
-self.base_layer = base_layer
-init_lora_weights = kwargs.pop("init_lora_weights", True)
+LoraLayer.__init__(self, base_layer)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-self.set_adapter(adapter_name)
-def merge(self, safe_merge: bool = False):
+def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
@@ -54,6 +53,9 @@ if is_bnb_available():
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
+adapter_names (`List[str]`, *optional*):
+The list of adapter names that should be merged. If None, all active adapters will be merged.
+Defaults to `None`.
"""
if self.merged:
warnings.warn(
@@ -61,7 +63,10 @@ if is_bnb_available():
f"You are now additionally merging {','.join(self.active_adapters)}."
)
-for active_adapter in self.active_adapters:
+if adapter_names is None:
+adapter_names = self.active_adapters
+for active_adapter in adapter_names:
if active_adapter not in self.lora_A.keys():
continue
warnings.warn(
@@ -69,8 +74,8 @@ if is_bnb_available():
)
lora_data = self.get_delta_weight(active_adapter)
-weight = self.base_layer.weight
-state = self.base_layer.state
+weight = self.get_base_layer().weight
+state = self.get_base_layer().state
if state.SCB is None:
state.SCB = weight.SCB
@@ -90,7 +95,7 @@ if is_bnb_available():
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
-self.base_layer.weight = bnb.nn.Int8Params(
+self.get_base_layer().weight = bnb.nn.Int8Params(
w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
).to(weight.device)
state.reset_grads()
@@ -110,8 +115,8 @@ if is_bnb_available():
)
lora_data = self.get_delta_weight(active_adapter)
-weight = self.base_layer.weight
-state = self.base_layer.state
+weight = self.get_base_layer().weight
+state = self.get_base_layer().state
if state.SCB is None:
state.SCB = weight.SCB
im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
@@ -124,7 +129,7 @@ if is_bnb_available():
output = bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
w_data = output.to(lora_data.dtype).to(lora_data.device) - lora_data
-self.base_layer.weight = bnb.nn.Int8Params(
+self.get_base_layer().weight = bnb.nn.Int8Params(
w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
).to(weight.device)
state.reset_grads()
@@ -169,6 +174,10 @@ if is_bnb_available():
return result
+def __repr__(self) -> str:
+rep = super().__repr__()
+return "lora." + rep
if is_bnb_4bit_available():
@@ -176,22 +185,20 @@ if is_bnb_4bit_available():
# Lora implemented in a dense layer
def __init__(
self,
-adapter_name,
-base_layer,
+base_layer: torch.nn.Module,
+adapter_name: str,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
+init_lora_weights: bool = True,
**kwargs,
) -> None:
super().__init__()
-LoraLayer.__init__(self, in_features=base_layer.in_features, out_features=base_layer.out_features)
-self.base_layer = base_layer
-init_lora_weights = kwargs.pop("init_lora_weights", True)
+LoraLayer.__init__(self, base_layer)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-self.set_adapter(adapter_name)
-def merge(self, safe_merge: bool = False):
+def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
@@ -200,6 +207,9 @@ if is_bnb_4bit_available():
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
+adapter_names (`List[str]`, *optional*):
+The list of adapter names that should be merged. If None, all active adapters will be merged.
+Defaults to `None`.
"""
if self.merged:
warnings.warn(
@@ -207,14 +217,17 @@ if is_bnb_4bit_available():
f"You are now additionally merging {','.join(self.active_adapters)}."
)
-for active_adapter in self.active_adapters:
+if adapter_names is None:
+adapter_names = self.active_adapters
+for active_adapter in adapter_names:
if active_adapter not in self.lora_A.keys():
continue
warnings.warn(
"Merge lora module to 4-bit linear may get different generations due to rounding errors."
)
# Refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
-weight = self.base_layer.weight
+weight = self.get_base_layer().weight
kwargs = weight.__dict__
lora_data = self.get_delta_weight(active_adapter)
@@ -224,7 +237,7 @@ if is_bnb_4bit_available():
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
-self.base_layer.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(
+self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(
weight.device
)
self.merged_adapters.append(active_adapter)
@@ -241,11 +254,11 @@ if is_bnb_4bit_available():
warnings.warn(
"Unmerge lora module to 4-bit linear may get different generations due to rounding errors."
)
-weight = self.base_layer.weight
+weight = self.get_base_layer().weight
kwargs = weight.__dict__
lora_data = self.get_delta_weight(active_adapter)
w_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) - lora_data
-self.base_layer.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(
+self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(
weight.device
)
@@ -262,11 +275,11 @@ if is_bnb_4bit_available():
if self.disable_adapters:
if self.merged:
self.unmerge()
-result = self.base_layer.forward(x, *args, **kwargs)
+result = self.base_layer(x, *args, **kwargs)
elif self.merged:
-result = self.base_layer.forward(x, *args, **kwargs)
+result = self.base_layer(x, *args, **kwargs)
else:
-result = self.base_layer.forward(x, *args, **kwargs)
+result = self.base_layer(x, *args, **kwargs)
# As per Tim Dettmers, for 4bit, we need to defensively clone here.
# The reason is that in some cases, an error can occur that backprop
# does not work on a manipulated view. This issue may be solved with
@@ -294,3 +307,7 @@ if is_bnb_4bit_available():
result += output
return result
+def __repr__(self) -> str:
+rep = super().__repr__()
+return "lora." + rep

View File

@@ -21,22 +21,21 @@ from peft.tuners.lora.layer import LoraLayer
class QuantLinear(torch.nn.Module, LoraLayer):
def __init__(
self,
-adapter_name,
-quant_linear_module,
+base_layer,
+adapter_name: str,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
+init_lora_weights: bool = True,
**kwargs,
):
-torch.nn.Module.__init__(self)
-LoraLayer.__init__(
-self, in_features=quant_linear_module.infeatures, out_features=quant_linear_module.outfeatures
-)
-self.quant_linear_module = quant_linear_module
-self.weight = quant_linear_module.qweight
-init_lora_weights = kwargs.pop("init_lora_weights", True)
+super().__init__()
+LoraLayer.__init__(self, base_layer)
+# self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
+# for backwards compatibility
+self.quant_linear_module = base_layer
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-self.set_adapter(adapter_name)
def forward(self, x: torch.Tensor):
# note: logic differs from default Linear because merging is not supported
@@ -65,6 +64,10 @@ class QuantLinear(torch.nn.Module, LoraLayer):
result += output
return result
+def __repr__(self) -> str:
+rep = super().__repr__()
+return "lora." + rep
# TODO: Check if it is better as suggested by users https://github.com/PanQiWei/AutoGPTQ/pull/102
# def reset_lora_parameters(self, adapter_name):
#     if adapter_name in self.lora_A.keys():
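For GPTQ the merge path stays unsupported, so the adapter only adds its low-rank contribution on top of the quantized forward pass. Conceptually (a simplified sketch, not the exact class above):

```py
import torch

def lora_gptq_forward(quant_linear, lora_A, lora_B, scaling, x: torch.Tensor) -> torch.Tensor:
    # Base output comes from the wrapped quantized module; the low-rank update is added on top.
    result = quant_linear(x)
    requires_conversion = x.dtype != lora_A.weight.dtype
    if requires_conversion:
        x = x.to(lora_A.weight.dtype)
    output = lora_B(lora_A(x)) * scaling
    if requires_conversion:
        output = output.to(result.dtype)
    return result + output
```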

View File

@@ -15,21 +15,25 @@
import math
import warnings
-from typing import Optional, Tuple, Union
+from typing import Any, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
+from transformers.pytorch_utils import Conv1D
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils.other import transpose
class LoraLayer(BaseTunerLayer):
-# List all names of layers that may contain adapter weights
-adapter_layer_names = ["lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B"]
+# All names of layers that may contain (trainable) adapter weights
+adapter_layer_names = ("lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B")
+# All names of other parameters that may contain adapter-related parameters
+other_param_names = ("r", "lora_alpha", "scaling", "lora_dropout")
-def __init__(self, in_features: int, out_features: int, **kwargs):
+def __init__(self, base_layer: nn.Module, **kwargs) -> None:
+self.base_layer = base_layer
self.r = {}
self.lora_alpha = {}
self.scaling = {}
@@ -42,21 +46,26 @@ class LoraLayer(BaseTunerLayer):
# Mark the weight as unmerged
self._disable_adapters = False
self.merged_adapters = []
base_layer = self.get_base_layer()
if isinstance(base_layer, nn.Linear):
in_features, out_features = base_layer.in_features, base_layer.out_features
elif isinstance(base_layer, nn.Conv2d):
in_features, out_features = base_layer.in_channels, base_layer.out_channels
elif isinstance(base_layer, nn.Embedding):
in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim
elif isinstance(base_layer, Conv1D):
in_features, out_features = (
base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
)
elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"):
# QuantLinear
in_features, out_features = base_layer.infeatures, base_layer.outfeatures
else:
raise ValueError(f"Unsupported layer type {type(base_layer)}")
self.in_features = in_features
self.out_features = out_features
self.kwargs = kwargs
def _init_empty_weights(self, cls, *args, **kwargs) -> None:
# A helper method that allows to initialize the layer of the given class without spending time to initialize the
# model weights. The implementation is inspired by
# https://pytorch.org/docs/stable/generated/torch.nn.utils.skip_init.html but this function cannot be used
# directly.
# Instead of this approach, it would be possible to bypass the __init__ of the class but that runs the risk of
# omitting important logic inside that __init__.
kwargs = kwargs.copy()
final_device = kwargs.pop("device", "cpu")
cls.__init__(self, *args, device="meta", **kwargs)
self.to_empty(device=final_device)
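The pattern introduced throughout this PR is that the adapter layer keeps a reference to the layer it wraps and delegates the forward pass to it. A minimal, self-contained illustration (not the actual PEFT classes):

```py
import torch
import torch.nn as nn

class TinyLoraLinear(nn.Module):
    """Toy illustration of the base layer pattern: keep a reference, call its forward."""
    def __init__(self, base_layer: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base_layer = base_layer            # reference, no re-initialization of the wrapped weights
        self.lora_A = nn.Linear(base_layer.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base_layer.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)      # start as an identity w.r.t. the base layer
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base_layer(x) + self.lora_B(self.lora_A(x)) * self.scaling

layer = TinyLoraLinear(nn.Linear(16, 32))
print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 32])
```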
def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
if r <= 0:
@@ -77,7 +86,7 @@ class LoraLayer(BaseTunerLayer):
if init_lora_weights:
self.reset_lora_parameters(adapter_name)
-weight = getattr(self, "weight", None)
+weight = getattr(self.get_base_layer(), "weight", None)
if weight is not None:
# the layer is already completely initialized, this is an update
if weight.dtype.is_floating_point or weight.dtype.is_complex:
@@ -98,20 +107,22 @@ class LoraLayer(BaseTunerLayer):
self.lora_dropout[adapter_name] = lora_dropout_layer
# Actual trainable parameters
+base_layer = self.get_base_layer()
if r > 0:
-kernel_size = self.kwargs["kernel_size"]
-stride = self.kwargs["stride"]
-padding = self.kwargs["padding"]
+kernel_size = base_layer.kernel_size
+stride = base_layer.stride
+padding = base_layer.padding
self.lora_A[adapter_name] = nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False)
self.lora_B[adapter_name] = nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False)
self.scaling[adapter_name] = lora_alpha / r
if init_lora_weights:
self.reset_lora_parameters(adapter_name)
-weight = getattr(self, "weight", None)
+weight = getattr(base_layer, "weight", None)
if weight is not None:
# the layer is already completely initialized, this is an update
-self.to(self.weight.device, dtype=weight.dtype)
+self.to(base_layer.weight.device, dtype=weight.dtype)
+self.set_adapter(self.active_adapters)
def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
if r <= 0:
@@ -134,10 +145,12 @@ class LoraLayer(BaseTunerLayer):
if init_lora_weights:
self.reset_lora_parameters(adapter_name)
-weight = getattr(self, "weight", None)
+base_layer = self.get_base_layer()
+weight = getattr(base_layer, "weight", None)
if weight is not None:
# the layer is already completely initialized, this is an update
-self.to(self.weight.device, dtype=weight.dtype)
+self.to(base_layer.weight.device, dtype=weight.dtype)
+self.set_adapter(self.active_adapters)
def reset_lora_parameters(self, adapter_name):
if adapter_name in self.lora_A.keys():
@@ -186,37 +199,29 @@ class LoraLayer(BaseTunerLayer):
# ------------------------------------------------------------------------------------------
-class Linear(nn.Linear, LoraLayer):
+class Linear(nn.Module, LoraLayer):
# Lora implemented in a dense layer
def __init__(
self,
+base_layer,
adapter_name: str,
-in_features: int,
-out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
is_target_conv_1d_layer: bool = False,
+init_lora_weights: bool = True,
**kwargs,
) -> None:
-init_lora_weights = kwargs.pop("init_lora_weights", True)
-# this gets the init from nn.Linear's super perspective, i.e.
-# nn.Module.__init__, which should always be called
-super(nn.Linear, self).__init__()
-# Note that we don't use self._init_empty_weights() for Linear because it is a bit slower and the benefit of
-# added robustness is not big enough for Linear.
-LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
-# Freezing the pre-trained weight matrix
+super().__init__()
+LoraLayer.__init__(self, base_layer)
self.fan_in_fan_out = fan_in_fan_out
+self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.is_target_conv_1d_layer = is_target_conv_1d_layer
-self.set_adapter(adapter_name)
-def merge(self, safe_merge: bool = False) -> None:
+def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
@@ -225,18 +230,26 @@ class Linear(nn.Linear, LoraLayer):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
+adapter_names (`List[str]`, *optional*):
+The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
+to `None`.
"""
if self.merged:
warnings.warn(
f"Already following adapters were merged {','.join(self.merged_adapters)}. "
f"You are now additionally merging {','.join(self.active_adapters)}."
)
-for active_adapter in self.active_adapters:
+if adapter_names is None:
+adapter_names = self.active_adapters
+for active_adapter in adapter_names:
if active_adapter in self.lora_A.keys():
+base_layer = self.get_base_layer()
if safe_merge:
# Note that safe_merge will be slower than the normal merge
# because of the copy operation.
-orig_weights = self.weight.data.clone()
+orig_weights = base_layer.weight.data.clone()
orig_weights += self.get_delta_weight(active_adapter)
if not torch.isfinite(orig_weights).all():
@@ -244,9 +257,9 @@ class Linear(nn.Linear, LoraLayer):
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
-self.weight.data = orig_weights
+base_layer.weight.data = orig_weights
else:
-self.weight.data += self.get_delta_weight(active_adapter)
+base_layer.weight.data += self.get_delta_weight(active_adapter)
self.merged_adapters.append(active_adapter)
def unmerge(self) -> None:
@@ -256,7 +269,7 @@ class Linear(nn.Linear, LoraLayer):
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.lora_A.keys():
-self.weight.data -= self.get_delta_weight(active_adapter)
+self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)
def get_delta_weight(self, adapter) -> torch.Tensor:
"""
@@ -292,20 +305,17 @@ class Linear(nn.Linear, LoraLayer):
return output_tensor
-def _linear(self, input: torch.Tensor) -> torch.Tensor:
-return F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
-def forward(self, x: torch.Tensor) -> torch.Tensor:
+def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
previous_dtype = x.dtype
if self.disable_adapters:
if self.merged:
self.unmerge()
-result = self._linear(x)
+result = self.base_layer(x, *args, **kwargs)
elif self.merged:
-result = self._linear(x)
+result = self.base_layer(x, *args, **kwargs)
else:
-result = self._linear(x)
+result = self.base_layer(x, *args, **kwargs)
for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
continue
@@ -319,26 +329,30 @@ class Linear(nn.Linear, LoraLayer):
result = result.to(previous_dtype)
return result
+def __repr__(self) -> str:
+rep = super().__repr__()
+return "lora." + rep
-class Embedding(nn.Embedding, LoraLayer):
+class Embedding(nn.Module, LoraLayer):
# LoRA implemented in a Embedding layer
def __init__(
self,
+base_layer: nn.Module,
adapter_name: str,
-num_embeddings: int,
-embedding_dim: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
+init_lora_weights: bool = True,
**kwargs,
) -> None:
-init_lora_weights = kwargs.pop("init_lora_weights", True)
-self._init_empty_weights(nn.Embedding, num_embeddings, embedding_dim, **kwargs)
-LoraLayer.__init__(self, in_features=num_embeddings, out_features=embedding_dim)
-self.update_layer_embedding(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-self.set_adapter(adapter_name)
-def merge(self, safe_merge: bool = False) -> None:
+super().__init__()
+LoraLayer.__init__(self, base_layer)
+self._active_adapter = adapter_name
+self.update_layer_embedding(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
+def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
@@ -347,18 +361,26 @@ class Embedding(nn.Embedding, LoraLayer):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
+adapter_names (`List[str]`, *optional*):
+The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
+to `None`.
"""
if self.merged:
warnings.warn(
f"Already following adapters were merged {','.join(self.merged_adapters)}. "
f"You are now additionally merging {','.join(self.active_adapters)}."
)
-for active_adapter in self.active_adapters:
+if adapter_names is None:
+adapter_names = self.active_adapters
+for active_adapter in adapter_names:
if active_adapter in self.lora_embedding_A.keys():
+base_layer = self.get_base_layer()
if safe_merge:
# Note that safe_merge will be slower than the normal merge
# because of the copy operation.
-orig_weights = self.weight.data.copy()
+orig_weights = base_layer.weight.data.copy()
orig_weights += self.get_delta_weight(active_adapter)
if not torch.isfinite(orig_weights).all():
@@ -366,9 +388,9 @@ class Embedding(nn.Embedding, LoraLayer):
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
-self.weight.data = orig_weights
+base_layer.weight.data = orig_weights
else:
-self.weight.data += self.get_delta_weight(active_adapter)
+base_layer.weight.data += self.get_delta_weight(active_adapter)
self.merged_adapters.append(active_adapter)
def unmerge(self) -> None:
@@ -378,7 +400,7 @@ class Embedding(nn.Embedding, LoraLayer):
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.lora_embedding_A.keys():
-self.weight.data -= self.get_delta_weight(active_adapter)
+self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)
def get_delta_weight(self, adapter) -> torch.Tensor:
"""
@@ -414,28 +436,28 @@ class Embedding(nn.Embedding, LoraLayer):
return output_tensor
-def _embed(self, input: torch.Tensor, weight: Optional[torch.Tensor] = None) -> torch.Tensor:
-weight = self.weight if weight is None else weight
+def _embed(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+base_layer = self.get_base_layer()
return F.embedding(
input,
weight,
-padding_idx=self.padding_idx,
-max_norm=self.max_norm,
-norm_type=self.norm_type,
-scale_grad_by_freq=self.scale_grad_by_freq,
-sparse=self.sparse,
+padding_idx=base_layer.padding_idx,
+max_norm=base_layer.max_norm,
+norm_type=base_layer.norm_type,
+scale_grad_by_freq=base_layer.scale_grad_by_freq,
+sparse=base_layer.sparse,
)
-def forward(self, x: torch.Tensor) -> torch.Tensor:
+def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
# TODO: no dtype conversion here, unlike in Linear, is that correct?
if self.disable_adapters:
if self.merged:
self.unmerge()
-result = self._embed(x)
+result = self.base_layer(x, *args, **kwargs)
elif self.merged:
-result = self._embed(x)
+result = self.base_layer(x, *args, **kwargs)
else:
-result = self._embed(x)
+result = self.base_layer(x, *args, **kwargs)
for active_adapter in self.active_adapters:
if active_adapter not in self.lora_embedding_A:
continue
@@ -447,38 +469,30 @@ class Embedding(nn.Embedding, LoraLayer):
return result
+def __repr__(self) -> str:
+rep = super().__repr__()
+return "lora." + rep
-class Conv2d(nn.Conv2d, LoraLayer):
+class Conv2d(nn.Module, LoraLayer):
# Lora implemented in a conv2d layer
def __init__(
self,
+base_layer: nn.Module,
adapter_name: str,
-in_channels: int,
-out_channels: int,
-kernel_size: Union[int, Tuple[int]],
-stride: Union[int, Tuple[int]] = 1,
-padding: Union[int, Tuple[int]] = 0,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
+init_lora_weights: bool = True,
**kwargs,
) -> None:
+super().__init__()
+LoraLayer.__init__(self, base_layer)
-init_lora_weights = kwargs.pop("init_lora_weights", True)
-self._init_empty_weights(nn.Conv2d, in_channels, out_channels, kernel_size, stride=stride, padding=padding)
LoraLayer.__init__(
self,
in_features=in_channels,
out_features=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
)
self._active_adapter = adapter_name
self.update_layer_conv2d(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-self.set_adapter(adapter_name)
-def merge(self, safe_merge: bool = False) -> None:
+def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
"""
Merge the active adapter weights inside the base weights
@@ -487,27 +501,35 @@ class Conv2d(nn.Conv2d, LoraLayer):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
+adapter_names (`List[str]`, *optional*):
+The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
+to `None`.
"""
if self.merged:
warnings.warn(
f"Already following adapters were merged {','.join(self.merged_adapters)}. "
f"You are now additionally merging {','.join(self.active_adapters)}."
)
-for active_adapter in self.active_adapters:
+if adapter_names is None:
+adapter_names = self.active_adapters
+for active_adapter in adapter_names:
if active_adapter in self.lora_A.keys():
+base_layer = self.get_base_layer()
if safe_merge:
# Note that safe_merge will be slower than the normal merge
# because of the copy operation.
-orig_weights = self.weight.data.copy()
+orig_weights = base_layer.weight.data.copy()
orig_weights += self.get_delta_weight(active_adapter)
if not torch.isfinite(orig_weights).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
-self.weight.data = orig_weights
+base_layer.weight.data = orig_weights
else:
-self.weight.data += self.get_delta_weight(active_adapter)
+base_layer.weight.data += self.get_delta_weight(active_adapter)
self.merged_adapters.append(active_adapter)
def unmerge(self) -> None:
@@ -517,7 +539,7 @@ class Conv2d(nn.Conv2d, LoraLayer):
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.lora_A.keys():
-self.weight.data -= self.get_delta_weight(active_adapter)
+self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)
def get_delta_weight(self, adapter) -> torch.Tensor:
"""
@@ -543,7 +565,7 @@ class Conv2d(nn.Conv2d, LoraLayer):
weight_B = weight_B.float()
# https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117
-if self.weight.size()[2:4] == (1, 1):
+if self.get_base_layer().weight.size()[2:4] == (1, 1):
# conv2d 1x1
output_tensor = (weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(
3
@@ -567,28 +589,17 @@ class Conv2d(nn.Conv2d, LoraLayer):
return output_tensor
+def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
-def _conv2d(self, input: torch.Tensor) -> torch.Tensor:
return F.conv2d(
input,
self.weight,
bias=self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
previous_dtype = x.dtype
if self.disable_adapters:
if self.merged:
self.unmerge()
-result = self._conv2d(x)
+result = self.base_layer(x, *args, **kwargs)
elif self.merged:
-result = self._conv2d(x)
+result = self.base_layer(x, *args, **kwargs)
else:
-result = self._conv2d(x)
+result = self.base_layer(x, *args, **kwargs)
for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
continue
@@ -601,3 +612,7 @@ class Conv2d(nn.Conv2d, LoraLayer):
result = result.to(previous_dtype)
return result
+def __repr__(self) -> str:
+rep = super().__repr__()
+return "lora." + rep

View File

@@ -19,9 +19,9 @@ from dataclasses import asdict, replace
from enum import Enum
from functools import reduce
from itertools import chain
+from typing import List, Optional
import torch
-from torch import nn
from tqdm import tqdm
from transformers.pytorch_utils import Conv1D
@@ -107,6 +107,8 @@ class LoraModel(BaseTuner):
- **peft_config** ([`LoraConfig`]): The configuration of the Lora model.
"""
+prefix: str = "lora_"
def __init__(self, model, config, adapter_name) -> None:
super().__init__(model, config, adapter_name)
@@ -164,7 +166,7 @@ class LoraModel(BaseTuner):
kwargs["gptq_quantization_config"] = quantization_config
# TODO: better deal with that
-if isinstance(target, LoraLayer) and isinstance(target, torch.nn.Conv2d):
+if isinstance(target, Conv2d):
target.update_layer_conv2d(
adapter_name,
r,
@@ -172,7 +174,7 @@ class LoraModel(BaseTuner):
lora_config.lora_dropout,
lora_config.init_lora_weights,
)
-elif isinstance(target, LoraLayer) and isinstance(target, torch.nn.Embedding):
+elif isinstance(target, Embedding):
target.update_layer_embedding(
adapter_name,
r,
@@ -180,8 +182,7 @@ class LoraModel(BaseTuner):
lora_config.lora_dropout,
lora_config.init_lora_weights,
)
-elif isinstance(target, LoraLayer):
+elif isinstance(target, Linear):
target.update_layer(
adapter_name,
r,
@@ -196,8 +197,7 @@ class LoraModel(BaseTuner):
new_module.requires_grad_(False)
self._replace_module(parent, target_name, new_module, target)
-@staticmethod
-def _replace_module(parent, child_name, new_module, child):
+def _replace_module(self, parent, child_name, new_module, child):
setattr(parent, child_name, new_module)
# It's not necessary to set requires_grad here, as that is handled by
# _mark_only_adapters_as_trainable
@@ -205,10 +205,7 @@ class LoraModel(BaseTuner):
# child layer wraps the original module, unpack it
if hasattr(child, "base_layer"):
child = child.base_layer
-elif hasattr(child, "quant_linear_module"):
-child = child.quant_linear_module
-# TODO: layers with base_layer don't need the weight to be copied, as they have a reference already
if not hasattr(new_module, "base_layer"):
new_module.weight = child.weight
if hasattr(child, "bias"):
@@ -223,14 +220,13 @@ class LoraModel(BaseTuner):
# dispatch to correct device
for name, module in new_module.named_modules():
-if "lora_" in name:
-module.to(child.weight.device)
-if "ranknum" in name:
-module.to(child.weight.device)
+if (self.prefix in name) or ("ranknum" in name):
+weight = child.qweight if hasattr(child, "qweight") else child.weight
+module.to(weight.device)
def _mark_only_adapters_as_trainable(self) -> None:
for n, p in self.model.named_parameters():
-if "lora_" not in n:
+if self.prefix not in n:
p.requires_grad = False
for active_adapter in self.active_adapters:
@@ -256,9 +252,13 @@ class LoraModel(BaseTuner):
loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)
-bias = kwargs.pop("bias", False)
-if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
+if isinstance(target, BaseTunerLayer):
+target_base_layer = target.get_base_layer()
+else:
+target_base_layer = target
+if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt):
eightbit_kwargs = kwargs.copy()
eightbit_kwargs.update(
{
@@ -268,8 +268,8 @@ class LoraModel(BaseTuner):
"index": target.index,
}
)
-new_module = Linear8bitLt(adapter_name, target, **eightbit_kwargs)
-elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit):
+new_module = Linear8bitLt(target, adapter_name, **eightbit_kwargs)
+elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit):
fourbit_kwargs = kwargs.copy()
fourbit_kwargs.update(
{
@@ -278,47 +278,37 @@ class LoraModel(BaseTuner):
"quant_type": target.weight.quant_type,
}
)
-new_module = Linear4bit(adapter_name, target, **fourbit_kwargs)
-elif AutoGPTQQuantLinear is not None and isinstance(target, AutoGPTQQuantLinear):
-new_module = QuantLinear(adapter_name, target, **kwargs)
+new_module = Linear4bit(target, adapter_name, **fourbit_kwargs)
+elif AutoGPTQQuantLinear is not None and isinstance(target_base_layer, AutoGPTQQuantLinear):
+new_module = QuantLinear(target, adapter_name, **kwargs)
target.weight = target.qweight
-elif isinstance(target, torch.nn.Embedding):
+elif isinstance(target_base_layer, torch.nn.Embedding):
embedding_kwargs = kwargs.copy()
embedding_kwargs.pop("fan_in_fan_out", None)
-in_features, out_features = target.num_embeddings, target.embedding_dim
-new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs)
-elif isinstance(target, torch.nn.Conv2d):
-out_channels, in_channels = target.weight.size()[:2]
-kernel_size = target.weight.size()[2:]
-stride = target.stride
-padding = target.padding
-new_module = Conv2d(adapter_name, in_channels, out_channels, kernel_size, stride, padding, **kwargs)
+new_module = Embedding(target, adapter_name, **embedding_kwargs)
+elif isinstance(target_base_layer, torch.nn.Conv2d):
+new_module = Conv2d(target, adapter_name, **kwargs)
+elif isinstance(target_base_layer, torch.nn.Linear):
+if kwargs["fan_in_fan_out"]:
+warnings.warn(
+"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
+"Setting fan_in_fan_out to False."
+)
+kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
+new_module = Linear(target, adapter_name, **kwargs)
+elif isinstance(target_base_layer, Conv1D):
+if not kwargs["fan_in_fan_out"]:
+warnings.warn(
+"fan_in_fan_out is set to False but the target module is `Conv1D`. "
+"Setting fan_in_fan_out to True."
+)
+kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
+new_module = Linear(target, adapter_name, is_target_conv_1d_layer=True, **kwargs)
else:
-if isinstance(target, torch.nn.Linear):
-in_features, out_features = target.in_features, target.out_features
-if kwargs["fan_in_fan_out"]:
-warnings.warn(
-"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
-"Setting fan_in_fan_out to False."
-)
-kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
-elif isinstance(target, Conv1D):
-in_features, out_features = (
-target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
-)
-kwargs["is_target_conv_1d_layer"] = True
-if not kwargs["fan_in_fan_out"]:
-warnings.warn(
-"fan_in_fan_out is set to False but the target module is `Conv1D`. "
-"Setting fan_in_fan_out to True."
-)
-kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
-else:
-raise ValueError(
-f"Target module {target} is not supported. Currently, only the following modules are supported: "
-"`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`."
-)
-new_module = Linear(adapter_name, in_features, out_features, bias=bias, **kwargs)
+raise ValueError(
+f"Target module {target} is not supported. Currently, only the following modules are supported: "
+"`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`."
+)
return new_module
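The key change in this dispatch is that type checks now look through an existing adapter wrapper via `get_base_layer()`. In isolation the idea is simply:

```py
import torch.nn as nn
from peft.tuners.tuners_utils import BaseTunerLayer

def unwrap_target(target: nn.Module) -> nn.Module:
    # If the target is already an adapter layer (e.g. when stacking adapters),
    # dispatch on the layer it wraps rather than on the wrapper itself.
    return target.get_base_layer() if isinstance(target, BaseTunerLayer) else target
```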
@@ -376,65 +366,31 @@ class LoraModel(BaseTuner):
)
return peft_config
-def _unload_and_optionally_merge(self, merge=True, progressbar: bool = False, safe_merge: bool = False):
+def _unload_and_optionally_merge(
+self,
+merge=True,
+progressbar: bool = False,
+safe_merge: bool = False,
+adapter_names: Optional[List[str]] = None,
+):
if merge:
if getattr(self.model, "quantization_method", None) == "gptq":
raise ValueError("Cannot merge LORA layers when the model is gptq quantized")
-key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
+key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
desc = "Unloading " + ("and merging " if merge else "") + "model"
for key in tqdm(key_list, disable=not progressbar, desc=desc):
try:
parent, target, target_name = _get_submodules(self.model, key)
except AttributeError:
continue
if isinstance(target, LoraLayer):
if isinstance(target, nn.Embedding):
new_module = torch.nn.Embedding(target.in_features, target.out_features)
elif isinstance(target, nn.Conv2d):
new_module = torch.nn.Conv2d(
target.in_channels,
target.out_channels,
kernel_size=target.kernel_size,
stride=target.stride,
padding=target.padding,
dilation=target.dilation,
)
elif is_bnb_available() and isinstance(target, Linear8bitLt):
bias = target.base_layer.bias is not None
new_module = bnb.nn.Linear8bitLt(
target.in_features,
target.out_features,
bias=bias,
has_fp16_weights=target.base_layer.state.has_fp16_weights,
memory_efficient_backward=target.base_layer.state.memory_efficient_backward,
threshold=target.base_layer.state.threshold,
index=target.base_layer.index,
device=target.base_layer.weight.device,
)
elif is_bnb_4bit_available() and isinstance(target, Linear4bit):
bias = target.base_layer.bias is not None
new_module = bnb.nn.Linear4bit(
target.in_features,
target.out_features,
bias=bias,
compute_dtype=target.base_layer.compute_dtype,
compress_statistics=target.base_layer.weight.compress_statistics,
quant_type=target.base_layer.weight.quant_type,
device=target.base_layer.weight.device,
)
else:
bias = target.bias is not None
if getattr(target, "is_target_conv_1d_layer", False):
new_module = Conv1D(target.out_features, target.in_features)
else:
new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias)
if merge:
target.merge(safe_merge=safe_merge)
self._replace_module(parent, target_name, new_module, target)
-# save any additional trainable modules part of `modules_to_save`
-if isinstance(target, ModulesToSaveWrapper):
+if hasattr(target, "base_layer"):
+if merge:
+target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
+self._replace_module(parent, target_name, target.get_base_layer(), target)
+elif isinstance(target, ModulesToSaveWrapper):
+# save any additional trainable modules part of `modules_to_save`
setattr(parent, target_name, target.modules_to_save[target.active_adapter])
return self.model
@@ -536,7 +492,7 @@ class LoraModel(BaseTuner):
# Do we really need that?
_freeze_adapter(self.model, adapter_name)
-key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
+key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
for key in key_list:
_, target, _ = _get_submodules(self.model, key)
if isinstance(target, LoraLayer):
@@ -660,32 +616,20 @@ class LoraModel(BaseTuner):
raise ValueError(f"Adapter {adapter_name} does not exist")
del self.peft_config[adapter_name]
-key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
+key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
+new_adapter = None
for key in key_list:
_, target, _ = _get_submodules(self.model, key)
if isinstance(target, LoraLayer):
+target.delete_adapter(adapter_name)
+if new_adapter is None:
+new_adapter = target.active_adapters[:]
-for attr in [
-"r",
-"lora_alpha",
"scaling",
"lora_A",
"lora_B",
"lora_embedding_A",
"lora_embedding_B",
"lora_dropout",
]:
if adapter_name in getattr(target, attr):
getattr(target, attr).pop(adapter_name)
if adapter_name in target.active_adapters:
resetting_active_adapter = (
list(self.peft_config.keys())[0] if len(self.peft_config) > 0 else "default"
)
warnings.warn(
f"Adapter {adapter_name} was active which is now deleted. Setting active adapter to {resetting_active_adapter}. "
)
target.set_adapter(resetting_active_adapter)
+self.active_adapter = new_adapter or []
-def merge_and_unload(self, progressbar: bool = False, safe_merge: bool = False):
+def merge_and_unload(
+self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[List[str]] = None
+):
r"""
This method merges the LoRa layers into the base model. This is needed if someone wants to use the base model
as a standalone model.
@@ -696,7 +640,9 @@ class LoraModel(BaseTuner):
safe_merge (`bool`):
whether to activate the safe merging check to check if there is any potential Nan in the adapter
weights
+adapter_names (`List[str]`, *optional*):
+The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
+to `None`.
Example:
```py
@@ -709,7 +655,9 @@ class LoraModel(BaseTuner):
>>> merged_model = model.merge_and_unload()
```
"""
-return self._unload_and_optionally_merge(progressbar=progressbar, safe_merge=safe_merge)
+return self._unload_and_optionally_merge(
+progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names
+)
def unload(self):
"""

View File

@@ -13,12 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import re
import warnings
from abc import abstractmethod
from dataclasses import dataclass, field
-from itertools import chain
-from typing import Dict, Optional, Set, Type, Union
+from typing import Any, Dict, List, Optional, Set, Type, Union
import torch
import torch.nn as nn
@@ -58,12 +56,15 @@ class LycorisConfig(PeftConfig):
)
-class LycorisLayer(BaseTunerLayer, nn.Module):
+class LycorisLayer(BaseTunerLayer):
r"""
A base layer for LyCORIS like adapters
"""
+# adapter_layer_names needs to be defined on the child class
+other_param_names = ("r", "alpha", "scaling", "rank_dropout", "module_dropout")
-def __init__(self):
+def __init__(self, base_layer: nn.Module) -> None:
+self.base_layer = base_layer
self.r = {}
self.alpha = {}
self.scaling = {}
@@ -91,56 +92,44 @@ class LycorisLayer(BaseTunerLayer, nn.Module):
cls.__init__(self, *args, device="meta", **kwargs)
self.to_empty(device=final_device)
-def _op(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
-raise NotImplementedError
@abstractmethod
def create_adapter_parameters(self, adapter_name: str, r: int, **kwargs):
...
+# TODO: refactor LoRA to use the same approach
+@abstractmethod
+def _get_delta_activations(self, adapter_name: str, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
+"""Activations added on top of the base layer output (i.e. after the base layer forward pass)"""
-def forward(self, x: torch.Tensor) -> torch.Tensor:
-previous_dtype = x.dtype
-if self.disable_adapters:
if self.merged:
self.unmerge()
result = self._op(x, self.weight)
elif self.merged:
result = self._op(x, self.weight)
else:
# Get base weights
weight = self.weight.data
# Execute all the adapters
for active_adapter in self.active_adapters:
if active_adapter not in self._available_adapters:
continue
module_dropout = self.module_dropout[active_adapter]
# Modify current execution weights
if (not self.training) or (self.training and torch.rand(1) > module_dropout):
weight = weight + self.get_delta_weight(active_adapter)
# Perform actual operation
result = self._op(x, weight)
result = result.to(previous_dtype)
return result
@abstractmethod
def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
...
-def merge(self) -> None:
+def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
if self.merged:
warnings.warn(
f"Already following adapters were merged {','.join(self.merged_adapters)}. "
f"You are now additionally merging {','.join(self.active_adapters)}."
)
-for active_adapter in self.active_adapters:
+if adapter_names is None:
+adapter_names = self.active_adapters
+for active_adapter in adapter_names:
if active_adapter in self._available_adapters:
-self.weight.data += self.get_delta_weight(active_adapter)
+base_layer = self.get_base_layer()
+if safe_merge:
+orig_weights = base_layer.weight.data
+orig_weights += self.get_delta_weight(active_adapter)
+if not torch.isfinite(orig_weights).all():
+raise ValueError(
+f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
+)
+base_layer.weight.data = orig_weights
+else:
+base_layer.weight.data += self.get_delta_weight(active_adapter)
self.merged_adapters.append(active_adapter)
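The LyCORIS layers now share the same safe-merge guard as LoRA; reduced to its core the check is:

```py
import torch

def checked_add(weight: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
    merged = weight + delta
    if not torch.isfinite(merged).all():
        raise ValueError("NaNs detected in the merged weights. The adapter seems to be broken")
    return merged
```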
@abstractmethod
@@ -170,7 +159,7 @@ class LycorisLayer(BaseTunerLayer, nn.Module):
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self._available_adapters:
-self.weight.data -= self.get_delta_weight(active_adapter)
+self.base_layer.weight.data -= self.get_delta_weight(active_adapter)
def unscale_layer(self, scale=None) -> None:
for active_adapter in self.active_adapters:
@@ -209,6 +198,7 @@ class LycorisTuner(BaseTuner):
def _check_target_module_exists(config, key):
return check_target_module_exists(config, key)
+@abstractmethod
def _create_and_replace(
self,
config: LycorisConfig,
@@ -219,68 +209,47 @@ class LycorisTuner(BaseTuner):
current_key,
**optional_kwargs,
):
+...
-"""
A private method to create and replace the target module with the adapter module.
"""
# Regexp matching - Find key which matches current target_name in patterns provided
pattern_keys = list(chain(config.rank_pattern.keys(), config.alpha_pattern.keys()))
target_name_key = next(filter(lambda key: re.match(f"(.*\.)?{key}$", current_key), pattern_keys), target_name)
kwargs = config.to_dict()
kwargs["r"] = config.rank_pattern.get(target_name_key, config.r)
kwargs["alpha"] = config.alpha_pattern.get(target_name_key, config.alpha)
if isinstance(target, LycorisLayer):
target.update_layer(adapter_name, **kwargs)
else:
new_module = self._create_new_module(config, adapter_name, target, **kwargs)
self._replace_module(parent, target_name, new_module, target)
@classmethod
def _create_new_module(cls, config: LycorisConfig, adapter_name: str, target: nn.Module, **kwargs) -> LycorisLayer:
# Find corresponding subtype of provided target module
new_module_cls = None
for subtype, target_cls in cls.layers_mapping.items():
-if isinstance(target, subtype):
+if (
+hasattr(target, "base_layer")
+and isinstance(target.get_base_layer(), subtype)
+and isinstance(target, BaseTunerLayer)
+):
+# nested tuner layers are allowed
+new_module_cls = target_cls
+break
+elif isinstance(target, subtype):
new_module_cls = target_cls
break
# We didn't find corresponding type, so adapter for this layer is not supported
if new_module_cls is None:
+supported_modules = ", ".join(layer.__name__ for layer in cls.layers_mapping.keys())
raise ValueError(
-f"Target module not found, currently only adapters for {', '.join([x.__name__ for x in cls.modules_mapping.keys()])} are supported"
+f"Target module of type {type(target)} not supported, "
+f"currently only adapters for {supported_modules} are supported"
)
+if isinstance(target, BaseTunerLayer):
+target_base_layer = target.get_base_layer()
-if isinstance(target, torch.nn.Conv2d):
-new_module = new_module_cls(
target.in_channels,
target.out_channels,
target.weight.size()[2:],
stride=target.stride,
padding=target.padding,
dilation=target.dilation,
groups=target.groups,
bias=target.bias is not None,
padding_mode=target.padding_mode,
device=target.weight.device,
dtype=target.weight.dtype,
adapter_name=adapter_name,
**kwargs,
)
elif isinstance(target, torch.nn.Linear):
new_module = new_module_cls(
target.in_features,
target.out_features,
bias=target.bias is not None,
device=target.weight.device,
dtype=target.weight.dtype,
adapter_name=adapter_name,
**kwargs,
)
else: else:
target_base_layer = target
if isinstance(target_base_layer, torch.nn.Conv2d):
new_module = new_module_cls(target, adapter_name=adapter_name, **kwargs)
elif isinstance(target_base_layer, torch.nn.Linear):
new_module = new_module_cls(target, adapter_name=adapter_name, **kwargs)
else:
supported_modules = ", ".join(layer.__name__ for layer in cls.layers_mapping.keys())
raise ValueError( raise ValueError(
"Target module not found, currently only adapters for nn.Linear and nn.Conv2d are supported" f"Target module of type {type(target)} not supported, "
f"currently only adapters for {supported_modules} are supported"
) )
return new_module return new_module
@ -300,12 +269,17 @@ class LycorisTuner(BaseTuner):
        setattr(parent, child_name, new_module)
        # It's not necessary to set requires_grad here, as that is handled by
        # _mark_only_adapters_as_trainable
-       new_module.weight = child.weight
-       if hasattr(child, "bias"):
-           new_module.bias = child.bias
+       if not hasattr(new_module, "base_layer"):
+           new_module.weight = child.weight
+           if hasattr(child, "bias"):
+               new_module.bias = child.bias

        if getattr(child, "state", None) is not None:
-           new_module.state = child.state
+           if hasattr(new_module, "base_layer"):
+               new_module.base_layer.state = child.state
+           else:
+               new_module.state = child.state
            new_module.to(child.weight.device)
# dispatch to correct device # dispatch to correct device
@ -318,46 +292,31 @@ class LycorisTuner(BaseTuner):
            if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
                module.enable_adapters(enabled)

-   def _unload_and_optionally_merge(self, merge=True, progressbar: bool = False):
+   def _unload_and_optionally_merge(
+       self,
+       merge: bool = True,
+       progressbar: bool = False,
+       safe_merge: bool = False,
+       adapter_names: Optional[List[str]] = None,
+   ):
        if merge:
            if getattr(self.model, "quantization_method", None) == "gptq":
                raise ValueError("Cannot merge LOHA layers when the model is gptq quantized")

-       key_list = [key for key, _ in self.model.named_modules() if "hada" not in key]
+       key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
        desc = "Unloading " + ("and merging " if merge else "") + "model"
        for key in tqdm(key_list, disable=not progressbar, desc=desc):
            try:
                parent, target, target_name = _get_submodules(self.model, key)
            except AttributeError:
                continue
-           if isinstance(target, LycorisLayer):
-               if isinstance(target, nn.Conv2d):
-                   new_module = torch.nn.Conv2d(
-                       target.in_channels,
-                       target.out_channels,
-                       kernel_size=target.kernel_size,
-                       stride=target.stride,
-                       padding=target.padding,
-                       dilation=target.dilation,
-                   )
-               elif isinstance(target, nn.Linear):
-                   bias = target.bias is not None
-                   new_module = torch.nn.Linear(
-                       target.in_features,
-                       target.out_features,
-                       bias=bias,
-                       device=target.weight.device,
-                   )
-               else:
-                   raise ValueError(
-                       "Cannot convert current module to torch module, currently only adapters for nn.Linear and nn.Conv2d are supported"
-                   )
-               if merge:
-                   target.merge()
-               self._replace_module(parent, target_name, new_module, target)
-
-           # save any additional trainable modules part of `modules_to_save`
-           if isinstance(target, ModulesToSaveWrapper):
+           if hasattr(target, "base_layer"):
+               if merge:
+                   target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
+               self._replace_module(parent, target_name, target.get_base_layer(), target)
+           elif isinstance(target, ModulesToSaveWrapper):
+               # save any additional trainable modules part of `modules_to_save`
                setattr(parent, target_name, target.modules_to_save[target.active_adapter])

        return self.model
@ -368,8 +327,34 @@ class LycorisTuner(BaseTuner):
    def disable_adapter_layers(self):
        self._set_adapter_layers(enabled=False)

-   def merge_and_unload(self, progressbar: bool = False):
-       return self._unload_and_optionally_merge(progressbar=progressbar)
+   def merge_and_unload(
+       self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[List[str]] = None
):
r"""
This method merges the adapter layers into the base model. This is needed if someone wants to use the base
model as a standalone model.
Args:
progressbar (`bool`):
whether to show a progressbar indicating the unload and merge process
safe_merge (`bool`):
whether to activate the safe merging check to check if there is any potential Nan in the adapter
weights
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
"""
return self._unload_and_optionally_merge(
progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names
)
def unload(self):
"""
Gets back the base model by removing all the lora modules without merging. This gives back the original base
model.
"""
return self._unload_and_optionally_merge(merge=False)
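A minimal usage sketch (not part of the diff) of the extended merge_and_unload / unload API introduced above; the base model name and adapter path are placeholders, and the call goes through the usual PeftModel wrapper:

    from transformers import AutoModelForCausalLM
    from peft import PeftModel

    base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
    peft_model = PeftModel.from_pretrained(base_model, "path/to/lycoris-adapter")  # placeholder adapter

    # Merge the listed adapters into the base weights, checking for NaNs first.
    merged = peft_model.merge_and_unload(progressbar=True, safe_merge=True, adapter_names=["default"])

    # Or remove the adapter layers without merging to recover the original base model.
    # restored = peft_model.unload()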
    def set_adapter(self, adapter_name):
        for module in self.model.modules():

@ -391,17 +376,12 @@ class LycorisTuner(BaseTuner):
        del self.peft_config[adapter_name]

        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
+       new_adapter = None
        for key in key_list:
            _, target, _ = _get_submodules(self.model, key)
            if isinstance(target, LycorisLayer):
-               for attr in target.adapter_layer_names:
-                   if adapter_name in getattr(target, attr):
-                       getattr(target, attr).pop(adapter_name)
-               if adapter_name in target.active_adapters:
-                   resetting_active_adapter = (
-                       list(self.peft_config.keys())[0] if len(self.peft_config) > 0 else "default"
-                   )
-                   warnings.warn(
-                       f"Adapter {adapter_name} was active which is now deleted. Setting active adapter to {resetting_active_adapter}. "
-                   )
-                   target.set_adapter(resetting_active_adapter)
+               target.delete_adapter(adapter_name)
+               if new_adapter is None:
+                   new_adapter = target.active_adapters[:]
+
+       self.active_adapter = new_adapter or []

View File

@ -104,7 +104,7 @@ class PromptEncoder(torch.nn.Module):
        encoder_num_layers_default = PromptEncoderConfig.encoder_num_layers
        if config.encoder_num_layers != encoder_num_layers_default:
            warnings.warn(
-               f"for {self.encoder_type}, the argument `encoder_num_layers` is ignored. "
+               f"for {self.encoder_type.value}, the argument `encoder_num_layers` is ignored. "
                f"Exactly {encoder_num_layers_default} MLP layers are used."
            )
        layers = [

View File

@ -37,6 +37,9 @@ class PromptTuningConfig(PromptLearningConfig):
            The text to initialize the prompt embedding. Only used if `prompt_tuning_init` is `TEXT`.
        tokenizer_name_or_path (`str`, *optional*):
            The name or path of the tokenizer. Only used if `prompt_tuning_init` is `TEXT`.
tokenizer_kwargs (`dict`, *optional*):
The keyword arguments to pass to `AutoTokenizer.from_pretrained`. Only used if `prompt_tuning_init` is
`TEXT`.
""" """
prompt_tuning_init: Union[PromptTuningInit, str] = field( prompt_tuning_init: Union[PromptTuningInit, str] = field(
@ -56,5 +59,20 @@ class PromptTuningConfig(PromptLearningConfig):
}, },
) )
tokenizer_kwargs: Optional[dict] = field(
default=None,
metadata={
"help": (
"The keyword arguments to pass to `AutoTokenizer.from_pretrained`. Only used if prompt_tuning_init is "
"`TEXT`"
),
},
)
    def __post_init__(self):
        self.peft_type = PeftType.PROMPT_TUNING
if self.tokenizer_kwargs and (self.prompt_tuning_init != PromptTuningInit.TEXT):
raise ValueError(
f"tokenizer_kwargs only valid when using prompt_tuning_init='{PromptTuningInit.TEXT.value}'."
)
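An illustrative sketch of the new tokenizer_kwargs option (the model id below is the tiny test model used elsewhere in this PR); passing tokenizer_kwargs with any init other than TEXT now raises a ValueError in __post_init__:

    from peft import PromptTuningConfig, PromptTuningInit

    config = PromptTuningConfig(
        task_type="CAUSAL_LM",
        num_virtual_tokens=8,
        prompt_tuning_init=PromptTuningInit.TEXT,
        prompt_tuning_init_text="Classify the sentiment of this review:",
        tokenizer_name_or_path="hf-internal-testing/tiny-random-OPTForCausalLM",
        tokenizer_kwargs={"use_fast": True},  # forwarded to AutoTokenizer.from_pretrained
    )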

View File

@ -66,7 +66,8 @@ class PromptEmbedding(torch.nn.Module):
        if config.prompt_tuning_init == PromptTuningInit.TEXT:
            from transformers import AutoTokenizer

-           tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)
+           tokenizer_kwargs = config.tokenizer_kwargs or {}
+           tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path, **tokenizer_kwargs)
            init_text = config.prompt_tuning_init_text
            init_token_ids = tokenizer(init_text)["input_ids"]
            # Trim or iterate until num_text_tokens matches total_virtual_tokens

@ -77,8 +78,9 @@ class PromptEmbedding(torch.nn.Module):
                num_reps = math.ceil(total_virtual_tokens / num_text_tokens)
                init_token_ids = init_token_ids * num_reps
            init_token_ids = init_token_ids[:total_virtual_tokens]
+           init_token_ids = torch.LongTensor(init_token_ids).to(word_embeddings.weight.device)

-           word_embedding_weights = word_embeddings(torch.LongTensor(init_token_ids)).detach().clone()
+           word_embedding_weights = word_embeddings(init_token_ids).detach().clone()
            word_embedding_weights = word_embedding_weights.to(torch.float32)
            self.embedding.weight = torch.nn.Parameter(word_embedding_weights)
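A standalone sketch of the trim-or-repeat logic above, with made-up numbers, showing how the init text's token ids are stretched to cover all virtual tokens:

    import math

    init_token_ids = [101, 7592, 2088, 102]   # e.g. 4 token ids from the init text
    total_virtual_tokens = 10

    num_text_tokens = len(init_token_ids)
    if num_text_tokens > total_virtual_tokens:
        init_token_ids = init_token_ids[:total_virtual_tokens]
    elif num_text_tokens < total_virtual_tokens:
        num_reps = math.ceil(total_virtual_tokens / num_text_tokens)
        init_token_ids = (init_token_ids * num_reps)[:total_virtual_tokens]

    print(len(init_token_ids))  # 10: the 4 original ids repeated and truncated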

View File

@ -16,15 +16,17 @@ from __future__ import annotations
import logging
import re
+import warnings
from abc import ABC, abstractmethod
from typing import Any, Union

+import torch
from torch import nn

from peft.utils import COMMON_LAYERS_PATTERN

from ..config import PeftConfig
-from ..utils import _get_submodules
+from ..utils import ModulesToSaveWrapper, _get_submodules

logger = logging.getLogger(__name__)
@ -210,6 +212,9 @@ class BaseTuner(nn.Module, ABC):
        is_target_modules_in_base_model = False
        key_list = [key for key, _ in model.named_modules()]

+       _check_for_modules_to_save = getattr(peft_config, "modules_to_save", None) is not None
+       _has_modules_to_save = False

        model_config = getattr(model, "config", {"model_type": "custom"})
        if hasattr(model_config, "to_dict"):
            model_config = model_config.to_dict()

@ -217,6 +222,22 @@ class BaseTuner(nn.Module, ABC):
        peft_config = self._prepare_adapter_config(peft_config, model_config)

        for key in key_list:
# Check for modules_to_save in case
if _check_for_modules_to_save and any(
key.endswith(f"{module_to_save}") for module_to_save in peft_config.modules_to_save
):
# Optionally set the modules to save
parent, target, target_name = _get_submodules(model, key)
if not isinstance(target, ModulesToSaveWrapper):
new_module = ModulesToSaveWrapper(target, adapter_name)
setattr(parent, target_name, new_module)
else:
target.update(adapter_name)
_has_modules_to_save = True
continue
            if not self._check_target_module_exists(peft_config, key):
                continue

@ -243,6 +264,12 @@ class BaseTuner(nn.Module, ABC):
                if adapter_name in n:
                    p.requires_grad = False
if _has_modules_to_save:
if not hasattr(model, "modules_to_save"):
model.modules_to_save = set(peft_config.modules_to_save)
else:
model.modules_to_save.update(set(peft_config.modules_to_save))
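A minimal sketch (with assumed model and layer names) of how modules_to_save is typically used; the new inject_adapter code above wraps such modules in ModulesToSaveWrapper and records them on the model:

    from transformers import AutoModelForSequenceClassification
    from peft import LoraConfig, get_peft_model

    base = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
    config = LoraConfig(
        task_type="SEQ_CLS",
        target_modules=["query", "value"],
        modules_to_save=["classifier"],  # trained fully and saved alongside the adapter weights
    )
    peft_model = get_peft_model(base, config)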
    def merge_adapter(self):
        """
        This method merges the LoRa layers into the base model.

@ -272,8 +299,10 @@ class BaseTunerLayer(ABC):
    """

    active_adapter = None

-   # List all names of layers that may contain adapter weights
-   adapter_layer_names: list[str] = []
+   # All names of layers that may contain adapter (trainable) weights
+   adapter_layer_names: tuple[str] = ()
+   # All names of other parameters that may contain adapter-related parameters
+   other_param_names: tuple[str] = ()

    # indicates whether all adapters should be disabled
    _disable_adapters: bool = False

@ -284,6 +313,34 @@ class BaseTunerLayer(ABC):
    # List all merged adapters
    merged_adapters: list[str] = []
def get_base_layer(self) -> nn.Module:
"""
(Recursively) get the base_layer.
This is necessary for the case that the tuner layer wraps another tuner layer.
"""
base_layer = self
while hasattr(base_layer, "base_layer"):
base_layer = base_layer.base_layer
return base_layer
@property
def weight(self) -> torch.Tensor:
# This is required for some transformers code, e.g. for T5, weight is accessed as:
# self.wo.weight
# where "wo" is the adapter layer.
# https://github.com/huggingface/transformers/blob/78f6ed6c70b29c1560780e3869a7ad4c6b3d2710/src/transformers
# /models/t5/modeling_t5.py#L292
base_layer = self.get_base_layer()
if hasattr(base_layer, "qweight"):
# QuantLinear
weight = base_layer.qweight
else:
# Other layers
weight = base_layer.weight
return weight
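A toy illustration (not PEFT code) of the recursive lookup that get_base_layer performs when one tuner layer wraps another:

    import torch.nn as nn

    class Wrapper(nn.Module):
        def __init__(self, base_layer: nn.Module):
            super().__init__()
            self.base_layer = base_layer

        def get_base_layer(self) -> nn.Module:
            base = self
            while hasattr(base, "base_layer"):
                base = base.base_layer
            return base

    linear = nn.Linear(4, 4)
    nested = Wrapper(Wrapper(linear))   # a tuner layer wrapping another tuner layer
    assert nested.get_base_layer() is linear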
    def merge(self, *args) -> None:
        raise NotImplementedError

@ -351,6 +408,54 @@ class BaseTunerLayer(ABC):
        self._active_adapter = adapter_names
def _all_available_adapter_names(self) -> list[str]:
"""Return a sorted list of all available adapter names"""
adapter_names = set()
for name in self.adapter_layer_names + self.other_param_names:
# we check each possible attribute and if it's a dict or ModuleDict, we assume that the keys are the adapter
# names
attr = getattr(self, name)
if hasattr(attr, "keys"):
adapter_names.update(attr.keys())
return sorted(adapter_names)
def delete_adapter(self, adapter_name: str) -> None:
"""
Delete an adapter from the layer
This should be called on all adapter layers, or else we will get an inconsistent state.
This method will also set a new active adapter if the deleted adapter was an active adapter. It is important
that the new adapter is chosen in a deterministic way, so that the same adapter is chosen on all layers.
Args:
adapter_name (`str`): The name of the adapter to delete
"""
for attr in self.adapter_layer_names + self.other_param_names:
if adapter_name in getattr(self, attr):
del getattr(self, attr)[adapter_name]
if adapter_name in self.active_adapters:
# choose a new active adapter
active_adapters = self.active_adapters[:]
active_adapters.remove(adapter_name)
if active_adapters:
self.set_adapter(active_adapters)
else:
# no active adapters left, set a new default adapter
# here we get the list of all adapters existing adapter names and choose the first one
remaining_adapters = self._all_available_adapter_names()
if not remaining_adapters:
self.set_adapter([])
else:
new_active_adapter = remaining_adapters[0]
warnings.warn(
f"Adapter {adapter_name} was active which is now deleted. Setting active adapter to "
f"{new_active_adapter}."
)
self.set_adapter(remaining_adapters[0])
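A hypothetical usage sketch of the new deterministic deletion behaviour; the model id and adapter names are placeholders, and delete_adapter is assumed to be reachable from the wrapping PEFT model:

    from transformers import AutoModelForCausalLM
    from peft import LoraConfig, get_peft_model

    base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
    model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"), adapter_name="first")
    model.add_adapter("second", LoraConfig(task_type="CAUSAL_LM"))
    model.set_adapter("second")

    # Deleting the active adapter now warns and falls back to a remaining adapter
    # ("first") on every layer, instead of assuming a "default" adapter exists.
    model.delete_adapter("second")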
def check_target_module_exists(config, key: str) -> bool | re.Match[str] | None:
    """A helper method to check if the passed module's key name matches any of the target modules in the adapter_config.

View File

@ -45,6 +45,6 @@ from .other import (
    infer_device,
    get_auto_gptq_quant_linear,
    get_quantization_config,
+   id_tensor_storage,
)
-from .hub_utils import hub_file_exists
from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict, load_peft_weights

View File

@ -1,29 +0,0 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from huggingface_hub import get_hf_file_metadata, hf_hub_url
from huggingface_hub.utils import EntryNotFoundError
def hub_file_exists(repo_id: str, filename: str, revision: str = None, repo_type: str = None) -> bool:
r"""
Checks if a file exists in a remote Hub repository.
"""
url = hf_hub_url(repo_id=repo_id, filename=filename, repo_type=repo_type, revision=revision)
try:
get_hf_file_metadata(url)
return True
except EntryNotFoundError:
return False

View File

@ -15,14 +15,15 @@
import copy
import inspect
import warnings
-from typing import Optional
+from typing import Optional, Tuple

import accelerate
import torch
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
from accelerate.utils import is_npu_available, is_xpu_available
+from safetensors.torch import storage_ptr, storage_size

-from ..import_utils import is_auto_gptq_available
+from ..import_utils import is_auto_gptq_available, is_torch_tpu_available

# Get current device name based on available devices

@ -276,8 +277,22 @@ def _set_trainable(model, adapter_name):

def _set_adapter(model, adapter_name):
def check_adapter_name(adapter_name):
if isinstance(adapter_name, str):
return adapter_name
# adapter_name is a list of str
if len(adapter_name) > 1:
raise ValueError("Only one adapter can be set at a time for modules_to_save")
elif len(adapter_name) == 0:
raise ValueError("Please specify at least one adapter to set")
adapter_name = adapter_name[0]
return adapter_name
    for module in model.modules():
        if isinstance(module, ModulesToSaveWrapper):
+           # only check the adapter_name if we actually encounter a ModulesToSaveWrapper, otherwise we don't care
+           adapter_name = check_adapter_name(adapter_name)
            module.set_adapter(adapter_name)
@ -412,33 +427,57 @@ def get_auto_gptq_quant_linear(gptq_quantization_config):
""" """
Get the right AutoGPTQQuantLinear class based on the quantization config file Get the right AutoGPTQQuantLinear class based on the quantization config file
""" """
if is_auto_gptq_available(): if gptq_quantization_config is not None and is_auto_gptq_available():
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
if gptq_quantization_config is not None: desc_act = gptq_quantization_config.desc_act
desc_act = gptq_quantization_config.desc_act group_size = gptq_quantization_config.group_size
group_size = gptq_quantization_config.group_size bits = gptq_quantization_config.bits
bits = gptq_quantization_config.bits if hasattr(gptq_quantization_config, "use_exllama"):
if hasattr(gptq_quantization_config, "use_exllama"): use_exllama = gptq_quantization_config.use_exllama
use_exllama = gptq_quantization_config.use_exllama else:
else: use_exllama = not gptq_quantization_config.disable_exllama
use_exllama = not gptq_quantization_config.disable_exllama if hasattr(gptq_quantization_config, "exllama_config"):
if hasattr(gptq_quantization_config, "exllama_config"): exllama_version = gptq_quantization_config.exllama_config["version"]
exllama_version = gptq_quantization_config.exllama_config["version"] else:
else: exllama_version = 1
exllama_version = 1 AutoGPTQQuantLinear = dynamically_import_QuantLinear(
AutoGPTQQuantLinear = dynamically_import_QuantLinear( use_triton=False,
use_triton=False, desc_act=desc_act,
desc_act=desc_act, group_size=group_size,
group_size=group_size, bits=bits,
bits=bits, disable_exllama=not (use_exllama and exllama_version == 1),
disable_exllama=not (use_exllama and exllama_version == 1), disable_exllamav2=not (use_exllama and exllama_version == 2),
disable_exllamav2=not (use_exllama and exllama_version == 2), )
) return AutoGPTQQuantLinear
return AutoGPTQQuantLinear
return None return None
def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
"""
Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For
example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
non-overlapping lifetimes may have the same id.
This method is the exact same copy of
https://github.com/huggingface/transformers/blob/main/src/transformers/pytorch_utils.py#L282C1-L300C58 but we added
it here manually to avoid import issue with old versions of transformers.
"""
if tensor.device.type == "xla" and is_torch_tpu_available():
# NOTE: xla tensors dont have storage
# use some other unique id to distinguish.
# this is a XLA tensor, it must be created using torch_xla's
# device. So the following import is safe:
import torch_xla
unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor)
else:
unique_id = storage_ptr(tensor)
return tensor.device, unique_id, storage_size(tensor)
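A small demonstration of what id_tensor_storage is used for: tensors that share the same underlying storage (such as a tensor and a view of it) receive the same identifier, which is how tied or duplicated weights can be detected before safetensors serialization. The helper below is a simplified CPU/GPU-only variant of the function above (the XLA branch is omitted):

    import torch
    from safetensors.torch import storage_ptr, storage_size

    def simple_id_tensor_storage(tensor: torch.Tensor):
        return tensor.device, storage_ptr(tensor), storage_size(tensor)

    a = torch.randn(4, 4)
    b = a.view(16)   # a view shares storage with a
    c = a.clone()    # a clone has its own storage

    assert simple_id_tensor_storage(a) == simple_id_tensor_storage(b)
    assert simple_id_tensor_storage(a) != simple_id_tensor_storage(c)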
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = {
    "t5": ["q", "v"],
    "mt5": ["q", "v"],

View File

@ -16,11 +16,10 @@ import os
from typing import Optional

import torch
-from huggingface_hub import hf_hub_download
+from huggingface_hub import file_exists, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from safetensors.torch import load_file as safe_load_file

-from .hub_utils import hub_file_exists
from .other import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, infer_device
from .peft_types import PeftType

@ -194,9 +193,9 @@ def load_peft_weights(model_id: str, device: Optional[str] = None, **hf_hub_down
            filename = os.path.join(path, WEIGHTS_NAME)
            use_safetensors = False
        else:
-           has_remote_safetensors_file = hub_file_exists(
-               model_id,
-               SAFETENSORS_WEIGHTS_NAME,
+           has_remote_safetensors_file = file_exists(
+               repo_id=model_id,
+               filename=SAFETENSORS_WEIGHTS_NAME,
                revision=hf_hub_download_kwargs.get("revision", None),
                repo_type=hf_hub_download_kwargs.get("repo_type", None),
            )
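A short sketch of the replacement API: huggingface_hub.file_exists performs the same remote check as the deleted hub_utils helper, so no custom wrapper is needed. The repo id below is just an arbitrary public example:

    from huggingface_hub import file_exists

    print(file_exists(repo_id="facebook/opt-350m", filename="config.json"))
    print(file_exists(repo_id="facebook/opt-350m", filename="adapter_model.safetensors"))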

View File

@ -115,6 +115,51 @@ class AdaptionPromptTester(TestCase, PeftCommonTester):
        self.assertTrue(dummy_output.requires_grad)
def test_save_pretrained_regression(self) -> None:
seed = 420
torch.manual_seed(seed)
model = LlamaForCausalLM(self._create_test_llama_config())
config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
model = get_peft_model(model, config)
model = model.to(self.torch_device)
with tempfile.TemporaryDirectory() as tmp_dirname:
model.save_pretrained(tmp_dirname, safe_serialization=False)
torch.manual_seed(seed)
model_from_pretrained = LlamaForCausalLM(self._create_test_llama_config())
model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)
# check if the state dicts are equal
state_dict = get_peft_model_state_dict(model)
state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained)
# check if same keys
self.assertEqual(state_dict.keys(), state_dict_from_pretrained.keys())
# Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate).
self.assertEqual(len(list(state_dict.keys())), 4)
# check if tensors equal
for key in state_dict.keys():
self.assertTrue(
torch.allclose(
state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device)
)
)
# check if `adapter_model.bin` is present
self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin")))
# check if `adapter_config.json` is present
self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json")))
# check if `model.safetensors` is not present
self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "model.safetensors")))
# check if `config.json` is not present
self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))
    def test_save_pretrained(self) -> None:
        seed = 420
        torch.manual_seed(seed)

@ -149,13 +194,13 @@ class AdaptionPromptTester(TestCase, PeftCommonTester):
            )

            # check if `adapter_model.bin` is present
-           self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin")))
+           self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors")))

            # check if `adapter_config.json` is present
            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json")))

-           # check if `pytorch_model.bin` is not present
-           self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "pytorch_model.bin")))
+           # check if `model.safetensors` is not present
+           self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "model.safetensors")))

            # check if `config.json` is not present
            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))

@ -199,13 +244,13 @@ class AdaptionPromptTester(TestCase, PeftCommonTester):
            )

            # check if `adapter_model.bin` is present
-           self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin")))
+           self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors")))

            # check if `adapter_config.json` is present
            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json")))

-           # check if `pytorch_model.bin` is not present
-           self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "pytorch_model.bin")))
+           # check if `model.safetensors` is not present
+           self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "model.safetensors")))

            # check if `config.json` is not present
            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))

File diff suppressed because it is too large Load Diff

View File

@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
+from unittest.mock import Mock, call, patch

import torch
from parameterized import parameterized
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer

-from peft import AdaLoraConfig
+from peft import AdaLoraConfig, PromptTuningConfig, PromptTuningInit, get_peft_model

from .testing_common import PeftCommonTester, PeftTestConfigManager

@ -76,14 +77,77 @@ class PeftDecoderModelTester(unittest.TestCase, PeftCommonTester):
    def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls, config_kwargs):
        self._test_prepare_for_training(model_id, config_cls, config_kwargs)
@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_prompt_tuning_text_prepare_for_training(self, test_name, model_id, config_cls, config_kwargs):
# Test that prompt tuning works with text init
if config_cls != PromptTuningConfig:
return
config_kwargs = config_kwargs.copy()
config_kwargs["prompt_tuning_init"] = PromptTuningInit.TEXT
config_kwargs["prompt_tuning_init_text"] = "This is a test prompt."
config_kwargs["tokenizer_name_or_path"] = model_id
self._test_prepare_for_training(model_id, config_cls, config_kwargs)
def test_prompt_tuning_text_tokenizer_kwargs(self):
# Allow users to pass additional arguments to Tokenizer.from_pretrained
# Fix for #1032
mock = Mock()
orig_from_pretrained = AutoTokenizer.from_pretrained
def mock_autotokenizer_from_pretrained(*args, **kwargs):
mock(*args, **kwargs)
return orig_from_pretrained(config.tokenizer_name_or_path)
model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
config = PromptTuningConfig(
base_model_name_or_path=model_id,
tokenizer_name_or_path=model_id,
num_virtual_tokens=10,
prompt_tuning_init=PromptTuningInit.TEXT,
task_type="CAUSAL_LM",
prompt_tuning_init_text="This is a test prompt.",
tokenizer_kwargs={"trust_remote_code": True, "foo": "bar"},
)
model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
with patch("transformers.AutoTokenizer.from_pretrained", mock_autotokenizer_from_pretrained):
model = get_peft_model(model, config)
expected_call = call(model_id, trust_remote_code=True, foo="bar")
self.assertEqual(mock.call_args, expected_call)
def test_prompt_tuning_config_invalid_args(self):
# Raise an error when tokenizer_kwargs is used with prompt_tuning_init!='TEXT', because this argument has no
# function in that case
model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
msg = "tokenizer_kwargs only valid when using prompt_tuning_init='TEXT'."
with self.assertRaisesRegex(ValueError, expected_regex=msg):
PromptTuningConfig(
base_model_name_or_path=model_id,
tokenizer_name_or_path=model_id,
num_virtual_tokens=10,
task_type="CAUSAL_LM",
prompt_tuning_init_text="This is a test prompt.",
prompt_tuning_init=PromptTuningInit.RANDOM, # <= should not be used together with tokenizer_kwargs
tokenizer_kwargs={"trust_remote_code": True, "foo": "bar"},
)
    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs):
        self._test_save_pretrained(model_id, config_cls, config_kwargs)

+   @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
+   def test_save_pretrained_pickle(self, test_name, model_id, config_cls, config_kwargs):
+       self._test_save_pretrained(model_id, config_cls, config_kwargs, safe_serialization=False)

    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_save_pretrained_selected_adapters(self, test_name, model_id, config_cls, config_kwargs):
        self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs)

+   @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
+   def test_save_pretrained_selected_adapters_pickle(self, test_name, model_id, config_cls, config_kwargs):
+       self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs, safe_serialization=False)

    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_from_pretrained_config_construction(self, test_name, model_id, config_cls, config_kwargs):
        self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs)

@ -101,6 +165,19 @@ class PeftDecoderModelTester(unittest.TestCase, PeftCommonTester):
    def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
        self._test_merge_layers(model_id, config_cls, config_kwargs)
@parameterized.expand(
PeftTestConfigManager.get_grid_parameters(
{
"model_ids": PEFT_DECODER_MODELS_TO_TEST,
"lora_kwargs": {"init_lora_weights": [False]},
"ia3_kwargs": {"init_ia3_weights": [False]},
"task_type": "CAUSAL_LM",
},
)
)
def test_merge_layers_multi(self, test_name, model_id, config_cls, config_kwargs):
self._test_merge_layers_multi(model_id, config_cls, config_kwargs)
    @parameterized.expand(
        PeftTestConfigManager.get_grid_parameters(
            {

@ -154,6 +231,10 @@ class PeftDecoderModelTester(unittest.TestCase, PeftCommonTester):
    def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
        self._test_delete_adapter(model_id, config_cls, config_kwargs)

+   @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
+   def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_kwargs):
+       self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs)

    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs):
        self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs)

@ -164,6 +245,7 @@ class PeftDecoderModelTester(unittest.TestCase, PeftCommonTester):
                "model_ids": PEFT_DECODER_MODELS_TO_TEST,
                "lora_kwargs": {"init_lora_weights": [False]},
                "adalora_kwargs": {"init_lora_weights": [False]},
+               "ia3_kwargs": {"init_ia3_weights": [False]},
                "task_type": "CAUSAL_LM",
            },
            filter_params_func=skip_adalora_and_gpt2,

View File

@ -70,10 +70,18 @@ class PeftEncoderDecoderModelTester(unittest.TestCase, PeftCommonTester):
    def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs):
        self._test_save_pretrained(model_id, config_cls, config_kwargs)

+   @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
+   def test_save_pretrained_pickle(self, test_name, model_id, config_cls, config_kwargs):
+       self._test_save_pretrained(model_id, config_cls, config_kwargs, safe_serialization=False)

    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_save_pretrained_selected_adapters(self, test_name, model_id, config_cls, config_kwargs):
        self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs)

+   @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
+   def test_save_pretrained_selected_adapters_pickle(self, test_name, model_id, config_cls, config_kwargs):
+       self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs, safe_serialization=False)

    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_from_pretrained_config_construction(self, test_name, model_id, config_cls, config_kwargs):
        self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs)

@ -128,6 +136,10 @@ class PeftEncoderDecoderModelTester(unittest.TestCase, PeftCommonTester):
    def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
        self._test_delete_adapter(model_id, config_cls, config_kwargs)

+   @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
+   def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_kwargs):
+       self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs)

    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs):
        self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs)

View File

@ -146,12 +146,17 @@ class PeftFeatureExtractionModelTester(unittest.TestCase, PeftCommonTester):
    def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
        self._test_delete_adapter(model_id, config_cls, config_kwargs)

+   @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
+   def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_kwargs):
+       self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs)

    @parameterized.expand(
        PeftTestConfigManager.get_grid_parameters(
            {
                "model_ids": PEFT_FEATURE_EXTRACTION_MODELS_TO_TEST,
                "lora_kwargs": {"init_lora_weights": [False]},
                "adalora_kwargs": {"init_lora_weights": [False]},
+               "ia3_kwargs": {"init_ia3_weights": [False]},
                "task_type": "FEATURE_EXTRACTION",
            },
        )

View File

@ -44,6 +44,7 @@ from peft import (
    prepare_model_for_int8_training,
    prepare_model_for_kbit_training,
)
+from peft.utils import SAFETENSORS_WEIGHTS_NAME

from .testing_utils import (
    require_auto_gptq,

@ -124,6 +125,14 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
        torch.cuda.empty_cache()
        gc.collect()
def _check_inference_finite(self, model, batch):
# try inference without Trainer class
training = model.training
model.eval()
output = model(**batch.to(model.device))
self.assertTrue(torch.isfinite(output.logits).all())
model.train(training)
@pytest.mark.single_gpu_tests @pytest.mark.single_gpu_tests
def test_causal_lm_training(self): def test_causal_lm_training(self):
r""" r"""
@ -177,7 +186,7 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
            model.cpu().save_pretrained(tmp_dir)

            self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
-           self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir))
+           self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))

            # assert loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@ -235,7 +244,7 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
model.cpu().save_pretrained(tmp_dir) model.cpu().save_pretrained(tmp_dir)
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir)) self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir)) self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))
# assert loss is not None # assert loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@ -296,7 +305,7 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
model.cpu().save_pretrained(tmp_dir) model.cpu().save_pretrained(tmp_dir)
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir)) self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir)) self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))
# assert loss is not None # assert loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@ -334,6 +343,8 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
        data = load_dataset("ybelkada/english_quotes_copy")
        data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
+       batch = tokenizer(data["train"][:3]["quote"], return_tensors="pt", padding=True)
+       self._check_inference_finite(model, batch)

        with tempfile.TemporaryDirectory() as tmp_dir:
            trainer = Trainer(

@ -357,7 +368,70 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
            model.cpu().save_pretrained(tmp_dir)

            self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
-           self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir))
+           self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))
# assert loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@pytest.mark.single_gpu_tests
@require_torch_gpu
def test_8bit_adalora_causalLM(self):
r"""
Tests the 8bit training with adalora
"""
model_id = "facebook/opt-350m"
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)
peft_config = AdaLoraConfig(
init_r=6,
target_r=4,
tinit=50,
tfinal=100,
deltaT=5,
beta1=0.3,
beta2=0.3,
orth_reg_weight=0.2,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)
data = load_dataset("ybelkada/english_quotes_copy")
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
batch = tokenizer(data["train"][:3]["quote"], return_tensors="pt", padding=True)
self._check_inference_finite(model, batch)
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = Trainer(
model=model,
train_dataset=data["train"],
args=TrainingArguments(
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
warmup_steps=2,
max_steps=3,
learning_rate=2e-4,
fp16=True,
logging_steps=1,
output_dir=tmp_dir,
),
data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False
trainer.train()
model.cpu().save_pretrained(tmp_dir)
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))
# assert loss is not None # assert loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@ -421,7 +495,7 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
model.cpu().save_pretrained(tmp_dir) model.cpu().save_pretrained(tmp_dir)
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir)) self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir)) self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))
# assert loss is not None # assert loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@ -481,7 +555,7 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
model.cpu().save_pretrained(tmp_dir) model.cpu().save_pretrained(tmp_dir)
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir)) self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir)) self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))
# assert loss is not None # assert loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@ -542,7 +616,7 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
model.cpu().save_pretrained(tmp_dir) model.cpu().save_pretrained(tmp_dir)
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir)) self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir)) self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))
# assert loss is not None # assert loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@ -640,7 +714,7 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
model.cpu().save_pretrained(tmp_dir) model.cpu().save_pretrained(tmp_dir)
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir)) self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir)) self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))
# assert loss is not None # assert loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@ -670,6 +744,14 @@ class PeftGPTQGPUTests(unittest.TestCase):
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
def _check_inference_finite(self, model, batch):
# try inference without Trainer class
training = model.training
model.eval()
output = model(**batch.to(model.device))
self.assertTrue(torch.isfinite(output.logits).all())
model.train(training)
@pytest.mark.single_gpu_tests @pytest.mark.single_gpu_tests
def test_causal_lm_training(self): def test_causal_lm_training(self):
r""" r"""
@ -719,7 +801,7 @@ class PeftGPTQGPUTests(unittest.TestCase):
model.cpu().save_pretrained(tmp_dir) model.cpu().save_pretrained(tmp_dir)
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir)) self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir)) self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))
# assert loss is not None # assert loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@ -737,6 +819,7 @@ class PeftGPTQGPUTests(unittest.TestCase):
quantization_config=self.quantization_config, quantization_config=self.quantization_config,
) )
tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
model = prepare_model_for_kbit_training(model) model = prepare_model_for_kbit_training(model)
peft_config = AdaLoraConfig( peft_config = AdaLoraConfig(
@ -758,6 +841,8 @@ class PeftGPTQGPUTests(unittest.TestCase):
data = load_dataset("ybelkada/english_quotes_copy") data = load_dataset("ybelkada/english_quotes_copy")
data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True) data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
batch = tokenizer(data["train"][:3]["quote"], return_tensors="pt", padding=True)
self._check_inference_finite(model, batch)
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
trainer = Trainer( trainer = Trainer(
@ -781,7 +866,7 @@ class PeftGPTQGPUTests(unittest.TestCase):
model.cpu().save_pretrained(tmp_dir) model.cpu().save_pretrained(tmp_dir)
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir)) self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir)) self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))
# assert loss is not None # assert loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@ -844,7 +929,7 @@ class PeftGPTQGPUTests(unittest.TestCase):
model.cpu().save_pretrained(tmp_dir) model.cpu().save_pretrained(tmp_dir)
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir)) self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir)) self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))
# assert loss is not None # assert loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

View File

@ -19,6 +19,7 @@ import unittest
import torch

from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model
+from peft.utils import ModulesToSaveWrapper


class DummyModel(torch.nn.Module):

@ -63,3 +64,28 @@ class TestPeft(unittest.TestCase):
        for key in peft_state_dict.keys():
            self.assertTrue("lora" in key)
def test_modules_to_save(self):
self.model = DummyModel()
lora_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
r=64,
bias="none",
target_modules=["linear"],
modules_to_save=["embedding"],
)
self.model = inject_adapter_in_model(lora_config, self.model)
for name, module in self.model.named_modules():
if name == "linear":
self.assertTrue(hasattr(module, "lora_A"))
self.assertTrue(hasattr(module, "lora_B"))
elif name == "embedding":
self.assertTrue(isinstance(module, ModulesToSaveWrapper))
state_dict = get_peft_model_state_dict(self.model)
self.assertTrue("embedding.weight" in state_dict.keys())

View File

@ -145,7 +145,52 @@ class MultiTaskPromptTuningTester(TestCase, PeftCommonTester):
                )
            )

-           # check if `adapter_model.bin` is present
+           # check if `adapter_model.safetensors` is present
self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors")))
# check if `adapter_config.json` is present
self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json")))
# check if `pytorch_model.bin` is not present
self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "pytorch_model.bin")))
# check if `config.json` is not present
self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))
def test_save_pretrained_regression(self) -> None:
seed = 420
torch.manual_seed(seed)
model = LlamaForCausalLM(self._create_test_llama_config())
model = get_peft_model(model, self._create_multitask_prompt_tuning_config())
model = model.to(self.torch_device)
with tempfile.TemporaryDirectory() as tmp_dirname:
model.save_pretrained(tmp_dirname, safe_serialization=False)
torch.manual_seed(seed)
model_from_pretrained = LlamaForCausalLM(self._create_test_llama_config())
model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)
# check if the state dicts are equal
state_dict = get_peft_model_state_dict(model)
state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained)
# check if same keys
self.assertEqual(state_dict.keys(), state_dict_from_pretrained.keys())
# Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate).
self.assertEqual(len(list(state_dict.keys())), 3)
# check if tensors equal
for key in state_dict.keys():
self.assertTrue(
torch.allclose(
state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device)
)
)
# check if `adapter_model.bin` is present for regression
            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin")))

            # check if `adapter_config.json` is present

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import json import json
import os import os
import pickle import pickle
@ -29,6 +30,7 @@ from peft import (
IA3Config, IA3Config,
LoraConfig, LoraConfig,
PeftModel, PeftModel,
PeftType,
PrefixTuningConfig, PrefixTuningConfig,
PromptEncoderConfig, PromptEncoderConfig,
PromptLearningConfig, PromptLearningConfig,
@@ -43,13 +45,6 @@ from peft.utils import _get_submodules, infer_device
 from .testing_utils import get_state_dict
 
-CONFIG_CLASSES = (
-    IA3Config,
-    LoraConfig,
-    PrefixTuningConfig,
-    PromptEncoderConfig,
-    PromptTuningConfig,
-)
 
 CONFIG_TESTING_KWARGS = (
     # IA³
     {
@@ -269,7 +264,7 @@ class PeftCommonTester:
         self.assertTrue(dummy_output.requires_grad)
 
-    def _test_save_pretrained(self, model_id, config_cls, config_kwargs):
+    def _test_save_pretrained(self, model_id, config_cls, config_kwargs, safe_serialization=True):
         # ensure that the weights are randomly initialized
         if issubclass(config_cls, LoraConfig):
             config_kwargs = config_kwargs.copy()
@@ -287,7 +282,10 @@
         model = model.to(self.torch_device)
 
         with tempfile.TemporaryDirectory() as tmp_dirname:
-            model.save_pretrained(tmp_dirname)
+            if safe_serialization:
+                model.save_pretrained(tmp_dirname)
+            else:
+                model.save_pretrained(tmp_dirname, safe_serialization=False)
 
             model_from_pretrained = self.transformers_class.from_pretrained(model_id)
             model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)
@@ -311,14 +309,16 @@
                 )
             )
 
-            # check if `adapter_model.bin` is present
-            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin")))
+            target_adapter_filename = "adapter_model.safetensors" if safe_serialization else "adapter_model.bin"
+
+            # check if `adapter_model.safetensors` is present
+            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, target_adapter_filename)))
 
             # check if `adapter_config.json` is present
             self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json")))
 
-            # check if `pytorch_model.bin` is not present
-            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "pytorch_model.bin")))
+            # check if `model.safetensors` is not present
+            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "model.safetensors")))
 
             # check if `config.json` is not present
             self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))
@@ -326,7 +326,7 @@
             self.check_modelcard(tmp_dirname, model)
             self.check_config_json(tmp_dirname, model)
 
-    def _test_save_pretrained_selected_adapters(self, model_id, config_cls, config_kwargs):
+    def _test_save_pretrained_selected_adapters(self, model_id, config_cls, config_kwargs, safe_serialization=True):
         if issubclass(config_cls, AdaLoraConfig):
             # AdaLora does not support adding more than 1 adapter
             return
@@ -355,7 +355,10 @@
         model.add_adapter("new_adapter", new_adapter_config)
 
         with tempfile.TemporaryDirectory() as tmp_dirname:
-            model.save_pretrained(tmp_dirname)
+            if safe_serialization:
+                model.save_pretrained(tmp_dirname)
+            else:
+                model.save_pretrained(tmp_dirname, safe_serialization=False)
 
             model_from_pretrained = self.transformers_class.from_pretrained(model_id)
             model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)
@@ -385,17 +388,19 @@
                 )
             )
 
-            # check if `adapter_model.bin` is present
-            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin")))
-            self.assertTrue(os.path.exists(os.path.join(new_adapter_dir, "adapter_model.bin")))
+            target_adapter_filename = "adapter_model.safetensors" if safe_serialization else "adapter_model.bin"
+
+            # check if `adapter_model.safetensors` is present
+            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, target_adapter_filename)))
+            self.assertTrue(os.path.exists(os.path.join(new_adapter_dir, target_adapter_filename)))
 
             # check if `adapter_config.json` is present
             self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json")))
             self.assertTrue(os.path.exists(os.path.join(new_adapter_dir, "adapter_config.json")))
 
-            # check if `pytorch_model.bin` is not present
-            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "pytorch_model.bin")))
-            self.assertFalse(os.path.exists(os.path.join(new_adapter_dir, "pytorch_model.bin")))
+            # check if `model.safetensors` is not present
+            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "model.safetensors")))
+            self.assertFalse(os.path.exists(os.path.join(new_adapter_dir, "model.safetensors")))
 
             # check if `config.json` is not present
             self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))
@@ -567,6 +572,71 @@
             logits_merged_from_pretrained = model_from_pretrained(**dummy_input)[0]
             self.assertTrue(torch.allclose(logits_merged, logits_merged_from_pretrained, atol=atol, rtol=rtol))
 
+    def _test_merge_layers_multi(self, model_id, config_cls, config_kwargs):
+        supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3]
+
+        if ("gpt2" in model_id.lower()) and (config_cls == IA3Config):
+            self.skipTest("Merging GPT2 adapters not supported for IA³ (yet)")
+
+        config = config_cls(
+            base_model_name_or_path=model_id,
+            **config_kwargs,
+        )
+
+        if config.peft_type not in supported_peft_types:
+            return
+
+        model = self.transformers_class.from_pretrained(model_id)
+        model = get_peft_model(model, config)
+        model = model.to(self.torch_device)
+
+        dummy_input = self.prepare_inputs_for_testing()
+        model.eval()
+
+        with torch.inference_mode():
+            logits_adapter_1 = model(**dummy_input)[0]
+
+        model.add_adapter("adapter-2", config)
+        model.set_adapter("adapter-2")
+        model.eval()
+
+        with torch.inference_mode():
+            logits_adapter_2 = model(**dummy_input)[0]
+
+        self.assertFalse(torch.allclose(logits_adapter_1, logits_adapter_2, atol=1e-3, rtol=1e-3))
+
+        model.set_adapter("default")
+
+        with torch.inference_mode():
+            logits_adapter_1_after_set = model(**dummy_input)[0]
+
+        self.assertTrue(torch.allclose(logits_adapter_1_after_set, logits_adapter_1, atol=1e-3, rtol=1e-3))
+
+        model_copy = copy.deepcopy(model)
+        model_copy_2 = copy.deepcopy(model)
+        model_merged_all = model.merge_and_unload(adapter_names=["adapter-2", "default"])
+
+        with torch.inference_mode():
+            logits_merged_all = model_merged_all(**dummy_input)[0]
+
+        self.assertFalse(torch.allclose(logits_merged_all, logits_adapter_2, atol=1e-3, rtol=1e-3))
+        self.assertFalse(torch.allclose(logits_merged_all, logits_adapter_1, atol=1e-3, rtol=1e-3))
+
+        model_merged_adapter_2 = model_copy.merge_and_unload(adapter_names=["adapter-2"])
+
+        with torch.inference_mode():
+            logits_merged_adapter_2 = model_merged_adapter_2(**dummy_input)[0]
+
+        self.assertTrue(torch.allclose(logits_merged_adapter_2, logits_adapter_2, atol=1e-3, rtol=1e-3))
+
+        model_merged_adapter_default = model_copy_2.merge_and_unload(adapter_names=["default"])
+
+        with torch.inference_mode():
+            logits_merged_adapter_default = model_merged_adapter_default(**dummy_input)[0]
+
+        self.assertTrue(torch.allclose(logits_merged_adapter_default, logits_adapter_1, atol=1e-3, rtol=1e-3))
+
     def _test_generate(self, model_id, config_cls, config_kwargs):
         model = self.transformers_class.from_pretrained(model_id)
         config = config_cls(
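The new _test_merge_layers_multi helper exercises merge_and_unload(adapter_names=...), i.e. merging only a subset of the registered adapters. A compressed sketch of that usage; the base model and adapter names below are placeholders chosen for illustration, not taken from the diff:

# Minimal sketch of selective merging, assuming any small causal LM;
# "facebook/opt-125m" and the adapter names are illustrative placeholders.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])

model = get_peft_model(base, config)    # registers the "default" adapter
model.add_adapter("adapter-2", config)  # second adapter with the same config

# Fold only "adapter-2" into the base weights and return a plain transformers model.
merged = model.merge_and_unload(adapter_names=["adapter-2"])

Passing several names merges every listed adapter into the base weights, which is why the test expects the all-merged logits to differ from the logits of either adapter taken on its own.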
@@ -815,42 +885,79 @@
             self.assertIsNotNone(param.grad)
 
     def _test_delete_adapter(self, model_id, config_cls, config_kwargs):
-        if issubclass(config_cls, AdaLoraConfig):
-            # AdaLora does not support adding more than 1 adapter
-            return
-        model = self.transformers_class.from_pretrained(model_id)
+        supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3]
+        # IA3 does not support deleting adapters yet, but it just needs to be added
+        # AdaLora does not support multiple adapters
         config = config_cls(
             base_model_name_or_path=model_id,
             **config_kwargs,
         )
+        if config.peft_type not in supported_peft_types:
+            return
+
+        model = self.transformers_class.from_pretrained(model_id)
         adapter_to_delete = "delete_me"
         model = get_peft_model(model, config)
         model.add_adapter(adapter_to_delete, config)
         model.set_adapter(adapter_to_delete)
         model = model.to(self.torch_device)
+        model.delete_adapter(adapter_to_delete)
+        self.assertFalse(adapter_to_delete in model.peft_config)
+        self.assertEqual(model.active_adapters, ["default"])
 
-        if config.peft_type not in ("LORA"):
-            with self.assertRaises(AttributeError):
-                model.delete_adapter(adapter_to_delete)
-        else:
-            model.delete_adapter(adapter_to_delete)
-            self.assertFalse(adapter_to_delete in model.peft_config)
-            key_list = [key for key, _ in model.named_modules() if "lora" not in key]
-            for key in key_list:
-                _, target, _ = _get_submodules(model, key)
-                if isinstance(target, LoraLayer):
-                    for attr in [
-                        "r",
-                        "lora_alpha",
-                        "scaling",
-                        "lora_A",
-                        "lora_B",
-                        "lora_embedding_A",
-                        "lora_embedding_B",
-                        "lora_dropout",
-                    ]:
-                        self.assertFalse(adapter_to_delete in getattr(target, attr))
+        key_list = [key for key, _ in model.named_modules()]
+        for key in key_list:
+            _, target, _ = _get_submodules(model, key)
+            attributes_to_check = getattr(target, "adapter_layer_names", []) + getattr(target, "other_param_names", [])
+            for attr in attributes_to_check:
+                self.assertFalse(adapter_to_delete in getattr(target, attr))
+
+        # check that we can also delete the last remaining adapter
+        model.delete_adapter("default")
+        self.assertFalse("default" in model.peft_config)
+        self.assertEqual(model.active_adapters, [])
+
+        input = self.prepare_inputs_for_testing()
+        # note: we cannot call model(**input) because PeftModel always expects there to be at least one adapter
+        model.base_model(**input)  # should not raise an error
+
+    def _test_delete_inactive_adapter(self, model_id, config_cls, config_kwargs):
+        # same as test_delete_adapter, but this time an inactive adapter is deleted
+        supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3]
+        # IA3 does not support deleting adapters yet, but it just needs to be added
+        # AdaLora does not support multiple adapters
+        config = config_cls(
+            base_model_name_or_path=model_id,
+            **config_kwargs,
+        )
+        if config.peft_type not in supported_peft_types:
+            return
+
+        model = self.transformers_class.from_pretrained(model_id)
+        adapter_to_delete = "delete_me"
+        model = get_peft_model(model, config)
+        model.add_adapter(adapter_to_delete, config)
+        # "delete_me" is added but not activated
+        model = model.to(self.torch_device)
+        model.delete_adapter(adapter_to_delete)
+        self.assertFalse(adapter_to_delete in model.peft_config)
+        self.assertEqual(model.active_adapters, ["default"])
+
+        key_list = [key for key, _ in model.named_modules()]
+        for key in key_list:
+            _, target, _ = _get_submodules(model, key)
+            attributes_to_check = getattr(target, "adapter_layer_names", []) + getattr(target, "other_param_names", [])
+            for attr in attributes_to_check:
+                self.assertFalse(adapter_to_delete in getattr(target, attr))
+
+        # check that we can also delete the last remaining adapter
+        model.delete_adapter("default")
+        self.assertFalse("default" in model.peft_config)
+        self.assertEqual(model.active_adapters, [])
+
+        input = self.prepare_inputs_for_testing()
+        # note: we cannot call model(**input) because PeftModel always expects there to be at least one adapter
+        model.base_model(**input)  # should not raise an error
 
     def _test_unload_adapter(self, model_id, config_cls, config_kwargs):
         model = self.transformers_class.from_pretrained(model_id)
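Both delete tests above walk through the same user-facing flow: add a second adapter, delete it, and verify that no layer still references it. A minimal sketch of that flow, with a placeholder base model and adapter name:

# Sketch of adding and deleting adapters; "facebook/opt-125m" and the adapter
# name "extra" are illustrative placeholders.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

config = LoraConfig(r=8, target_modules=["q_proj", "v_proj"])
model = get_peft_model(
    AutoModelForCausalLM.from_pretrained("facebook/opt-125m"), config
)                              # registers the "default" adapter
model.add_adapter("extra", config)

model.delete_adapter("extra")  # removes the adapter's parameters from every layer
assert "extra" not in model.peft_config
assert model.active_adapters == ["default"]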
@@ -861,12 +968,12 @@
         model = get_peft_model(model, config)
         model = model.to(self.torch_device)
 
-        if config.peft_type not in ("LORA", "ADALORA"):
+        if config.peft_type not in ("LORA", "ADALORA", "IA3"):
             with self.assertRaises(AttributeError):
                 model = model.unload()
         else:
             dummy_input = self.prepare_inputs_for_testing()
-            logits_with_lora = model(**dummy_input)[0]
+            logits_with_adapter = model(**dummy_input)[0]
 
             transformers_model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
             logits_transformers = transformers_model(**dummy_input)[0]
@@ -875,7 +982,7 @@
             model = model.unload()
             logits_unload = model(**dummy_input)[0]
 
-            self.assertFalse(torch.allclose(logits_with_lora, logits_unload, atol=1e-10, rtol=1e-10))
+            self.assertFalse(torch.allclose(logits_with_adapter, logits_unload, atol=1e-10, rtol=1e-10))
             self.assertTrue(torch.allclose(logits_transformers, logits_unload, atol=1e-4, rtol=1e-4))
 
     def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kwargs):
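These two hunks rename a local variable and extend the unload test to IA³. For context, unload() strips the adapter layers without merging them, so the restored model's outputs should match the plain transformers model again; a minimal sketch with a placeholder base model:

# Minimal sketch of unload(); "facebook/opt-125m" is an illustrative placeholder.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

peft_model = get_peft_model(
    AutoModelForCausalLM.from_pretrained("facebook/opt-125m"),
    LoraConfig(r=8, target_modules=["q_proj", "v_proj"]),
)

restored = peft_model.unload()  # drops the adapter instead of folding it into the weights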
@@ -885,13 +992,14 @@
         adapter_list = ["adapter1", "adapter_2", "adapter_3"]
         weight_list = [0.5, 1.5, 1.5]
-        model = self.transformers_class.from_pretrained(model_id)
+
         config = config_cls(
             base_model_name_or_path=model_id,
             **config_kwargs,
         )
         if not isinstance(config, (LoraConfig)):
             return
+
+        model = self.transformers_class.from_pretrained(model_id)
         model = get_peft_model(model, config, adapter_list[0])
         model.add_adapter(adapter_list[1], config)
         model.add_adapter(adapter_list[2], replace(config, r=20))
@@ -930,7 +1038,7 @@
         for new_adapter in new_adapters:
             self.assertTrue(new_adapter in model.peft_config)
 
-        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
+        key_list = [key for key, _ in model.named_modules()]
         for key in key_list:
             _, target, _ = _get_submodules(model, key)
             if isinstance(target, LoraLayer):
@@ -1006,7 +1114,7 @@
             # must be False
             if isinstance(peft_model, StableDiffusionPipeline):
                 # for SD, check that most pixels have different values
-                self.assertTrue((output_before != output_peft).float().mean() > 0.9)
+                self.assertTrue((output_before != output_peft).float().mean() > 0.8)
             else:
                 self.assertFalse(torch.allclose(output_before, output_peft))