feat: add ignored_params support for fsdp2 (#3731)

* feat: add ignored_params support for fsdp2

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* feat: add ignored_params support for fsdp2

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* feat: add ignored_params support for fsdp2

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* feat: add ignored_params support for fsdp2

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* test: update testcase for fsdp2 ignored_params

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* fix: add defensive use of ignored params

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* fix: styling errors

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

---------

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
Author: Mehant Kammakomati
Date: 2025-08-18 18:01:19 +05:30
Committed by: GitHub
Parent: 23cf4ef8a3
Commit: a7d6f28f99
3 changed files with 49 additions and 9 deletions

View File

@@ -1564,7 +1564,7 @@ class FullyShardedDataParallelPlugin:
            `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2.
        ignored_modules (`Optional[Union[Iterable[torch.nn.Module], str]]`, defaults to `None`):
            A list of modules to ignore when wrapping with FSDP. When passing a string, will match the modules by name
            using regex fullmatch.
            using regex fullmatch. If `fsdp_version` is set to 2, the matched modules are converted to their parameters and passed to FSDP2 as `ignored_params`.
        state_dict_type (`Union[str, torch.distributed.fsdp.StateDictType]`, defaults to `'FULL_STATE_DICT'`):
            State dict type to use. If a string, it must be one of `full_state_dict`, `local_state_dict`, or
            `sharded_state_dict`.
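
The new docstring sentence describes the FSDP2 path added further below. As a rough usage sketch (not part of the commit, and assuming `FullyShardedDataParallelPlugin` and `Accelerator` accept these fields as in current accelerate releases):

    # Hypothetical configuration: exclude every q_proj module from FSDP2 sharding
    # by passing a regex string; module names are matched with re.fullmatch.
    from accelerate import Accelerator
    from accelerate.utils import FullyShardedDataParallelPlugin

    fsdp_plugin = FullyShardedDataParallelPlugin(
        fsdp_version=2,
        ignored_modules=r".*\.q_proj$",  # string -> regex; converted to ignored_params for FSDP2
    )
    accelerator = Accelerator(fsdp_plugin=fsdp_plugin)  # in an FSDP-enabled launch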
@@ -1948,7 +1948,12 @@ class FullyShardedDataParallelPlugin:
            # Create a function that will be used to initialize the parameters of the model
            # when using `sync_module_states`
            self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)

        if is_torch_version("<", "2.7.0") and self.fsdp_version == 2 and self.ignored_modules is not None:
            _fsdp2_warnings.add(
                "FSDP2 ignored_params/ignored_modules is not available for torch version < 2.7.0. "
                "Setting ignored_modules to None."
            )
            self.ignored_modules = None

        # Single warning for all deprecation warnings due to FSDP2 conversion
        if _fsdp2_warnings:
            logger.warning("Multiple deprecation warnings due to FSDP2 conversion:\n".join(_fsdp2_warnings))
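
The guard above only warns and falls back to `ignored_modules=None`; a caller that depends on ignored parameters under FSDP2 may prefer to fail fast. A minimal sketch (not part of the commit), assuming `is_torch_version` is importable from `accelerate.utils`:

    # Hypothetical pre-flight check mirroring the guard in the hunk above.
    from accelerate.utils import is_torch_version

    if is_torch_version("<", "2.7.0"):
        raise RuntimeError(
            "FSDP2 ignored_params/ignored_modules requires torch >= 2.7.0; "
            "the plugin would reset ignored_modules to None and only log a warning."
        )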

View File

@@ -14,12 +14,14 @@
import copy
import functools
import os
import re
import shutil
import warnings
from collections import defaultdict
from collections.abc import Iterable
from contextlib import nullcontext
from pathlib import Path
from typing import Callable
from typing import Callable, Union
import torch
@@ -629,6 +631,10 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
        "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
        "mesh": mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)] if mesh is not None else None,
    }
    if fsdp2_plugin.ignored_modules is not None:
        fsdp2_kwargs["ignored_params"] = get_parameters_from_modules(
            fsdp2_plugin.ignored_modules, model, accelerator.device
        )

    model_has_params4bit = False
    for name, param in model.named_parameters():
@@ -791,3 +797,31 @@ def fsdp2_canonicalize_names(named_params: dict) -> dict:
    }
    named_params = {k.replace("._orig_mod", ""): v for k, v in named_params.items()}
    return named_params


def get_parameters_from_modules(
    modules: Union[Iterable[torch.nn.Module], str], model: torch.nn.Module, device: torch.device
) -> set[torch.nn.Parameter]:
    """Converts modules to parameters, where `modules` can be a regex string or an iterable of `torch.nn.Module`.

    Args:
        modules (`Union[Iterable[torch.nn.Module], str]`): Iterable of modules, or a regex string that is matched
            against module names from `model.named_modules()` using `re.fullmatch`.
        model (`torch.nn.Module`): Model whose named modules are searched when `modules` is a regex string.
        device (`torch.device`): Device the regex-matched modules are moved to before their parameters are collected.

    Returns:
        `set[torch.nn.Parameter]`: Set of parameters contained in the given modules (or `None` if `modules` is `None`).
    """
    if modules is None:
        return None
    parameters = []
    # code taken from accelerate while preparing kwargs for FSDP
    if isinstance(modules, str):
        reg = re.compile(modules)
        mapped_modules = []
        for name, module in model.named_modules():
            if reg.fullmatch(name):
                module.to(device)
                mapped_modules.append(module)
        modules = mapped_modules
    for module in modules:
        parameters.extend(list(module.parameters()))
    return set(parameters)
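
A quick illustration of what the helper returns; a minimal sketch (not part of the commit) exercising the function defined above on a toy model:

    # Toy example: regex strings are matched against model.named_modules() names with
    # re.fullmatch; iterables of modules are used directly.
    import torch

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))

    # "0" fullmatches only the first child module's name inside the Sequential container.
    params = get_parameters_from_modules("0", model, torch.device("cpu"))
    assert params == set(model[0].parameters())

    # Passing modules directly skips the regex branch entirely.
    params = get_parameters_from_modules([model[1]], model, torch.device("cpu"))
    assert params == set(model[1].parameters())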

View File

@@ -400,18 +400,19 @@ class FSDPPluginIntegration(AccelerateTestCase):
    def test_ignored_modules_regex(self):
        # Check that FSDP's ignored_modules can be a string, in which case it is treated as a regex
        if self.current_fsdp_version != 1:
            self.skipTest("ignored_modules only relevant for FSDP1")
        env = self.fsdp_envs[1].copy()
        env["FSDP_IGNORED_MODULES"] = ".*\\.q_proj$"
        with patch_environment(**env):
            accelerator = Accelerator()
            model = AutoModel.from_pretrained(LLAMA_TESTING)
            # model has 2 layers
            layers_to_ignore = {model.layers[0].self_attn.q_proj, model.layers[1].self_attn.q_proj}
            model = accelerator.prepare(model)
            assert model._ignored_modules == layers_to_ignore
            if self.current_fsdp_version == 1:
                # model has 2 layers
                layers_to_ignore = {model.layers[0].self_attn.q_proj, model.layers[1].self_attn.q_proj}
                assert model._ignored_modules == layers_to_ignore
            else:
                params_to_ignore = {model.layers[0].self_attn.q_proj.weight, model.layers[1].self_attn.q_proj.weight}
                assert model._ignored_params == params_to_ignore

    @require_fsdp2