Compare commits

...

4 Commits

Author SHA1 Message Date
431c0e2d5c Release 0.13.2 (patch release for #2144) 2024-10-11 13:37:09 +02:00
dd4ce0365c FIX Bug in target module optimization if suffix (#2144)
Solves the following bug:

https://github.com/huggingface/diffusers/pull/9622#issuecomment-2404789721

The cause for the bug is as follows: When we have, say, a module called
"bar.0.query" that we want to target and another module called
"foo_bar.0.query" that we don't want to target, there was potential for
an error. This is not caused by _find_minimal_target_modules directly,
but rather the bug was inside of BaseTuner.inject_adapter and how the
names_no_target were chosen. Those used to be chosen based on suffix. In
our example, however, "bar.0.query" is a suffix of "foo_bar.0.query",
therefore "foo_bar.0.query" was *not* added to names_no_target when it
should have. As a consequence, during the optimization, it looks like
"query" is safe to use as target_modules because we don't see that it
wrongly matches "foo_bar.0.query".
2024-10-11 13:32:36 +02:00
b8da272660 Release 0.13.1 (patch release for #2113) 2024-10-08 14:17:55 +02:00
61c57f4f65 FIX low_cpu_mem_usage consolidates devices (#2113)
See: https://github.com/huggingface/diffusers/pull/9510#issuecomment-2378316687

Right now, the low_cpu_mem_usage=True option does not consolidate the
devices. E.g. when the model is on GPU and the state_dict on CPU, the
adapter weight will be on CPU after loading, when it should be GPU. This
fix ensures that the devices are consolidated.
2024-10-08 14:16:53 +02:00
6 changed files with 102 additions and 3 deletions

View File

@ -15,7 +15,7 @@
from setuptools import find_packages, setup
VERSION = "0.13.0"
VERSION = "0.13.2"
extras = {}
extras["quality"] = [

View File

@ -17,7 +17,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.13.0"
__version__ = "0.13.2"
from .auto import (
AutoPeftModel,

View File

@ -460,7 +460,9 @@ class BaseTuner(nn.Module, ABC):
and len(peft_config.target_modules) >= MIN_TARGET_MODULES_FOR_OPTIMIZATION
):
names_no_target = [
name for name in key_list if not any(name.endswith(suffix) for suffix in peft_config.target_modules)
name
for name in key_list
if not any((name == suffix) or name.endswith("." + suffix) for suffix in peft_config.target_modules)
]
new_target_modules = _find_minimal_target_modules(peft_config.target_modules, names_no_target)
if len(new_target_modules) < len(peft_config.target_modules):

View File

@ -456,6 +456,10 @@ def set_peft_model_state_dict(
)
if low_cpu_mem_usage:
load_result = model.load_state_dict(peft_model_state_dict, strict=False, assign=True)
# ensure that the correct device is set
for module in model.modules():
if hasattr(module, "_move_adapter_to_device_of_base_layer"):
module._move_adapter_to_device_of_base_layer(adapter_name)
else:
load_result = model.load_state_dict(peft_model_state_dict, strict=False)

View File

@ -55,8 +55,11 @@ from peft import (
PromptEncoderConfig,
TaskType,
get_peft_model,
get_peft_model_state_dict,
inject_adapter_in_model,
prepare_model_for_kbit_training,
replace_lora_weights_loftq,
set_peft_model_state_dict,
)
from peft.tuners import boft
from peft.utils import SAFETENSORS_WEIGHTS_NAME, infer_device
@ -3226,3 +3229,51 @@ class TestPTuningReproducibility:
torch.testing.assert_close(output_loaded, output_peft)
torch.testing.assert_close(gen_loaded, gen_peft)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
@pytest.mark.single_gpu_tests
class TestLowCpuMemUsageDifferentDevices:
"""Test for the low CPU memory usage option for loading PEFT models.
There are already tests for this in test_initialization.py but here we want to specifically test diverging devices
for the model and state_dict.
"""
model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
@pytest.mark.parametrize("device_model, device_sd", [("cpu", "cuda"), ("cuda", "cpu")])
def test_low_cpu_mem_usage_model_model_on_gpu_state_dict_on_cpu_works(self, device_model, device_sd):
inputs = {"input_ids": torch.randint(0, 100, (1, 10)), "attention_mask": torch.ones(1, 10)}
inputs = {k: v.to(device_model) for k, v in inputs.items()}
model = AutoModelForCausalLM.from_pretrained(self.model_id).to(device_model)
lora_config = LoraConfig(init_lora_weights=False, target_modules="all-linear")
model = get_peft_model(model, lora_config)
model.eval()
logits_not_low_cpu_mem = model(**inputs).logits
state_dict = get_peft_model_state_dict(model)
peft_model_state_dict = {}
# remap the state dict so that it can be correctly loaded, and move weights to the other device
prefix = "base_model.model."
for k, v in state_dict.items():
k = k[len(prefix) :]
peft_model_state_dict[k] = v.to(device_sd)
del model
model = AutoModelForCausalLM.from_pretrained(self.model_id).to(device_model)
model.eval()
inject_adapter_in_model(lora_config, model, low_cpu_mem_usage=True)
load_result = set_peft_model_state_dict(model, peft_model_state_dict, low_cpu_mem_usage=True)
# sanity check: all lora keys are matched
assert not any("lora" in k for k in load_result.missing_keys)
assert not any("lora" in k for k in load_result.unexpected_keys)
logits_low_cpu_mem = model(**inputs).logits
assert torch.allclose(logits_low_cpu_mem, logits_not_low_cpu_mem)
assert {p.device.type for p in model.parameters()} == {device_model}

View File

@ -1327,3 +1327,45 @@ class TestFindMinimalTargetModules:
expected = {"time_emb_proj", "proj", "proj_out"}
result = find_minimal_target_modules(target_modules, other_module_names)
assert result == expected
def test_get_peft_modules_module_name_is_suffix_of_another_module(self):
# Solves the following bug:
# https://github.com/huggingface/diffusers/pull/9622#issuecomment-2404789721
# The cause for the bug is as follows: When we have, say, a module called "bar.0.query" that we want to target
# and another module called "foo_bar.0.query" that we don't want to target, there was potential for an error.
# This is not caused by _find_minimal_target_modules directly, but rather the bug was inside of
# BaseTuner.inject_adapter and how the names_no_target were chosen. Those used to be chosen based on suffix. In
# our example, however, "bar.0.query" is a suffix of "foo_bar.0.query", therefore "foo_bar.0.query" was *not*
# added to names_no_target when it should have. As a consequence, during the optimization, it looks like "query"
# is safe to use as target_modules because we don't see that it wrongly matches "foo_bar.0.query".
# ensure that we have sufficiently many modules to trigger the optimization
n_layers = MIN_TARGET_MODULES_FOR_OPTIMIZATION + 1
class InnerModule(nn.Module):
def __init__(self):
super().__init__()
self.query = nn.Linear(10, 10)
class OuterModule(nn.Module):
def __init__(self):
super().__init__()
# note that "transformer_blocks" is a suffix of "single_transformer_blocks"
self.transformer_blocks = nn.ModuleList([InnerModule() for _ in range(n_layers)])
self.single_transformer_blocks = nn.ModuleList([InnerModule() for _ in range(n_layers)])
# we want to match all "transformer_blocks" layers but not "single_transformer_blocks"
target_modules = [f"transformer_blocks.{i}.query" for i in range(n_layers)]
model = get_peft_model(OuterModule(), LoraConfig(target_modules=target_modules))
# sanity check: we should have n_layers PEFT layers in model.transformer_blocks
transformer_blocks = model.base_model.model.transformer_blocks
assert sum(isinstance(module, BaseTunerLayer) for module in transformer_blocks.modules()) == n_layers
# we should not have any PEFT layers in model.single_transformer_blocks
single_transformer_blocks = model.base_model.model.single_transformer_blocks
assert not any(isinstance(module, BaseTunerLayer) for module in single_transformer_blocks.modules())
# target modules should *not* be simplified to "query" as that would match "single_transformers_blocks" too
assert model.peft_config["default"].target_modules != {"query"}