Add LoftQ initialization method for LoRA (#1150)

---------

Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
This commit is contained in:
yxli2123
2023-11-29 11:08:17 -05:00
committed by GitHub
parent 8298f1a366
commit 2b901ee572
12 changed files with 1514 additions and 12 deletions

View File

@ -34,6 +34,7 @@ Supported methods:
7. MultiTask Prompt Tuning: [Multitask Prompt Tuning Enables Parameter-Efficient Transfer Learning](https://arxiv.org/abs/2303.02861)
8. LoHa: [FedPara: Low-Rank Hadamard Product for Communication-Efficient Federated Learning](https://arxiv.org/abs/2108.06098)
9. LoKr: [KronA: Parameter Efficient Tuning with Kronecker Adapter](https://arxiv.org/abs/2212.10650) based on [Navigating Text-To-Image Customization:From LyCORIS Fine-Tuning to Model Evaluation](https://arxiv.org/abs/2309.14859) implementation
10. LoftQ: [LoftQ: LoRA-Fine-Tuning-aware Quantization for Large Language Models](https://arxiv.org/abs/2310.08659)
## Getting started

View File

@ -0,0 +1,69 @@
# LoftQ: LoRA-fine-tuning-aware Quantization
## Introduction
LoftQ provides better initialization for the LoRA adapters A and B
together with a quantization of the pre-trained weights W, so that Q + AB approximates the original W more closely than plain quantization.
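In each LoftQ iteration, the current backbone is quantized and the quantization residual is approximated by a rank-r SVD, which becomes the LoRA initialization. Below is a toy sketch of one such iteration; the names `fake_quantize` and `loftq_step` are illustrative only, the uniform quantizer is just a stand-in for NF4, and the actual implementation lives in `peft.utils.loftq_utils.loftq_init`.
```python
# Toy sketch of one LoftQ iteration (illustrative only).
import torch

def fake_quantize(w, num_bits=4):
    # stand-in uniform quantizer; the real code uses NF4 via bitsandbytes
    scale = w.abs().max()
    levels = 2 ** num_bits - 1
    return torch.round(w / scale * levels) / levels * scale

def loftq_step(weight, rank):
    qweight = fake_quantize(weight)              # quantize the backbone
    residual = weight - qweight                  # error introduced by quantization
    U, S, Vh = torch.linalg.svd(residual, full_matrices=False)
    sqrt_s = torch.sqrt(S[:rank])
    lora_B = U[:, :rank] * sqrt_s                # (out_features, rank)
    lora_A = sqrt_s.unsqueeze(1) * Vh[:rank]     # (rank, in_features)
    return qweight, lora_A, lora_B

W = torch.randn(64, 64)
qW, A, B = loftq_step(W, rank=16)
print(torch.norm(W - (qW + B @ A)), torch.norm(W - qW))  # the adapters absorb part of the quantization error
```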
## Quantization
We recommend saving the quantized backbone model in fp16/fp32
and then loading it as [NormalFloat4](https://arxiv.org/abs/2305.14314).
The following example shows how to quantize the llama-2-7b model and save/load it.
```sh
python quantize_save_load.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--token HF_TOKEN \
--bits 4 --iter 5 --rank 16 \
--save_dir model_zoo/loftq/
```
- `HF_TOKEN` is the token used to access the [LLAMA models](https://huggingface.co/meta-llama).
- The `quantize_and_save()` function quantizes the backbone and initializes the LoRA adapters.
It creates two folders under `$save_dir`: the quantized backbone is stored at `Llama-2-7b-hf-4bit-16rank`,
and the LoRA adapters are stored in the sub-folder `Llama-2-7b-hf-4bit-16rank/loftq_init` (see the sketch after this list).
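For reference, the core of `quantize_and_save()` boils down to requesting LoftQ initialization through `LoraConfig`. A minimal sketch follows; the hyper-parameters mirror the command above, the `inference_mode`/dropout settings and the unwrapping/saving logic of the full script are omitted, and a GPU is assumed.
```python
from transformers import AutoModelForCausalLM
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
loftq_config = LoftQConfig(loftq_bits=4, loftq_iter=5)
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    init_lora_weights="loftq",  # run LoftQ when the LoRA layers are created
    loftq_config=loftq_config,
)
# the backbone weights of the targeted modules are replaced by their LoftQ (de)quantized version
lora_model = get_peft_model(model, lora_config)
```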
## Fine-tuning
Here is an example of loading the quantized backbone and the LoRA adapters produced above, where `args.save_dir` is the `--save_dir` passed to `quantize_save_load.py`:
```python
import os

from transformers import AutoModelForCausalLM

from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained(
    os.path.join(args.save_dir, "Llama-2-7b-hf-4bit-16rank"),
    load_in_4bit=True,
)
peft_model = PeftModel.from_pretrained(
    base_model,
    os.path.join(args.save_dir, "Llama-2-7b-hf-4bit-16rank", "loftq_init"),
    is_trainable=True,
)
```
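The example scripts in this folder load the backbone with an explicit NF4 `BitsAndBytesConfig` (without double quantization) rather than the bare `load_in_4bit=True` shortcut. A sketch of the equivalent call, assuming a local `save_dir` as in the quantization command above:
```python
import os

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

save_dir = "model_zoo/loftq"  # the --save_dir used during quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",          # NormalFloat4, as recommended above
    bnb_4bit_use_double_quant=False,
)
base_model = AutoModelForCausalLM.from_pretrained(
    os.path.join(save_dir, "Llama-2-7b-hf-4bit-16rank"),
    quantization_config=bnb_config,
    device_map="auto",
)
peft_model = PeftModel.from_pretrained(
    base_model,
    os.path.join(save_dir, "Llama-2-7b-hf-4bit-16rank", "loftq_init"),
    is_trainable=True,
)
```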
We also provide an example of fine-tuning a LoftQ-initialized model on GSM8K.
It loads the quantized backbone and LoRA adapters from the [LoftQ Hugging Face Hub](https://huggingface.co/LoftQ).
```sh
python train_gsm8k_llama.py \
--model_name_or_path LoftQ/Llama-2-7b-hf-4bit-64rank \
--output_dir exp_results/gsm8k/llama-2-7b/bit4-rank64/lr3e-4 \
--learning_rate 3e-4 \
--seed 202 \
--dataset_name gsm8k \
--dataset_config main \
--pad_to_max_length \
--max_source_length 128 \
--max_target_length 256 \
--num_train_epochs 5 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 4 \
--with_tracking \
--report_to tensorboard
```
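When `--adapter_name_or_path` is not given, `train_gsm8k_llama.py` loads the LoftQ-initialized adapters from the `loftq_init` subfolder of the backbone repository. A rough sketch of that loading step, with the repo id taken from the command above:
```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

MODEL_ID = "LoftQ/Llama-2-7b-hf-4bit-64rank"
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4"),
)
# the adapters initialized by LoftQ live in the `loftq_init` subfolder of the same repo
peft_model = PeftModel.from_pretrained(base_model, MODEL_ID, subfolder="loftq_init", is_trainable=True)
peft_model.print_trainable_parameters()
```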

View File

@ -0,0 +1,244 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import torch
import torch.nn as nn
from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoTokenizer,
BitsAndBytesConfig,
)
from peft import LoftQConfig, LoraConfig, PeftModel, TaskType, get_peft_model
class Shell(nn.Module):
def __init__(self, weight, bias=None):
super().__init__()
self.weight = nn.Parameter(weight, requires_grad=False)
if bias is not None:
self.bias = nn.Parameter(bias, requires_grad=False)
def unwrap_model(model, sub_module_name=".base_layer"):
sub_module_name_list = [k.split(sub_module_name)[0] for k in model.state_dict().keys() if sub_module_name in k]
sub_module_name_set = set(sub_module_name_list)
for name in sub_module_name_set:
# get the parent of the submodule
name_parent = ".".join(name.split(".")[:-1])
name_child = name.split(".")[-1]
sub_module = model.get_submodule(name_parent)
print(sub_module)
# replace with shell
child = getattr(sub_module, name_child)
weight = getattr(child.base_layer, "weight", None)
bias = getattr(child.base_layer, "bias", None)
shell = Shell(weight, bias)
setattr(sub_module, name_child, shell)
print("You have unwrapped the model. Use it on your own risk.")
def print_model(model, name):
print("=" * 10 + name + "=" * 10)
print(model)
for name, param in model.named_parameters():
if torch.is_tensor(param):
if param.dtype in [torch.float32, torch.float16]:
print(
name,
param.shape,
param.device,
param.dtype,
param.requires_grad,
param.mean().item(),
param.max().item(),
)
else:
print(name, param.shape, param.device, param.dtype, param.requires_grad)
def arg_parse():
parser = argparse.ArgumentParser(description="Quantize a model with LoftQ.")
parser.add_argument(
"--model_name_or_path",
type=str,
default=None,
required=True,
help="The name or path of the fp32/16 model.",
)
parser.add_argument(
"--token",
type=str,
default=None,
help="The access token to download model from HuggingFace Hub.",
)
parser.add_argument(
"--bits",
type=int,
default=4,
help="The quantized bits",
)
parser.add_argument(
"--iter",
type=int,
default=1,
help="The alternating steps in LoftQ",
)
parser.add_argument(
"--rank",
type=int,
default=16,
help="The rank of the LoRA adapter",
)
parser.add_argument(
"--save_dir",
type=str,
default="./model_zoo/loftq/",
help="The rank of the LoRA adapter",
)
args = parser.parse_args()
return args
def quantize_and_save():
args = arg_parse()
# Download weights and configure LoRA
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, token=args.token, trust_remote_code=True)
if any(name in args.model_name_or_path.lower() for name in ["llama", "mistral", "falcon"]):
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path, token=args.token, trust_remote_code=True, device_map="auto"
)
task_type = TaskType.CAUSAL_LM
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"]
elif any(name in args.model_name_or_path.lower() for name in ["bart", "t5"]):
model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path, token=args.token, device_map="auto")
task_type = TaskType.SEQ_2_SEQ_LM
target_modules = ["q_proj", "k_proj", "v_proj", "fc1", "fc2", "out_proj"]
elif any(name in args.model_name_or_path.lower() for name in ["deberta", "roberta", "bert"]):
model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, token=args.token)
model = model.cuda()
task_type = TaskType.SEQ_CLS
target_modules = ["query_proj", "key_proj", "value_proj", "dense"] # embeddings not supported by peft
else:
raise NotImplementedError("Other models not supported yet.")
# Config of LoftQ
loftq_config = LoftQConfig(loftq_bits=args.bits, loftq_iter=args.iter)
lora_config = LoraConfig(
task_type=task_type,
inference_mode=True,
r=args.rank,
lora_alpha=16 if task_type is TaskType.CAUSAL_LM else args.rank,
lora_dropout=0.1,
target_modules=target_modules,
init_lora_weights="loftq",
loftq_config=loftq_config,
)
# Obtain LoftQ model
lora_model = get_peft_model(model, lora_config)
base_model = lora_model.get_base_model()
# Save LoftQ model
model_name = args.model_name_or_path.split("/")[-1] + f"-{args.bits}bit" + f"-{args.rank}rank"
base_model_dir = os.path.join(args.save_dir, model_name)
lora_model_dir = os.path.join(args.save_dir, model_name, "loftq_init")
# save lora adapters first
lora_model.base_model.peft_config[
"default"
].base_model_name_or_path = base_model_dir # This can be a local path or Hub model id
lora_model.base_model.peft_config["default"].init_lora_weights = True # Don't apply LoftQ when loading again
lora_model.save_pretrained(lora_model_dir)
print_model(lora_model, "lora_model")
# remove lora adapters and save the backbone
unwrap_model(base_model)
base_model.save_pretrained(base_model_dir)
tokenizer.save_pretrained(base_model_dir)
print_model(base_model, "base_model")
return base_model_dir, lora_model_dir
def load_loftq(base_model_path, lora_adapter_path):
if any(name in base_model_path.lower() for name in ["llama", "mistral", "falcon"]):
model = AutoModelForCausalLM.from_pretrained(
base_model_path,
device_map="auto",
low_cpu_mem_usage=True,
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type="nf4",
),
)
elif any(name in base_model_path.lower() for name in ["bart", "t5"]):
model = AutoModelForSeq2SeqLM.from_pretrained(
base_model_path,
device_map="auto",
low_cpu_mem_usage=True,
load_in_4bit=True,
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type="nf4",
),
)
elif any(name in base_model_path.lower() for name in ["deberta", "roberta", "bert"]):
model = AutoModelForSequenceClassification.from_pretrained(
base_model_path,
low_cpu_mem_usage=True,
load_in_4bit=True,
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type="nf4",
),
)
else:
raise NotImplementedError("Other models not supported yet.")
lora_model = PeftModel.from_pretrained(model, lora_adapter_path, is_trainable=True)
# Do training or inference below
print_model(lora_model, "lora_model")
print_model(model, "base_model")
if __name__ == "__main__":
base_dir, lora_dir = quantize_and_save()
load_loftq(base_dir, lora_dir)
# example command:
# python quantize_save_load.py \
# --model_name_or_path meta-llama/Llama-2-7b-hf \
# --token XXX \
# --bits 4 --iter 5 --rank 16 \
# --save_dir ./model_zoo/loftq/

View File

@ -0,0 +1,866 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
import logging
import math
import os
import random
import re
from pathlib import Path
import datasets
import torch
import transformers
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from huggingface_hub import Repository, create_repo
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
SchedulerType,
default_data_collator,
get_scheduler,
)
from transformers.utils import send_example_telemetry
from transformers.utils.versions import require_version
from peft import PeftModel
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
# check_min_version("4.32.0.dev0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
HF_TOKEN = os.environ.get("HF_TOKEN")  # read the Hub access token from the environment; never hard-code it
def parse_args():
parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task")
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help="The name of the dataset to use (via the datasets library).",
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The configuration name of the dataset to use (via the datasets library).",
)
parser.add_argument(
"--train_file", type=str, default=None, help="A csv, txt or a json file containing the training data."
)
parser.add_argument(
"--validation_file", type=str, default=None, help="A csv, txt or a json file containing the validation data."
)
parser.add_argument(
"--validation_split_percentage",
default=5,
help="The percentage of the train set used as validation set in case there's no validation split",
)
parser.add_argument(
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
required=False,
)
parser.add_argument(
"--config_name",
type=str,
default=None,
help="Pretrained config name or path if not the same as model_name",
)
parser.add_argument(
"--tokenizer_name",
type=str,
default=None,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--use_slow_tokenizer",
action="store_true",
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
)
parser.add_argument(
"--per_device_train_batch_size",
type=int,
default=8,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument(
"--per_device_eval_batch_size",
type=int,
default=8,
help="Batch size (per device) for the evaluation dataloader.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-5,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--lr_scheduler_type",
type=SchedulerType,
default="linear",
help="The scheduler type to use.",
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
)
parser.add_argument(
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--model_type",
type=str,
default=None,
help="Model type to use if training from scratch.",
choices=MODEL_TYPES,
)
parser.add_argument(
"--ignore_pad_token_for_loss",
type=bool,
default=True,
help="Whether to ignore the tokens corresponding to padded labels in the loss computation or not.",
)
parser.add_argument(
"--max_source_length",
type=int,
default=128,
help=(
"The maximum total input sequence length after "
"tokenization.Sequences longer than this will be truncated, sequences shorter will be padded."
),
)
parser.add_argument(
"--max_target_length",
type=int,
default=128,
help=(
"The maximum total sequence length for target text after "
"tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
"during ``evaluate`` and ``predict``."
),
)
parser.add_argument(
"--pad_to_max_length",
action="store_true",
help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
)
parser.add_argument(
"--preprocessing_num_workers",
type=int,
default=None,
help="The number of processes to use for the preprocessing.",
)
parser.add_argument(
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
)
parser.add_argument(
"--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using TXT files."
)
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument(
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
)
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--trust_remote_code",
type=bool,
default=False,
help=(
"Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
"should only be set to `True` for repositories you trust and in which you have read the code, as it will"
"execute code present on the Hub on your local machine."
),
)
parser.add_argument(
"--checkpointing_steps",
type=str,
default=None,
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help="If the training should continue from a checkpoint folder.",
)
parser.add_argument(
"--with_tracking",
action="store_true",
help="Whether to enable experiment trackers for logging.",
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
parser.add_argument(
"--low_cpu_mem_usage",
action="store_true",
help=(
"It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded."
"If passed, LLM loading time and RAM consumption will be benefited."
),
)
##########################
# Generation Config #
##########################
parser.add_argument(
"--temperature",
type=float,
default=0.8,
help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
)
parser.add_argument("--k", type=int, default=40, help="Choose k candidate words")
parser.add_argument("--p", type=float, default=0.95, help="The sum of probability of candidate words is 0.9 ")
##########################
# Exp Args #
##########################
parser.add_argument(
"--adapter_name_or_path",
type=str,
default=None,
help=(
"The LoRA adapter checkpoint. Set None if you want to fine-tune from LoftQ."
"Specify a path if you want to evaluate."
),
)
args = parser.parse_args()
# Sanity checks
if args.dataset_name is None and args.train_file is None and args.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
else:
if args.train_file is not None:
extension = args.train_file.split(".")[-1]
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file."
if args.validation_file is not None:
extension = args.validation_file.split(".")[-1]
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
if args.push_to_hub:
assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
return args
def main():
args = parse_args()
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_clm_no_trainer", args)
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
# in the environment
accelerator_log_kwargs = {}
if args.with_tracking:
accelerator_log_kwargs["log_with"] = args.report_to
accelerator_log_kwargs["project_dir"] = args.output_dir
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub:
# Retrieve or infer repo_name
repo_name = args.hub_model_id
if repo_name is None:
repo_name = Path(args.output_dir).absolute().name
# Create repo and retrieve repo_id
repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
# Clone repo locally
repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
#
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
# 'text' is found. You can easily tweak this behavior (see below).
#
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
if "validation" not in raw_datasets.keys():
raw_datasets["validation"] = load_dataset(
args.dataset_name,
args.dataset_config_name,
split=f"train[:{args.validation_split_percentage}%]",
)
raw_datasets["train"] = load_dataset(
args.dataset_name,
args.dataset_config_name,
split=f"train[{args.validation_split_percentage}%:]",
)
else:
data_files = {}
dataset_args = {}
if args.train_file is not None:
data_files["train"] = args.train_file
if args.validation_file is not None:
data_files["validation"] = args.validation_file
extension = args.train_file.split(".")[-1]
if extension == "txt":
extension = "text"
dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks
raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
if "validation" not in raw_datasets.keys():
raw_datasets["validation"] = load_dataset(
extension,
data_files=data_files,
split=f"train[:{args.validation_split_percentage}%]",
**dataset_args,
)
raw_datasets["train"] = load_dataset(
extension,
data_files=data_files,
split=f"train[{args.validation_split_percentage}%:]",
**dataset_args,
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
# Load pretrained model and tokenizer
#
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
if args.config_name:
config = AutoConfig.from_pretrained(
args.config_name,
trust_remote_code=args.trust_remote_code,
)
elif args.model_name_or_path:
config = AutoConfig.from_pretrained(
args.model_name_or_path,
trust_remote_code=args.trust_remote_code,
)
else:
config = CONFIG_MAPPING[args.model_type]()
logger.warning("You are instantiating a new config instance from scratch.")
if args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_name, use_fast=not args.use_slow_tokenizer, trust_remote_code=args.trust_remote_code
)
elif args.model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(
args.model_name_or_path,
use_fast=not args.use_slow_tokenizer,
trust_remote_code=args.trust_remote_code,
)
else:
raise ValueError(
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
)
##########################
# Tokenizer #
##########################
tokenizer.pad_token_id = 0 # unk. we want this to be different from the eos token
tokenizer.padding_side = "left" # Allow batched inference
tokenizer.truncation_side = "left"
if args.model_name_or_path:
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
low_cpu_mem_usage=True,
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=config.torch_dtype,
),
)
else:
logger.info("Training new model from scratch")
model = AutoModelForCausalLM.from_config(config, trust_remote_code=args.trust_remote_code)
##########################
# Peft Model #
##########################
if args.adapter_name_or_path is None:
model = PeftModel.from_pretrained(model, args.model_name_or_path, subfolder="loftq_init", is_trainable=True)
else:
model = PeftModel.from_pretrained(model, args.adapter_name_or_path, is_trainable=True)
model.print_trainable_parameters()
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
model.resize_token_embeddings(len(tokenizer))
# Preprocessing the datasets.
# First we tokenize all the texts.
##########################
# GSM8K dataset #
##########################
# Preprocessing the datasets.
# First we tokenize all the texts.
column_names = raw_datasets["train"].column_names
# Get the column names for source/target.
source_column, target_column = "question", "answer"
# Temporarily set max_target_length for training.
padding = "max_length" if args.pad_to_max_length else False
task_prompt = "\nAnswer the above question. First think step by step and then answer the final number.\n"
def prompt_process(sent_1, sent_2, prompt_1="", prompt_2="", prompt_3=""):
sent_2 = sent_2.replace("####", "The final answer is")
return prompt_1 + sent_1 + prompt_2 + sent_2 + prompt_3
def preprocess_function_train(examples):
sources = examples[source_column]
targets = examples[target_column]
inputs = [prompt_process(source, target, prompt_2=task_prompt) for (source, target) in zip(sources, targets)]
model_inputs = tokenizer(
inputs,
max_length=args.max_source_length + args.max_target_length,
padding=padding,
truncation=True,
return_tensors="pt",
)
labels = copy.deepcopy(model_inputs)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
if padding == "max_length" and args.ignore_pad_token_for_loss:
# get the length of the target tokens. -1 to kick out the <BOS> token
target_tokens = tokenizer(targets, padding=False)
target_len = [len(label) - 1 for label in target_tokens["input_ids"]]
# don't calculate the loss from source and padding (left padding)
for i in range(len(labels["input_ids"])):
labels["input_ids"][i, : -target_len[i]] = -100
model_inputs["labels"] = labels["input_ids"]
return model_inputs
def preprocess_function_test(examples):
sources = examples[source_column]
labels = examples[target_column]
inputs = [source + task_prompt for source in sources]
model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)
labels = tokenizer(labels, max_length=args.max_target_length, padding=padding, truncation=True)
model_inputs["labels"] = labels["input_ids"]
return model_inputs
with accelerator.main_process_first():
train_dataset = raw_datasets["train"].map(
preprocess_function_train,
batched=True,
num_proc=args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not args.overwrite_cache,
desc="Running tokenizer on training dataset",
)
eval_dataset = raw_datasets["test"].map(
preprocess_function_test,
batched=True,
num_proc=args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not args.overwrite_cache,
desc="Running tokenizer on test dataset",
)
# Log a few random samples from the set:
for index in random.sample(range(len(train_dataset)), 2):
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
for index in random.sample(range(len(eval_dataset)), 2):
logger.info(f"Sample {index} of the validation set: {eval_dataset[index]}.")
# DataLoaders creation:
train_dataloader = DataLoader(
train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size
)
eval_dataloader = DataLoader(
eval_dataset, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size
)
# Optimizer
# Split weights in two groups, one with weight decay and the other not.
no_decay = ["bias", "layer_norm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and "lora" in n],
"weight_decay": args.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
if accelerator.distributed_type == DistributedType.TPU:
model.tie_weights()
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Figure out how many steps we should save the Accelerator states
checkpointing_steps = args.checkpointing_steps
if checkpointing_steps is not None and checkpointing_steps.isdigit():
checkpointing_steps = int(checkpointing_steps)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if args.with_tracking:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("clm_no_trainer", experiment_config)
# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
checkpoint_path = args.resume_from_checkpoint
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
checkpoint_path = path
path = os.path.basename(checkpoint_path)
accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
accelerator.load_state(path)
# Extract `epoch_{i}` or `step_{i}`
training_difference = os.path.splitext(path)[0]
if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
completed_steps = starting_epoch * num_update_steps_per_epoch
else:
# need to multiply `gradient_accumulation_steps` to reflect real steps
resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
starting_epoch = resume_step // len(train_dataloader)
resume_step -= starting_epoch * len(train_dataloader)
completed_steps = resume_step // args.gradient_accumulation_steps
# update the progress_bar if load from checkpoint
progress_bar.update(completed_steps)
for epoch in range(starting_epoch, args.num_train_epochs):
model.train()
if args.with_tracking:
total_loss = 0
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
# We skip the first `n` batches in the dataloader when resuming from a checkpoint
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
else:
active_dataloader = train_dataloader
for step, batch in enumerate(active_dataloader):
with accelerator.accumulate(model):
outputs = model(**batch)
loss = outputs.loss
# We keep track of the loss at each epoch
if args.with_tracking:
total_loss += loss.detach().float()
accelerator.backward(loss)
accelerator.print(f"Epoch: {epoch} | Step: {step} | Loss: {loss}")
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
completed_steps += 1
if isinstance(checkpointing_steps, int):
if completed_steps % checkpointing_steps == 0:
output_dir = f"step_{completed_steps}"
if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir)
if completed_steps >= args.max_train_steps:
break
model.eval()
gen_kwargs = {
"max_new_tokens": args.max_target_length,
"temperature": args.temperature,
"top_k": args.k,
"top_p": args.p,
"do_sample": True,
}
ans_pred_list = []
ans_gold_list = []
for step, batch in enumerate(eval_dataloader):
with torch.no_grad():
gen_kwargs["input_ids"] = batch["input_ids"]
gen_kwargs["attention_mask"] = batch["attention_mask"]
generated_tokens = accelerator.unwrap_model(model).generate(**gen_kwargs)
pred_tokens = generated_tokens[:, args.max_source_length :]
pred_tokens = accelerator.pad_across_processes(pred_tokens, dim=1, pad_index=tokenizer.pad_token_id)
gold_tokens = batch["labels"]
if not args.pad_to_max_length:
# If we did not pad to max length, we need to pad the labels too
gold_tokens = accelerator.pad_across_processes(
batch["labels"], dim=1, pad_index=tokenizer.pad_token_id
)
pred_tokens, gold_tokens = accelerator.gather_for_metrics((pred_tokens, gold_tokens))
pred_tokens, gold_tokens = pred_tokens.cpu().numpy(), gold_tokens.cpu().numpy()
if isinstance(pred_tokens, tuple):
pred_tokens = pred_tokens[0]
decoded_pred = tokenizer.batch_decode(pred_tokens, skip_special_tokens=True)
decoded_gold = tokenizer.batch_decode(gold_tokens, skip_special_tokens=True)
# Extract the numbers in sentences
accelerator.print(decoded_pred)
ans_pred_list += [extract_answer_number(sentence_pred) for sentence_pred in decoded_pred]
ans_gold_list += [extract_answer_number(sentence_gold) for sentence_gold in decoded_gold]
accelerator.print(ans_pred_list)
accelerator.print(ans_gold_list)
accuracy = compute_accuracy(ans_gold_list, ans_pred_list)
logger.info(f"epoch {epoch}: accuracy: {accuracy}")
if args.with_tracking:
accelerator.log(
{
"accuracy": accuracy,
"train_loss": total_loss.item() / len(train_dataloader),
"epoch": epoch,
"step": completed_steps,
},
step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir)
repo.push_to_hub(
commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
)
if args.checkpointing_steps == "epoch":
output_dir = f"epoch_{epoch}"
if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir)
if args.with_tracking:
accelerator.end_training()
if args.output_dir is not None:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
PATTERN_NUMBER = re.compile(r"-?\d+\.?\d*")
def extract_answer_number(sentence: str) -> float:
sentence = sentence.replace(",", "")
pred = PATTERN_NUMBER.findall(sentence)
if not pred:
return float("inf")
segment = sentence.split("The final answer is ")
if len(segment) > 1:
pred_answer = segment[1]
pred_answer = PATTERN_NUMBER.findall(pred_answer)
if len(pred_answer) > 0:
pred_answer = pred_answer[0]
else:
pred_answer = float(pred[-1])
else:
pred_answer = float(pred[-1])
if isinstance(pred_answer, str):
try:
pred_answer = float(pred_answer)
except ValueError:
pred_answer = float("inf")
return pred_answer
def compute_accuracy(pred: list, gold: list):
acc = 0.0
for p, g in zip(pred, gold):
if p == g:
acc += 1
return acc / len(pred)
if __name__ == "__main__":
main()
# example command
# python train_gsm8k_llama.py \
# --model_name_or_path LoftQ/Llama-2-7b-hf-bit4-rank64-backbone \
# --adapter_name_or_path LoftQ/Llama-2-7b-hf-bit4-rank64-adapters \
# --output_dir exp_results/gsm8k/llama-2-7b/bit4-rank64/lr3e-4 \
# --learning_rate 3e-4 \
# --seed 202 \
# --dataset_name gsm8k \
# --dataset_config main \
# --pad_to_max_length \
# --max_source_length 128 \
# --max_target_length 256 \
# --num_train_epochs 5 \
# --per_device_train_batch_size 4 \
# --per_device_eval_batch_size 4 \
# --gradient_accumulation_steps 4 \
# --with_tracking \
# --report_to tensorboard

15
requirements.txt Normal file
View File

@ -0,0 +1,15 @@
accelerate
torch
safetensors
bitsandbytes
scipy
peft
transformers
tqdm
packaging
pytest
numpy
pyyaml
datasets
psutil
setuptools

View File

@ -48,6 +48,7 @@ from .tuners import (
AdaptionPromptConfig,
AdaptionPromptModel,
LoraConfig,
LoftQConfig,
LoraModel,
LoHaConfig,
LoHaModel,

View File

@ -18,7 +18,7 @@
# limitations under the License.
from .adaption_prompt import AdaptionPromptConfig, AdaptionPromptModel
from .lora import LoraConfig, LoraModel, LoftQConfig
from .loha import LoHaConfig, LoHaModel
from .lokr import LoKrConfig, LoKrModel
from .ia3 import IA3Config, IA3Model

View File

@ -15,13 +15,13 @@
from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from .config import LoftQConfig, LoraConfig
from .gptq import QuantLinear
from .layer import Conv2d, Embedding, Linear, LoraLayer
from .model import LoraModel
__all__ = ["LoraConfig", "LoftQConfig", "Conv2d", "Embedding", "LoraLayer", "Linear", "LoraModel", "QuantLinear"]
if is_bnb_available():

View File

@ -22,6 +22,25 @@ from peft.config import PeftConfig
from peft.utils import PeftType
@dataclass
class LoftQConfig:
"""
This is the sub-configuration class to store the LoftQ settings used by a [`LoraModel`].
Args:
loftq_bits (`int`): The number of bits used to quantize the backbone weights, e.g. 4.
loftq_iter (`int`): The number of alternating LoftQ iterations to run.
"""
loftq_bits: int = field(default=4, metadata={"help": "Quantization bits for LoftQ"})
loftq_iter: int = field(default=1, metadata={"help": "Alternating iterations for LoftQ"})
@dataclass
class LoraConfig(PeftConfig):
"""
@ -78,7 +97,7 @@ class LoraConfig(PeftConfig):
"the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved."
}, },
) )
init_lora_weights: bool | Literal["gaussian"] = field( init_lora_weights: bool | Literal["gaussian", "loftq"] = field(
default=True, default=True,
metadata={ metadata={
"help": ( "help": (
@ -86,6 +105,7 @@ class LoraConfig(PeftConfig):
"initialization from the reference implementation from Microsoft. Passing 'gaussian' results " "initialization from the reference implementation from Microsoft. Passing 'gaussian' results "
"in Gaussian initialization scaled by the LoRA rank for linear and layers. Setting the initialization " "in Gaussian initialization scaled by the LoRA rank for linear and layers. Setting the initialization "
"to False leads to completely random initialization and is discouraged." "to False leads to completely random initialization and is discouraged."
"Pass `'loftq'` to use LoftQ initialization"
),
},
)
@ -121,6 +141,16 @@ class LoraConfig(PeftConfig):
)
},
)
# dict type is used when loading config.json
loftq_config: Union[LoftQConfig, dict] = field(
default_factory=dict,
metadata={
"help": (
"The configuration of LoftQ. If this is not None, then LoftQ will be used to quantize the backbone "
"weights and initialize Lora layers."
)
},
)
def __post_init__(self):
self.peft_type = PeftType.LORA
@ -134,3 +164,16 @@ class LoraConfig(PeftConfig):
# if target_modules is a regex expression, then layers_pattern should be None
if isinstance(self.target_modules, str) and self.layers_pattern is not None:
raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.")
# handle init_lora_weights and loftq_config
if self.init_lora_weights == "loftq":
import importlib
if not importlib.util.find_spec("scipy"):
raise ImportError("The required package 'scipy' is not installed. Please install it to continue.")
if self.loftq_config is None:
raise ValueError("`loftq_config` must be specified when `init_lora_weights` is 'loftq'.")
# convert loftq_config to dict
if self.loftq_config is not None and not isinstance(self.loftq_config, dict):
self.loftq_config = vars(self.loftq_config)

View File

@ -15,7 +15,7 @@
import math
import warnings
from typing import Any, List, Optional, Union
import torch
import torch.nn as nn
@ -46,6 +46,7 @@ class LoraLayer(BaseTunerLayer):
# Mark the weight as unmerged
self._disable_adapters = False
self.merged_adapters = []
self.kwargs = kwargs
base_layer = self.get_base_layer()
if isinstance(base_layer, nn.Linear):
@ -83,7 +84,10 @@ class LoraLayer(BaseTunerLayer):
self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=False)
self.scaling[adapter_name] = lora_alpha / r
if init_lora_weights == "loftq":
self.loftq_init(adapter_name)
elif init_lora_weights:
self.reset_lora_parameters(adapter_name, init_lora_weights)
weight = getattr(self.get_base_layer(), "weight", None)
@ -115,7 +119,10 @@ class LoraLayer(BaseTunerLayer):
self.lora_A[adapter_name] = nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False)
self.lora_B[adapter_name] = nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False)
self.scaling[adapter_name] = lora_alpha / r
if init_lora_weights == "loftq":
self.loftq_init(adapter_name)
elif init_lora_weights:
self.reset_lora_parameters(adapter_name, init_lora_weights)
weight = getattr(base_layer, "weight", None)
@ -142,6 +149,10 @@ class LoraLayer(BaseTunerLayer):
self.lora_embedding_A[adapter_name] = nn.Parameter(weight_A)
self.lora_embedding_B[adapter_name] = nn.Parameter(weight_B)
self.scaling[adapter_name] = lora_alpha / r
if init_lora_weights == "loftq":
self.loftq_init(adapter_name)
elif init_lora_weights:
self.reset_lora_parameters(adapter_name, init_lora_weights)
base_layer = self.get_base_layer()
@ -170,6 +181,27 @@ class LoraLayer(BaseTunerLayer):
nn.init.zeros_(self.lora_embedding_A[adapter_name])
nn.init.normal_(self.lora_embedding_B[adapter_name])
def loftq_init(self, adapter_name):
from peft.utils.loftq_utils import loftq_init
weight = self.get_base_layer().weight
kwargs = {
"num_bits": self.kwargs.get("loftq_bits", 4),
"reduced_rank": self.r[adapter_name],
"num_iter": self.kwargs.get("loftq_iter", 1),
}
qweight, lora_A, lora_B = loftq_init(weight, **kwargs)
if adapter_name in self.lora_A.keys():
# assign the LoftQ-computed low-rank factors to lora_A and lora_B
self.lora_A[adapter_name].weight.data = lora_A
self.lora_B[adapter_name].weight.data = lora_B
if adapter_name in self.lora_embedding_A.keys():
# same assignment for the embedding variants of the adapter
self.lora_embedding_A[adapter_name].weight.data = lora_A
self.lora_embedding_B[adapter_name].weight.data = lora_B
self.get_base_layer().weight.data = qweight
def set_scale(self, adapter, scale):
if adapter not in self.scaling:
# Ignore the case where the adapter is not in the layer
@ -218,11 +250,11 @@ class Linear(nn.Module, LoraLayer):
lora_dropout: float = 0.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
is_target_conv_1d_layer: bool = False,
init_lora_weights: Union[bool, str] = True,
**kwargs,
) -> None:
super().__init__()
LoraLayer.__init__(self, base_layer, **kwargs)
self.fan_in_fan_out = fan_in_fan_out
self._active_adapter = adapter_name
@ -351,7 +383,7 @@ class Embedding(nn.Module, LoraLayer):
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
init_lora_weights: Union[bool, str] = True,
**kwargs,
) -> None:
super().__init__()
@ -491,7 +523,7 @@ class Conv2d(nn.Module, LoraLayer):
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
init_lora_weights: Union[bool, str] = True,
**kwargs,
) -> None:
super().__init__()

View File

@ -286,8 +286,10 @@ class LoraModel(BaseTuner):
elif isinstance(target_base_layer, torch.nn.Embedding):
embedding_kwargs = kwargs.copy()
embedding_kwargs.pop("fan_in_fan_out", None)
embedding_kwargs.update(lora_config.loftq_config)
new_module = Embedding(target, adapter_name, **embedding_kwargs)
elif isinstance(target_base_layer, torch.nn.Conv2d):
kwargs.update(lora_config.loftq_config)
new_module = Conv2d(target, adapter_name, **kwargs)
elif isinstance(target_base_layer, torch.nn.Linear):
if kwargs["fan_in_fan_out"]:
@ -296,6 +298,7 @@ class LoraModel(BaseTuner):
"Setting fan_in_fan_out to False." "Setting fan_in_fan_out to False."
) )
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
kwargs.update(lora_config.loftq_config)
new_module = Linear(target, adapter_name, **kwargs) new_module = Linear(target, adapter_name, **kwargs)
elif isinstance(target_base_layer, Conv1D): elif isinstance(target_base_layer, Conv1D):
if not kwargs["fan_in_fan_out"]: if not kwargs["fan_in_fan_out"]:
@ -304,6 +307,7 @@ class LoraModel(BaseTuner):
"Setting fan_in_fan_out to True." "Setting fan_in_fan_out to True."
) )
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
kwargs.update(lora_config.loftq_config)
new_module = Linear(target, adapter_name, is_target_conv_1d_layer=True, **kwargs) new_module = Linear(target, adapter_name, is_target_conv_1d_layer=True, **kwargs)
else: else:
raise ValueError( raise ValueError(

View File

@ -0,0 +1,227 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Reference code: https://github.com/yxli2123/LoftQ/blob/main/utils.py
# Reference paper: https://arxiv.org/abs/2310.08659
import logging
from typing import Union
import torch
from peft.import_utils import is_bnb_4bit_available, is_bnb_available
if is_bnb_available():
import bitsandbytes as bnb
class NFQuantizer:
def __init__(self, num_bits=2, device="cuda", method="normal", block_size=64, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_bits = num_bits
self.device = device
self.method = method
self.block_size = block_size
if self.method == "normal":
self.norm_lookup_table = self.create_normal_map(num_bits=self.num_bits)
self.norm_lookup_table = self.norm_lookup_table.to(device)
elif self.method == "uniform":
self.norm_lookup_table = self.create_uniform_map(num_bits=self.num_bits)
self.norm_lookup_table = self.norm_lookup_table.to(device)
else:
raise NotImplementedError("Other quantization methods not supported yet.")
@staticmethod
def create_uniform_map(symmetric=False, num_bits=4):
if symmetric:
# print("symmetric uniform quantization")
negative = torch.linspace(-1, 0, 2 ** (num_bits - 1))
positive = torch.linspace(0, 1, 2 ** (num_bits - 1))
table = torch.cat([negative, positive[1:]])
else:
# print("asymmetric uniform quantization")
table = torch.linspace(-1, 1, 2**num_bits)
return table
@staticmethod
def create_normal_map(offset=0.9677083, symmetric=False, num_bits=2):
try:
from scipy.stats import norm
except ImportError:
raise ImportError("The required package 'scipy' is not installed. Please install it to continue.")
variations = 2**num_bits
if symmetric:
v = norm.ppf(torch.linspace(1 - offset, offset, variations + 1)).tolist()
values = []
for index in range(len(v) - 1):
values.append(0.5 * v[index] + 0.5 * v[index + 1])
v = values
else:
# one more positive value, this is an asymmetric type
v1 = norm.ppf(torch.linspace(offset, 0.5, variations // 2 + 1)[:-1]).tolist()
v2 = [0]
v3 = (-norm.ppf(torch.linspace(offset, 0.5, variations // 2)[:-1])).tolist()
v = v1 + v2 + v3
values = torch.Tensor(v)
values = values.sort().values
values /= values.max()
return values
def quantize_tensor(self, weight):
max_abs = torch.abs(weight).max()
weight_normed = weight / max_abs
weight_normed_expanded = weight_normed.unsqueeze(-1)
# Reshape L to have the same number of dimensions as X_expanded
L_reshaped = torch.tensor(self.norm_lookup_table).reshape(1, -1)
# Calculate the absolute difference between X_expanded and L_reshaped
abs_diff = torch.abs(weight_normed_expanded - L_reshaped)
# Find the index of the minimum absolute difference for each element
qweight = torch.argmin(abs_diff, dim=-1)
return qweight, max_abs
def dequantize_tensor(self, qweight, max_abs):
qweight_flatten = qweight.flatten()
weight_normed = self.norm_lookup_table[qweight_flatten]
weight = weight_normed * max_abs
weight = weight.reshape(qweight.shape)
return weight
def quantize_block(self, weight):
if len(weight.shape) != 2:
raise ValueError(f"Only support 2D matrix, but your input has {len(weight.shape)} dimensions.")
if weight.shape[0] * weight.shape[1] % self.block_size != 0:
raise ValueError(
f"Weight with shape ({weight.shape[0]} x {weight.shape[1]}) "
f"is not dividable by block size {self.block_size}."
)
M, N = weight.shape
device = weight.device
# Quantization
weight_flatten = weight.flatten() # (M*N, )
weight_block = weight_flatten.reshape(-1, self.block_size) # (L, B), L = M * N / B
if self.method == "normal":
weight_max = weight_block.abs().max(dim=-1)[0] # (L, 1)
elif self.method == "uniform":
weight_max = weight_block.mean(dim=-1) + 2.5 * weight_block.std(dim=-1)
else:
raise NotImplementedError("Method not supported yet.")
weight_max = weight_max.unsqueeze(-1)
weight_divabs = weight_block / weight_max # (L, B)
weight_divabs = weight_divabs.unsqueeze(-1) # (L, B, 1)
L_reshaped = self.norm_lookup_table.reshape(1, -1) # (1, 2**K)
abs_diff = torch.abs(weight_divabs - L_reshaped) # (L, B, 2**K)
qweight = torch.argmin(abs_diff, dim=-1) # (L, B)
# Pack multiple k-bit into uint8
qweight = qweight.reshape(-1, 8 // self.num_bits)
qweight_pack = torch.zeros((M * N // 8 * self.num_bits, 1), dtype=torch.uint8, device=device)
# data format example:
# [1, 0, 3, 2] or [01, 00, 11, 10] -> [10110001], LIFO
for i in range(8 // self.num_bits):
qweight[:, i] = qweight[:, i] << i * self.num_bits
qweight_pack[:, 0] |= qweight[:, i]
return qweight_pack, weight_max, weight.shape
def dequantize_block(self, qweight, weight_max, weight_shape):
# unpack weight
device = qweight.device
weight = torch.zeros((qweight.shape[0], 8 // self.num_bits), dtype=torch.float32, device=device)
for i in range(8 // self.num_bits):
lookup_table_idx = qweight.to(torch.long) % 2**self.num_bits # keep the rightmost num_bits bits
lookup_table_idx = lookup_table_idx.to(torch.int)
weight[:, i] = self.norm_lookup_table[lookup_table_idx].squeeze()
qweight = qweight >> self.num_bits # shift right to expose the next packed value
weight_block = weight.reshape(-1, self.block_size)
weight = weight_block * weight_max
weight = weight.reshape(weight_shape)
return weight
def _low_rank_decomposition(weight, reduced_rank=32):
"""
:param weight: The matrix to decompose, of shape (H, W)
:param reduced_rank: the final rank of the decomposition
:return: a dict containing the low-rank factors L and R together with the full SVD results
"""
matrix_dimension = len(weight.size())
if matrix_dimension != 2:
raise ValueError(f"Only support 2D matrix, but your input has {matrix_dimension} dimensions.")
# Use SVD to decompose a matrix, default full_matrices is False to save parameters
U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
L = U @ (torch.sqrt(torch.diag(S)[:, 0:reduced_rank]))
R = torch.sqrt(torch.diag(S)[0:reduced_rank, :]) @ Vh
return {"L": L, "R": R, "U": U, "S": S, "Vh": Vh, "reduced_rank": reduced_rank}
@torch.no_grad()
def loftq_init(weight: Union[torch.Tensor, torch.nn.Parameter], num_bits: int, reduced_rank: int, num_iter=1):
if num_bits not in [2, 4, 8]:
raise ValueError("Only support 2, 4, 8 bits quantization")
if num_iter <= 0:
raise ValueError("Number of iterations must be greater than 0")
out_feature, in_feature = weight.size()
device = weight.device
dtype = weight.dtype
logging.info(
f"Weight: ({out_feature}, {in_feature}) | Rank: {reduced_rank} "
f"| Num Iter: {num_iter} | Num Bits: {num_bits}"
)
if not is_bnb_4bit_available():
quantizer = NFQuantizer(num_bits=num_bits, device=device, method="normal", block_size=64)
weight = weight.to(torch.float32)
res = weight.clone()
for i in range(num_iter):
torch.cuda.empty_cache()
# Quantization
if num_bits == 4 and is_bnb_4bit_available():
qweight = bnb.nn.Params4bit(
res.to("cpu"), requires_grad=False, compress_statistics=False, quant_type="nf4"
).to(device)
dequantized_weight = bnb.functional.dequantize_4bit(qweight.data, qweight.quant_state)
else:
quantized_weight, max_abs, shape = quantizer.quantize_block(res)
dequantized_weight = quantizer.dequantize_block(quantized_weight, max_abs, shape)
res = weight - dequantized_weight
# Decompose the residual by SVD
output = _low_rank_decomposition(res, reduced_rank=reduced_rank)
L, R, reduced_rank = output["L"], output["R"], output["reduced_rank"]
res = weight - torch.mm(L, R)
lora_A, lora_B = R, L
return dequantized_weight.to(dtype), lora_A, lora_B