Mirror of https://github.com/huggingface/peft.git, synced 2025-10-20 15:33:48 +08:00
FIX Correctly determine no_split_modules (#2570)
See the discussion in https://github.com/huggingface/transformers/pull/38141 for context. In the PEFT fsdp_auto_wrap_policy, we determine the _no_split_modules. However, this currently neglects to visit the children of the model, which can be required for some architectures. This PR fixes that. Note that the _get_no_split_modules function is largely copied from transformers. One change is that it doesn't take the device_map argument; in transformers, that argument is only used inside an error message, not for the logic proper, so it is safe to remove. Moreover, I made an unrelated change to fsdp_auto_wrap_policy, namely making its local imports global (there was no reason for them to be local).
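As an illustration of where this matters, here is a minimal sketch of how the policy is typically wired into accelerate's FSDP plugin; the model id, the LoRA settings, and the presence of an FSDP-enabled accelerate config are assumptions for illustration, not part of this PR:

# Sketch only: typical wiring of fsdp_auto_wrap_policy into accelerate's FSDP plugin.
# The model id and LoRA settings below are placeholders, and an FSDP-enabled
# `accelerate config` is assumed.
from accelerate import Accelerator
from peft import LoraConfig, get_peft_model
from peft.utils.other import fsdp_auto_wrap_policy
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
model = get_peft_model(base_model, LoraConfig(task_type="CAUSAL_LM"))

accelerator = Accelerator()
if getattr(accelerator.state, "fsdp_plugin", None) is not None:
    # With this fix, _no_split_modules defined on child PreTrainedModel instances
    # (e.g. in vision-language models) are picked up as well, not only those on
    # the top-level model.
    accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)

model = accelerator.prepare(model)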
@@ -14,6 +14,7 @@
 from __future__ import annotations

 import copy
+import functools
 import inspect
 import os
 import re
@@ -24,12 +25,14 @@ from typing import Any, Optional, Union

 import accelerate
 import torch
+from accelerate import FullyShardedDataParallelPlugin
 from accelerate.hooks import add_hook_to_module, remove_hook_from_module
 from accelerate.utils import is_npu_available, is_xpu_available
 from huggingface_hub import file_exists
 from huggingface_hub.errors import EntryNotFoundError, HFValidationError
 from packaging import version
 from safetensors.torch import storage_ptr, storage_size
+from transformers import PreTrainedModel

 from ..import_utils import is_auto_gptq_available, is_gptqmodel_available, is_torch_tpu_available
 from .constants import (
@@ -942,12 +945,33 @@ def _prepare_prompt_learning_config(peft_config, model_config):
     return peft_config


+def _get_no_split_modules(model) -> set[str]:
+    """
+    Get the modules of the model that should not be split when using device_map. We iterate through the modules to get
+    the underlying `_no_split_modules`.
+
+    Returns:
+        `List[str]`: List of modules that should not be split
+    """
+    # After discussion in https://github.com/huggingface/transformers/pull/38141, based on:
+    # https://github.com/huggingface/transformers/blob/1e921a3a9cea92b383ca4b0484ee45596bbdadc3/src/transformers/modeling_utils.py#L2677-L2704
+    _no_split_modules: set[str] = set()
+    if not hasattr(model, "_no_split_modules"):
+        return _no_split_modules
+
+    modules_to_check = [model]
+    while len(modules_to_check) > 0:
+        module = modules_to_check.pop(-1)
+        # if the module does not appear in _no_split_modules, we also check the children
+        if module.__class__.__name__ not in _no_split_modules:
+            if isinstance(module, PreTrainedModel):
+                if module._no_split_modules is not None:
+                    _no_split_modules = _no_split_modules | set(module._no_split_modules)
+            modules_to_check += list(module.children())
+    return _no_split_modules
+
+
 def fsdp_auto_wrap_policy(model):
-    import functools
-    import os
-
-    from accelerate import FullyShardedDataParallelPlugin
-
     if hasattr(FullyShardedDataParallelPlugin, "get_module_class_from_name"):
         get_module_class_from_name = FullyShardedDataParallelPlugin.get_module_class_from_name
     else:
@@ -956,9 +980,7 @@ def fsdp_auto_wrap_policy(model):

     from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder

-    default_transformer_cls_names_to_wrap = (
-        ",".join(model._no_split_modules) if getattr(model, "_no_split_modules", None) is not None else ""
-    )
+    default_transformer_cls_names_to_wrap = ",".join(_get_no_split_modules(model))
     transformer_cls_names_to_wrap = os.environ.get(
         "FSDP_TRANSFORMER_CLS_TO_WRAP", default_transformer_cls_names_to_wrap
     ).split(",")
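To make the behavior change concrete, here is a hedged before/after comparison on a nested architecture; it reuses the tiny Llava checkpoint from the tests added below, and the exact layer class names depend on the transformers version:

# Sketch only: compare the old default (top-level attribute only) with the new
# helper (children are visited too).
from peft.utils.other import _get_no_split_modules
from transformers import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "hf-internal-testing/tiny-random-LlavaForConditionalGeneration"
)

# Old logic: Llava's top-level _no_split_modules is empty, so nothing would be wrapped.
old_default = (
    ",".join(model._no_split_modules) if getattr(model, "_no_split_modules", None) is not None else ""
)
# New logic: the decoder and vision tower layer classes are found on the children.
new_default = ",".join(_get_no_split_modules(model))

print(repr(old_default))  # ''
print(repr(new_default))  # e.g. 'CLIPEncoderLayer,LlamaDecoderLayer' (order may vary)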
@@ -17,10 +17,10 @@ import copy
 import pytest
 import torch
 from torch import nn
-from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification
+from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, LlavaForConditionalGeneration

 from peft import LoraConfig, PeftModel, VeraConfig, get_peft_model
-from peft.utils.other import ModulesToSaveWrapper
+from peft.utils.other import ModulesToSaveWrapper, _get_no_split_modules


 class ModelWithModuleDict(nn.Module):
@@ -507,3 +507,26 @@ class TestAdapterTargeting:
         }

         assert adapter_invariant_keys1 == adapter_invariant_keys2
+
+
+class TestGetNoSplitModules:
+    # Ensure that children are considered when determining _no_split_modules
+    # see https://github.com/huggingface/transformers/pull/38141
+
+    def test_get_no_split_modules_simple(self):
+        # choose a model where recursively visiting children is *not* required
+        model_id = "facebook/opt-125m"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        assert model._no_split_modules == ["OPTDecoderLayer"]
+        no_split_modules = _get_no_split_modules(model)
+        assert no_split_modules == {"OPTDecoderLayer"}
+
+    def test_get_no_split_modules_recursive(self):
+        # choose a model where recursively visiting children is required
+        model_id = "hf-internal-testing/tiny-random-LlavaForConditionalGeneration"
+        model = LlavaForConditionalGeneration.from_pretrained(model_id)
+        # sanity check: just visiting the model itself is not enough:
+        assert model._no_split_modules == []
+
+        no_split_modules = _get_no_split_modules(model)
+        assert no_split_modules == {"CLIPEncoderLayer", "LlamaDecoderLayer"}
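The new tests can be run in isolation by selecting the class name; the test file path is not shown in this diff, so the sketch below selects by keyword from the repository root:

# Sketch only: run just the tests added in this PR by matching the class name.
import pytest

raise SystemExit(pytest.main(["tests", "-k", "TestGetNoSplitModules"]))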