Mirror of https://github.com/huggingface/peft.git (synced 2025-10-20 15:33:48 +08:00)
Training PEFT models with new tokens being added to the embedding layers and tokenizer (#1147)
* add support for saving base layers weights along with adapter weights
* Update save_and_load.py
* Add an example showing the usage of the added feature
* refactor the functionality
* fix
* refactoring code
  1. Add `is_embedding_layer_resized` parameter to `save_pretrained`
  2. Fix the deduplication in README when adding PEFT details.
  3. `save_pretrained` should only save the model when `is_main_process=True`, which is one of the parameters of `save_pretrained`.
* update example
* fix the model card
* fix model card
* 😅
* fix model card
* automate setting `is_embedding_layer_resized`
* nits
* Update peft_lora_clm_with_additional_tokens.ipynb
* add test
* fix tests
* maybe fixes the issue?
* address comments

Co-Authored-By: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Commit 8298f1a366 (parent f0fb9516d8), committed via GitHub.
(One file diff is suppressed because it is too large to display.)
src/peft/peft_model.py
@@ -159,6 +159,8 @@ class PeftModel(PushToHubMixin, torch.nn.Module):
         save_directory: str,
         safe_serialization: bool = True,
         selected_adapters: Optional[List[str]] = None,
+        save_embedding_layers: Union[str, bool] = "auto",
+        is_main_process: bool = True,
         **kwargs: Any,
     ):
         r"""
@@ -172,6 +174,14 @@ class PeftModel(PushToHubMixin, torch.nn.Module):
                 exist).
             safe_serialization (`bool`, *optional*):
                 Whether to save the adapter files in safetensors format.
             selected_adapters (`list(str)`, *optional*):
                 A list of adapters to be saved. If `None`, will default to all adapters.
+            save_embedding_layers (`Union[bool, str]`, *optional*, defaults to `auto`):
+                If `True`, save the embedding layers in addition to adapter weights. If `auto`, checks the common
+                embedding layers `peft.utils.other.EMBEDDING_LAYER_NAMES` in config's `target_modules` when available.
+                Based on it sets the boolean flag. This only works for 🤗 transformers models.
+            is_main_process (`bool`, *optional*):
+                Whether the process calling this is the main process or not. Will default to `True`.
             kwargs (additional keyword arguments, *optional*):
                 Additional keyword arguments passed along to the `push_to_hub` method.
         """
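For orientation, here is a minimal usage sketch of the new parameter. The base model id, the added token, and the target module names below are illustrative placeholders (LLaMA/Mistral-style naming), not part of this diff:

```python
# Sketch only: base model id, new token, and target_modules are illustrative assumptions.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

base_model_id = "some-org/some-causal-lm"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
model = AutoModelForCausalLM.from_pretrained(base_model_id)

# Add new tokens and resize the embedding matrix so they become trainable.
tokenizer.add_special_tokens({"additional_special_tokens": ["<|my_new_token|>"]})
model.resize_token_embeddings(len(tokenizer))

# Target the embedding / LM head layers in addition to the attention projections.
config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj", "embed_tokens", "lm_head"],
)
peft_model = get_peft_model(model, config)

# ... training ...

# With the default save_embedding_layers="auto", the (resized) embedding weights are
# written alongside the adapter because "embed_tokens"/"lm_head" appear in target_modules.
peft_model.save_pretrained("my-lora-adapter", save_embedding_layers="auto")
tokenizer.save_pretrained("my-lora-adapter")  # keep the enlarged tokenizer next to the adapter
```

Saving the tokenizer next to the adapter is a separate step: `save_pretrained` on the PEFT model only writes the model weights and config.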
@@ -190,19 +200,23 @@ class PeftModel(PushToHubMixin, torch.nn.Module):
                     f" {list(self.peft_config.keys())} - got {selected_adapters}."
                 )

-        os.makedirs(save_directory, exist_ok=True)
-        self.create_or_update_model_card(save_directory)
+        if is_main_process:
+            os.makedirs(save_directory, exist_ok=True)
+            self.create_or_update_model_card(save_directory)

         for adapter_name in selected_adapters:
             peft_config = self.peft_config[adapter_name]
             # save only the trainable weights
             output_state_dict = get_peft_model_state_dict(
-                self, state_dict=kwargs.get("state_dict", None), adapter_name=adapter_name
+                self,
+                state_dict=kwargs.get("state_dict", None),
+                adapter_name=adapter_name,
+                save_embedding_layers=save_embedding_layers,
             )
             output_dir = os.path.join(save_directory, adapter_name) if adapter_name != "default" else save_directory
             os.makedirs(output_dir, exist_ok=True)

-            if safe_serialization:
+            if is_main_process and safe_serialization:
                 # Section copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2111-L2134
                 # Safetensors does not allow tensor aliasing.
                 # We're going to remove aliases before saving
@@ -230,7 +244,7 @@ class PeftModel(PushToHubMixin, torch.nn.Module):
                     os.path.join(output_dir, SAFETENSORS_WEIGHTS_NAME),
                     metadata={"format": "pt"},
                 )
-            else:
+            elif is_main_process:
                 torch.save(output_state_dict, os.path.join(output_dir, WEIGHTS_NAME))

             # save the config and change the inference mode to `True`
@@ -257,7 +271,8 @@ class PeftModel(PushToHubMixin, torch.nn.Module):
             else:
                 auto_mapping_dict = None

-            peft_config.save_pretrained(output_dir, auto_mapping_dict=auto_mapping_dict)
+            if is_main_process:
+                peft_config.save_pretrained(output_dir, auto_mapping_dict=auto_mapping_dict)
             peft_config.inference_mode = inference_mode

     @classmethod
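The `is_main_process` guard matters for multi-process (e.g. DDP / 🤗 Accelerate) training, where only one rank should write to disk while every rank still runs the same code path. A hedged sketch, assuming an Accelerate-style script (the `Accelerator` usage and directory name are illustrative):

```python
# Sketch only: assumes an Accelerate training script; model/optimizer setup is omitted.
from accelerate import Accelerator

accelerator = Accelerator()
# ... accelerator.prepare(...), training loop ...

unwrapped = accelerator.unwrap_model(peft_model)  # peft_model from the earlier sketch
unwrapped.save_pretrained(
    "my-lora-adapter",
    is_main_process=accelerator.is_main_process,  # only rank 0 writes the files
    save_embedding_layers="auto",
)
```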
@@ -721,24 +736,27 @@ class PeftModel(PushToHubMixin, torch.nn.Module):
         if hasattr(self.config, "quantization_config"):
             quantization_config = self.config.quantization_config.to_dict()
         training_config_text = ""
+        quantization_prefix = "The following `bitsandbytes` quantization config was used during training:"
         # Adds quantization information if it was used
         if quantization_config is not None:
-            training_config_text += "\nThe following `bitsandbytes` quantization config was used during training:\n"
+            training_config_text += f"\n{quantization_prefix}\n"
             training_config_text += "\n".join([f"- {name}: {value}" for name, value in quantization_config.items()])
             training_config_text += "\n"

-        training_procedure_heading = "## Training procedure\n"
-        if training_procedure_heading in lines:
-            lines.insert(lines.index(training_procedure_heading) + 2, training_config_text)
-        else:
-            lines.append(f"{training_procedure_heading}\n{training_config_text}")
+        training_procedure_heading = "## Training procedure"
+        if quantization_prefix not in lines and bool(training_config_text):
+            if training_procedure_heading in lines:
+                lines.insert(lines.index(training_procedure_heading) + 2, training_config_text)
+            else:
+                lines.append(f"{training_procedure_heading}\n{training_config_text}")

         # Adds peft version
-        framework_block_heading = "### Framework versions\n"
-        if framework_block_heading in lines:
-            lines.insert(lines.index(framework_block_heading) + 2, f"- PEFT {__version__}\n")
-        else:
-            lines.append(f"{framework_block_heading}\n\n- PEFT {__version__}\n")
+        framework_block_heading = "### Framework versions"
+        if f"- PEFT {__version__}" not in lines:
+            if framework_block_heading in lines:
+                lines.insert(lines.index(framework_block_heading) + 2, f"- PEFT {__version__}")
+            else:
+                lines.append(f"{framework_block_heading}\n\n- PEFT {__version__}")

         card.text = "\n".join(lines)
         card.save(filename)
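The reworked model-card code above follows an insert-once pattern: a block is added only when its marker line is not already present, so calling `save_pretrained` repeatedly no longer appends duplicate quantization or framework sections to the README. A standalone sketch of that pattern (heading and marker strings are illustrative):

```python
# Sketch only: demonstrates the dedup guard on a plain list of README lines.
lines = ["# my-model", "", "## Training procedure", ""]
heading = "### Framework versions"
marker = "- PEFT x.y.z"  # placeholder version string

if marker not in lines:  # if a previous save already added it, do nothing
    if heading in lines:
        lines.insert(lines.index(heading) + 2, marker)
    else:
        lines.append(f"{heading}\n\n{marker}")

print("\n".join(lines))
```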
src/peft/utils/other.py
@@ -583,3 +583,4 @@ TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING = {
 WEIGHTS_NAME = "adapter_model.bin"
 SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors"
 CONFIG_NAME = "adapter_config.json"
+EMBEDDING_LAYER_NAMES = ["embed_tokens", "lm_head"]
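`EMBEDDING_LAYER_NAMES` lists the module names 🤗 transformers models commonly use for input/output embeddings; the `"auto"` mode simply checks whether any of them show up in the adapter config's `target_modules` (see the `get_peft_model_state_dict` change below). In essence:

```python
# Sketch of the "auto" decision; target_modules here is an illustrative config value.
EMBEDDING_LAYER_NAMES = ["embed_tokens", "lm_head"]
target_modules = ["q_proj", "v_proj", "embed_tokens", "lm_head"]

save_embedding_layers = any(name in target_modules for name in EMBEDDING_LAYER_NAMES)
print(save_embedding_layers)  # True -> embedding weights will be saved
```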
src/peft/utils/save_and_load.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import warnings
 from typing import Optional

 import torch
@@ -20,11 +21,26 @@ from huggingface_hub import file_exists, hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError
 from safetensors.torch import load_file as safe_load_file

-from .other import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, infer_device
+from .other import EMBEDDING_LAYER_NAMES, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, infer_device
 from .peft_types import PeftType


-def get_peft_model_state_dict(model, state_dict=None, adapter_name="default", unwrap_compiled=False):
+def has_valid_embedding_base_layer(layer):
+    """Check if the layer has an embedding base layer"""
+    return hasattr(layer, "base_layer") and isinstance(layer.base_layer, (torch.nn.Linear, torch.nn.Embedding))
+
+
+def get_embedding_layer_name(model, layer, is_prompt_learning):
+    """Get the name of the embedding module for a given layer."""
+    for name, module in model.named_modules():
+        if (is_prompt_learning and module == layer) or module == layer.base_layer:
+            return name
+    return None
+
+
+def get_peft_model_state_dict(
+    model, state_dict=None, adapter_name="default", unwrap_compiled=False, save_embedding_layers="auto"
+):
     """
     Get the state dict of the Peft model.
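To illustrate the two helpers added here: on a LoRA-wrapped embedding, the tuner module keeps the original `nn.Embedding` as `base_layer`, and `get_embedding_layer_name` walks `named_modules()` to recover its fully qualified name so the matching entries can be copied out of the state dict. A rough sketch (the printed name depends on the wrapped architecture and is only an example):

```python
# Sketch only: peft_model is a LoRA-wrapped 🤗 transformers model (see the earlier sketch).
from peft.utils.save_and_load import get_embedding_layer_name, has_valid_embedding_base_layer

emb = peft_model.get_input_embeddings()        # a LoRA embedding wrapper around nn.Embedding
print(has_valid_embedding_base_layer(emb))     # True: .base_layer is an nn.Embedding

name = get_embedding_layer_name(peft_model, emb, is_prompt_learning=False)
print(name)  # e.g. "base_model.model.model.embed_tokens"
```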
@@ -37,6 +53,10 @@ def get_peft_model_state_dict(model, state_dict=None, adapter_name="default", un
             The name of the adapter whose state dict should be returned.
         unwrap_compiled (`bool`, *optional*, defaults to `False`):
             Whether to unwrap the model if torch.compile was used.
+        save_embedding_layers (`Union[bool, str]`, *optional*, defaults to `auto`):
+            If `True`, save the embedding layers in addition to adapter weights. If `auto`, checks the common embedding
+            layers `peft.utils.other.EMBEDDING_LAYER_NAMES` in config's `target_modules` when available. Based on it
+            sets the boolean flag. This only works for 🤗 transformers models.
     """
     if unwrap_compiled:
         model = getattr(model, "_orig_mod", model)
@@ -100,6 +120,27 @@ def get_peft_model_state_dict(model, state_dict=None, adapter_name="default", un
             if any(f"{module_name}.modules_to_save.{adapter_name}" in key for module_name in model.modules_to_save):
                 to_return[key.replace("modules_to_save.", "")] = value

+    # check the common embedding layers in `target_modules` to reset `save_embedding_layers` if necessary
+    if (
+        save_embedding_layers == "auto"
+        and hasattr(config, "target_modules")
+        and any(k in config.target_modules for k in EMBEDDING_LAYER_NAMES)
+    ):
+        warnings.warn("Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`.")
+        save_embedding_layers = True
+    elif save_embedding_layers == "auto":
+        save_embedding_layers = False
+
+    if save_embedding_layers and hasattr(model, "get_input_embeddings"):
+        for layer in [model.get_input_embeddings(), model.get_output_embeddings()]:
+            if config.is_prompt_learning or has_valid_embedding_base_layer(layer):
+                # support from version >= 0.6.2
+                embedding_module_name = get_embedding_layer_name(model, layer, config.is_prompt_learning)
+                if embedding_module_name:
+                    to_return.update({k: v for k, v in state_dict.items() if embedding_module_name in k})
+    elif save_embedding_layers:
+        warnings.warn("Could not identify embedding layer(s) because the model is not a 🤗 transformers model.")
+
     to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()}
     return to_return
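The practical consequence of this block: the adapter file now also carries the (resized) embedding tensors, so at load time the base model's embeddings must first be resized to the training-time vocabulary before attaching the adapter. A hedged loading sketch (model id and directory are placeholders; it assumes the tokenizer was saved next to the adapter, as in the earlier sketch):

```python
# Sketch only: base model id and adapter directory are illustrative.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

adapter_dir = "my-lora-adapter"
tokenizer = AutoTokenizer.from_pretrained(adapter_dir)   # tokenizer with the added tokens
base = AutoModelForCausalLM.from_pretrained("some-org/some-causal-lm")

# Grow the embedding matrix to match the checkpoint before loading the adapter;
# otherwise the stored embed_tokens / lm_head tensors cannot be restored.
base.resize_token_embeddings(len(tokenizer))

model = PeftModel.from_pretrained(base, adapter_dir)
```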
tests/test_custom_models.py
@@ -333,6 +333,33 @@ class ModelEmbConv1D(nn.Module):
         return X


+class ModelEmbWithEmbeddingUtils(nn.Module):
+    # Adds `get_input_embeddings` and `get_output_embeddings` methods to mimic 🤗 transformers models
+    def __init__(self):
+        super().__init__()
+        self.embed_tokens = nn.Embedding(100, 5)
+        self.conv1d = Conv1D(1, 5)
+        self.relu = nn.ReLU()
+        self.flat = nn.Flatten()
+        self.lin0 = nn.Linear(10, 2)
+        self.sm = nn.LogSoftmax(dim=-1)
+
+    def forward(self, X):
+        X = self.embed_tokens(X)
+        X = self.conv1d(X)
+        X = self.relu(X)
+        X = self.flat(X)
+        X = self.lin0(X)
+        X = self.sm(X)
+        return X
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def get_output_embeddings(self):
+        return None
+
+
 class ModelConv2D(nn.Module):
     def __init__(self):
         super().__init__()
@@ -750,6 +777,55 @@ class PeftCustomModelTester(unittest.TestCase, PeftCommonTester):
             # rough check that the model card is pre-filled
             self.assertGreater(len(model_card), 1000)

+    @parameterized.expand(["auto", True, False])
+    def test_targeting_lora_to_embedding_layer(self, save_embedding_layers):
+        model = ModelEmbWithEmbeddingUtils()
+        config = LoraConfig(target_modules=["embed_tokens", "lin0"], init_lora_weights=False)
+        model = get_peft_model(model, config)
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            if save_embedding_layers == "auto":
+                # assert warning
+                msg_start = "Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`."
+                with self.assertWarns(UserWarning, msg=msg_start):
+                    model.save_pretrained(tmp_dirname, save_embedding_layers=save_embedding_layers)
+            else:
+                model.save_pretrained(tmp_dirname, save_embedding_layers=save_embedding_layers)
+            from safetensors.torch import load_file as safe_load_file
+
+            state_dict = safe_load_file(os.path.join(tmp_dirname, "adapter_model.safetensors"))
+            if save_embedding_layers in ["auto", True]:
+                self.assertTrue("base_model.model.embed_tokens.base_layer.weight" in state_dict)
+                self.assertTrue(
+                    torch.allclose(
+                        model.base_model.model.embed_tokens.base_layer.weight,
+                        state_dict["base_model.model.embed_tokens.base_layer.weight"],
+                    )
+                )
+            else:
+                self.assertFalse("base_model.model.embed_tokens.base_layer.weight" in state_dict)
+            del state_dict
+
+    @parameterized.expand(["auto", True, False])
+    def test_targeting_lora_to_embedding_layer_non_transformers(self, save_embedding_layers):
+        model = ModelEmbConv1D()
+        config = LoraConfig(target_modules=["emb", "lin0"], init_lora_weights=False)
+        model = get_peft_model(model, config)
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            if save_embedding_layers is True:
+                # assert warning
+                msg_start = "Could not identify embedding layer(s) because the model is not a 🤗 transformers model."
+                with self.assertWarns(UserWarning, msg=msg_start):
+                    model.save_pretrained(tmp_dirname, save_embedding_layers=save_embedding_layers)
+            else:
+                model.save_pretrained(tmp_dirname, save_embedding_layers=save_embedding_layers)
+            from safetensors.torch import load_file as safe_load_file
+
+            state_dict = safe_load_file(os.path.join(tmp_dirname, "adapter_model.safetensors"))
+            self.assertFalse("base_model.model.emb.base_layer.weight" in state_dict)
+            del state_dict
+
     @parameterized.expand(
         [
             LoraConfig(target_modules=["lin0"], init_lora_weights=False),
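Outside the unittest harness, the same check these tests make can be reproduced directly by saving an adapter and inspecting which keys landed in the safetensors file (`peft_model` refers to the earlier sketch):

```python
# Sketch only: lists the embedding-related keys written by save_pretrained.
import os
import tempfile

from safetensors.torch import load_file as safe_load_file

with tempfile.TemporaryDirectory() as tmp_dirname:
    peft_model.save_pretrained(tmp_dirname, save_embedding_layers="auto")
    state_dict = safe_load_file(os.path.join(tmp_dirname, "adapter_model.safetensors"))
    print([k for k in state_dict if "base_layer.weight" in k])
    # expected to include e.g. "base_model.model.model.embed_tokens.base_layer.weight"
```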