Mirror of https://github.com/huggingface/peft.git
`AuxiliaryTrainingWrapper.adapter_state_dict` now uses an externally supplied state dict to compute the module state dict, avoiding problems with DeepSpeed (or FSDP) when dealing with distributed parameters. It is not possible to simply wrap everything in `GatheredParameters` context managers, since doing so leads to a deadlock when running on more than one process (for reasons that remain unclear). Since transformers, or more specifically accelerate, already handles state dict fetching for the whole model, it is more economical to reuse that state dict and rewrite the methods that previously depended on `state_dict()` calls.
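The core idea, as a minimal sketch (the helper name and arguments are illustrative, not actual PEFT APIs): derive each wrapper's view from a state dict that was already gathered at the model level, instead of calling `state_dict()` on possibly sharded submodules.

```python
# Illustrative sketch only: slice a module-relative view out of a model-level state
# dict that accelerate/DeepSpeed has already gathered. The helper name is made up.
def module_relative_state_dict(full_state_dict, module_name):
    prefix = f"{module_name}."
    return {k.removeprefix(prefix): v for k, v in full_state_dict.items() if k.startswith(prefix)}
```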
@@ -434,9 +434,10 @@ class AuxiliaryTrainingWrapper(torch.nn.Module):
         """Return a mapping from the key present in disk-loaded state dict
         and how it should be represented in the loaded model's state dict.
 
-        If a key is not present here, it is assumed to be mapped 1:1.
+        The default should be a 1:1 mapping but it is important to define a mapping as it also serves as the
+        ground-truth for which keys are supposed to be loaded from a saved state dict.
         """
-        return {}
+        raise NotImplementedError
 
     def unload_and_optionally_merge_module(
         self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]]
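The hunk above changes the base-class contract: `adapter_state_dict_load_map` becomes the ground truth for which keys a wrapper loads, and subclasses must now implement it explicitly. A toy illustration of the resulting interface (not a real `AuxiliaryTrainingWrapper` subclass; the `extra_module` key names are invented):

```python
# Toy sketch of the wrapper contract after this change. The load map enumerates every
# key that should be loaded, and adapter_state_dict() reads values from the externally
# supplied, module-relative state dict instead of calling self.state_dict().
class ToyAuxWrapper:
    def adapter_state_dict_load_map(self, adapter_name):
        # saved key (adapter-agnostic) -> key as it appears in the loaded model
        return {"extra_module.weight": f"extra_module.{adapter_name}.weight"}

    def adapter_state_dict(self, adapter_name, state_dict):
        # state_dict is module-relative and already gathered by accelerate/DeepSpeed
        return {"extra_module.weight": state_dict[f"extra_module.{adapter_name}.weight"]}
```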
@@ -550,15 +551,24 @@ class ModulesToSaveWrapper(AuxiliaryTrainingWrapper):
         self._active_adapter = adapter_name
 
     def adapter_state_dict_load_map(self, adapter_name):
-        # The state dict returned by ModulesToSaveWrapper
-        return {k: f"modules_to_save.{adapter_name}.{k}" for k in self.adapter_state_dict(adapter_name)}
-
-    def adapter_state_dict(self, adapter_name):
+        # Maps the module keys as they are in the saved state dict to the in-memory state dict.
+        # Must contain all keys that are supposed to be loaded.
+        if adapter_name not in self._adapters:
+            # In case of multiple adapters, each bringing their own modules to save, each
+            # ModulesToSaveWrapper will be queried but not every wrapper is obliged to serve the same adapters.
+            return {}
+        return {k: f"modules_to_save.{adapter_name}.{k}" for k in self.modules_to_save[adapter_name].state_dict()}
+
+    def adapter_state_dict(self, adapter_name, state_dict):
         if adapter_name not in self._adapters:
             # In case of multiple adapters, each bringing their own modules to save, each
             # ModulesToSaveWrapper will be queried but not every wrapper is obliged to serve the same adapters.
             return {}
-        return self.modules_to_save[adapter_name].state_dict()
+
+        return {
+            k: state_dict[f"modules_to_save.{adapter_name}.{k}"]
+            for k in self.modules_to_save[adapter_name].state_dict()
+        }
 
     def unload_and_optionally_merge_module(
         self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]]
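To make the key handling above concrete, here is a small, self-contained walk-through with invented names (a wrapper around a `classifier` module, adapter `default`, dummy tensors); it mirrors what the two methods effectively compute:

```python
import torch

# Module-relative state dict, as get_peft_model_state_dict would pass it in (invented
# names and shapes; "original_module" belongs to the wrapper but not to the adapter).
state_dict = {
    "modules_to_save.default.weight": torch.zeros(2, 4),
    "modules_to_save.default.bias": torch.zeros(2),
    "original_module.weight": torch.zeros(2, 4),
}

# adapter_state_dict("default", state_dict) effectively yields adapter-agnostic keys
# suitable for the checkpoint on disk: {"weight": ..., "bias": ...}
saved = {
    k.removeprefix("modules_to_save.default."): v
    for k, v in state_dict.items()
    if k.startswith("modules_to_save.default.")
}

# adapter_state_dict_load_map("default") effectively yields where each saved key has
# to be stored when loading back into the wrapped model.
load_map = {k: f"modules_to_save.default.{k}" for k in saved}
# -> {"weight": "modules_to_save.default.weight", "bias": "modules_to_save.default.bias"}
```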
@@ -651,7 +661,12 @@ class TrainableTokensWrapper(AuxiliaryTrainingWrapper):
 
         super().update(active_adapter)
 
-    def adapter_state_dict(self, adapter_name):
+    def adapter_state_dict_load_map(self, adapter_name):
+        if self.token_adapter.tied_adapter:
+            return {}
+        return {"token_adapter.trainable_tokens_delta": f"token_adapter.trainable_tokens_delta.{adapter_name}"}
+
+    def adapter_state_dict(self, adapter_name, state_dict):
         if self.token_adapter.tied_adapter:
             # storing of weight-tied layers is not up to us and will be handled by
             # transformers. we're just here to keep those layers in sync during training.
@@ -659,9 +674,7 @@ class TrainableTokensWrapper(AuxiliaryTrainingWrapper):
             return {}
 
         return {
-            f"token_adapter.{k}": v
-            for k, v in self.token_adapter.state_dict().items()
-            if k.startswith("trainable_tokens_") and k.endswith(f".{adapter_name}")
+            f"token_adapter.{k}": state_dict[f"token_adapter.{k}.{adapter_name}"] for k in ["trainable_tokens_delta"]
         }
 
     def enable_adapters(self, enabled: bool):
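For the token adapter there is only a single delta entry per adapter, so (assuming an untied adapter named `default`) the two methods boil down to one mapping:

```python
adapter_name = "default"  # illustrative adapter name

# adapter_state_dict_load_map: saved key -> key in the loaded model's state dict
load_map = {
    "token_adapter.trainable_tokens_delta": f"token_adapter.trainable_tokens_delta.{adapter_name}"
}

# adapter_state_dict picks that same entry out of the externally supplied state dict:
# {"token_adapter.trainable_tokens_delta": state_dict[f"token_adapter.trainable_tokens_delta.{adapter_name}"]}
```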
@@ -197,7 +197,16 @@ def get_peft_model_state_dict(
     # ADDITIONAL TRAINING MODULES / MODULES_TO_SAVE
     for name, module in model.named_modules():
         if isinstance(module, AuxiliaryTrainingWrapper):
-            to_return.update({f"{name}.{k}": v for k, v in module.adapter_state_dict(adapter_name).items()})
+            # Compute the module-relative state dict to make it easier for the adapter to fetch the
+            # appropriate keys that the module thinks need to be saved. We cannot rely on the module's own
+            # `.state_dict()` since accelerators like DeepSpeed require special handling, which is done for
+            # the model state dict above but most likely not in the module itself. See #2450.
+            module_state_dict = {
+                k.removeprefix(f"{name}."): v for k, v in state_dict.items() if k.startswith(f"{name}.")
+            }
+            to_return.update(
+                {f"{name}.{k}": v for k, v in module.adapter_state_dict(adapter_name, module_state_dict).items()}
+            )
 
     # DEAL WITH EMBEDDINGS
     # check the common embedding layers in `target_modules` to reset `save_embedding_layers` if necessary
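On the saving side this mainly matters when the training framework shards parameters. A hedged usage sketch (assuming an `accelerate.Accelerator` and a prepared PEFT model; the exact gathering call depends on the setup): gather the full state dict first, then let `get_peft_model_state_dict` slice the adapter and `modules_to_save` entries out of it, as in the hunk above.

```python
from accelerate import Accelerator
from peft import get_peft_model_state_dict

accelerator = Accelerator()
# ... build, wrap with PEFT, and accelerator.prepare(...) the model here (omitted) ...

# Under DeepSpeed ZeRO-3 / FSDP, gather the full (unsharded) state dict first ...
full_state_dict = accelerator.get_state_dict(model)
# ... and hand it to PEFT so that wrappers never have to call state_dict() themselves.
peft_state_dict = get_peft_model_state_dict(model, state_dict=full_state_dict)
```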
@@ -343,14 +352,9 @@ def set_peft_model_state_dict(
             # `modules_to_save.{adapter_name}.` prefix. This prefix must be restored when loading the model from the
             # saved state dict which is why we fetch a load key map from the wrapper.
             key_map = module.adapter_state_dict_load_map(adapter_name)
 
-            for k in module.adapter_state_dict(adapter_name):
-                # each saved state dict is adapter specific, i.e. does not contain the adapter name
-                # but the loaded state dict does include adapter names since we can have multiple.
-                k_no_adapter = k.replace(f".{adapter_name}", "")
-
-                store_key = f"{name}.{key_map.get(k, k)}"
-                lookup_key = f"{name}.{k_no_adapter}"
+            for k in key_map:
+                lookup_key = f"{name}.{k}"
+                store_key = f"{name}.{key_map[k]}"
 
                 state_dict[store_key] = peft_model_state_dict[lookup_key]
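A toy walk-through of the simplified loading loop above, with invented module and key names:

```python
import torch

name = "base_model.model.classifier"                    # wrapper's module path (invented)
key_map = {"weight": "modules_to_save.default.weight"}  # from adapter_state_dict_load_map
peft_model_state_dict = {f"{name}.weight": torch.zeros(2, 4)}  # adapter-agnostic, as saved

state_dict = {}
for k in key_map:
    lookup_key = f"{name}.{k}"
    store_key = f"{name}.{key_map[k]}"
    state_dict[store_key] = peft_model_state_dict[lookup_key]
# -> {"base_model.model.classifier.modules_to_save.default.weight": tensor([...])}
```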