Mirror of https://github.com/huggingface/peft.git

Add LoftQ initialization method for LoRA (#1150)

Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
@@ -34,6 +34,7 @@ Supported methods:
 7. MultiTask Prompt Tuning: [Multitask Prompt Tuning Enables Parameter-Efficient Transfer Learning](https://arxiv.org/abs/2303.02861)
 8. LoHa: [FedPara: Low-Rank Hadamard Product for Communication-Efficient Federated Learning](https://arxiv.org/abs/2108.06098)
 9. LoKr: [KronA: Parameter Efficient Tuning with Kronecker Adapter](https://arxiv.org/abs/2212.10650) based on [Navigating Text-To-Image Customization:From LyCORIS Fine-Tuning to Model Evaluation](https://arxiv.org/abs/2309.14859) implementation
+10. LoftQ: [LoftQ: LoRA-Fine-Tuning-aware Quantization for Large Language Models](https://arxiv.org/abs/2310.08659)
 ## Getting started
examples/loftq_finetuning/README.md (new file, 69 lines)
# LoftQ: LoRA-fine-tuning-aware Quantization

## Introduction

LoftQ provides better initialization for the LoRA adapters A and B, together with quantization of the pre-trained weights W.
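In PEFT this initialization is requested through the LoRA config. The following minimal sketch mirrors the `quantize_save_load.py` example below; the model name and hyperparameters are simply the ones used there:

```python
from transformers import AutoModelForCausalLM
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model

# load the full-precision backbone first; LoftQ quantizes it and fits A and B to the quantization error
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

loftq_config = LoftQConfig(loftq_bits=4, loftq_iter=5)  # 4-bit quantization, 5 alternating steps
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    init_lora_weights="loftq",
    loftq_config=loftq_config,
)
model = get_peft_model(base_model, lora_config)
```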
## Quantization

We recommend saving the quantized backbone model as fp16/fp32 and loading it as [NormalFloat4](https://arxiv.org/abs/2305.14314).

We provide a simple example showing how to quantize the llama-2-7b model and save/load it.

```sh
python quantize_save_load.py \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --token HF_TOKEN \
    --bits 4 --iter 5 --rank 16 \
    --save_dir model_zoo/loftq/
```
- `HF_TOKEN` is the token used to access the [LLAMA models](https://huggingface.co/meta-llama).
- The `quantize_and_save()` function quantizes the backbone and initializes the LoRA adapters. It creates two folders under `$save_dir`: the quantized backbone is at `Llama-2-7b-hf-4bit-16rank`, and the LoRA adapters are at the sub-folder `Llama-2-7b-hf-4bit-16rank/loftq_init`.
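For reference, loading the saved fp16/fp32 backbone as NormalFloat4 looks roughly like the following. This is a sketch of the `load_loftq()` helper in `quantize_save_load.py`; the path assumes the default `--save_dir` used above:

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# the backbone was saved in full/half precision; it is quantized to NF4 only at load time
backbone = AutoModelForCausalLM.from_pretrained(
    "model_zoo/loftq/Llama-2-7b-hf-4bit-16rank",
    device_map="auto",
    low_cpu_mem_usage=True,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type="nf4",
    ),
)
```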
## Fine-tuning

Here is an example to load the quantized backbone and LoRA adapters:

```python
import os

from transformers import AutoModelForCausalLM
from peft import PeftModel


base_model = AutoModelForCausalLM.from_pretrained(
    os.path.join(args.save_dir, "Llama-2-7b-hf-4bit-16rank"),
    load_in_4bit=True,
)
peft_model = PeftModel.from_pretrained(
    base_model,
    os.path.join(args.save_dir, "Llama-2-7b-hf-4bit-16rank", "loftq_init"),
    is_trainable=True,
)
```
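Note that `is_trainable=True` is what makes the adapters trainable again: `PeftModel.from_pretrained()` loads adapters frozen by default, and the adapters were saved with `inference_mode=True` in `quantize_save_load.py`.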
We also provide an example of fine-tuning LoftQ on GSM8K. We load the quantized backbone and LoRA adapters from the [LoftQ Huggingface hub](https://huggingface.co/LoftQ).

```sh
python train_gsm8k_llama.py \
    --model_name_or_path LoftQ/Llama-2-7b-hf-4bit-64rank \
    --output_dir exp_results/gsm8k/llama-2-7b/bit4-rank64/lr3e-4 \
    --learning_rate 3e-4 \
    --seed 202 \
    --dataset_name gsm8k \
    --dataset_config main \
    --pad_to_max_length \
    --max_source_length 128 \
    --max_target_length 256 \
    --num_train_epochs 5 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --with_tracking \
    --report_to tensorboard
```
examples/loftq_finetuning/quantize_save_load.py (new file, 244 lines)
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import torch
import torch.nn as nn
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
)

from peft import LoftQConfig, LoraConfig, PeftModel, TaskType, get_peft_model


class Shell(nn.Module):
    def __init__(self, weight, bias=None):
        super().__init__()
        self.weight = nn.Parameter(weight, requires_grad=False)
        if bias is not None:
            self.bias = nn.Parameter(bias, requires_grad=False)


def unwrap_model(model, sub_module_name=".base_layer"):
    sub_module_name_list = [k.split(sub_module_name)[0] for k in model.state_dict().keys() if sub_module_name in k]
    sub_module_name_set = set(sub_module_name_list)
    for name in sub_module_name_set:
        # get the parent of the submodule
        name_parent = ".".join(name.split(".")[:-1])
        name_child = name.split(".")[-1]
        sub_module = model.get_submodule(name_parent)
        print(sub_module)

        # replace with shell
        child = getattr(sub_module, name_child)
        weight = getattr(child.base_layer, "weight", None)
        bias = getattr(child.base_layer, "bias", None)
        shell = Shell(weight, bias)

        setattr(sub_module, name_child, shell)

    print("You have unwrapped the model. Use it at your own risk.")


def print_model(model, name):
    print("=" * 10 + name + "=" * 10)
    print(model)
    for name, param in model.named_parameters():
        if torch.is_tensor(param):
            if param.dtype in [torch.float32, torch.float16]:
                print(
                    name,
                    param.shape,
                    param.device,
                    param.dtype,
                    param.requires_grad,
                    param.mean().item(),
                    param.max().item(),
                )
            else:
                print(name, param.shape, param.device, param.dtype, param.requires_grad)


def arg_parse():
    parser = argparse.ArgumentParser(description="Quantize a model with LoftQ.")
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="The name or path of the fp32/16 model.",
    )
    parser.add_argument(
        "--token",
        type=str,
        default=None,
        help="The access token to download model from HuggingFace Hub.",
    )
    parser.add_argument(
        "--bits",
        type=int,
        default=4,
        help="The quantized bits",
    )
    parser.add_argument(
        "--iter",
        type=int,
        default=1,
        help="The alternating steps in LoftQ",
    )
    parser.add_argument(
        "--rank",
        type=int,
        default=16,
        help="The rank of the LoRA adapter",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./model_zoo/loftq/",
        help="The directory to save the quantized backbone and LoRA adapters",
    )
    args = parser.parse_args()
    return args


def quantize_and_save():
    args = arg_parse()

    # Download weights and configure LoRA
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, token=args.token, trust_remote_code=True)
    if any(name in args.model_name_or_path.lower() for name in ["llama", "mistral", "falcon"]):
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path, token=args.token, trust_remote_code=True, device_map="auto"
        )
        task_type = TaskType.CAUSAL_LM
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"]

    elif any(name in args.model_name_or_path.lower() for name in ["bart", "t5"]):
        model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path, token=args.token, device_map="auto")
        task_type = TaskType.SEQ_2_SEQ_LM
        target_modules = ["q_proj", "k_proj", "v_proj", "fc1", "fc2", "out_proj"]

    elif any(name in args.model_name_or_path.lower() for name in ["deberta", "roberta", "bert"]):
        model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, token=args.token)
        model = model.cuda()
        task_type = TaskType.SEQ_CLS
        target_modules = ["query_proj", "key_proj", "value_proj", "dense"]  # embeddings not supported by peft
    else:
        raise NotImplementedError("Other models not supported yet.")

    # Config of LoftQ
    loftq_config = LoftQConfig(loftq_bits=args.bits, loftq_iter=args.iter)

    lora_config = LoraConfig(
        task_type=task_type,
        inference_mode=True,
        r=args.rank,
        lora_alpha=16 if task_type is TaskType.CAUSAL_LM else args.rank,
        lora_dropout=0.1,
        target_modules=target_modules,
        init_lora_weights="loftq",
        loftq_config=loftq_config,
    )

    # Obtain LoftQ model
    lora_model = get_peft_model(model, lora_config)
    base_model = lora_model.get_base_model()

    # Save LoftQ model
    model_name = args.model_name_or_path.split("/")[-1] + f"-{args.bits}bit" + f"-{args.rank}rank"
    base_model_dir = os.path.join(args.save_dir, model_name)
    lora_model_dir = os.path.join(args.save_dir, model_name, "loftq_init")

    # save lora adapters first
    lora_model.base_model.peft_config[
        "default"
    ].base_model_name_or_path = base_model_dir  # This can be a local path or Hub model id
    lora_model.base_model.peft_config["default"].init_lora_weights = True  # Don't apply LoftQ when loading again

    lora_model.save_pretrained(lora_model_dir)
    print_model(lora_model, "lora_model")

    # remove lora adapters and save the backbone
    unwrap_model(base_model)
    base_model.save_pretrained(base_model_dir)
    tokenizer.save_pretrained(base_model_dir)

    print_model(base_model, "base_model")

    return base_model_dir, lora_model_dir


def load_loftq(base_model_path, lora_adapter_path):
    if any(name in base_model_path.lower() for name in ["llama", "mistral", "falcon"]):
        model = AutoModelForCausalLM.from_pretrained(
            base_model_path,
            device_map="auto",
            low_cpu_mem_usage=True,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type="nf4",
            ),
        )
    elif any(name in base_model_path.lower() for name in ["bart", "t5"]):
        model = AutoModelForSeq2SeqLM.from_pretrained(
            base_model_path,
            device_map="auto",
            low_cpu_mem_usage=True,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type="nf4",
            ),
        )
    elif any(name in base_model_path.lower() for name in ["deberta", "roberta", "bert"]):
        model = AutoModelForSequenceClassification.from_pretrained(
            base_model_path,
            low_cpu_mem_usage=True,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type="nf4",
            ),
        )
    else:
        raise NotImplementedError("Other models not supported yet.")

    lora_model = PeftModel.from_pretrained(model, lora_adapter_path, is_trainable=True)

    # Do training or inference below
    print_model(lora_model, "lora_model")
    print_model(model, "base_model")


if __name__ == "__main__":
    base_dir, lora_dir = quantize_and_save()
    load_loftq(base_dir, lora_dir)

# example command:
# python quantize_save_load.py \
#     --model_name_or_path meta-llama/Llama-2-7b-hf \
#     --token XXX \
#     --bits 4 --iter 5 --rank 16 \
#     --save_dir ./model_zoo/loftq/
examples/loftq_finetuning/train_gsm8k_llama.py (new file, 866 lines)
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2023-present the HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import datasets
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from accelerate import Accelerator, DistributedType
|
||||||
|
from accelerate.logging import get_logger
|
||||||
|
from accelerate.utils import set_seed
|
||||||
|
from datasets import load_dataset
|
||||||
|
from huggingface_hub import Repository, create_repo
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from transformers import (
|
||||||
|
CONFIG_MAPPING,
|
||||||
|
MODEL_MAPPING,
|
||||||
|
AutoConfig,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
BitsAndBytesConfig,
|
||||||
|
SchedulerType,
|
||||||
|
default_data_collator,
|
||||||
|
get_scheduler,
|
||||||
|
)
|
||||||
|
from transformers.utils import send_example_telemetry
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
|
from peft import PeftModel
|
||||||
|
|
||||||
|
|
||||||
|
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
|
||||||
|
# check_min_version("4.32.0.dev0")
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||||
|
|
||||||
|
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
|
||||||
|
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||||
|
HF_TOKEN = "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"
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task")
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset_name",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The name of the dataset to use (via the datasets library).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset_config_name",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The configuration name of the dataset to use (via the datasets library).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--train_file", type=str, default=None, help="A csv, txt or a json file containing the training data."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--validation_file", type=str, default=None, help="A csv, txt or a json file containing the validation data."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--validation_split_percentage",
|
||||||
|
default=5,
|
||||||
|
help="The percentage of the train set used as validation set in case there's no validation split",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name_or_path",
|
||||||
|
type=str,
|
||||||
|
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--config_name",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Pretrained config name or path if not the same as model_name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokenizer_name",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_slow_tokenizer",
|
||||||
|
action="store_true",
|
||||||
|
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--per_device_train_batch_size",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="Batch size (per device) for the training dataloader.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--per_device_eval_batch_size",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="Batch size (per device) for the evaluation dataloader.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--learning_rate",
|
||||||
|
type=float,
|
||||||
|
default=5e-5,
|
||||||
|
help="Initial learning rate (after the potential warmup period) to use.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
|
||||||
|
parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_train_steps",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--gradient_accumulation_steps",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lr_scheduler_type",
|
||||||
|
type=SchedulerType,
|
||||||
|
default="linear",
|
||||||
|
help="The scheduler type to use.",
|
||||||
|
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
|
||||||
|
)
|
||||||
|
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
|
||||||
|
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_type",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Model type to use if training from scratch.",
|
||||||
|
choices=MODEL_TYPES,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ignore_pad_token_for_loss",
|
||||||
|
type=bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to ignore the tokens corresponding to padded labels in the loss computation or not.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_source_length",
|
||||||
|
type=int,
|
||||||
|
default=128,
|
||||||
|
help=(
|
||||||
|
"The maximum total input sequence length after "
|
||||||
|
"tokenization.Sequences longer than this will be truncated, sequences shorter will be padded."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_target_length",
|
||||||
|
type=int,
|
||||||
|
default=128,
|
||||||
|
help=(
|
||||||
|
"The maximum total sequence length for target text after "
|
||||||
|
"tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
|
||||||
|
"during ``evaluate`` and ``predict``."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pad_to_max_length",
|
||||||
|
action="store_true",
|
||||||
|
help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--preprocessing_num_workers",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="The number of processes to use for the preprocessing.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using TXT files."
|
||||||
|
)
|
||||||
|
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
|
||||||
|
)
|
||||||
|
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--trust_remote_code",
|
||||||
|
type=bool,
|
||||||
|
default=False,
|
||||||
|
help=(
|
||||||
|
"Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
|
||||||
|
"should only be set to `True` for repositories you trust and in which you have read the code, as it will"
|
||||||
|
"execute code present on the Hub on your local machine."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpointing_steps",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--resume_from_checkpoint",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="If the training should continue from a checkpoint folder.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--with_tracking",
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to enable experiment trackers for logging.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--report_to",
|
||||||
|
type=str,
|
||||||
|
default="tensorboard",
|
||||||
|
help=(
|
||||||
|
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
|
||||||
|
' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.'
|
||||||
|
"Only applicable when `--with_tracking` is passed."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--low_cpu_mem_usage",
|
||||||
|
action="store_true",
|
||||||
|
help=(
|
||||||
|
"It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded."
|
||||||
|
"If passed, LLM loading time and RAM consumption will be benefited."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
##########################
|
||||||
|
# Generation Config #
|
||||||
|
##########################
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature",
|
||||||
|
type=float,
|
||||||
|
default=0.8,
|
||||||
|
help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
|
||||||
|
)
|
||||||
|
parser.add_argument("--k", type=int, default=40, help="Choose k candidate words")
|
||||||
|
parser.add_argument("--p", type=float, default=0.95, help="The sum of probability of candidate words is 0.9 ")
|
||||||
|
|
||||||
|
##########################
|
||||||
|
# Exp Args #
|
||||||
|
##########################
|
||||||
|
parser.add_argument(
|
||||||
|
"--adapter_name_or_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"The LoRA adapter checkpoint. Set None if you want to fine-tune from LoftQ."
|
||||||
|
"Specify a path if you want to evaluate."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Sanity checks
|
||||||
|
if args.dataset_name is None and args.train_file is None and args.validation_file is None:
|
||||||
|
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||||
|
else:
|
||||||
|
if args.train_file is not None:
|
||||||
|
extension = args.train_file.split(".")[-1]
|
||||||
|
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file."
|
||||||
|
if args.validation_file is not None:
|
||||||
|
extension = args.validation_file.split(".")[-1]
|
||||||
|
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
|
||||||
|
|
||||||
|
if args.push_to_hub:
|
||||||
|
assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
|
||||||
|
# information sent is the one passed as arguments along with your Python/PyTorch versions.
|
||||||
|
send_example_telemetry("run_clm_no_trainer", args)
|
||||||
|
|
||||||
|
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
|
||||||
|
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
|
||||||
|
# in the environment
|
||||||
|
accelerator_log_kwargs = {}
|
||||||
|
|
||||||
|
if args.with_tracking:
|
||||||
|
accelerator_log_kwargs["log_with"] = args.report_to
|
||||||
|
accelerator_log_kwargs["project_dir"] = args.output_dir
|
||||||
|
|
||||||
|
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)
|
||||||
|
|
||||||
|
# Make one log on every process with the configuration for debugging.
|
||||||
|
logging.basicConfig(
|
||||||
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
|
level=logging.INFO,
|
||||||
|
)
|
||||||
|
logger.info(accelerator.state, main_process_only=False)
|
||||||
|
if accelerator.is_local_main_process:
|
||||||
|
datasets.utils.logging.set_verbosity_warning()
|
||||||
|
transformers.utils.logging.set_verbosity_info()
|
||||||
|
else:
|
||||||
|
datasets.utils.logging.set_verbosity_error()
|
||||||
|
transformers.utils.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
# If passed along, set the training seed now.
|
||||||
|
if args.seed is not None:
|
||||||
|
set_seed(args.seed)
|
||||||
|
|
||||||
|
# Handle the repository creation
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
if args.push_to_hub:
|
||||||
|
# Retrieve or infer repo_name
|
||||||
|
repo_name = args.hub_model_id
|
||||||
|
if repo_name is None:
|
||||||
|
repo_name = Path(args.output_dir).absolute().name
|
||||||
|
# Create repo and retrieve repo_id
|
||||||
|
repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
|
||||||
|
# Clone repo locally
|
||||||
|
repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
|
||||||
|
|
||||||
|
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
||||||
|
if "step_*" not in gitignore:
|
||||||
|
gitignore.write("step_*\n")
|
||||||
|
if "epoch_*" not in gitignore:
|
||||||
|
gitignore.write("epoch_*\n")
|
||||||
|
elif args.output_dir is not None:
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
||||||
|
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
||||||
|
# (the dataset will be downloaded automatically from the datasets Hub).
|
||||||
|
#
|
||||||
|
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
|
||||||
|
# 'text' is found. You can easily tweak this behavior (see below).
|
||||||
|
#
|
||||||
|
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
|
||||||
|
# download the dataset.
|
||||||
|
if args.dataset_name is not None:
|
||||||
|
# Downloading and loading a dataset from the hub.
|
||||||
|
raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
|
||||||
|
if "validation" not in raw_datasets.keys():
|
||||||
|
raw_datasets["validation"] = load_dataset(
|
||||||
|
args.dataset_name,
|
||||||
|
args.dataset_config_name,
|
||||||
|
split=f"train[:{args.validation_split_percentage}%]",
|
||||||
|
)
|
||||||
|
raw_datasets["train"] = load_dataset(
|
||||||
|
args.dataset_name,
|
||||||
|
args.dataset_config_name,
|
||||||
|
split=f"train[{args.validation_split_percentage}%:]",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
data_files = {}
|
||||||
|
dataset_args = {}
|
||||||
|
if args.train_file is not None:
|
||||||
|
data_files["train"] = args.train_file
|
||||||
|
if args.validation_file is not None:
|
||||||
|
data_files["validation"] = args.validation_file
|
||||||
|
extension = args.train_file.split(".")[-1]
|
||||||
|
if extension == "txt":
|
||||||
|
extension = "text"
|
||||||
|
dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks
|
||||||
|
raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
|
||||||
|
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
||||||
|
if "validation" not in raw_datasets.keys():
|
||||||
|
raw_datasets["validation"] = load_dataset(
|
||||||
|
extension,
|
||||||
|
data_files=data_files,
|
||||||
|
split=f"train[:{args.validation_split_percentage}%]",
|
||||||
|
**dataset_args,
|
||||||
|
)
|
||||||
|
raw_datasets["train"] = load_dataset(
|
||||||
|
extension,
|
||||||
|
data_files=data_files,
|
||||||
|
split=f"train[{args.validation_split_percentage}%:]",
|
||||||
|
**dataset_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||||
|
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||||
|
|
||||||
|
# Load pretrained model and tokenizer
|
||||||
|
#
|
||||||
|
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
|
||||||
|
# download model & vocab.
|
||||||
|
if args.config_name:
|
||||||
|
config = AutoConfig.from_pretrained(
|
||||||
|
args.config_name,
|
||||||
|
trust_remote_code=args.trust_remote_code,
|
||||||
|
)
|
||||||
|
elif args.model_name_or_path:
|
||||||
|
config = AutoConfig.from_pretrained(
|
||||||
|
args.model_name_or_path,
|
||||||
|
trust_remote_code=args.trust_remote_code,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
config = CONFIG_MAPPING[args.model_type]()
|
||||||
|
logger.warning("You are instantiating a new config instance from scratch.")
|
||||||
|
|
||||||
|
if args.tokenizer_name:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
args.tokenizer_name, use_fast=not args.use_slow_tokenizer, trust_remote_code=args.trust_remote_code
|
||||||
|
)
|
||||||
|
elif args.model_name_or_path:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
args.model_name_or_path,
|
||||||
|
use_fast=not args.use_slow_tokenizer,
|
||||||
|
trust_remote_code=args.trust_remote_code,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
||||||
|
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
||||||
|
)
|
||||||
|
|
||||||
|
##########################
|
||||||
|
# Tokenizer #
|
||||||
|
##########################
|
||||||
|
tokenizer.pad_token_id = 0 # unk. we want this to be different from the eos token
|
||||||
|
tokenizer.padding_side = "left" # Allow batched inference
|
||||||
|
tokenizer.truncation_side = "left"
|
||||||
|
|
||||||
|
if args.model_name_or_path:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
args.model_name_or_path,
|
||||||
|
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||||
|
config=config,
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
quantization_config=BitsAndBytesConfig(
|
||||||
|
load_in_4bit=True,
|
||||||
|
bnb_4bit_use_double_quant=False,
|
||||||
|
bnb_4bit_quant_type="nf4",
|
||||||
|
bnb_4bit_compute_dtype=config.torch_dtype,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info("Training new model from scratch")
|
||||||
|
model = AutoModelForCausalLM.from_config(config, trust_remote_code=args.trust_remote_code)
|
||||||
|
|
||||||
|
##########################
|
||||||
|
# Peft Model #
|
||||||
|
##########################
|
||||||
|
if args.adapter_name_or_path is None:
|
||||||
|
model = PeftModel.from_pretrained(model, args.model_name_or_path, subfolder="loftq_init", is_trainable=True)
|
||||||
|
else:
|
||||||
|
model = PeftModel.from_pretrained(model, args.adapter_name_or_path, is_trainable=True)
|
||||||
|
model.print_trainable_parameters()
|
||||||
|
|
||||||
|
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
|
||||||
|
# on a small vocab and want a smaller embedding size, remove this test.
|
||||||
|
embedding_size = model.get_input_embeddings().weight.shape[0]
|
||||||
|
if len(tokenizer) > embedding_size:
|
||||||
|
model.resize_token_embeddings(len(tokenizer))
|
||||||
|
|
||||||
|
# Preprocessing the datasets.
|
||||||
|
# First we tokenize all the texts.
|
||||||
|
##########################
|
||||||
|
# GSM8K dataset #
|
||||||
|
##########################
|
||||||
|
|
||||||
|
# Preprocessing the datasets.
|
||||||
|
# First we tokenize all the texts.
|
||||||
|
column_names = raw_datasets["train"].column_names
|
||||||
|
|
||||||
|
# Get the column names for source/target.
|
||||||
|
source_column, target_column = "question", "answer"
|
||||||
|
|
||||||
|
# Temporarily set max_target_length for training.
|
||||||
|
padding = "max_length" if args.pad_to_max_length else False
|
||||||
|
task_prompt = "\nAnswer the above question. First think step by step and then answer the final number.\n"
|
||||||
|
|
||||||
|
def prompt_process(sent_1, sent_2, prompt_1="", prompt_2="", prompt_3=""):
|
||||||
|
sent_2 = sent_2.replace("####", "The final answer is")
|
||||||
|
return prompt_1 + sent_1 + prompt_2 + sent_2 + prompt_3
|
||||||
|
|
||||||
|
def preprocess_function_train(examples):
|
||||||
|
sources = examples[source_column]
|
||||||
|
targets = examples[target_column]
|
||||||
|
|
||||||
|
inputs = [prompt_process(source, target, prompt_2=task_prompt) for (source, target) in zip(sources, targets)]
|
||||||
|
|
||||||
|
model_inputs = tokenizer(
|
||||||
|
inputs,
|
||||||
|
max_length=args.max_source_length + args.max_target_length,
|
||||||
|
padding=padding,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
labels = copy.deepcopy(model_inputs)
|
||||||
|
|
||||||
|
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
|
||||||
|
# padding in the loss.
|
||||||
|
if padding == "max_length" and args.ignore_pad_token_for_loss:
|
||||||
|
# get the length of the target tokens. -1 to kick out the <BOS> token
|
||||||
|
target_tokens = tokenizer(targets, padding=False)
|
||||||
|
target_len = [len(label) - 1 for label in target_tokens["input_ids"]]
|
||||||
|
|
||||||
|
# don't calculate the loss from source and padding (left padding)
|
||||||
|
for i in range(len(labels["input_ids"])):
|
||||||
|
labels["input_ids"][i, : -target_len[i]] = -100
|
||||||
|
|
||||||
|
model_inputs["labels"] = labels["input_ids"]
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
def preprocess_function_test(examples):
|
||||||
|
sources = examples[source_column]
|
||||||
|
labels = examples[target_column]
|
||||||
|
|
||||||
|
inputs = [source + task_prompt for source in sources]
|
||||||
|
|
||||||
|
model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)
|
||||||
|
labels = tokenizer(labels, max_length=args.max_target_length, padding=padding, truncation=True)
|
||||||
|
|
||||||
|
model_inputs["labels"] = labels["input_ids"]
|
||||||
|
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
with accelerator.main_process_first():
|
||||||
|
train_dataset = raw_datasets["train"].map(
|
||||||
|
preprocess_function_train,
|
||||||
|
batched=True,
|
||||||
|
num_proc=args.preprocessing_num_workers,
|
||||||
|
remove_columns=column_names,
|
||||||
|
load_from_cache_file=not args.overwrite_cache,
|
||||||
|
desc="Running tokenizer on training dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
eval_dataset = raw_datasets["test"].map(
|
||||||
|
preprocess_function_test,
|
||||||
|
batched=True,
|
||||||
|
num_proc=args.preprocessing_num_workers,
|
||||||
|
remove_columns=column_names,
|
||||||
|
load_from_cache_file=not args.overwrite_cache,
|
||||||
|
desc="Running tokenizer on test dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log a few random samples from the set:
|
||||||
|
for index in random.sample(range(len(train_dataset)), 2):
|
||||||
|
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
|
||||||
|
for index in random.sample(range(len(eval_dataset)), 2):
|
||||||
|
logger.info(f"Sample {index} of the validation set: {eval_dataset[index]}.")
|
||||||
|
|
||||||
|
# DataLoaders creation:
|
||||||
|
train_dataloader = DataLoader(
|
||||||
|
train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size
|
||||||
|
)
|
||||||
|
eval_dataloader = DataLoader(
|
||||||
|
eval_dataset, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# Optimizer
|
||||||
|
# Split weights in two groups, one with weight decay and the other not.
|
||||||
|
no_decay = ["bias", "layer_norm.weight"]
|
||||||
|
optimizer_grouped_parameters = [
|
||||||
|
{
|
||||||
|
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and "lora" in n],
|
||||||
|
"weight_decay": args.weight_decay,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
|
||||||
|
|
||||||
|
# Scheduler and math around the number of training steps.
|
||||||
|
overrode_max_train_steps = False
|
||||||
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
if args.max_train_steps is None:
|
||||||
|
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||||
|
overrode_max_train_steps = True
|
||||||
|
|
||||||
|
lr_scheduler = get_scheduler(
|
||||||
|
name=args.lr_scheduler_type,
|
||||||
|
optimizer=optimizer,
|
||||||
|
num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
|
||||||
|
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare everything with our `accelerator`.
|
||||||
|
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
|
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
|
||||||
|
)
|
||||||
|
|
||||||
|
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
|
||||||
|
if accelerator.distributed_type == DistributedType.TPU:
|
||||||
|
model.tie_weights()
|
||||||
|
|
||||||
|
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||||
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
if overrode_max_train_steps:
|
||||||
|
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||||
|
# Afterwards we recalculate our number of training epochs
|
||||||
|
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||||
|
|
||||||
|
# Figure out how many steps we should save the Accelerator states
|
||||||
|
checkpointing_steps = args.checkpointing_steps
|
||||||
|
if checkpointing_steps is not None and checkpointing_steps.isdigit():
|
||||||
|
checkpointing_steps = int(checkpointing_steps)
|
||||||
|
|
||||||
|
# We need to initialize the trackers we use, and also store our configuration.
|
||||||
|
# The trackers initialize automatically on the main process.
|
||||||
|
if args.with_tracking:
|
||||||
|
experiment_config = vars(args)
|
||||||
|
# TensorBoard cannot log Enums, need the raw value
|
||||||
|
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
|
||||||
|
accelerator.init_trackers("clm_no_trainer", experiment_config)
|
||||||
|
|
||||||
|
# Train!
|
||||||
|
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||||
|
|
||||||
|
logger.info("***** Running training *****")
|
||||||
|
logger.info(f" Num examples = {len(train_dataset)}")
|
||||||
|
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
||||||
|
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
|
||||||
|
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||||
|
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||||
|
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||||
|
# Only show the progress bar once on each machine.
|
||||||
|
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||||
|
completed_steps = 0
|
||||||
|
starting_epoch = 0
|
||||||
|
|
||||||
|
# Potentially load in the weights and states from a previous save
|
||||||
|
if args.resume_from_checkpoint:
|
||||||
|
if args.resume_from_checkpoint is not None and args.resume_from_checkpoint != "":
|
||||||
|
checkpoint_path = args.resume_from_checkpoint
|
||||||
|
path = os.path.basename(args.resume_from_checkpoint)
|
||||||
|
else:
|
||||||
|
# Get the most recent checkpoint
|
||||||
|
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
|
||||||
|
dirs.sort(key=os.path.getctime)
|
||||||
|
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
|
||||||
|
checkpoint_path = path
|
||||||
|
path = os.path.basename(checkpoint_path)
|
||||||
|
|
||||||
|
accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
|
||||||
|
accelerator.load_state(path)
|
||||||
|
# Extract `epoch_{i}` or `step_{i}`
|
||||||
|
training_difference = os.path.splitext(path)[0]
|
||||||
|
|
||||||
|
if "epoch" in training_difference:
|
||||||
|
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
|
||||||
|
resume_step = None
|
||||||
|
completed_steps = starting_epoch * num_update_steps_per_epoch
|
||||||
|
else:
|
||||||
|
# need to multiply `gradient_accumulation_steps` to reflect real steps
|
||||||
|
resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
|
||||||
|
starting_epoch = resume_step // len(train_dataloader)
|
||||||
|
resume_step -= starting_epoch * len(train_dataloader)
|
||||||
|
completed_steps = resume_step // args.gradient_accumulation_steps
|
||||||
|
|
||||||
|
# update the progress_bar if load from checkpoint
|
||||||
|
progress_bar.update(completed_steps)
|
||||||
|
|
||||||
|
for epoch in range(starting_epoch, args.num_train_epochs):
|
||||||
|
model.train()
|
||||||
|
if args.with_tracking:
|
||||||
|
total_loss = 0
|
||||||
|
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
|
||||||
|
# We skip the first `n` batches in the dataloader when resuming from a checkpoint
|
||||||
|
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
|
||||||
|
else:
|
||||||
|
active_dataloader = train_dataloader
|
||||||
|
for step, batch in enumerate(active_dataloader):
|
||||||
|
with accelerator.accumulate(model):
|
||||||
|
outputs = model(**batch)
|
||||||
|
loss = outputs.loss
|
||||||
|
# We keep track of the loss at each epoch
|
||||||
|
if args.with_tracking:
|
||||||
|
total_loss += loss.detach().float()
|
||||||
|
accelerator.backward(loss)
|
||||||
|
accelerator.print(f"Epoch: {epoch} | Step: {step} | Loss: {loss}")
|
||||||
|
optimizer.step()
|
||||||
|
lr_scheduler.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||||
|
if accelerator.sync_gradients:
|
||||||
|
progress_bar.update(1)
|
||||||
|
completed_steps += 1
|
||||||
|
|
||||||
|
if isinstance(checkpointing_steps, int):
|
||||||
|
if completed_steps % checkpointing_steps == 0:
|
||||||
|
output_dir = f"step_{completed_steps}"
|
||||||
|
if args.output_dir is not None:
|
||||||
|
output_dir = os.path.join(args.output_dir, output_dir)
|
||||||
|
accelerator.save_state(output_dir)
|
||||||
|
if completed_steps >= args.max_train_steps:
|
||||||
|
break
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
gen_kwargs = {
|
||||||
|
"max_new_tokens": args.max_target_length,
|
||||||
|
"temperature": args.temperature,
|
||||||
|
"top_k": args.k,
|
||||||
|
"top_p": args.p,
|
||||||
|
"do_sample": True,
|
||||||
|
}
|
||||||
|
ans_pred_list = []
|
||||||
|
ans_gold_list = []
|
||||||
|
for step, batch in enumerate(eval_dataloader):
|
||||||
|
with torch.no_grad():
|
||||||
|
gen_kwargs["input_ids"] = batch["input_ids"]
|
||||||
|
gen_kwargs["attention_mask"] = batch["attention_mask"]
|
||||||
|
generated_tokens = accelerator.unwrap_model(model).generate(**gen_kwargs)
|
||||||
|
|
||||||
|
pred_tokens = generated_tokens[:, args.max_source_length :]
|
||||||
|
pred_tokens = accelerator.pad_across_processes(pred_tokens, dim=1, pad_index=tokenizer.pad_token_id)
|
||||||
|
gold_tokens = batch["labels"]
|
||||||
|
|
||||||
|
if not args.pad_to_max_length:
|
||||||
|
# If we did not pad to max length, we need to pad the labels too
|
||||||
|
gold_tokens = accelerator.pad_across_processes(
|
||||||
|
batch["labels"], dim=1, pad_index=tokenizer.pad_token_id
|
||||||
|
)
|
||||||
|
|
||||||
|
pred_tokens, gold_tokens = accelerator.gather_for_metrics((pred_tokens, gold_tokens))
|
||||||
|
pred_tokens, gold_tokens = pred_tokens.cpu().numpy(), gold_tokens.cpu().numpy()
|
||||||
|
|
||||||
|
if isinstance(pred_tokens, tuple):
|
||||||
|
pred_tokens = pred_tokens[0]
|
||||||
|
decoded_pred = tokenizer.batch_decode(pred_tokens, skip_special_tokens=True)
|
||||||
|
decoded_gold = tokenizer.batch_decode(gold_tokens, skip_special_tokens=True)
|
||||||
|
|
||||||
|
# Extract the numbers in sentences
|
||||||
|
accelerator.print(decoded_pred)
|
||||||
|
ans_pred_list += [extract_answer_number(sentence_pred) for sentence_pred in decoded_pred]
|
||||||
|
ans_gold_list += [extract_answer_number(sentence_gold) for sentence_gold in decoded_gold]
|
||||||
|
|
||||||
|
accelerator.print(ans_pred_list)
|
||||||
|
accelerator.print(ans_gold_list)
|
||||||
|
accuracy = compute_accuracy(ans_gold_list, ans_pred_list)
|
||||||
|
|
||||||
|
logger.info(f"epoch {epoch}: accuracy: {accuracy}")
|
||||||
|
|
||||||
|
if args.with_tracking:
|
||||||
|
accelerator.log(
|
||||||
|
{
|
||||||
|
"accuracy": accuracy,
|
||||||
|
"train_loss": total_loss.item() / len(train_dataloader),
|
||||||
|
"epoch": epoch,
|
||||||
|
"step": completed_steps,
|
||||||
|
},
|
||||||
|
step=completed_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.push_to_hub and epoch < args.num_train_epochs - 1:
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
unwrapped_model = accelerator.unwrap_model(model)
|
||||||
|
unwrapped_model.save_pretrained(
|
||||||
|
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
|
||||||
|
)
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
|
repo.push_to_hub(
|
||||||
|
commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.checkpointing_steps == "epoch":
|
||||||
|
output_dir = f"epoch_{epoch}"
|
||||||
|
if args.output_dir is not None:
|
||||||
|
output_dir = os.path.join(args.output_dir, output_dir)
|
||||||
|
accelerator.save_state(output_dir)
|
||||||
|
|
||||||
|
if args.with_tracking:
|
||||||
|
accelerator.end_training()
|
||||||
|
|
||||||
|
if args.output_dir is not None:
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
unwrapped_model = accelerator.unwrap_model(model)
|
||||||
|
unwrapped_model.save_pretrained(
|
||||||
|
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
|
||||||
|
)
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
|
if args.push_to_hub:
|
||||||
|
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
|
||||||
|
|
||||||
|
|
||||||
|
PATTERN_NUMBER = re.compile(r"-?\d+\.?\d*")
|
||||||
|
|
||||||
|
|
||||||
|
def extract_answer_number(sentence: str) -> float:
|
||||||
|
sentence = sentence.replace(",", "")
|
||||||
|
pred = PATTERN_NUMBER.findall(sentence)
|
||||||
|
if not pred:
|
||||||
|
return float("inf")
|
||||||
|
segment = sentence.split("The final answer is ")
|
||||||
|
if len(segment) > 1:
|
||||||
|
pred_answer = segment[1]
|
||||||
|
pred_answer = PATTERN_NUMBER.findall(pred_answer)
|
||||||
|
if len(pred_answer) > 0:
|
||||||
|
pred_answer = pred_answer[0]
|
||||||
|
else:
|
||||||
|
pred_answer = float(pred[-1])
|
||||||
|
else:
|
||||||
|
pred_answer = float(pred[-1])
|
||||||
|
|
||||||
|
if isinstance(pred_answer, str):
|
||||||
|
try:
|
||||||
|
pred_answer = float(pred_answer)
|
||||||
|
except ValueError:
|
||||||
|
pred_answer = float("inf")
|
||||||
|
return pred_answer
|
||||||
|
|
||||||
|
|
||||||
|
def compute_accuracy(pred: list, gold: list):
|
||||||
|
acc = 0.0
|
||||||
|
for p, g in zip(pred, gold):
|
||||||
|
if p == g:
|
||||||
|
acc += 1
|
||||||
|
|
||||||
|
return acc / len(pred)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
|
# example command
|
||||||
|
|
||||||
|
# python train_gsm8k_llama.py \
|
||||||
|
# --model_name_or_path LoftQ/Llama-2-7b-hf-bit4-rank64-backbone \
|
||||||
|
# --adapter_name_or_path LoftQ/Llama-2-7b-hf-bit4-rank64-adapters \
|
||||||
|
# --output_dir exp_results/gsm8k/llama-2-7b/bit4-rank64/lr3e-4 \
|
||||||
|
# --learning_rate 1e-4 \
|
||||||
|
# --seed 202 \
|
||||||
|
# --dataset_name gsm8k \
|
||||||
|
# --dataset_config main \
|
||||||
|
# --pad_to_max_length \
|
||||||
|
# --max_source_length 128 \
|
||||||
|
# --max_target_length 256 \
|
||||||
|
# --num_train_epochs 5 \
|
||||||
|
# --per_device_train_batch_size 4 \
|
||||||
|
# --per_device_eval_batch_size 4 \
|
||||||
|
# --gradient_accumulation_steps 4 \
|
||||||
|
# --with_tracking \
|
||||||
|
# --report_to tensorboard
|
requirements.txt (new file, 15 lines)
accelerate
torch
safetensors
bitsandbytes
scipy
peft
transformers
tqdm
packaging
pytest
numpy
pyyaml
datasets
psutil
setuptools
@@ -48,6 +48,7 @@ from .tuners import (
     AdaptionPromptConfig,
     AdaptionPromptModel,
     LoraConfig,
+    LoftQConfig,
     LoraModel,
     LoHaConfig,
     LoHaModel,
@@ -18,7 +18,7 @@
 # limitations under the License.

 from .adaption_prompt import AdaptionPromptConfig, AdaptionPromptModel
-from .lora import LoraConfig, LoraModel
+from .lora import LoraConfig, LoraModel, LoftQConfig
 from .loha import LoHaConfig, LoHaModel
 from .lokr import LoKrConfig, LoKrModel
 from .ia3 import IA3Config, IA3Model
@@ -15,13 +15,13 @@

 from peft.import_utils import is_bnb_4bit_available, is_bnb_available

-from .config import LoraConfig
+from .config import LoftQConfig, LoraConfig
 from .gptq import QuantLinear
 from .layer import Conv2d, Embedding, Linear, LoraLayer
 from .model import LoraModel


-__all__ = ["LoraConfig", "Conv2d", "Embedding", "LoraLayer", "Linear", "LoraModel", "QuantLinear"]
+__all__ = ["LoraConfig", "LoftQConfig", "Conv2d", "Embedding", "LoraLayer", "Linear", "LoraModel", "QuantLinear"]

 if is_bnb_available():
@@ -22,6 +22,25 @@ from peft.config import PeftConfig
 from peft.utils import PeftType


+@dataclass
+class LoftQConfig:
+    """
+    This is the sub-configuration class to store the configuration of a [`LoraModel`].
+
+    Args:
+        bits_pattern (`dict`): The mapping from layer names or regexp expression to bits which are different from the
+            default bits specified by `bits`. For example, `{"model.decoder.layers.0.encoder_attn.k_proj": 2}`.
+        bits (`int`): Quantization bits for LoftQ.
+        iter (`int`): Alternating iterations for LoftQ.
+        fake (`bool`): If True, use fp16/fp32, which is needed the first time in order to save the weights. If False,
+            use bitsandbytes 4-bit linear models, whose weights cannot be saved. It is recommended to set this to
+            True, save the weights, and then load the saved weights in 4 bits.
+    """
+
+    loftq_bits: int = field(default=4, metadata={"help": "Quantization bits for LoftQ"})
+    loftq_iter: int = field(default=1, metadata={"help": "Alternating iterations for LoftQ"})
+
+
 @dataclass
 class LoraConfig(PeftConfig):
     """
@ -78,7 +97,7 @@ class LoraConfig(PeftConfig):
            "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved."
        },
    )
-    init_lora_weights: bool | Literal["gaussian"] = field(
+    init_lora_weights: bool | Literal["gaussian", "loftq"] = field(
        default=True,
        metadata={
            "help": (
@ -86,6 +105,7 @@ class LoraConfig(PeftConfig):
                "initialization from the reference implementation from Microsoft. Passing 'gaussian' results "
                "in Gaussian initialization scaled by the LoRA rank for linear and layers. Setting the initialization "
                "to False leads to completely random initialization and is discouraged."
+                "Pass `'loftq'` to use LoftQ initialization"
            ),
        },
    )
@ -121,6 +141,16 @@ class LoraConfig(PeftConfig):
            )
        },
    )
+    # dict type is used when loading config.json
+    loftq_config: Union[LoftQConfig, dict] = field(
+        default_factory=dict,
+        metadata={
+            "help": (
+                "The configuration of LoftQ. If this is not None, then LoftQ will be used to quantize the backbone "
+                "weights and initialize Lora layers."
+            )
+        },
+    )

    def __post_init__(self):
        self.peft_type = PeftType.LORA
@ -134,3 +164,16 @@ class LoraConfig(PeftConfig):
        # if target_modules is a regex expression, then layers_pattern should be None
        if isinstance(self.target_modules, str) and self.layers_pattern is not None:
            raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.")
+
+        # handle init_lora_weights and loftq_config
+        if self.init_lora_weights == "loftq":
+            import importlib
+
+            if not importlib.util.find_spec("scipy"):
+                raise ImportError("The required package 'scipy' is not installed. Please install it to continue.")
+            if self.loftq_config is None:
+                raise ValueError("`loftq_config` must be specified when `init_lora_weights` is 'loftq'.")
+
+        # convert loftq_config to dict
+        if self.loftq_config is not None and not isinstance(self.loftq_config, dict):
+            self.loftq_config = vars(self.loftq_config)
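Taken together, the config changes above define the user-facing entry point for LoftQ. A minimal sketch of how the new fields are meant to be combined (the model name, rank, and target modules below are illustrative, not part of this commit):

```python
from transformers import AutoModelForCausalLM
from peft import LoftQConfig, LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# LoftQConfig carries only the quantization settings; LoraConfig wires it in
# through init_lora_weights="loftq" and the new loftq_config field.
loftq_config = LoftQConfig(loftq_bits=4, loftq_iter=1)
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    init_lora_weights="loftq",
    loftq_config=loftq_config,
)
peft_model = get_peft_model(base_model, lora_config)
```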
@ -15,7 +15,7 @@

import math
import warnings
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Union

import torch
import torch.nn as nn
@ -46,6 +46,7 @@ class LoraLayer(BaseTunerLayer):
        # Mark the weight as unmerged
        self._disable_adapters = False
        self.merged_adapters = []
+        self.kwargs = kwargs

        base_layer = self.get_base_layer()
        if isinstance(base_layer, nn.Linear):
@ -83,7 +84,10 @@ class LoraLayer(BaseTunerLayer):
        self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
        self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=False)
        self.scaling[adapter_name] = lora_alpha / r
-        if init_lora_weights:
+
+        if init_lora_weights == "loftq":
+            self.loftq_init(adapter_name)
+        elif init_lora_weights:
            self.reset_lora_parameters(adapter_name, init_lora_weights)

        weight = getattr(self.get_base_layer(), "weight", None)
@ -115,7 +119,10 @@ class LoraLayer(BaseTunerLayer):
        self.lora_A[adapter_name] = nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False)
        self.lora_B[adapter_name] = nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False)
        self.scaling[adapter_name] = lora_alpha / r
-        if init_lora_weights:
+
+        if init_lora_weights == "loftq":
+            self.loftq_init(adapter_name)
+        elif init_lora_weights:
            self.reset_lora_parameters(adapter_name, init_lora_weights)

        weight = getattr(base_layer, "weight", None)
@ -142,6 +149,10 @@ class LoraLayer(BaseTunerLayer):
        self.lora_embedding_A[adapter_name] = nn.Parameter(weight_A)
        self.lora_embedding_B[adapter_name] = nn.Parameter(weight_B)
        self.scaling[adapter_name] = lora_alpha / r
+
+        if init_lora_weights == "loftq":
+            self.loftq_init(adapter_name)
+        elif init_lora_weights:
            self.reset_lora_parameters(adapter_name, init_lora_weights)

        base_layer = self.get_base_layer()
@ -170,6 +181,27 @@ class LoraLayer(BaseTunerLayer):
            nn.init.zeros_(self.lora_embedding_A[adapter_name])
            nn.init.normal_(self.lora_embedding_B[adapter_name])

+    def loftq_init(self, adapter_name):
+        from peft.utils.loftq_utils import loftq_init
+
+        weight = self.get_base_layer().weight
+        kwargs = {
+            "num_bits": self.kwargs.get("loftq_bits", 4),
+            "reduced_rank": self.r[adapter_name],
+            "num_iter": self.kwargs.get("loftq_iter", 1),
+        }
+
+        qweight, lora_A, lora_B = loftq_init(weight, **kwargs)
+        if adapter_name in self.lora_A.keys():
+            # initialize A the same way as the default for nn.Linear and B to zero
+            self.lora_A[adapter_name].weight.data = lora_A
+            self.lora_B[adapter_name].weight.data = lora_B
+        if adapter_name in self.lora_embedding_A.keys():
+            # initialize a the same way as the default for nn.linear and b to zero
+            self.lora_embedding_A[adapter_name].weight.data = lora_A
+            self.lora_embedding_B[adapter_name].weight.data = lora_B
+        self.get_base_layer().weight.data = qweight
+
    def set_scale(self, adapter, scale):
        if adapter not in self.scaling:
            # Ignore the case where the adapter is not in the layer
@ -218,11 +250,11 @@ class Linear(nn.Module, LoraLayer):
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        is_target_conv_1d_layer: bool = False,
-        init_lora_weights: bool = True,
+        init_lora_weights: Union[bool, str] = True,
        **kwargs,
    ) -> None:
        super().__init__()
-        LoraLayer.__init__(self, base_layer)
+        LoraLayer.__init__(self, base_layer, **kwargs)
        self.fan_in_fan_out = fan_in_fan_out

        self._active_adapter = adapter_name
@ -351,7 +383,7 @@ class Embedding(nn.Module, LoraLayer):
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
-        init_lora_weights: bool = True,
+        init_lora_weights: Union[bool, str] = True,
        **kwargs,
    ) -> None:
        super().__init__()
@ -491,7 +523,7 @@ class Conv2d(nn.Module, LoraLayer):
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
-        init_lora_weights: bool = True,
+        init_lora_weights: Union[bool, str] = True,
        **kwargs,
    ) -> None:
        super().__init__()
@ -286,8 +286,10 @@ class LoraModel(BaseTuner):
        elif isinstance(target_base_layer, torch.nn.Embedding):
            embedding_kwargs = kwargs.copy()
            embedding_kwargs.pop("fan_in_fan_out", None)
+            embedding_kwargs.update(lora_config.loftq_config)
            new_module = Embedding(target, adapter_name, **embedding_kwargs)
        elif isinstance(target_base_layer, torch.nn.Conv2d):
+            kwargs.update(lora_config.loftq_config)
            new_module = Conv2d(target, adapter_name, **kwargs)
        elif isinstance(target_base_layer, torch.nn.Linear):
            if kwargs["fan_in_fan_out"]:
@ -296,6 +298,7 @@ class LoraModel(BaseTuner):
                    "Setting fan_in_fan_out to False."
                )
                kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
+            kwargs.update(lora_config.loftq_config)
            new_module = Linear(target, adapter_name, **kwargs)
        elif isinstance(target_base_layer, Conv1D):
            if not kwargs["fan_in_fan_out"]:
@ -304,6 +307,7 @@ class LoraModel(BaseTuner):
                    "Setting fan_in_fan_out to True."
                )
                kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
+            kwargs.update(lora_config.loftq_config)
            new_module = Linear(target, adapter_name, is_target_conv_1d_layer=True, **kwargs)
        else:
            raise ValueError(
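For readers tracing how the LoftQ settings reach the layers: the `kwargs.update(lora_config.loftq_config)` calls above merge the dict-converted `LoftQConfig` into the keyword arguments of whichever LoRA module gets created. A rough sketch of the effect (values illustrative, not from this commit):

```python
# loftq_config was turned into a plain dict in LoraConfig.__post_init__ via vars()
loftq_config = {"loftq_bits": 4, "loftq_iter": 1}

# kwargs assembled by LoraModel for a target module (values illustrative)
kwargs = {"r": 16, "lora_alpha": 16, "lora_dropout": 0.0, "init_lora_weights": "loftq"}
kwargs.update(loftq_config)

# Linear/Embedding/Conv2d forward **kwargs down to LoraLayer, which stores them in
# self.kwargs; loftq_init() later reads "loftq_bits" and "loftq_iter" from there.
```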
227
src/peft/utils/loftq_utils.py
Normal file
@ -0,0 +1,227 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Reference code: https://github.com/yxli2123/LoftQ/blob/main/utils.py
# Reference paper: https://arxiv.org/abs/2310.08659

import logging
from typing import Union

import torch

from peft.import_utils import is_bnb_4bit_available, is_bnb_available


if is_bnb_available():
    import bitsandbytes as bnb


class NFQuantizer:
    def __init__(self, num_bits=2, device="cuda", method="normal", block_size=64, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_bits = num_bits
        self.device = device
        self.method = method
        self.block_size = block_size
        if self.method == "normal":
            self.norm_lookup_table = self.create_normal_map(num_bits=self.num_bits)
            self.norm_lookup_table = self.norm_lookup_table.to(device)
        elif self.method == "uniform":
            self.norm_lookup_table = self.create_uniform_map(num_bits=self.num_bits)
            self.norm_lookup_table = self.norm_lookup_table.to(device)
        else:
            raise NotImplementedError("Other quantization methods not supported yet.")

    @staticmethod
    def create_uniform_map(symmetric=False, num_bits=4):
        if symmetric:
            # print("symmetric uniform quantization")
            negative = torch.linspace(-1, 0, 2 ** (num_bits - 1))
            positive = torch.linspace(0, 1, 2 ** (num_bits - 1))
            table = torch.cat([negative, positive[1:]])
        else:
            # print("asymmetric uniform quantization")
            table = torch.linspace(-1, 1, 2**num_bits)
        return table

    @staticmethod
    def create_normal_map(offset=0.9677083, symmetric=False, num_bits=2):
        try:
            from scipy.stats import norm
        except ImportError:
            raise ImportError("The required package 'scipy' is not installed. Please install it to continue.")

        variations = 2**num_bits
        if symmetric:
            v = norm.ppf(torch.linspace(1 - offset, offset, variations + 1)).tolist()
            values = []
            for index in range(len(v) - 1):
                values.append(0.5 * v[index] + 0.5 * v[index + 1])
            v = values
        else:
            # one more positive value, this is an asymmetric type
            v1 = norm.ppf(torch.linspace(offset, 0.5, variations // 2 + 1)[:-1]).tolist()
            v2 = [0]
            v3 = (-norm.ppf(torch.linspace(offset, 0.5, variations // 2)[:-1])).tolist()
            v = v1 + v2 + v3

        values = torch.Tensor(v)
        values = values.sort().values
        values /= values.max()
        return values

    def quantize_tensor(self, weight):
        max_abs = torch.abs(weight).max()
        weight_normed = weight / max_abs

        weight_normed_expanded = weight_normed.unsqueeze(-1)

        # Reshape L to have the same number of dimensions as X_expanded
        L_reshaped = torch.tensor(self.norm_lookup_table).reshape(1, -1)

        # Calculate the absolute difference between X_expanded and L_reshaped
        abs_diff = torch.abs(weight_normed_expanded - L_reshaped)

        # Find the index of the minimum absolute difference for each element
        qweight = torch.argmin(abs_diff, dim=-1)
        return qweight, max_abs

    def dequantize_tensor(self, qweight, max_abs):
        qweight_flatten = qweight.flatten()

        weight_normed = self.norm_lookup_table[qweight_flatten]
        weight = weight_normed * max_abs

        weight = weight.reshape(qweight.shape)

        return weight

    def quantize_block(self, weight):
        if len(weight.shape) != 2:
            raise ValueError(f"Only support 2D matrix, but your input has {len(weight.shape)} dimensions.")
        if weight.shape[0] * weight.shape[1] % self.block_size != 0:
            raise ValueError(
                f"Weight with shape ({weight.shape[0]} x {weight.shape[1]}) "
                f"is not dividable by block size {self.block_size}."
            )

        M, N = weight.shape
        device = weight.device

        # Quantization
        weight_flatten = weight.flatten()  # (M*N, )
        weight_block = weight_flatten.reshape(-1, self.block_size)  # (L, B), L = M * N / B
        if self.method == "normal":
            weight_max = weight_block.abs().max(dim=-1)[0]  # (L, 1)
        elif self.method == "uniform":
            weight_max = weight_block.mean(dim=-1) + 2.5 * weight_block.std(dim=-1)
        else:
            raise NotImplementedError("Method not supported yet.")
        weight_max = weight_max.unsqueeze(-1)
        weight_divabs = weight_block / weight_max  # (L, B)
        weight_divabs = weight_divabs.unsqueeze(-1)  # (L, B, 1)
        L_reshaped = self.norm_lookup_table.reshape(1, -1)  # (1, 2**K)

        abs_diff = torch.abs(weight_divabs - L_reshaped)  # (L, B, 2**K)
        qweight = torch.argmin(abs_diff, dim=-1)  # (L, B)

        # Pack multiple k-bit into uint8
        qweight = qweight.reshape(-1, 8 // self.num_bits)
        qweight_pack = torch.zeros((M * N // 8 * self.num_bits, 1), dtype=torch.uint8, device=device)

        # data format example:
        # [1, 0, 3, 2] or [01, 00, 11, 10] -> [10110001], LIFO
        for i in range(8 // self.num_bits):
            qweight[:, i] = qweight[:, i] << i * self.num_bits
            qweight_pack[:, 0] |= qweight[:, i]

        return qweight_pack, weight_max, weight.shape

    def dequantize_block(self, qweight, weight_max, weight_shape):
        # unpack weight
        device = qweight.device
        weight = torch.zeros((qweight.shape[0], 8 // self.num_bits), dtype=torch.float32, device=device)
        for i in range(8 // self.num_bits):
            lookup_table_idx = qweight.to(torch.long) % 2**self.num_bits  # get the most right 2 bits
            lookup_table_idx = lookup_table_idx.to(torch.int)
            weight[:, i] = self.norm_lookup_table[lookup_table_idx].squeeze()
            qweight = qweight >> self.num_bits  # right shift 2 bits of the original data

        weight_block = weight.reshape(-1, self.block_size)
        weight = weight_block * weight_max
        weight = weight.reshape(weight_shape)

        return weight


def _low_rank_decomposition(weight, reduced_rank=32):
    """
    :param weight: The matrix to decompose, of shape (H, W) :param reduced_rank: the final rank :return:
    """
    matrix_dimension = len(weight.size())
    if matrix_dimension != 2:
        raise ValueError(f"Only support 2D matrix, but your input has {matrix_dimension} dimensions.")

    # Use SVD to decompose a matrix, default full_matrices is False to save parameters
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)

    L = U @ (torch.sqrt(torch.diag(S)[:, 0:reduced_rank]))
    R = torch.sqrt(torch.diag(S)[0:reduced_rank, :]) @ Vh

    return {"L": L, "R": R, "U": U, "S": S, "Vh": Vh, "reduced_rank": reduced_rank}


@torch.no_grad()
def loftq_init(weight: Union[torch.Tensor, torch.nn.Parameter], num_bits: int, reduced_rank: int, num_iter=1):
    if num_bits not in [2, 4, 8]:
        raise ValueError("Only support 2, 4, 8 bits quantization")
    if num_iter <= 0:
        raise ValueError("Number of iterations must be greater than 0")

    out_feature, in_feature = weight.size()
    device = weight.device
    dtype = weight.dtype

    logging.info(
        f"Weight: ({out_feature}, {in_feature}) | Rank: {reduced_rank} "
        f"| Num Iter: {num_iter} | Num Bits: {num_bits}"
    )
    if not is_bnb_4bit_available():
        quantizer = NFQuantizer(num_bits=num_bits, device=device, method="normal", block_size=64)

    weight = weight.to(torch.float32)
    res = weight.clone()
    for i in range(num_iter):
        torch.cuda.empty_cache()
        # Quantization
        if num_bits == 4 and is_bnb_4bit_available():
            qweight = bnb.nn.Params4bit(
                res.to("cpu"), requires_grad=False, compress_statistics=False, quant_type="nf4"
            ).to(device)
            dequantized_weight = bnb.functional.dequantize_4bit(qweight.data, qweight.quant_state)
        else:
            quantized_weight, max_abs, shape = quantizer.quantize_block(res)
            dequantized_weight = quantizer.dequantize_block(quantized_weight, max_abs, shape)

        res = weight - dequantized_weight

        # Decompose the residual by SVD
        output = _low_rank_decomposition(res, reduced_rank=reduced_rank)
        L, R, reduced_rank = output["L"], output["R"], output["reduced_rank"]
        res = weight - torch.mm(L, R)

    lora_A, lora_B = R, L

    return dequantized_weight.to(dtype), lora_A, lora_B
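As a quick orientation for the new utility module, here is a minimal sketch of calling `loftq_init` directly on a weight matrix (shapes and hyperparameters are illustrative; the 4-bit path uses bitsandbytes when available, otherwise the `NFQuantizer` fallback requires scipy):

```python
import torch

from peft.utils.loftq_utils import loftq_init

# stand-in for a pretrained (out_features, in_features) weight
weight = torch.randn(768, 768)
qweight, lora_A, lora_B = loftq_init(weight, num_bits=4, reduced_rank=16, num_iter=5)

# qweight is the dequantized approximation of `weight`; lora_A (r x in_features)
# and lora_B (out_features x r) hold the low-rank correction, so that
# qweight + lora_B @ lora_A ≈ weight.
```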