Fix #2450: Revamp adapter_state_dict_* methods (#2456)

`AuxiliaryTrainingWrapper.adapter_state_dict` now uses an externally supplied state dict to
compute the module state dict, avoiding problems with DeepSpeed (or FSDP) when dealing
with distributed parameters.

It is not possible to simply wrap everything in `GatheredParameters` context managers, since
doing so leads to a deadlock when running on more than one process (reasons unclear).
Since transformers, or more specifically accelerate, already handles state dict fetching
for the whole model, it is more economical to use that state dict and rewrite the methods
that previously depended on `state_dict()` calls.
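
To illustrate the approach, here is a minimal, runnable Python sketch (not the actual PEFT code; `ToyWrapper` and the surrounding names are made up for illustration): the caller computes the full-model state dict once, at the point where distributed gathering has already been handled, and passes each wrapper only its module-relative slice instead of letting the wrapper call its own `state_dict()`.

```python
# Illustrative sketch only, not the PEFT implementation. The caller slices the
# full-model state dict per module and hands that slice to the wrapper, so the
# wrapper never has to call its own .state_dict() on possibly sharded params.
import torch


class ToyWrapper(torch.nn.Module):
    """Hypothetical stand-in for an AuxiliaryTrainingWrapper-like module."""

    def __init__(self, adapter_name: str):
        super().__init__()
        self.modules_to_save = torch.nn.ModuleDict({adapter_name: torch.nn.Linear(4, 4)})

    def adapter_state_dict(self, adapter_name: str, state_dict: dict) -> dict:
        # Pick the adapter's entries out of the externally supplied state dict
        # rather than calling self.state_dict(), which may hold sharded parameters.
        return {
            k: state_dict[f"modules_to_save.{adapter_name}.{k}"]
            for k in self.modules_to_save[adapter_name].state_dict()
        }


model = torch.nn.ModuleDict({"wrapped": ToyWrapper("default")})
# In PEFT, this full state dict would come from transformers/accelerate.
full_state_dict = model.state_dict()

for name, module in model.named_modules():
    if isinstance(module, ToyWrapper):
        # Module-relative view of the full state dict, mirroring what the caller does.
        module_sd = {
            k.removeprefix(f"{name}."): v
            for k, v in full_state_dict.items()
            if k.startswith(f"{name}.")
        }
        print(sorted(module.adapter_state_dict("default", module_sd)))  # ['bias', 'weight']
```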
Authored by githubnemo on 2025-03-27 14:08:10 +01:00, committed by GitHub
parent 911da6f356
commit 7279a9ff2e
2 changed files with 37 additions and 20 deletions


@@ -434,9 +434,10 @@ class AuxiliaryTrainingWrapper(torch.nn.Module):
"""Return a mapping from the key present in disk-loaded state dict
and how it should be represented in the loaded model's state dict.
If a key is not present here, it is assumed to be mapped 1:1.
The default should be a 1:1 mapping but it is important to define a mapping as it also serves as the
ground-truth for which keys are supposed to be loaded from a saved state dict.
"""
return {}
raise NotImplementedError
def unload_and_optionally_merge_module(
self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]]
@@ -550,15 +551,24 @@ class ModulesToSaveWrapper(AuxiliaryTrainingWrapper):
self._active_adapter = adapter_name
def adapter_state_dict_load_map(self, adapter_name):
# The state dict returned by ModulesToSaveWrapper
return {k: f"modules_to_save.{adapter_name}.{k}" for k in self.adapter_state_dict(adapter_name)}
def adapter_state_dict(self, adapter_name):
# Maps the module keys as they are in the saved state dict to the in-memory state dict.
# Must contain all keys that are supposed to be loaded.
if adapter_name not in self._adapters:
# In case of multiple adapters, each bringing their own modules to save, each
# ModulesToSaveWrapper will be queried but not every wrapper is obliged to serve the same adapters.
return {}
return self.modules_to_save[adapter_name].state_dict()
return {k: f"modules_to_save.{adapter_name}.{k}" for k in self.modules_to_save[adapter_name].state_dict()}
def adapter_state_dict(self, adapter_name, state_dict):
if adapter_name not in self._adapters:
# In case of multiple adapters, each bringing their own modules to save, each
# ModulesToSaveWrapper will be queried but not every wrapper is obliged to serve the same adapters.
return {}
return {
k: state_dict[f"modules_to_save.{adapter_name}.{k}"]
for k in self.modules_to_save[adapter_name].state_dict()
}
def unload_and_optionally_merge_module(
self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]]
@@ -651,7 +661,12 @@ class TrainableTokensWrapper(AuxiliaryTrainingWrapper):
super().update(active_adapter)
def adapter_state_dict(self, adapter_name):
def adapter_state_dict_load_map(self, adapter_name):
if self.token_adapter.tied_adapter:
return {}
return {"token_adapter.trainable_tokens_delta": f"token_adapter.trainable_tokens_delta.{adapter_name}"}
def adapter_state_dict(self, adapter_name, state_dict):
if self.token_adapter.tied_adapter:
# storing of weight-tied layers is not up to us and will be handled by
# transformers. we're just here to keep those layers in sync during training.
@@ -659,9 +674,7 @@ class TrainableTokensWrapper(AuxiliaryTrainingWrapper):
return {}
return {
f"token_adapter.{k}": v
for k, v in self.token_adapter.state_dict().items()
if k.startswith("trainable_tokens_") and k.endswith(f".{adapter_name}")
f"token_adapter.{k}": state_dict[f"token_adapter.{k}.{adapter_name}"] for k in ["trainable_tokens_delta"]
}
def enable_adapters(self, enabled: bool):


@@ -197,7 +197,16 @@ def get_peft_model_state_dict(
# ADDITIONAL TRAINING MODULES / MODULES_TO_SAVE
for name, module in model.named_modules():
if isinstance(module, AuxiliaryTrainingWrapper):
to_return.update({f"{name}.{k}": v for k, v in module.adapter_state_dict(adapter_name).items()})
# Compute the module-relative state dict to make it easier for the adapter to fetch the appropriate
# keys that the module thinks need to be saved. We cannot rely on the module's own `.state_dict()`
# since accelerators like DeepSpeed require special handling, which is done for the model
# state dict above but most likely not in the module itself. See #2450.
module_state_dict = {
k.removeprefix(f"{name}."): v for k, v in state_dict.items() if k.startswith(f"{name}.")
}
to_return.update(
{f"{name}.{k}": v for k, v in module.adapter_state_dict(adapter_name, module_state_dict).items()}
)
# DEAL WITH EMBEDDINGS
# check the common embedding layers in `target_modules` to reset `save_embedding_layers` if necessary
@@ -343,14 +352,9 @@ def set_peft_model_state_dict(
# `modules_to_save.{adapter_name}.` prefix. This prefix must be restored when loading the model from the
# saved state dict, which is why we fetch a load key map from the wrapper.
key_map = module.adapter_state_dict_load_map(adapter_name)
for k in module.adapter_state_dict(adapter_name):
# each saved state dict is adapter specific, i.e. does not contain the adapter name
# but the loaded state dict does include adapter names since we can have multiple.
k_no_adapter = k.replace(f".{adapter_name}", "")
store_key = f"{name}.{key_map.get(k, k)}"
lookup_key = f"{name}.{k_no_adapter}"
for k in key_map:
lookup_key = f"{name}.{k}"
store_key = f"{name}.{key_map[k]}"
state_dict[store_key] = peft_model_state_dict[lookup_key]