mirror of
https://github.com/huggingface/peft.git
synced 2025-10-20 15:33:48 +08:00
ENH BOFT don't save boft_P buffer (#2050)
The buffer does not need to be part of the checkpoint, by making it non-persistent, the file size can be greatly reduced.
This commit is contained in:
@ -337,7 +337,7 @@ class BOFTLayer(BaseTunerLayer):
|
||||
perm_mat = self.perm2mat(perm)
|
||||
P[i] = perm_mat
|
||||
|
||||
self.register_buffer("boft_P", P)
|
||||
self.register_buffer("boft_P", P, persistent=False)
|
||||
|
||||
self.boft_R[adapter_name] = nn.Parameter(
|
||||
torch.zeros(boft_n_butterfly_factor + 1, boft_block_num, boft_block_size, boft_block_size)
|
||||
@ -771,7 +771,7 @@ class Conv2d(nn.Module, BOFTLayer):
|
||||
perm_mat = self.perm2mat(perm)
|
||||
P[i] = perm_mat
|
||||
|
||||
self.register_buffer("boft_P", P)
|
||||
self.register_buffer("boft_P", P, persistent=False)
|
||||
|
||||
self.boft_R[adapter_name] = nn.Parameter(
|
||||
torch.zeros(boft_n_butterfly_factor + 1, boft_block_num, boft_block_size, boft_block_size)
|
||||
|
84
tests/test_boft.py
Normal file
84
tests/test_boft.py
Normal file
@ -0,0 +1,84 @@
|
||||
# Copyright 2024-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from peft import BOFTConfig, PeftModel, get_peft_model
|
||||
from peft.utils import infer_device
|
||||
|
||||
|
||||
class TestBoft:
|
||||
device = infer_device()
|
||||
|
||||
def test_boft_state_dict(self, tmp_path):
|
||||
# see #2050
|
||||
# ensure that the boft_P buffer is not stored in the checkpoint file and is not necessary to load the model
|
||||
# correctly
|
||||
torch.manual_seed(0)
|
||||
|
||||
inputs = torch.arange(10).view(-1, 1).to(self.device)
|
||||
model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
|
||||
model.eval()
|
||||
output_base = model(inputs).logits
|
||||
|
||||
config = BOFTConfig(init_weights=False)
|
||||
model = get_peft_model(model, config)
|
||||
model.eval()
|
||||
output_peft = model(inputs).logits
|
||||
|
||||
atol, rtol = 1e-5, 1e-8
|
||||
# sanity check: loading boft changed the output
|
||||
assert not torch.allclose(output_base, output_peft, atol=atol, rtol=rtol)
|
||||
|
||||
model.save_pretrained(tmp_path)
|
||||
del model
|
||||
|
||||
# check that the boft_P buffer is not present
|
||||
state_dict = load_file(tmp_path / "adapter_model.safetensors")
|
||||
assert not any("boft_P" in key for key in state_dict)
|
||||
|
||||
# sanity check: the model still produces the same output after loading
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
|
||||
model = PeftModel.from_pretrained(model, tmp_path)
|
||||
output_loaded = model(inputs).logits
|
||||
assert torch.allclose(output_peft, output_loaded, atol=atol, rtol=rtol)
|
||||
|
||||
def test_boft_old_checkpoint_including_boft_P(self, tmp_path):
|
||||
# see #2050
|
||||
# This test exists to ensure that after the boft_P buffer was made non-persistent, old checkpoints can still be
|
||||
# loaded successfully.
|
||||
torch.manual_seed(0)
|
||||
|
||||
inputs = torch.arange(10).view(-1, 1).to(self.device)
|
||||
model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
|
||||
|
||||
# first create the expected output
|
||||
config = BOFTConfig(init_weights=False)
|
||||
model = get_peft_model(model, config)
|
||||
model.eval()
|
||||
output_peft = model(inputs).logits
|
||||
del model
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
|
||||
# checkpoint from before the PR whose state_dict still contains boft_P
|
||||
hub_id = "peft-internal-testing/boft-tiny-opt-peft-v0.12"
|
||||
model = PeftModel.from_pretrained(model, hub_id)
|
||||
output_old = model(inputs).logits
|
||||
|
||||
atol, rtol = 1e-5, 1e-8
|
||||
assert torch.allclose(output_peft, output_old, atol=atol, rtol=rtol)
|
Reference in New Issue
Block a user