
Sparse High Rank Adapters

Introduction

Sparse High Rank Adapters (SHiRA) is an alternative type of adapter that has been found to have significant advantages over low-rank adapters. Specifically, SHiRA achieves better accuracy than LoRA on a variety of vision and language tasks. It also offers simpler and higher-quality multi-adapter fusion by significantly reducing concept loss, a common problem with low-rank adapters. SHiRA adapts the model to a downstream task by directly finetuning a small, sparse subset of the base model's parameters.
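
Conceptually, SHiRA learns a sparse, high-rank update: a fixed binary mask selects a small fraction of a weight matrix's entries, and only those entries are trained while the rest of the base weight stays frozen. The snippet below is a minimal illustration of that idea in plain PyTorch, not the PEFT implementation; the layer shape and mask density are made up for the example.

import torch

torch.manual_seed(0)

out_features, in_features = 64, 64
base_weight = torch.randn(out_features, in_features)  # frozen base weight

# Fixed sparse binary mask: only ~1% of the entries are trainable.
mask = (torch.rand_like(base_weight) < 0.01).float()

# Trainable sparse update; gradients only flow into the masked positions.
delta = torch.zeros_like(base_weight, requires_grad=True)

# Effective adapted weight: base weight plus the masked (sparse) update.
adapted_weight = base_weight + mask * delta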

Quick start

import torch
from peft import ShiraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
dataset = load_dataset("imdb", split="train[:1%]")
shira_config = ShiraConfig(
    r=32,
)
peft_model = get_peft_model(model, shira_config)
training_args = SFTConfig(dataset_text_field="text", max_seq_length=128)
trainer = SFTTrainer(
    model=peft_model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
)
trainer.train()
peft_model.save_pretrained("shira-opt-350m")
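
To check how many parameters SHiRA actually trains, you can print the trainable parameter summary on the PEFT model before launching training:

peft_model.print_trainable_parameters()
# prints the trainable parameter count, total parameter count, and trainable percentage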

For more options and more detailed example code, refer to the SHiRA finetuning script. Run it with:

python3 examples/shira_finetuning/shira_finetuning.py --base_model facebook/opt-350m

If you want to run DDP with Accelerate, run accelerate config to set up your DDP configuration, then run:

accelerate launch examples/shira_finetuning/shira_finetuning.py --base_model facebook/opt-350m

Add --device_map cpu if you want to finetune on CPU.

If you want to train SHiRA with a custom sparse mask function that requires custom keyword arguments, see the definition of the custom_random_mask_function_with_custom_kwargs function in the shira_finetuning.py script. You can enable it with the --use_custom_random_mask_function_with_custom_kwargs argument; without this argument, SHiRA defaults to a random sparse mask. Run the code as follows:

python3 examples/shira_finetuning/shira_finetuning.py --base_model facebook/opt-350m --use_custom_random_mask_function_with_custom_kwargs
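
For illustration only, a custom mask function generally needs to produce a binary mask with the same shape as the layer's weight and roughly r * (in_features + out_features) nonzero entries, matching the parameter budget of a rank-r LoRA. The sketch below assumes a mask function that receives the base linear layer and the rank and returns such a mask; the exact signature, the supported keyword arguments, and how the function is attached to ShiraConfig should be taken from custom_random_mask_function_with_custom_kwargs in the example script.

import torch
import torch.nn as nn

def custom_random_mask(base_layer: nn.Linear, r: int, custom_seed: int = 42) -> torch.Tensor:
    # Hypothetical custom mask function: selects r * (out + in) random weight
    # positions to train, mirroring the parameter count of a rank-r LoRA.
    weight = base_layer.weight
    num_trainable = r * (weight.shape[0] + weight.shape[1])
    generator = torch.Generator().manual_seed(custom_seed)
    idx = torch.randperm(weight.numel(), generator=generator)[:num_trainable]
    mask = torch.zeros(weight.numel(), dtype=weight.dtype)
    mask[idx] = 1.0
    return mask.view_as(weight)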

Use the model

You can load and use the model like any other 🤗 PEFT model:

from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
shira_model = PeftModel.from_pretrained(model, "shira-opt-350m")
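
Once loaded, the adapted model can be used for generation like any other causal language model. The prompt below is just an example:

inputs = tokenizer("The movie was", return_tensors="pt")
outputs = shira_model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))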

Citation

@inproceedings{NEURIPS2024_18c0102c,
 author = {Bhardwaj, Kartikeya and Pandey, Nilesh Prasad and Priyadarshi, Sweta and Ganapathy, Viswanath and Kadambi, Shreya and Esteves, Rafael and Borse, Shubhankar and Whatmough, Paul and Garrepalli, Risheek and Van Baalen, Mart and Teague, Harris and Nagel, Markus},
 booktitle = {Advances in Neural Information Processing Systems},
 editor = {A. Globerson and L. Mackey and D. Belgrave and A. Fan and U. Paquet and J. Tomczak and C. Zhang},
 pages = {13685--13715},
 publisher = {Curran Associates, Inc.},
 title = {Sparse High Rank Adapters},
 url = {https://proceedings.neurips.cc/paper_files/paper/2024/file/18c0102cb7f1a02c14f0929089b2e576-Paper-Conference.pdf},
 volume = {37},
 year = {2024}
}