mirror of
https://github.com/huggingface/accelerate.git
synced 2025-10-20 10:03:46 +08:00
refactor: nit change for get_parameters_from_modules (code debt) (#3815)
* refactor: nit change for get_parameters_from_modules Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com> * fix: quality check Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com> --------- Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
This commit is contained in:
committed by
GitHub
parent
f0313a64a2
commit
5998f8625b
@ -630,11 +630,8 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
|||||||
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
|
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
|
||||||
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
|
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
|
||||||
"mesh": mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)] if mesh is not None else None,
|
"mesh": mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)] if mesh is not None else None,
|
||||||
|
"ignored_params": get_parameters_from_modules(fsdp2_plugin.ignored_modules, model, accelerator.device),
|
||||||
}
|
}
|
||||||
if fsdp2_plugin.ignored_modules is not None:
|
|
||||||
fsdp2_kwargs["ignored_params"] = get_parameters_from_modules(
|
|
||||||
fsdp2_plugin.ignored_modules, model, accelerator.device
|
|
||||||
)
|
|
||||||
|
|
||||||
model_has_params4bit = False
|
model_has_params4bit = False
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
@ -808,10 +805,10 @@ def get_parameters_from_modules(
|
|||||||
modules (`Union[Iterable[torch.nn.Module], str]`): List of modules
|
modules (`Union[Iterable[torch.nn.Module], str]`): List of modules
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`List[torch.nn.Parameter]`: List of parameters
|
`set[torch.nn.Parameter]`: List of parameters
|
||||||
"""
|
"""
|
||||||
if modules is None:
|
if modules is None:
|
||||||
return None
|
return set()
|
||||||
parameters = []
|
parameters = []
|
||||||
# code taken from accelerate while preparing kwargs for FSDP
|
# code taken from accelerate while preparing kwargs for FSDP
|
||||||
if isinstance(modules, str):
|
if isinstance(modules, str):
|
||||||
|
Reference in New Issue
Block a user