From 5998f8625b8dfde9253c241233ff13bc2c18635d Mon Sep 17 00:00:00 2001
From: Mehant Kammakomati
Date: Tue, 14 Oct 2025 17:41:32 +0530
Subject: [PATCH] refactor: nit change for get_parameters_from_modules (code
 debt) (#3815)

* refactor: nit change for get_parameters_from_modules

Signed-off-by: Mehant Kammakomati

* fix: quality check

Signed-off-by: Mehant Kammakomati

---------

Signed-off-by: Mehant Kammakomati
---
 src/accelerate/utils/fsdp_utils.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/src/accelerate/utils/fsdp_utils.py b/src/accelerate/utils/fsdp_utils.py
index d5677e6e..3803048c 100644
--- a/src/accelerate/utils/fsdp_utils.py
+++ b/src/accelerate/utils/fsdp_utils.py
@@ -630,11 +630,8 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
         # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
         "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
         "mesh": mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)] if mesh is not None else None,
+        "ignored_params": get_parameters_from_modules(fsdp2_plugin.ignored_modules, model, accelerator.device),
     }
-    if fsdp2_plugin.ignored_modules is not None:
-        fsdp2_kwargs["ignored_params"] = get_parameters_from_modules(
-            fsdp2_plugin.ignored_modules, model, accelerator.device
-        )
 
     model_has_params4bit = False
     for name, param in model.named_parameters():
@@ -808,10 +805,10 @@ def get_parameters_from_modules(
         modules (`Union[Iterable[torch.nn.Module], str]`): List of modules
 
     Returns:
-        `List[torch.nn.Parameter]`: List of parameters
+        `set[torch.nn.Parameter]`: Set of parameters
     """
     if modules is None:
-        return None
+        return set()
     parameters = []
     # code taken from accelerate while preparing kwargs for FSDP
     if isinstance(modules, str):
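
Note on the pattern above: the refactor works because `get_parameters_from_modules` now returns an empty set instead of `None` when `ignored_modules` is unset, so the call site can populate `ignored_params` unconditionally inside the kwargs dict literal. This assumes `fully_shard` treats an empty set of ignored parameters the same as passing no ignored parameters at all. A minimal standalone sketch of the pattern, with a hypothetical `collect_params` helper standing in for `get_parameters_from_modules`:

    from typing import Iterable, Optional

    import torch

    def collect_params(modules: Optional[Iterable[torch.nn.Module]]) -> set[torch.nn.Parameter]:
        # Returning set() for the "nothing ignored" case lets callers
        # skip the None check entirely.
        if modules is None:
            return set()
        return {p for module in modules for p in module.parameters()}

    # Before the patch, the caller had to branch before building kwargs:
    #     if ignored_modules is not None:
    #         kwargs["ignored_params"] = collect_params(ignored_modules)
    # After, the kwarg is filled in directly in the dict literal:
    kwargs = {"ignored_params": collect_params(None)}
    assert kwargs["ignored_params"] == set()

Returning an empty container as the neutral value is the same trade-off the surrounding code already makes for `MixedPrecisionPolicy` (note the comment about `fully_shard` not accepting `None`): the helper absorbs the special case once, instead of every call site re-checking for `None`.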