Compare commits


1 commit

Commit: c42eb227d9
Message: Prepare 0.15.1 patch (#2459)
Date: 2025-03-27 16:33:26 +01:00

This is a patch release containing a fix for #2450, a bug that could result in the loss of `modules_to_save` when training with DeepSpeed ZeRO stage 3.
4 changed files with 39 additions and 22 deletions
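For context, the affected workflow looks roughly like the sketch below (model name, hyperparameters, and paths are illustrative assumptions, not taken from this diff): a LoRA adapter with `modules_to_save` trained under DeepSpeed ZeRO stage 3 and saved through the PEFT helpers patched below.

# Illustrative sketch only; names and values are assumptions.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForSequenceClassification

base = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
config = LoraConfig(
    r=8,
    target_modules=["query", "value"],
    modules_to_save=["classifier"],  # fully trained and saved alongside the LoRA weights
)
model = get_peft_model(base, config)

# ... train with DeepSpeed ZeRO stage 3 (e.g. via accelerate or Trainer) ...

# Saving goes through get_peft_model_state_dict(); before this patch, the classifier
# weights under `modules_to_save` could be missing from the checkpoint because they
# were read from the wrapper's own .state_dict() instead of the DeepSpeed-gathered
# model state dict (see the save_and_load.py hunk below).
model.save_pretrained("out/adapter")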

setup.py

@@ -15,7 +15,7 @@
 from setuptools import find_packages, setup
 
-VERSION = "0.15.0"
+VERSION = "0.15.1"
 
 extras = {}
 extras["quality"] = [

src/peft/__init__.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "0.15.0"
+__version__ = "0.15.1"
 
 from .auto import (
     MODEL_TYPE_TO_PEFT_MODEL_MAPPING,

src/peft/utils/other.py

@@ -434,9 +434,10 @@ class AuxiliaryTrainingWrapper(torch.nn.Module):
         """Return a mapping from the key present in disk-loaded state dict
         and how it should be represented in the loaded model's state dict.
 
-        If a key is not present here, it is assumed to be mapped 1:1.
+        The default should be a 1:1 mapping but it is important to define a mapping as it also serves as the
+        ground-truth for which keys are supposed to be loaded from a saved state dict.
         """
-        return {}
+        raise NotImplementedError
 
     def unload_and_optionally_merge_module(
         self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]]
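To make the new contract concrete, here is a hypothetical example (keys invented, not part of the diff) of what a subclass's load map looks like and why the base implementation now raises instead of returning an empty dict.

# Hypothetical load map for a wrapper that owns a single Linear layer and the
# adapter "default": saved-state-dict key -> key inside the loaded model.
load_map = {
    "weight": "modules_to_save.default.weight",
    "bias": "modules_to_save.default.bias",
}
# The map doubles as the ground truth for which keys must be loaded, so a silent
# default of {} (as before) would mean "load nothing"; raising NotImplementedError
# forces every subclass to spell its keys out.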
@@ -550,15 +551,24 @@ class ModulesToSaveWrapper(AuxiliaryTrainingWrapper):
         self._active_adapter = adapter_name
 
     def adapter_state_dict_load_map(self, adapter_name):
-        # The state dict returned by ModulesToSaveWrapper
-        return {k: f"modules_to_save.{adapter_name}.{k}" for k in self.adapter_state_dict(adapter_name)}
-
-    def adapter_state_dict(self, adapter_name):
+        # Maps the module keys as they are in the saved state dict to the in-memory state dict.
+        # Must contain all keys that are supposed to be loaded.
         if adapter_name not in self._adapters:
             # In case of multiple adapters, each bringing their own modules to save, each
             # ModulesToSaveWrapper will be queried but not every wrapper is obliged to serve the same adapters.
             return {}
-        return self.modules_to_save[adapter_name].state_dict()
+        return {k: f"modules_to_save.{adapter_name}.{k}" for k in self.modules_to_save[adapter_name].state_dict()}
+
+    def adapter_state_dict(self, adapter_name, state_dict):
+        if adapter_name not in self._adapters:
+            # In case of multiple adapters, each bringing their own modules to save, each
+            # ModulesToSaveWrapper will be queried but not every wrapper is obliged to serve the same adapters.
+            return {}
+        return {
+            k: state_dict[f"modules_to_save.{adapter_name}.{k}"]
+            for k in self.modules_to_save[adapter_name].state_dict()
+        }
 
     def unload_and_optionally_merge_module(
         self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]]
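A small, hedged walk-through (tensor values invented) of what the new `adapter_state_dict(adapter_name, state_dict)` returns: the wrapper only decides which keys it owns and reads their values from the module-relative state dict passed in by the caller.

import torch

# Module-relative state dict as the caller would pass it in (values invented).
module_state_dict = {
    "original_module.weight": torch.zeros(2, 2),
    "modules_to_save.default.weight": torch.ones(2, 2),
}
adapter_name = "default"
owned_keys = ["weight"]  # mirrors self.modules_to_save[adapter_name].state_dict().keys()

adapter_sd = {k: module_state_dict[f"modules_to_save.{adapter_name}.{k}"] for k in owned_keys}
print(adapter_sd)  # {'weight': ...} -- the trained copy, taken from the gathered state dict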
@@ -651,7 +661,12 @@ class TrainableTokensWrapper(AuxiliaryTrainingWrapper):
         super().update(active_adapter)
 
-    def adapter_state_dict(self, adapter_name):
+    def adapter_state_dict_load_map(self, adapter_name):
+        if self.token_adapter.tied_adapter:
+            return {}
+
+        return {"token_adapter.trainable_tokens_delta": f"token_adapter.trainable_tokens_delta.{adapter_name}"}
+
+    def adapter_state_dict(self, adapter_name, state_dict):
         if self.token_adapter.tied_adapter:
             # storing of weight-tied layers is not up to us and will be handled by
             # transformers. we're just here to keep those layers in sync during training.
@@ -659,9 +674,7 @@ class TrainableTokensWrapper(AuxiliaryTrainingWrapper):
             return {}
 
         return {
-            f"token_adapter.{k}": v
-            for k, v in self.token_adapter.state_dict().items()
-            if k.startswith("trainable_tokens_") and k.endswith(f".{adapter_name}")
+            f"token_adapter.{k}": state_dict[f"token_adapter.{k}.{adapter_name}"] for k in ["trainable_tokens_delta"]
        }
 
     def enable_adapters(self, enabled: bool):
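For reference, these methods back the trainable-tokens support introduced in 0.15.0; a hedged configuration sketch follows (the parameter values and target modules are assumptions, not part of this diff).

from peft import LoraConfig

# Sketch only: indices and target modules are placeholders. The resulting
# `trainable_tokens_delta` weights are what adapter_state_dict() and
# adapter_state_dict_load_map() above save and restore.
config = LoraConfig(
    r=8,
    target_modules=["q_proj", "v_proj"],
    trainable_token_indices=[0, 1, 2],  # e.g. newly added special tokens
)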

src/peft/utils/save_and_load.py

@@ -197,7 +197,16 @@ def get_peft_model_state_dict(
     # ADDITIONAL TRAINING MODULES / MODULES_TO_SAVE
     for name, module in model.named_modules():
         if isinstance(module, AuxiliaryTrainingWrapper):
-            to_return.update({f"{name}.{k}": v for k, v in module.adapter_state_dict(adapter_name).items()})
+            # Compute the module-relative state dict to make it easier for the adapter to fetch the appropriate
+            # keys that the module thinks need to be saved. We cannot rely on `.state_dict()` internally of the
+            # module since accelerators like DeepSpeed require special handling which is done for the model
+            # state dict from above but most likely not in the module itself. See #2450.
+            module_state_dict = {
+                k.removeprefix(f"{name}."): v for k, v in state_dict.items() if k.startswith(f"{name}.")
+            }
+            to_return.update(
+                {f"{name}.{k}": v for k, v in module.adapter_state_dict(adapter_name, module_state_dict).items()}
+            )
 
     # DEAL WITH EMBEDDINGS
     # check the common embedding layers in `target_modules` to reset `save_embedding_layers` if necessary
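The prefix handling added above can be reproduced standalone; the following sketch uses invented module names and keys to show how the module-relative view is built from the already-gathered model state dict.

import torch

# Full model state dict after gathering (e.g. what DeepSpeed ZeRO-3 assembles);
# keys and shapes are invented for illustration.
state_dict = {
    "classifier.modules_to_save.default.weight": torch.ones(2, 2),
    "classifier.original_module.weight": torch.zeros(2, 2),
    "encoder.layer.0.weight": torch.zeros(2, 2),
}
name = "classifier"  # the AuxiliaryTrainingWrapper's name within the model

module_state_dict = {
    k.removeprefix(f"{name}."): v for k, v in state_dict.items() if k.startswith(f"{name}.")
}
print(sorted(module_state_dict))
# ['modules_to_save.default.weight', 'original_module.weight']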
@@ -343,14 +352,9 @@ def set_peft_model_state_dict(
             # `modules_to_save.{adapter_name}.` prefix. This prefix must be restored when loading the model from the
             # saved state dict which is why we fetch a load key map from the wrapper.
             key_map = module.adapter_state_dict_load_map(adapter_name)
-            for k in module.adapter_state_dict(adapter_name):
-                # each saved state dict is adapter specific, i.e. does not contain the adapter name
-                # but the loaded state dict does include adapter names since we can have multiple.
-                k_no_adapter = k.replace(f".{adapter_name}", "")
-
-                store_key = f"{name}.{key_map.get(k, k)}"
-                lookup_key = f"{name}.{k_no_adapter}"
+            for k in key_map:
+                lookup_key = f"{name}.{k}"
+                store_key = f"{name}.{key_map[k]}"
 
                 state_dict[store_key] = peft_model_state_dict[lookup_key]
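Finally, a hedged sketch (keys invented) of what the simplified load loop does: the load map alone determines which saved keys are restored and how they regain their `modules_to_save.{adapter_name}.` prefix.

import torch

name = "classifier"
# As returned by ModulesToSaveWrapper.adapter_state_dict_load_map("default").
key_map = {"weight": "modules_to_save.default.weight"}

# Saved adapter file: adapter-specific, so no adapter name in the keys.
peft_model_state_dict = {"classifier.weight": torch.ones(2, 2)}

state_dict = {}
for k in key_map:
    lookup_key = f"{name}.{k}"          # key as stored on disk
    store_key = f"{name}.{key_map[k]}"  # key the in-memory model expects
    state_dict[store_key] = peft_model_state_dict[lookup_key]

print(list(state_dict))  # ['classifier.modules_to_save.default.weight']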