mirror of
https://github.com/huggingface/peft.git
synced 2025-10-20 15:33:48 +08:00
Randlora documentation and some example usage (#2524)
This is a follow up to #2464 and issue #2441. Entails documentation for RandLora and slightly updated example usage in the model.py docstring. Also adds RandLoRA to method comparison. --------- Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
This commit is contained in:
@ -124,6 +124,8 @@
|
||||
title: Bone
|
||||
- local: package_reference/trainable_tokens
|
||||
title: Trainable Tokens
|
||||
- local: package_reference/randlora
|
||||
title: RandLora
|
||||
|
||||
title: Adapters
|
||||
- sections:
|
||||
|
45
docs/source/package_reference/randlora.md
Normal file
45
docs/source/package_reference/randlora.md
Normal file
@ -0,0 +1,45 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# RandLora: Full-rank parameter-efficient fine-tuning of large models
|
||||
[RandLora](https://huggingface.co/papers/2502.00987) is a parameter-efficient fine-tuning technique that is similar to [LoRA](https://huggingface.co/papers/2106.09685) and [VeRA](https://huggingface.co/papers/2310.11454) but performs full rank updates to improve performance. RandLora can be particulary usefull when adapting large model to hard tasks that require complex updates while preserving the parameter efficiency of LoRA. The full rank update of RandLora is achieved by linearly scaling random bases. The random bases are a collection of multiple low rank matrices such that the summation of their ranks if greater or equal to the full rank of the parameter matrices. The trainable parameters of RandLora are two diagonal matrices (vectors) that get multiplied with the right hand low rank random bases, in a similar way to VeRA's update. To maintain low memory usage, RandLora uses a custom function that prevents storing unnecessary bases in memory for backpropagation.
|
||||
|
||||
RandLora presents the noteworthy difference that contrary to other LoRA-like PEFT algorithm, increasing RandLora's random base ranks increases the amount of trainable parameters. Because number of bases x bases rank is constant in RandLora, reducing the rank will increase the number of random bases, hence the number of base-specific trainable diagonal bases.
|
||||
|
||||
Because reducing the rank of RandLora's random bases will increase their number, RandLora can become slower to train than LoRA for very small ranks where typically, ranks below 4 with result in a large training time increase. This does not affect inference though as the RandLora adapters can be merged into the pretrained weight matrices.
|
||||
|
||||
RandLora additionally supports training with sparse, ternary random bases (only containing -1, 0 and 1). These bases are as described in [Bingham et al.](https://cs-people.bu.edu/evimaria/cs565/kdd-rp.pdf) and [Ping et al.](https://hastie.su.domains/Papers/Ping/KDD06_rp.pdf) and could theoretically be used to reduce compute needs by performing aggregations instead of matrix multiplications to create the weight update. This is not currently supported. Although it does not currently reduce compute, using sparse random bases in RandLora can reduce overfitting in some cases. For users intersted in using sparse ternary bases, the `sparse` option is recommended over the `very_sparse` one that can reduce perfromance.
|
||||
|
||||
Similarly to VeRA, when saving the RandLora's parameters, it's possible to eschew storing the low rank matrices by setting `save_projection=False` on the `VeraConfig`. In that case, these matrices will be restored based on the fixed random seed from the `projection_prng_key` argument. This cuts down on the size of the checkpoint, but we cannot guarantee reproducibility on all devices and for all future versions of PyTorch. If you want to ensure reproducibility, set `save_projection=True` (which is the default).
|
||||
|
||||
As in Vera and to handle different shapes of adapted layers, RandLora initializes shared A and B matrices with the largest required size for each dimension. During the forward pass, submatrices A and B for a given layer are sliced out from these shared matrices and used as described in the paper. For example, adapting two linear layers of shapes (100, 20) and (80, 50) will create A and B matrices of shapes (rank, 50) and (100, rank) respectively. Then, to adapt a layer of shape (100, 20), submatrices A and B of shapes (rank, 20) and (100, rank) will be extracted.
|
||||
|
||||
RandLora currently has the following constraint:
|
||||
|
||||
- Only `nn.Linear` layers are supported.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
> Low-Rank Adaptation (LoRA) and its variants have shown impressive results in reducing the number of trainable parameters and memory requirements of large transformer networks while maintaining fine-tuning performance. The low-rank nature of the weight update inherently limits the representation power of fine-tuned models, however, thus potentially compromising performance on complex tasks. This raises a critical question: when a performance gap between LoRA and standard fine-tuning is observed, is it due to the reduced number of trainable parameters or the rank deficiency?
|
||||
This paper aims to answer this question by introducing RandLora, a parameter-efficient method that performs full-rank updates using a learned linear combinations of low-rank, non-trainable random matrices. Our method limits the number of trainable parameters by restricting optimization to diagonal scaling matrices applied to the fixed random matrices. This allows us to effectively overcome the low-rank limitations while maintaining parameter and memory efficiency during training. Through extensive experimentation across vision, language, and vision-language benchmarks, we systematically evaluate the limitations of LoRA and existing random basis methods. Our findings reveal that full-rank updates are beneficial across vision and language tasks individually, and even more so for vision-language tasks, where RandLora significantly reduces---and sometimes eliminates---the performance gap between standard fine-tuning and LoRA, demonstrating its efficacy.
|
||||
|
||||
## RandLoraConfig
|
||||
|
||||
[[autodoc]] tuners.randlora.config.RandLoraConfig
|
||||
|
||||
## RandLoraModel
|
||||
|
||||
[[autodoc]] tuners.randlora.model.RandLoraModel
|
112
examples/randlora_finetuning/README.md
Normal file
112
examples/randlora_finetuning/README.md
Normal file
@ -0,0 +1,112 @@
|
||||
# RandLora: Full-rank parameter-efficient fine-tuning of large models
|
||||
|
||||
## Introduction
|
||||
[RandLora](https://huggingface.co/papers/2502.00987) is a parameter-efficient fine-tuning technique that is similar to LoRA and VeRA but performs full rank updates to improve performance. RandLora can be particulary usefull when adapting large model to hard tasks that require complex updates while preserving the parameter efficiency of LoRA. The full rank update of RandLora is acheived by linearly scaling random bases. The random bases are a collection of multiple low rank matrices such that the summation of their ranks if greater or equal to the full rank of the parameter matrices. The trainable parameters of RandLora are two diagonal matrices (vectors) that get multiplied with the right hand low rank random bases, in a similar way to VeRA's update. To maintain low memory usage, RandLora uses a custom function that prevents storing unnecessary bases in memory for backpropagation.
|
||||
|
||||
## Quick start
|
||||
```python
|
||||
import torch
|
||||
from peft import RandLoraConfig, get_peft_model
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer
|
||||
from datasets import load_dataset
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", device_map="cuda")
|
||||
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
|
||||
randlora_config = RandLoraConfig()
|
||||
|
||||
peft_model = get_peft_model(model, lora_config)
|
||||
trainer = transformers.Trainer(
|
||||
model=peft_model,
|
||||
train_dataset=dataset,
|
||||
dataset_text_field="text",
|
||||
max_seq_length=2048,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
trainer.train()
|
||||
peft_model.save_pretrained("randlora-llama-7b")
|
||||
```
|
||||
|
||||
There is no additional change needed to your standard PEFT training procedure, simply swap your `LoraConfig` for a `RandLoraConfig`. Note however that RandLora's trainable parameter count is **inversely proportional** to the rank parameter `r`. Lower `r` to increase and increase it to reduce trainable parameters of RandLora.
|
||||
|
||||
Run the finetuning script simply by running:
|
||||
```bash
|
||||
python examples/randlora_finetuning/randlora_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --data_path timdettmers/openassistant-guanaco
|
||||
```
|
||||
This 👆🏻 by default will load the model in peft set up with RandLora config. Now if you wanna quickly compare it with Lora, all you need to do is to input ` --use_lora` in the command line and reduce `--randlora_alpha` to 2x the rank. So same above example would be 👇🏻;
|
||||
|
||||
```bash
|
||||
python examples/randlora_finetuning/randlora_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --data_path timdettmers/openassistant-guanaco --use_lora --rank 32 --randlora_alpha 64
|
||||
```
|
||||
|
||||
RandLora can be made to use sparse or very sparse random bases. These sparse matrices can help reduce overfitting. Add `--very_sparse` to run with very sparse matrices or `--sparse` for sparse matrices:
|
||||
|
||||
```bash
|
||||
python examples/randlora_finetuning/randlora_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --sparse
|
||||
```
|
||||
|
||||
RandLora also supports quantization. To use 4-bit quantization try:
|
||||
|
||||
```bash
|
||||
python examples/randlora_finetuning/randlora_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --quantize
|
||||
```
|
||||
|
||||
By default the RandLora layers are the key and value layers of LLama model. Adding adapters on more layers will increase memory usage. If you wish to choose a different set of layers for RandLora to be applied on, you can simply define it using:
|
||||
```bash
|
||||
python examples/randlora_finetuning/randlora_finetuning.py --randlora_target_modules "q_proj,k_proj,v_proj"
|
||||
```
|
||||
|
||||
### Full example of the script
|
||||
```bash
|
||||
python randlora_finetuning.py \
|
||||
--base_model "PATH_TO_MODEL" \
|
||||
--data_path "PATH_TO_DATASET" \
|
||||
--output_dir "PATH_TO_OUTPUT_DIR" \
|
||||
--batch_size 1 \
|
||||
--num_epochs 3 \
|
||||
--learning_rate 3e-4 \
|
||||
--cutoff_len 512 \
|
||||
--val_set_size 500 \
|
||||
--quantize \
|
||||
--eval_step 10 \
|
||||
--save_step 100 \
|
||||
--device "cuda:0" \
|
||||
--rank 32 \
|
||||
--randlora_alpha 640 \
|
||||
--randlora_dropout 0.05 \
|
||||
--randlora_target_modules "k_proj,v_proj" \
|
||||
--hub_model_id "YOUR_HF_REPO" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
## RandLora vs. LoRA
|
||||
RandLora differs from LoRA and other related low rank approximation algorithms by chanllenging the low rank paradigm. RandLora adapters learn **full-rank** updates as the [paper](https://huggingface.co/papers/2502.00987) shows that the low rank constraint of LoRA can constrain performance gains as trainable parameters increase (with higher ranks). As a result, using RandLora is specifically recommended for difficult tasks that are underfit by LoRA. RandLoRA however also often improves performance for common tasks. If increasing LoRA's rank improves performance for your task, RandLora will most likely outperform.
|
||||
|
||||
RandLora is expected to increase performance over LoRA for equivalent amounts of trainable parameters, mostly for larger equivalent amounts (> LoRA rank 4).
|
||||
|
||||
RandLora's performance increase comes with two limitations:
|
||||
|
||||
1. Performance is dependent on using a large `randlora_alpha` scaling parameter (usually 20x the basis rank). This large parameter can sometimes make training the update unstable, reduce the learning rate or the scaling parameter if this is the case.
|
||||
|
||||
2. Increase training time over LoRA when using very low RandLora basis ranks.
|
||||
|
||||
## RandLora vs. VeRA
|
||||
RandLora shares similarities with VeRA in that both algorithms use random basis combinations to address some of LoRA's limitations. The limitations addressed by each algorithm is however different.
|
||||
VeRA aims to reduce trainable parameters beyond rank 1 LoRAs while RandLoRA reduces the performance limitation due to the low rank of the update as the trainable parameter count increases.
|
||||
|
||||
RandLora is expected to:
|
||||
|
||||
1. Improve performance over VeRA when more trainable parameters are required (hard tasks)
|
||||
|
||||
2. Reduce memory usage over VeRA thanks to RandLora's random base sharing strategy
|
||||
|
||||
|
||||
## Citation
|
||||
```
|
||||
@inproceedings{2025_ICLR_RandLoRA,
|
||||
title="{RandLoRA: Full rank parameter-efficient fine-tuning of large models}",
|
||||
author="Albert, Paul and Zhang, Frederic Z. and Saratchandran, Hemanth and Rodriguez-Opazo, Cristian and van den Hengel, Anton and Abbasnejad, Ehsan",
|
||||
booktitle="{International Conference on Learning Representations (ICLR)}",
|
||||
year="2025"
|
||||
}
|
||||
```
|
8099
examples/randlora_finetuning/qrandlora_finetuning.ipynb
Normal file
8099
examples/randlora_finetuning/qrandlora_finetuning.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
225
examples/randlora_finetuning/randlora_finetuning.py
Normal file
225
examples/randlora_finetuning/randlora_finetuning.py
Normal file
@ -0,0 +1,225 @@
|
||||
# This script is based on examples/dora_finetuning/dora_finetuning.py
|
||||
import os
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
DataCollatorForLanguageModeling,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
)
|
||||
|
||||
from peft import LoraConfig, RandLoraConfig, get_peft_model, prepare_model_for_kbit_training
|
||||
|
||||
|
||||
def train_model(
|
||||
base_model: str,
|
||||
data_path: str,
|
||||
output_dir: str,
|
||||
batch_size: int,
|
||||
num_epochs: int,
|
||||
learning_rate: float,
|
||||
cutoff_len: int,
|
||||
val_set_size: int,
|
||||
use_lora: bool,
|
||||
quantize: bool,
|
||||
eval_step: int,
|
||||
save_step: int,
|
||||
device: str,
|
||||
rank: int,
|
||||
randlora_alpha: int,
|
||||
randlora_dropout: float,
|
||||
randlora_target_modules: str,
|
||||
hub_model_id: str,
|
||||
push_to_hub: bool,
|
||||
sparse: bool,
|
||||
very_sparse: bool,
|
||||
):
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
hf_token = os.getenv("HF_TOKEN")
|
||||
|
||||
# Setup device
|
||||
device = torch.device(device)
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# load tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token)
|
||||
|
||||
# Compute type
|
||||
torch_dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
|
||||
|
||||
# QRandLora (quantized randlora): IF YOU WANNA QUANTIZE THE MODEL
|
||||
if quantize:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
token=hf_token,
|
||||
quantization_config=BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=(
|
||||
torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
|
||||
),
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
),
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
# setup for quantized training
|
||||
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
torch_dtype=torch_dtype,
|
||||
token=hf_token,
|
||||
)
|
||||
# LoRa config for the PEFT model
|
||||
if use_lora:
|
||||
peft_config = LoraConfig(
|
||||
r=rank, # Rank of matrix
|
||||
lora_alpha=randlora_alpha,
|
||||
target_modules=(randlora_target_modules.split(",") if randlora_target_modules else ["k_proj", "v_proj"]),
|
||||
lora_dropout=randlora_dropout,
|
||||
bias="none",
|
||||
)
|
||||
else:
|
||||
peft_config = RandLoraConfig(
|
||||
r=rank, # Rank of random bases
|
||||
randlora_alpha=randlora_alpha,
|
||||
target_modules=(randlora_target_modules.split(",") if randlora_target_modules else ["k_proj", "v_proj"]),
|
||||
randlora_dropout=randlora_dropout,
|
||||
bias="none",
|
||||
sparse=sparse,
|
||||
very_sparse=very_sparse,
|
||||
)
|
||||
|
||||
# get the peft model with RandLora config
|
||||
model = get_peft_model(model, peft_config)
|
||||
|
||||
model.to(device) # MODEL TO GPU/CUDA
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# Load the dataset
|
||||
dataset = load_dataset(data_path)
|
||||
|
||||
def tokenize_function(examples):
|
||||
inputs = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=cutoff_len)
|
||||
inputs["labels"] = inputs["input_ids"].copy() # setting labels for a language modeling task
|
||||
return inputs
|
||||
|
||||
# Tokenize the dataset and prepare for training
|
||||
tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
|
||||
|
||||
# Data collator to dynamically pad the batched examples
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
|
||||
|
||||
# Compute the total amount of training step for warmup
|
||||
max_steps = int((len(dataset) // batch_size) * num_epochs)
|
||||
|
||||
# Define training arguments
|
||||
training_args = TrainingArguments(
|
||||
output_dir=output_dir,
|
||||
num_train_epochs=num_epochs,
|
||||
per_device_train_batch_size=batch_size,
|
||||
per_device_eval_batch_size=batch_size,
|
||||
warmup_steps=int(max_steps * 0.1), # 10% of total trainig steps
|
||||
weight_decay=0.01,
|
||||
logging_dir="./logs",
|
||||
logging_steps=eval_step,
|
||||
save_steps=save_step,
|
||||
save_total_limit=2,
|
||||
push_to_hub=push_to_hub,
|
||||
hub_model_id=hub_model_id,
|
||||
gradient_accumulation_steps=16
|
||||
// batch_size, # Maintaining a minimum batch size of 16 post accumulation is recommended to ensure good performance
|
||||
learning_rate=learning_rate,
|
||||
hub_token=hf_token,
|
||||
label_names=["labels"],
|
||||
)
|
||||
|
||||
# Clear CUDA cache to free memory
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Initialize the Trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=tokenized_datasets["train"],
|
||||
eval_dataset=tokenized_datasets["test"],
|
||||
data_collator=data_collator,
|
||||
)
|
||||
|
||||
# Start model training
|
||||
trainer.train()
|
||||
|
||||
# Save and push the trained model and tokenizer
|
||||
if push_to_hub:
|
||||
# Push the main model to the hub
|
||||
trainer.push_to_hub(commit_message="Fine-tuned model")
|
||||
|
||||
# Save the model and tokenizer locally
|
||||
model.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Fine-tune LLaMA with DoRA and PEFT")
|
||||
parser.add_argument("--base_model", type=str, default="huggyllama/llama-7b", help="Base model path or name")
|
||||
parser.add_argument(
|
||||
"--data_path", type=str, default="timdettmers/openassistant-guanaco", help="Dataset path or name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir", type=str, default="path/to/output", help="Output directory for the fine-tuned model"
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
|
||||
parser.add_argument("--learning_rate", type=float, default=3e-4, help="Learning rate")
|
||||
parser.add_argument("--cutoff_len", type=int, default=512, help="Cutoff length for tokenization")
|
||||
parser.add_argument("--val_set_size", type=int, default=500, help="Validation set size")
|
||||
parser.add_argument("--use_lora", action="store_true", help="Apply Lora instead of RandLora")
|
||||
parser.add_argument("--quantize", action="store_true", help="Use quantization")
|
||||
parser.add_argument("--eval_step", type=int, default=10, help="Evaluation step interval")
|
||||
parser.add_argument("--save_step", type=int, default=100, help="Save step interval")
|
||||
parser.add_argument("--device", type=str, default="cuda:0", help="Device to use for training")
|
||||
parser.add_argument("--rank", type=int, default=32, help="RandLora basis rank")
|
||||
parser.add_argument("--randlora_alpha", type=int, default=640, help="RandLora alpha")
|
||||
parser.add_argument("--randlora_dropout", type=float, default=0.05, help="RandLora dropout rate")
|
||||
parser.add_argument(
|
||||
"--randlora_target_modules", type=str, default=None, help="Comma-separated list of target modules for RandLora"
|
||||
)
|
||||
parser.add_argument("--sparse", action="store_true", help="Use sparse matrix multiplication")
|
||||
parser.add_argument("--very_sparse", action="store_true", help="Use very sparse matrix multiplication")
|
||||
parser.add_argument(
|
||||
"--hub_model_id",
|
||||
type=str,
|
||||
default="path/to/repo",
|
||||
help="Repository name to push the model on the Hugging Face Hub",
|
||||
)
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to Hugging Face Hub")
|
||||
args = parser.parse_args()
|
||||
train_model(
|
||||
base_model=args.base_model,
|
||||
data_path=args.data_path,
|
||||
output_dir=args.output_dir,
|
||||
batch_size=args.batch_size,
|
||||
num_epochs=args.num_epochs,
|
||||
learning_rate=args.learning_rate,
|
||||
cutoff_len=args.cutoff_len,
|
||||
val_set_size=args.val_set_size,
|
||||
use_lora=args.use_lora,
|
||||
quantize=args.quantize,
|
||||
eval_step=args.eval_step,
|
||||
save_step=args.save_step,
|
||||
device=args.device,
|
||||
rank=args.rank,
|
||||
randlora_alpha=args.randlora_alpha,
|
||||
randlora_dropout=args.randlora_dropout,
|
||||
randlora_target_modules=args.randlora_target_modules,
|
||||
hub_model_id=args.hub_model_id,
|
||||
push_to_hub=args.push_to_hub,
|
||||
sparse=args.sparse,
|
||||
very_sparse=args.very_sparse,
|
||||
)
|
@ -0,0 +1,22 @@
|
||||
{
|
||||
"auto_mapping": null,
|
||||
"base_model_name_or_path": null,
|
||||
"bias": "none",
|
||||
"fan_in_fan_out": false,
|
||||
"inference_mode": false,
|
||||
"init_weights": true,
|
||||
"layers_pattern": null,
|
||||
"layers_to_transform": null,
|
||||
"modules_to_save": null,
|
||||
"peft_type": "RANDLORA",
|
||||
"projection_prng_key": 0,
|
||||
"r": 32,
|
||||
"randlora_alpha": 640,
|
||||
"randlora_dropout": 0.0,
|
||||
"revision": null,
|
||||
"save_projection": true,
|
||||
"sparse": false,
|
||||
"target_modules": null,
|
||||
"task_type": null,
|
||||
"very_sparse": false
|
||||
}
|
@ -324,7 +324,6 @@ class Linear(nn.Linear, RandLoraLayer):
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
previous_dtype = x.dtype
|
||||
|
||||
if self.disable_adapters:
|
||||
if self.merged:
|
||||
self.unmerge()
|
||||
|
@ -56,7 +56,10 @@ def _kaiming_init(
|
||||
`torch.Tensor`: The initialised tensor.
|
||||
"""
|
||||
if isinstance(tensor_or_shape, tuple):
|
||||
tensor = torch.empty(tensor_or_shape, dtype=torch.float32)
|
||||
tensor = torch.empty(
|
||||
tensor_or_shape,
|
||||
dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16,
|
||||
)
|
||||
else:
|
||||
tensor = tensor_or_shape
|
||||
|
||||
@ -86,7 +89,7 @@ class RandLoraModel(BaseTuner):
|
||||
>>> from peft import RandLoraConfig, get_peft_model
|
||||
|
||||
>>> base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
|
||||
>>> config = RandLoraConfig(r=128)
|
||||
>>> config = RandLoraConfig(r=32)
|
||||
>>> model = get_peft_model(base_model, config)
|
||||
```
|
||||
|
||||
|
Reference in New Issue
Block a user