FIX Correctly determine no_split_modules (#2570)

See discussion in https://github.com/huggingface/transformers/pull/38141
for context.

In the PEFT fsdp_auto_wrap_policy, we determine the model's _no_split_modules.
However, this currently neglects to visit the children of the model,
which can be required for some architectures. This PR fixes that.
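
To illustrate why visiting the children matters, here is a minimal standalone sketch of the traversal. It is not the PEFT helper itself (which, as the diff below shows, additionally checks isinstance(module, PreTrainedModel) rather than looking for the attribute); the Wrapper/Inner classes and their class names are hypothetical and only mimic architectures like Llava, where the top-level model declares an empty _no_split_modules while its child models carry the real entries:

import torch.nn as nn


def collect_no_split_modules(model) -> set[str]:
    # Simplified stand-in for the new helper: it keys off a `_no_split_modules`
    # attribute instead of isinstance(module, PreTrainedModel).
    found: set[str] = set()
    modules_to_check = [model]
    while modules_to_check:
        module = modules_to_check.pop(-1)
        if module.__class__.__name__ not in found:
            declared = getattr(module, "_no_split_modules", None)
            if declared:
                found |= set(declared)
            # the crucial part of the fix: also descend into the children
            modules_to_check += list(module.children())
    return found


class Inner(nn.Module):
    # a child model that declares its own no-split classes (hypothetical name)
    _no_split_modules = ["InnerDecoderLayer"]

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)


class Wrapper(nn.Module):
    # the top-level model declares nothing useful, like Llava with []
    _no_split_modules = []

    def __init__(self):
        super().__init__()
        self.inner = Inner()


print(collect_no_split_modules(Wrapper()))  # {'InnerDecoderLayer'}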

Note that the _get_no_split_modules function is largely copied from
transformers. One change is that it doesn't take the device_map
argument. That argument is used in transformers inside an error message
but not for the logic proper. I think it's safe to remove.

Moreover, I made an unrelated change to fsdp_auto_wrap_policy, namely
making local imports global (there was no reason for them to be local).
Author: Benjamin Bossan
Date: 2025-06-16 17:21:06 +02:00
Committed by: GitHub
Parent: 759bb70ace
Commit: fc254e39d9
2 changed files with 55 additions and 10 deletions


@@ -14,6 +14,7 @@
from __future__ import annotations

import copy
import functools
import inspect
import os
import re
@@ -24,12 +25,14 @@ from typing import Any, Optional, Union
import accelerate
import torch
from accelerate import FullyShardedDataParallelPlugin
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
from accelerate.utils import is_npu_available, is_xpu_available
from huggingface_hub import file_exists
from huggingface_hub.errors import EntryNotFoundError, HFValidationError
from packaging import version
from safetensors.torch import storage_ptr, storage_size
from transformers import PreTrainedModel
from ..import_utils import is_auto_gptq_available, is_gptqmodel_available, is_torch_tpu_available
from .constants import (
@@ -942,12 +945,33 @@ def _prepare_prompt_learning_config(peft_config, model_config):
    return peft_config


def _get_no_split_modules(model) -> set[str]:
    """
    Get the modules of the model that should not be split when using device_map. We iterate through the modules to get
    the underlying `_no_split_modules`.

    Returns:
        `List[str]`: List of modules that should not be split
    """
    # After discussion in https://github.com/huggingface/transformers/pull/38141, based on:
    # https://github.com/huggingface/transformers/blob/1e921a3a9cea92b383ca4b0484ee45596bbdadc3/src/transformers/modeling_utils.py#L2677-L2704
    _no_split_modules: set[str] = set()
    if not hasattr(model, "_no_split_modules"):
        return _no_split_modules

    modules_to_check = [model]
    while len(modules_to_check) > 0:
        module = modules_to_check.pop(-1)
        # if the module does not appear in _no_split_modules, we also check the children
        if module.__class__.__name__ not in _no_split_modules:
            if isinstance(module, PreTrainedModel):
                if module._no_split_modules is not None:
                    _no_split_modules = _no_split_modules | set(module._no_split_modules)
            modules_to_check += list(module.children())
    return _no_split_modules


def fsdp_auto_wrap_policy(model):
-    import functools
-    import os
-    from accelerate import FullyShardedDataParallelPlugin
    if hasattr(FullyShardedDataParallelPlugin, "get_module_class_from_name"):
        get_module_class_from_name = FullyShardedDataParallelPlugin.get_module_class_from_name
    else:
@@ -956,9 +980,7 @@ def fsdp_auto_wrap_policy(model):
    from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder

-    default_transformer_cls_names_to_wrap = (
-        ",".join(model._no_split_modules) if getattr(model, "_no_split_modules", None) is not None else ""
-    )
+    default_transformer_cls_names_to_wrap = ",".join(_get_no_split_modules(model))
    transformer_cls_names_to_wrap = os.environ.get(
        "FSDP_TRANSFORMER_CLS_TO_WRAP", default_transformer_cls_names_to_wrap
    ).split(",")
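
For context on how the new helper feeds into the policy above: the collected class names become the default for accelerate's FSDP_TRANSFORMER_CLS_TO_WRAP environment variable, which still takes precedence when it is set. A small usage sketch, not part of this commit (it assumes the tiny Llava test checkpoint used in the tests below and network access to the Hub):

import os

from transformers import LlavaForConditionalGeneration

from peft.utils.other import _get_no_split_modules

model = LlavaForConditionalGeneration.from_pretrained(
    "hf-internal-testing/tiny-random-LlavaForConditionalGeneration"
)

# the recursive lookup now also finds the classes declared by the child models
default_cls = ",".join(_get_no_split_modules(model))

# the env var override keeps working as before; the default only applies when it is unset
cls_to_wrap = os.environ.get("FSDP_TRANSFORMER_CLS_TO_WRAP", default_cls).split(",")
print(cls_to_wrap)  # e.g. ['CLIPEncoderLayer', 'LlamaDecoderLayer'] (set order may vary)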


@@ -17,10 +17,10 @@ import copy
import pytest
import torch
from torch import nn
-from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification
+from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, LlavaForConditionalGeneration

from peft import LoraConfig, PeftModel, VeraConfig, get_peft_model
-from peft.utils.other import ModulesToSaveWrapper
+from peft.utils.other import ModulesToSaveWrapper, _get_no_split_modules


class ModelWithModuleDict(nn.Module):
@@ -507,3 +507,26 @@ class TestAdapterTargeting:
        }
        assert adapter_invariant_keys1 == adapter_invariant_keys2


class TestGetNoSplitModules:
    # Ensure that children are considered when determining _no_split_modules
    # see https://github.com/huggingface/transformers/pull/38141
    def test_get_no_split_modules_simple(self):
        # choose a model where recursively visiting children is *not* required
        model_id = "facebook/opt-125m"
        model = AutoModelForCausalLM.from_pretrained(model_id)
        assert model._no_split_modules == ["OPTDecoderLayer"]
        no_split_modules = _get_no_split_modules(model)
        assert no_split_modules == {"OPTDecoderLayer"}

    def test_get_no_split_modules_recursive(self):
        # choose a model where recursively visiting children is required
        model_id = "hf-internal-testing/tiny-random-LlavaForConditionalGeneration"
        model = LlavaForConditionalGeneration.from_pretrained(model_id)
        # sanity check: just visiting the model itself is not enough:
        assert model._no_split_modules == []
        no_split_modules = _get_no_split_modules(model)
        assert no_split_modules == {"CLIPEncoderLayer", "LlamaDecoderLayer"}