Mirror of https://github.com/huggingface/peft.git (synced 2025-10-20 23:43:47 +08:00)

Compare commits: 1 commit, v0.15.0...patch-rele
Commit: c42eb227d9

setup.py (2 changed lines)
@@ -15,7 +15,7 @@
 from setuptools import find_packages, setup


-VERSION = "0.15.0"
+VERSION = "0.15.1"

 extras = {}
 extras["quality"] = [

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-__version__ = "0.15.0"
+__version__ = "0.15.1"

 from .auto import (
     MODEL_TYPE_TO_PEFT_MODEL_MAPPING,

@@ -434,9 +434,10 @@ class AuxiliaryTrainingWrapper(torch.nn.Module):
         """Return a mapping from the key present in disk-loaded state dict
         and how it should be represented in the loaded model's state dict.

-        If a key is not present here, it is assumed to be mapped 1:1.
+        The default should be a 1:1 mapping but it is important to define a mapping as it also serves as the
+        ground-truth for which keys are supposed to be loaded from a saved state dict.
         """
-        return {}
+        raise NotImplementedError

     def unload_and_optionally_merge_module(
         self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]]
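The base class method now raises `NotImplementedError` instead of returning an empty dict, so every `AuxiliaryTrainingWrapper` subclass has to spell out which keys it expects in a saved state dict and where they live in memory. A minimal sketch of that contract, using a hypothetical toy subclass (the `ToyWrapper` class, its `extra` ModuleDict and the "default" adapter name are made up for illustration, not PEFT code):

```python
import torch


class ToyWrapper(torch.nn.Module):
    """Hypothetical wrapper that keeps one extra trainable module per adapter."""

    def __init__(self):
        super().__init__()
        self.extra = torch.nn.ModuleDict({"default": torch.nn.Linear(4, 4)})

    def adapter_state_dict_load_map(self, adapter_name):
        # saved key (no adapter name) -> key in this module's state dict (with adapter name)
        return {k: f"extra.{adapter_name}.{k}" for k in self.extra[adapter_name].state_dict()}


print(ToyWrapper().adapter_state_dict_load_map("default"))
# {'weight': 'extra.default.weight', 'bias': 'extra.default.bias'}
```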
@@ -550,15 +551,24 @@ class ModulesToSaveWrapper(AuxiliaryTrainingWrapper):
         self._active_adapter = adapter_name

     def adapter_state_dict_load_map(self, adapter_name):
-        # The state dict returned by ModulesToSaveWrapper
-        return {k: f"modules_to_save.{adapter_name}.{k}" for k in self.adapter_state_dict(adapter_name)}
-
-    def adapter_state_dict(self, adapter_name):
+        # Maps the module keys as they are in the saved state dict to the in-memory state dict.
+        # Must contain all keys that are supposed to be loaded.
         if adapter_name not in self._adapters:
             # In case of multiple adapters, each bringing their own modules to save, each
             # ModulesToSaveWrapper will be queried but not every wrapper is obliged to serve the same adapters.
             return {}
-        return self.modules_to_save[adapter_name].state_dict()
+        return {k: f"modules_to_save.{adapter_name}.{k}" for k in self.modules_to_save[adapter_name].state_dict()}
+
+    def adapter_state_dict(self, adapter_name, state_dict):
+        if adapter_name not in self._adapters:
+            # In case of multiple adapters, each bringing their own modules to save, each
+            # ModulesToSaveWrapper will be queried but not every wrapper is obliged to serve the same adapters.
+            return {}
+
+        return {
+            k: state_dict[f"modules_to_save.{adapter_name}.{k}"]
+            for k in self.modules_to_save[adapter_name].state_dict()
+        }

     def unload_and_optionally_merge_module(
         self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]]
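Two things change for `ModulesToSaveWrapper`: the load map is built from the wrapped module's own keys instead of going through `adapter_state_dict`, and `adapter_state_dict` now takes a caller-supplied `state_dict` and looks values up there rather than calling `.state_dict()` itself. A standalone plain-dict sketch of those two comprehensions; the adapter name "default", the key list and the placeholder values are illustrative only:

```python
adapter_name = "default"
local_keys = ["weight", "bias"]  # stands in for modules_to_save["default"].state_dict() keys

# adapter_state_dict_load_map: saved key -> in-memory key with the adapter name restored
load_map = {k: f"modules_to_save.{adapter_name}.{k}" for k in local_keys}
print(load_map)
# {'weight': 'modules_to_save.default.weight', 'bias': 'modules_to_save.default.bias'}

# adapter_state_dict at save time: read values from the module-relative state dict prepared
# by the caller (already gathered under DeepSpeed), not from self.state_dict()
module_state_dict = {
    "modules_to_save.default.weight": "<tensor w>",
    "modules_to_save.default.bias": "<tensor b>",
    "original_module.weight": "<not saved>",
}
to_save = {k: module_state_dict[f"modules_to_save.{adapter_name}.{k}"] for k in local_keys}
print(to_save)
# {'weight': '<tensor w>', 'bias': '<tensor b>'}
```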
@@ -651,7 +661,12 @@ class TrainableTokensWrapper(AuxiliaryTrainingWrapper):

         super().update(active_adapter)

-    def adapter_state_dict(self, adapter_name):
+    def adapter_state_dict_load_map(self, adapter_name):
+        if self.token_adapter.tied_adapter:
+            return {}
+        return {"token_adapter.trainable_tokens_delta": f"token_adapter.trainable_tokens_delta.{adapter_name}"}
+
+    def adapter_state_dict(self, adapter_name, state_dict):
         if self.token_adapter.tied_adapter:
             # storing of weight-tied layers is not up to us and will be handled by
             # transformers. we're just here to keep those layers in sync during training.
@@ -659,9 +674,7 @@ class TrainableTokensWrapper(AuxiliaryTrainingWrapper):
             return {}

         return {
-            f"token_adapter.{k}": v
-            for k, v in self.token_adapter.state_dict().items()
-            if k.startswith("trainable_tokens_") and k.endswith(f".{adapter_name}")
+            f"token_adapter.{k}": state_dict[f"token_adapter.{k}.{adapter_name}"] for k in ["trainable_tokens_delta"]
         }

     def enable_adapters(self, enabled: bool):
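The same pattern applies to `TrainableTokensWrapper`, again sketched with plain dicts and made-up values for an untied adapter named "default": the load map restores the adapter name on the `trainable_tokens_delta` key, and the save path pulls the value out of the caller-provided state dict.

```python
adapter_name = "default"

# adapter_state_dict_load_map: saved key -> in-memory key
load_map = {
    "token_adapter.trainable_tokens_delta": f"token_adapter.trainable_tokens_delta.{adapter_name}"
}

# adapter_state_dict: values come from the module-relative state dict passed in by the caller
module_state_dict = {f"token_adapter.trainable_tokens_delta.{adapter_name}": "<delta tensor>"}
to_save = {
    f"token_adapter.{k}": module_state_dict[f"token_adapter.{k}.{adapter_name}"]
    for k in ["trainable_tokens_delta"]
}
print(to_save)  # {'token_adapter.trainable_tokens_delta': '<delta tensor>'}
```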
@@ -197,7 +197,16 @@ def get_peft_model_state_dict(
     # ADDITIONAL TRAINING MODULES / MODULES_TO_SAVE
     for name, module in model.named_modules():
         if isinstance(module, AuxiliaryTrainingWrapper):
-            to_return.update({f"{name}.{k}": v for k, v in module.adapter_state_dict(adapter_name).items()})
+            # Compute the module-relative state dict to make it easier for the adapter to fetch the appropriate
+            # keys that the module thinks need to be saved. We cannot rely on `.state_dict()` internally of the
+            # module since accelerators like DeepSpeed require special handling which is done for the model
+            # state dict from above but most likely not in the module itself. See #2450.
+            module_state_dict = {
+                k.removeprefix(f"{name}."): v for k, v in state_dict.items() if k.startswith(f"{name}.")
+            }
+            to_return.update(
+                {f"{name}.{k}": v for k, v in module.adapter_state_dict(adapter_name, module_state_dict).items()}
+            )

     # DEAL WITH EMBEDDINGS
     # check the common embedding layers in `target_modules` to reset `save_embedding_layers` if necessary
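The key idea in this hunk is the module-relative view of the already-prepared model state dict: accelerators like DeepSpeed gather sharded parameters when the model-level state dict is built (see #2450), so the wrapper should read values from that dict rather than from its own `.state_dict()`. A minimal, self-contained sketch of the slicing step with a toy model and a hypothetical helper name:

```python
import torch


def module_relative_state_dict(state_dict, name):
    # keep only keys under `name` and strip that prefix, mirroring the comprehension in the diff
    return {k.removeprefix(f"{name}."): v for k, v in state_dict.items() if k.startswith(f"{name}.")}


model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 2))
full_state_dict = model.state_dict()  # in PEFT this is the already gathered model state dict
print(sorted(module_relative_state_dict(full_state_dict, "1")))
# ['bias', 'weight']
```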
@@ -343,14 +352,9 @@ def set_peft_model_state_dict(
             # `modules_to_save.{adapter_name}.` prefix. This prefix must be restored when loading the model from the
             # saved state dict which is why we fetch a load key map from the wrapper.
             key_map = module.adapter_state_dict_load_map(adapter_name)

-            for k in module.adapter_state_dict(adapter_name):
-                # each saved state dict is adapter specific, i.e. does not contain the adapter name
-                # but the loaded state dict does include adapter names since we can have multiple.
-                k_no_adapter = k.replace(f".{adapter_name}", "")
-
-                store_key = f"{name}.{key_map.get(k, k)}"
-                lookup_key = f"{name}.{k_no_adapter}"
+            for k in key_map:
+                lookup_key = f"{name}.{k}"
+                store_key = f"{name}.{key_map[k]}"

                 state_dict[store_key] = peft_model_state_dict[lookup_key]
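On the loading side, the loop now iterates over the load map itself, which makes the map the ground truth for which keys are read from the saved state dict and where they are stored in the model. A standalone sketch with illustrative names (the module path, adapter name and placeholder values are made up):

```python
name = "base_model.model.classifier"  # hypothetical wrapper location in the model
adapter_name = "default"

# e.g. what ModulesToSaveWrapper.adapter_state_dict_load_map("default") could return
key_map = {"weight": "modules_to_save.default.weight", "bias": "modules_to_save.default.bias"}

# keys as they appear in the saved, adapter-name-free state dict
peft_model_state_dict = {f"{name}.weight": "<tensor w>", f"{name}.bias": "<tensor b>"}

state_dict = {}
for k in key_map:
    lookup_key = f"{name}.{k}"          # key in the saved state dict
    store_key = f"{name}.{key_map[k]}"  # key the in-memory model expects
    state_dict[store_key] = peft_model_state_dict[lookup_key]

print(state_dict)
# {'base_model.model.classifier.modules_to_save.default.weight': '<tensor w>',
#  'base_model.model.classifier.modules_to_save.default.bias': '<tensor b>'}
```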