Mirror of https://github.com/huggingface/peft.git, synced 2025-10-20 23:43:47 +08:00
Compare commits: patch-rele...v0.13.2 (4 commits)
| Author | SHA1       | Date |
|--------|------------|------|
|        | 431c0e2d5c |      |
|        | dd4ce0365c |      |
|        | b8da272660 |      |
|        | 61c57f4f65 |      |
setup.py (2 changed lines)
@@ -15,7 +15,7 @@
 from setuptools import find_packages, setup
 
 
-VERSION = "0.13.0"
+VERSION = "0.13.2"
 
 extras = {}
 extras["quality"] = [
@@ -17,7 +17,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "0.13.0"
+__version__ = "0.13.2"
 
 from .auto import (
     AutoPeftModel,
@@ -460,7 +460,9 @@ class BaseTuner(nn.Module, ABC):
             and len(peft_config.target_modules) >= MIN_TARGET_MODULES_FOR_OPTIMIZATION
         ):
             names_no_target = [
-                name for name in key_list if not any(name.endswith(suffix) for suffix in peft_config.target_modules)
+                name
+                for name in key_list
+                if not any((name == suffix) or name.endswith("." + suffix) for suffix in peft_config.target_modules)
             ]
             new_target_modules = _find_minimal_target_modules(peft_config.target_modules, names_no_target)
             if len(new_target_modules) < len(peft_config.target_modules):
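The matching rule in this hunk is the substance of the fix: a target suffix should only count as a match for a module name at a `.` boundary or as the full name. A standalone sketch of the old versus new predicate (plain Python, not PEFT's internal helper; the module names are borrowed from the test further down):

```python
# Sketch of the predicate change. With a bare endswith(), the unrelated
# "single_transformer_blocks.0.query" also counts as a match for the target
# "transformer_blocks.0.query", so it is wrongly left out of names_no_target.
target_modules = ["transformer_blocks.0.query"]

def old_rule(name):
    return any(name.endswith(suffix) for suffix in target_modules)

def new_rule(name):
    return any((name == suffix) or name.endswith("." + suffix) for suffix in target_modules)

name = "single_transformer_blocks.0.query"
print(old_rule(name))  # True  -> treated as targeted, excluded from names_no_target
print(new_rule(name))  # False -> correctly kept in names_no_target
```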
@@ -456,6 +456,10 @@ def set_peft_model_state_dict(
         )
     if low_cpu_mem_usage:
         load_result = model.load_state_dict(peft_model_state_dict, strict=False, assign=True)
+        # ensure that the correct device is set
+        for module in model.modules():
+            if hasattr(module, "_move_adapter_to_device_of_base_layer"):
+                module._move_adapter_to_device_of_base_layer(adapter_name)
     else:
         load_result = model.load_state_dict(peft_model_state_dict, strict=False)
 
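The four added lines exist because `load_state_dict(..., assign=True)` swaps the state_dict's tensors into the module instead of copying them into pre-allocated parameters, so the loaded weights keep whatever device the state_dict lived on. A torch-only sketch of that behavior (the `nn.Linear` here is illustrative, not PEFT code; requires a PyTorch version that supports `assign=True`):

```python
import torch
import torch.nn as nn

with torch.device("meta"):
    layer = nn.Linear(4, 4)  # parameters allocated on the meta device, no real storage yet

state_dict = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}  # lives on CPU

# assign=True replaces the meta parameters with the state_dict tensors instead of
# copying into them, so loading succeeds, but the parameters now sit on whatever
# device the state_dict tensors were on (here: CPU), not where the model "should" be.
layer.load_state_dict(state_dict, strict=False, assign=True)
print({p.device.type for p in layer.parameters()})  # {'cpu'}
```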
@@ -55,8 +55,11 @@ from peft import (
     PromptEncoderConfig,
     TaskType,
     get_peft_model,
+    get_peft_model_state_dict,
+    inject_adapter_in_model,
     prepare_model_for_kbit_training,
     replace_lora_weights_loftq,
+    set_peft_model_state_dict,
 )
 from peft.tuners import boft
 from peft.utils import SAFETENSORS_WEIGHTS_NAME, infer_device
@@ -3226,3 +3229,51 @@ class TestPTuningReproducibility:
 
         torch.testing.assert_close(output_loaded, output_peft)
         torch.testing.assert_close(gen_loaded, gen_peft)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
+@pytest.mark.single_gpu_tests
+class TestLowCpuMemUsageDifferentDevices:
+    """Test for the low CPU memory usage option for loading PEFT models.
+
+    There are already tests for this in test_initialization.py but here we want to specifically test diverging devices
+    for the model and state_dict.
+
+    """
+
+    model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+
+    @pytest.mark.parametrize("device_model, device_sd", [("cpu", "cuda"), ("cuda", "cpu")])
+    def test_low_cpu_mem_usage_model_model_on_gpu_state_dict_on_cpu_works(self, device_model, device_sd):
+        inputs = {"input_ids": torch.randint(0, 100, (1, 10)), "attention_mask": torch.ones(1, 10)}
+        inputs = {k: v.to(device_model) for k, v in inputs.items()}
+
+        model = AutoModelForCausalLM.from_pretrained(self.model_id).to(device_model)
+        lora_config = LoraConfig(init_lora_weights=False, target_modules="all-linear")
+        model = get_peft_model(model, lora_config)
+        model.eval()
+        logits_not_low_cpu_mem = model(**inputs).logits
+
+        state_dict = get_peft_model_state_dict(model)
+        peft_model_state_dict = {}
+        # remap the state dict so that it can be correctly loaded, and move weights to the other device
+        prefix = "base_model.model."
+        for k, v in state_dict.items():
+            k = k[len(prefix) :]
+            peft_model_state_dict[k] = v.to(device_sd)
+
+        del model
+
+        model = AutoModelForCausalLM.from_pretrained(self.model_id).to(device_model)
+        model.eval()
+        inject_adapter_in_model(lora_config, model, low_cpu_mem_usage=True)
+        load_result = set_peft_model_state_dict(model, peft_model_state_dict, low_cpu_mem_usage=True)
+
+        # sanity check: all lora keys are matched
+        assert not any("lora" in k for k in load_result.missing_keys)
+        assert not any("lora" in k for k in load_result.unexpected_keys)
+
+        logits_low_cpu_mem = model(**inputs).logits
+
+        assert torch.allclose(logits_low_cpu_mem, logits_not_low_cpu_mem)
+        assert {p.device.type for p in model.parameters()} == {device_model}
@@ -1327,3 +1327,45 @@ class TestFindMinimalTargetModules:
         expected = {"time_emb_proj", "proj", "proj_out"}
         result = find_minimal_target_modules(target_modules, other_module_names)
         assert result == expected
+
+    def test_get_peft_modules_module_name_is_suffix_of_another_module(self):
+        # Solves the following bug:
+        # https://github.com/huggingface/diffusers/pull/9622#issuecomment-2404789721
+
+        # The cause for the bug is as follows: When we have, say, a module called "bar.0.query" that we want to target
+        # and another module called "foo_bar.0.query" that we don't want to target, there was potential for an error.
+        # This is not caused by _find_minimal_target_modules directly, but rather the bug was inside of
+        # BaseTuner.inject_adapter and how the names_no_target were chosen. Those used to be chosen based on suffix. In
+        # our example, however, "bar.0.query" is a suffix of "foo_bar.0.query", therefore "foo_bar.0.query" was *not*
+        # added to names_no_target when it should have. As a consequence, during the optimization, it looks like "query"
+        # is safe to use as target_modules because we don't see that it wrongly matches "foo_bar.0.query".
+
+        # ensure that we have sufficiently many modules to trigger the optimization
+        n_layers = MIN_TARGET_MODULES_FOR_OPTIMIZATION + 1
+
+        class InnerModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.query = nn.Linear(10, 10)
+
+        class OuterModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                # note that "transformer_blocks" is a suffix of "single_transformer_blocks"
+                self.transformer_blocks = nn.ModuleList([InnerModule() for _ in range(n_layers)])
+                self.single_transformer_blocks = nn.ModuleList([InnerModule() for _ in range(n_layers)])
+
+        # we want to match all "transformer_blocks" layers but not "single_transformer_blocks"
+        target_modules = [f"transformer_blocks.{i}.query" for i in range(n_layers)]
+        model = get_peft_model(OuterModule(), LoraConfig(target_modules=target_modules))
+
+        # sanity check: we should have n_layers PEFT layers in model.transformer_blocks
+        transformer_blocks = model.base_model.model.transformer_blocks
+        assert sum(isinstance(module, BaseTunerLayer) for module in transformer_blocks.modules()) == n_layers
+
+        # we should not have any PEFT layers in model.single_transformer_blocks
+        single_transformer_blocks = model.base_model.model.single_transformer_blocks
+        assert not any(isinstance(module, BaseTunerLayer) for module in single_transformer_blocks.modules())
+
+        # target modules should *not* be simplified to "query" as that would match "single_transformers_blocks" too
+        assert model.peft_config["default"].target_modules != {"query"}
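For context, `_find_minimal_target_modules` tries to shrink a long `target_modules` list down to a small set of suffixes that still match only the intended modules; the bug made `"query"` look safe because the non-target names were computed with the broken suffix rule. A simplified, pure-Python stand-in (not PEFT's actual algorithm) that illustrates the idea with the module names from this test:

```python
# Simplified stand-in for the "minimal target modules" optimization: for each
# target, keep the shortest dot-separated suffix that matches no non-target name
# under the boundary-aware rule from the fix above.
def matches(name, suffix):
    return name == suffix or name.endswith("." + suffix)

def minimal_target_modules(targets, others):
    result = set()
    for target in targets:
        parts = target.split(".")
        for i in range(1, len(parts) + 1):
            suffix = ".".join(parts[-i:])
            if not any(matches(other, suffix) for other in others):
                result.add(suffix)
                break
    return result

targets = [f"transformer_blocks.{i}.query" for i in range(2)]
others = [f"single_transformer_blocks.{i}.query" for i in range(2)]
print(minimal_target_modules(targets, others))
# {'transformer_blocks.0.query', 'transformer_blocks.1.query'}
# With the old (broken) names_no_target rule, "others" would have come out empty,
# the result would collapse to {"query"}, and the single_transformer_blocks
# modules would be wrongly targeted as well.
```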