diff --git a/src/peft/tuners/boft/layer.py b/src/peft/tuners/boft/layer.py
index 05e145b7..c6ad31bd 100644
--- a/src/peft/tuners/boft/layer.py
+++ b/src/peft/tuners/boft/layer.py
@@ -337,7 +337,7 @@ class BOFTLayer(BaseTunerLayer):
             perm_mat = self.perm2mat(perm)
             P[i] = perm_mat
 
-        self.register_buffer("boft_P", P)
+        self.register_buffer("boft_P", P, persistent=False)
 
         self.boft_R[adapter_name] = nn.Parameter(
             torch.zeros(boft_n_butterfly_factor + 1, boft_block_num, boft_block_size, boft_block_size)
@@ -771,7 +771,7 @@ class Conv2d(nn.Module, BOFTLayer):
             perm_mat = self.perm2mat(perm)
             P[i] = perm_mat
 
-        self.register_buffer("boft_P", P)
+        self.register_buffer("boft_P", P, persistent=False)
 
         self.boft_R[adapter_name] = nn.Parameter(
             torch.zeros(boft_n_butterfly_factor + 1, boft_block_num, boft_block_size, boft_block_size)
diff --git a/tests/test_boft.py b/tests/test_boft.py
new file mode 100644
index 00000000..d0cf74e3
--- /dev/null
+++ b/tests/test_boft.py
@@ -0,0 +1,84 @@
+# Copyright 2024-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from safetensors.torch import load_file
+from transformers import AutoModelForCausalLM
+
+from peft import BOFTConfig, PeftModel, get_peft_model
+from peft.utils import infer_device
+
+
+class TestBoft:
+    device = infer_device()
+
+    def test_boft_state_dict(self, tmp_path):
+        # see #2050
+        # ensure that the boft_P buffer is not stored in the checkpoint file and is not necessary to load the model
+        # correctly
+        torch.manual_seed(0)
+
+        inputs = torch.arange(10).view(-1, 1).to(self.device)
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
+        model.eval()
+        output_base = model(inputs).logits
+
+        config = BOFTConfig(init_weights=False)
+        model = get_peft_model(model, config)
+        model.eval()
+        output_peft = model(inputs).logits
+
+        atol, rtol = 1e-5, 1e-8
+        # sanity check: applying BOFT changed the output
+        assert not torch.allclose(output_base, output_peft, atol=atol, rtol=rtol)
+
+        model.save_pretrained(tmp_path)
+        del model
+
+        # check that the boft_P buffer is not present in the saved state dict
+        state_dict = load_file(tmp_path / "adapter_model.safetensors")
+        assert not any("boft_P" in key for key in state_dict)
+
+        # sanity check: the model still produces the same output after loading
+        model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
+        model = PeftModel.from_pretrained(model, tmp_path)
+        output_loaded = model(inputs).logits
+        assert torch.allclose(output_peft, output_loaded, atol=atol, rtol=rtol)
+
+    def test_boft_old_checkpoint_including_boft_P(self, tmp_path):
+        # see #2050
+        # This test exists to ensure that after the boft_P buffer was made non-persistent, old checkpoints can still
+        # be loaded successfully.
+        torch.manual_seed(0)
+
+        inputs = torch.arange(10).view(-1, 1).to(self.device)
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
+
+        # first create the expected output
+        config = BOFTConfig(init_weights=False)
+        model = get_peft_model(model, config)
+        model.eval()
+        output_peft = model(inputs).logits
+        del model
+
+        model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
+        # checkpoint from before the PR whose state_dict still contains boft_P
+        hub_id = "peft-internal-testing/boft-tiny-opt-peft-v0.12"
+        model = PeftModel.from_pretrained(model, hub_id)
+        output_old = model(inputs).logits
+
+        atol, rtol = 1e-5, 1e-8
+        assert torch.allclose(output_peft, output_old, atol=atol, rtol=rtol)
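
For reviewers: below is a minimal standalone sketch (not part of this diff; the `Toy` module and its names are hypothetical) of the PyTorch behavior the change and the tests rely on. A buffer registered with `persistent=False` is excluded from `state_dict()`, and an old checkpoint that still contains the key loads cleanly under `strict=False` (the stale entry is merely reported back as an unexpected key), which is roughly how PEFT loads adapter weights.

```python
import torch
from torch import nn


class Toy(nn.Module):
    """Hypothetical module mimicking how BOFTLayer registers boft_P."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(3))
        # Like boft_P, this tensor is recomputed deterministically at init
        # time, so there is no need to serialize it with the checkpoint.
        self.register_buffer("P", torch.eye(3), persistent=False)


module = Toy()
# The non-persistent buffer is excluded from the saved state dict ...
assert "P" not in module.state_dict()

# ... and an old checkpoint that still contains it loads without error when
# strict=False; the stale key is only reported back as unexpected.
old_checkpoint = {"weight": torch.zeros(3), "P": torch.eye(3)}
result = module.load_state_dict(old_checkpoint, strict=False)
assert result.unexpected_keys == ["P"]
assert not result.missing_keys
```

Since `boft_P` is rebuilt deterministically whenever the layer is initialized, dropping it from the checkpoint shrinks the saved adapter without affecting correctness.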