mirror of
https://github.com/huggingface/accelerate.git
synced 2025-10-20 18:13:46 +08:00
ENH: Allow FSDP ignored modules to be regex (#3698)
* ENH: Allow FSDP ignored modules to be regex Description For FSDP, there is an option to indicate ignored_modules, which should be a list of modules are ignored by FSDP. Even though this argument was supported in accelerate, it was not very usable: 1. Listing all modules can tricky, especially with something like PEFT, where the whole model is wrapped and thus the module structure changes. 2. When configuring this argument, accelerate takes a detour via environment variables. These can only be strings. Therefore, passing a list of modules is not feasible. Moreover, I noticed that the environment variable for ignored_modules was not even set, so configuring this argument didn't even work. Status This PR is lacking tests. I would be happy for pointers on how to add those. Context When using PEFT with LoRA and the target_parameters feature, I ran into an issue training such a model with FSDP. The only working fix I found was to ignore the layers targeted by LoRA. However, I could not configure accelerate to do that. With this PR, it is possible. I could successfully trained such a PEFT model that targets q_proj and v_proj by setting fsdp_ignored_modules: '.*\.(q_proj$|v_proj$)'. * Fix type annotation * Fix failing test
This commit is contained in:
@ -1872,6 +1872,17 @@ class Accelerator:
|
||||
"limit_all_gathers": fsdp_plugin.limit_all_gathers,
|
||||
"device_id": self.device,
|
||||
}
|
||||
|
||||
if isinstance(kwargs["ignored_modules"], str):
|
||||
reg = re.compile(kwargs["ignored_modules"])
|
||||
ignored = []
|
||||
for name, module in model.named_modules():
|
||||
if reg.fullmatch(name):
|
||||
# ensure that the device for these modules is still set correctly
|
||||
module.to(self.device)
|
||||
ignored.append(module)
|
||||
kwargs["ignored_modules"] = ignored
|
||||
|
||||
model = FSDP(model, **kwargs)
|
||||
if fsdp_plugin.activation_checkpointing:
|
||||
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
||||
|
@ -1561,8 +1561,9 @@ class FullyShardedDataParallelPlugin:
|
||||
Whether to offload parameters to CPU. Should be either a `bool` or an instance of
|
||||
`torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or
|
||||
`torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2.
|
||||
ignored_modules (`Optional[Iterable[torch.nn.Module]]`, defaults to `None`):
|
||||
A list of modules to ignore when wrapping with FSDP.
|
||||
ignored_modules (`Optional[Union[Iterable[torch.nn.Module], str]]`, defaults to `None`):
|
||||
A list of modules to ignore when wrapping with FSDP. When passing a string, will match the modules by name
|
||||
using regex fullmatch.
|
||||
state_dict_type (`Union[str, torch.distributed.fsdp.StateDictType]`, defaults to `'FULL_STATE_DICT'`):
|
||||
State dict type to use. If a string, it must be one of `full_state_dict`, `local_state_dict`, or
|
||||
`sharded_state_dict`.
|
||||
@ -1660,7 +1661,7 @@ class FullyShardedDataParallelPlugin:
|
||||
"help": "Whether to offload parameters to CPU. Should be either a `bool` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2. Defaults to `False`"
|
||||
},
|
||||
)
|
||||
ignored_modules: Optional[Iterable[torch.nn.Module]] = field(
|
||||
ignored_modules: Optional[Union[Iterable[torch.nn.Module], str]] = field(
|
||||
default=None,
|
||||
metadata={"help": "A list of modules to ignore when wrapping with FSDP."},
|
||||
)
|
||||
@ -1896,6 +1897,9 @@ class FullyShardedDataParallelPlugin:
|
||||
str_to_bool(os.environ.get(env_prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1
|
||||
)
|
||||
|
||||
if self.ignored_modules is None:
|
||||
self.ignored_modules = os.environ.get(env_prefix + "IGNORED_MODULES", None)
|
||||
|
||||
if self.cpu_ram_efficient_loading is None:
|
||||
self.cpu_ram_efficient_loading = (
|
||||
str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False")) == 1
|
||||
|
@ -328,6 +328,8 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> dict[str, str]:
|
||||
current_env["FSDP_CPU_RAM_EFFICIENT_LOADING"] = str(args.fsdp_cpu_ram_efficient_loading).lower()
|
||||
current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower()
|
||||
current_env["FSDP_ACTIVATION_CHECKPOINTING"] = str(args.fsdp_activation_checkpointing).lower()
|
||||
if getattr(args, "fsdp_ignored_modules", None) is not None:
|
||||
current_env["FSDP_IGNORED_MODULES"] = str(args.fsdp_ignored_modules)
|
||||
|
||||
if args.use_megatron_lm:
|
||||
prefix = "MEGATRON_LM_"
|
||||
|
@ -15,6 +15,7 @@ fsdp_config:
|
||||
fsdp_sync_module_states: true
|
||||
fsdp_transformer_layer_cls_to_wrap: BertLayer
|
||||
fsdp_use_orig_params: true
|
||||
fsdp_ignored_modules: null
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'no'
|
||||
|
Reference in New Issue
Block a user