ENH: Allow FSDP ignored modules to be regex (#3698)

* ENH: Allow FSDP ignored modules to be regex

Description

For FSDP, there is an option to indicate ignored_modules, which should
be a list of modules that are ignored by FSDP. Even though this argument was
supported in accelerate, it was not very usable:

1. Listing all modules can be tricky, especially with something like PEFT,
where the whole model is wrapped and thus the module structure changes.
2. When configuring this argument, accelerate takes a detour via
environment variables. These can only be strings. Therefore, passing a
list of modules is not feasible.

Moreover, I noticed that the environment variable for ignored_modules
was never set in the first place, so configuring this argument had no effect at all.

Status

This PR is lacking tests. I would be happy for pointers on how to add
those.

Context

When using PEFT with LoRA and the target_parameters feature, I ran into
an issue training such a model with FSDP. The only working fix I found
was to ignore the layers targeted by LoRA. However, I could not
configure accelerate to do that. With this PR, it is possible: I
successfully trained such a PEFT model that targets q_proj and v_proj by
setting fsdp_ignored_modules: '.*\.(q_proj$|v_proj$)'.
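
For illustration, here is a minimal standalone sketch (not part of the PR) of how such a regex is matched against module names with re.fullmatch; the PEFT-style names below are hypothetical:

import re

# Hypothetical module names as they might appear in a PEFT-wrapped model.
names = [
    "base_model.model.model.layers.0.self_attn.q_proj",
    "base_model.model.model.layers.0.self_attn.v_proj",
    "base_model.model.model.layers.0.self_attn.k_proj",
    "base_model.model.model.layers.0.mlp.up_proj",
]
pattern = re.compile(r".*\.(q_proj$|v_proj$)")
for name in names:
    # fullmatch must cover the entire name, so only q_proj/v_proj leaves match
    status = "ignored" if pattern.fullmatch(name) else "wrapped by FSDP"
    print(f"{name}: {status}")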

* Fix type annotation

* Fix failing test
Author: Benjamin Bossan, 2025-08-05 14:23:14 +02:00 (committed by GitHub)
Commit: 24e48f3d20 (parent 6640ff415c)
4 changed files with 21 additions and 3 deletions

@@ -1872,6 +1872,17 @@ class Accelerator:
                 "limit_all_gathers": fsdp_plugin.limit_all_gathers,
                 "device_id": self.device,
             }
+            if isinstance(kwargs["ignored_modules"], str):
+                reg = re.compile(kwargs["ignored_modules"])
+                ignored = []
+                for name, module in model.named_modules():
+                    if reg.fullmatch(name):
+                        # ensure that the device for these modules is still set correctly
+                        module.to(self.device)
+                        ignored.append(module)
+                kwargs["ignored_modules"] = ignored
             model = FSDP(model, **kwargs)
             if fsdp_plugin.activation_checkpointing:
                 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
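
To make this code path concrete, here is a rough usage sketch, assuming a script started with accelerate launch in an FSDP-enabled distributed setup; the toy model and its submodule names are made up for illustration, while Accelerator and FullyShardedDataParallelPlugin are the real accelerate APIs:

import torch.nn as nn
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin

# Toy stand-in for an attention block; only the submodule names matter here.
class Attention(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(8, 8)
        self.k_proj = nn.Linear(8, 8)
        self.v_proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.q_proj(x) + self.k_proj(x) + self.v_proj(x)

model = nn.Sequential(Attention(), Attention())

# Pass a regex string instead of a list of module objects.
fsdp_plugin = FullyShardedDataParallelPlugin(ignored_modules=r".*\.(q_proj$|v_proj$)")
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
model = accelerator.prepare(model)  # matching q_proj/v_proj submodules are left unwrapped by FSDP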

@@ -1561,8 +1561,9 @@ class FullyShardedDataParallelPlugin:
             Whether to offload parameters to CPU. Should be either a `bool` or an instance of
             `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or
             `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2.
-        ignored_modules (`Optional[Iterable[torch.nn.Module]]`, defaults to `None`):
-            A list of modules to ignore when wrapping with FSDP.
+        ignored_modules (`Optional[Union[Iterable[torch.nn.Module], str]]`, defaults to `None`):
+            A list of modules to ignore when wrapping with FSDP. When passing a string, will match the modules by name
+            using regex fullmatch.
         state_dict_type (`Union[str, torch.distributed.fsdp.StateDictType]`, defaults to `'FULL_STATE_DICT'`):
             State dict type to use. If a string, it must be one of `full_state_dict`, `local_state_dict`, or
             `sharded_state_dict`.
@@ -1660,7 +1661,7 @@ class FullyShardedDataParallelPlugin:
             "help": "Whether to offload parameters to CPU. Should be either a `bool` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2. Defaults to `False`"
         },
     )
-    ignored_modules: Optional[Iterable[torch.nn.Module]] = field(
+    ignored_modules: Optional[Union[Iterable[torch.nn.Module], str]] = field(
         default=None,
         metadata={"help": "A list of modules to ignore when wrapping with FSDP."},
     )
@@ -1896,6 +1897,9 @@ class FullyShardedDataParallelPlugin:
                     str_to_bool(os.environ.get(env_prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1
                 )
+            if self.ignored_modules is None:
+                self.ignored_modules = os.environ.get(env_prefix + "IGNORED_MODULES", None)
             if self.cpu_ram_efficient_loading is None:
                 self.cpu_ram_efficient_loading = (
                     str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False")) == 1
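
A small sketch of the environment-variable path added above, assuming the FSDP_ prefix exported by the launcher (see the change below) and that the plugin can be constructed standalone; other FSDP_* defaults may also be read during construction:

import os
from accelerate.utils import FullyShardedDataParallelPlugin

# Normally `accelerate launch` exports this from `fsdp_ignored_modules` in the config file.
os.environ["FSDP_IGNORED_MODULES"] = r".*\.(q_proj$|v_proj$)"

plugin = FullyShardedDataParallelPlugin()
# The attribute stays a plain regex string at this point; it is only resolved
# to actual modules when Accelerator wraps the model with FSDP (see the hunk further up).
print(plugin.ignored_modules)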

@@ -328,6 +328,8 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> dict[str, str]:
         current_env["FSDP_CPU_RAM_EFFICIENT_LOADING"] = str(args.fsdp_cpu_ram_efficient_loading).lower()
         current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower()
         current_env["FSDP_ACTIVATION_CHECKPOINTING"] = str(args.fsdp_activation_checkpointing).lower()
+        if getattr(args, "fsdp_ignored_modules", None) is not None:
+            current_env["FSDP_IGNORED_MODULES"] = str(args.fsdp_ignored_modules)
     if args.use_megatron_lm:
         prefix = "MEGATRON_LM_"

@@ -15,6 +15,7 @@ fsdp_config:
   fsdp_sync_module_states: true
   fsdp_transformer_layer_cls_to_wrap: BertLayer
   fsdp_use_orig_params: true
+  fsdp_ignored_modules: null
 machine_rank: 0
 main_training_function: main
 mixed_precision: 'no'