mirror of https://github.com/huggingface/accelerate.git, synced 2025-10-20 10:03:46 +08:00
feat: add ignored_params support for fsdp2 (#3731)
* feat: add ignored_params support for fsdp2
* test: update testcase for fsdp2 ignored_params
* fix: add defensive use of ignored params
* fix: styling errors

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
committed by GitHub
parent 23cf4ef8a3
commit a7d6f28f99
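The commit lets FSDP2 runs keep selected parameters out of sharding. A minimal usage sketch (editor's illustration, not part of the commit; it assumes the `FullyShardedDataParallelPlugin` fields shown in the diff below):

    # Sketch: exclude every module whose name fullmatches the regex from FSDP2 sharding.
    from accelerate import Accelerator
    from accelerate.utils import FullyShardedDataParallelPlugin

    fsdp_plugin = FullyShardedDataParallelPlugin(
        fsdp_version=2,
        ignored_modules=r".*\.q_proj$",  # regex string; matched modules' parameters stay unsharded
    )
    accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
    # model = accelerator.prepare(model)  # matched parameters are forwarded to fully_shard as ignored_params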
@@ -1564,7 +1564,7 @@ class FullyShardedDataParallelPlugin:
             `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2.
         ignored_modules (`Optional[Union[Iterable[torch.nn.Module], str]]`, defaults to `None`):
             A list of modules to ignore when wrapping with FSDP. When passing a string, will match the modules by name
-            using regex fullmatch.
+            using regex fullmatch. If `fsdp_version` is set to 2, the modules are converted to parameters and used.
         state_dict_type (`Union[str, torch.distributed.fsdp.StateDictType]`, defaults to `'FULL_STATE_DICT'`):
             State dict type to use. If a string, it must be one of `full_state_dict`, `local_state_dict`, or
             `sharded_state_dict`.
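For reference, a small self-contained sketch of the fullmatch semantics described in the docstring above (editor's illustration, not from the commit):

    import re
    import torch.nn as nn

    model = nn.ModuleDict({"encoder": nn.Linear(4, 4), "lm_head": nn.Linear(4, 2)})
    reg = re.compile("lm_head")
    # named_modules() yields fully qualified names; only exact (full) regex matches are ignored.
    ignored = [m for name, m in model.named_modules() if reg.fullmatch(name)]
    # With fsdp_version=1 the matched modules are ignored directly; with fsdp_version=2 their
    # parameters are collected and handed to fully_shard instead.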
@@ -1948,7 +1948,12 @@ class FullyShardedDataParallelPlugin:
             # Create a function that will be used to initialize the parameters of the model
             # when using `sync_module_states`
             self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)
-
+        if is_torch_version("<", "2.7.0") and self.fsdp_version == 2 and self.ignored_modules is not None:
+            _fsdp2_warnings.add(
+                "FSDP2 ignored_params/ignored_modules is not available for torch version < 2.7.0"
+                "Setting ignored_modules to None."
+            )
+            self.ignored_modules = None
         # Single warning for all deprecation warnings due to FSDP2 conversion
         if _fsdp2_warnings:
             logger.warning("Multiple deprecation warnings due to FSDP2 conversion:\n".join(_fsdp2_warnings))
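On torch < 2.7.0 the plugin only warns and drops `ignored_modules` for FSDP2. A hedged sketch of how calling code could fail fast instead (it assumes only `is_torch_version` from `accelerate.utils`):

    from accelerate.utils import is_torch_version

    if is_torch_version("<", "2.7.0"):
        # The plugin would reset ignored_modules to None and continue, sharding everything;
        # raising here makes the limitation explicit instead.
        raise RuntimeError("FSDP2 ignored_params/ignored_modules requires torch >= 2.7.0")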
@@ -14,12 +14,14 @@
 import copy
 import functools
 import os
+import re
 import shutil
 import warnings
 from collections import defaultdict
+from collections.abc import Iterable
 from contextlib import nullcontext
 from pathlib import Path
-from typing import Callable
+from typing import Callable, Union

 import torch

@@ -629,6 +631,10 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
         "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
         "mesh": mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)] if mesh is not None else None,
     }
+    if fsdp2_plugin.ignored_modules is not None:
+        fsdp2_kwargs["ignored_params"] = get_parameters_from_modules(
+            fsdp2_plugin.ignored_modules, model, accelerator.device
+        )

     model_has_params4bit = False
     for name, param in model.named_parameters():
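The extra kwarg ultimately reaches torch's `fully_shard` call. A rough sketch of the resulting call shape (assumption: torch >= 2.7, where `fully_shard` accepts `ignored_params`; the call itself needs an initialized process group, so it is left commented out):

    import torch.nn as nn
    from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

    model = nn.Linear(8, 8)
    ignored_params = set(model.parameters())  # in the real path: get_parameters_from_modules(...)
    fsdp2_kwargs = {"mp_policy": MixedPrecisionPolicy(), "ignored_params": ignored_params}
    # fully_shard(model, **fsdp2_kwargs)  # requires torch.distributed / a device mesh to be initialized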
@@ -791,3 +797,31 @@ def fsdp2_canonicalize_names(named_params: dict) -> dict:
     }
     named_params = {k.replace("._orig_mod", ""): v for k, v in named_params.items()}
     return named_params
+
+
+def get_parameters_from_modules(
+    modules: Union[Iterable[torch.nn.Module], str], model, device
+) -> set[torch.nn.Parameter]:
+    """Converts modules to parameters where modules can be a string or list of torch.nn.Module
+
+    Args:
+        modules (`Union[Iterable[torch.nn.Module], str]`): List of modules
+
+    Returns:
+        `List[torch.nn.Parameter]`: List of parameters
+    """
+    if modules is None:
+        return None
+    parameters = []
+    # code taken from accelerate while preparing kwargs for FSDP
+    if isinstance(modules, str):
+        reg = re.compile(modules)
+        mapped_modules = []
+        for name, module in model.named_modules():
+            if reg.fullmatch(name):
+                module.to(device)
+                mapped_modules.append(module)
+        modules = mapped_modules
+    for module in modules:
+        parameters.extend(list(module.parameters()))
+    return set(parameters)
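A quick usage sketch for the new helper (editor's illustration; it assumes the function defined above is in scope):

    import torch
    import torch.nn as nn

    toy = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
    # Regex form: "1" fullmatches the name of the second child, so its weight and bias are returned.
    ignored = get_parameters_from_modules("1", toy, torch.device("cpu"))
    assert ignored == set(toy[1].parameters())
    # Iterable form: pass the modules directly.
    assert get_parameters_from_modules([toy[0]], toy, torch.device("cpu")) == set(toy[0].parameters())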
@@ -400,18 +400,19 @@ class FSDPPluginIntegration(AccelerateTestCase):

     def test_ignored_modules_regex(self):
         # Check that FSDP's ignored_modules can be a string, in which case it is treated as a regex
-        if self.current_fsdp_version != 1:
-            self.skipTest("ignored_modules only relevant for FSDP1")
-
         env = self.fsdp_envs[1].copy()
         env["FSDP_IGNORED_MODULES"] = ".*\\.q_proj$"
         with patch_environment(**env):
             accelerator = Accelerator()
             model = AutoModel.from_pretrained(LLAMA_TESTING)
-            # model has 2 layers
-            layers_to_ignore = {model.layers[0].self_attn.q_proj, model.layers[1].self_attn.q_proj}
             model = accelerator.prepare(model)
-            assert model._ignored_modules == layers_to_ignore
+            if self.current_fsdp_version == 1:
+                # model has 2 layers
+                layers_to_ignore = {model.layers[0].self_attn.q_proj, model.layers[1].self_attn.q_proj}
+                assert model._ignored_modules == layers_to_ignore
+            else:
+                params_to_ignore = {model.layers[0].self_attn.q_proj.weight, model.layers[1].self_attn.q_proj.weight}
+                assert model._ignored_params == params_to_ignore


     @require_fsdp2