Remove deprecated FindTiedParametersResult (#3786)

Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
This commit is contained in:
Yuanyuan Chen
2025-09-19 21:00:44 +08:00
committed by GitHub
parent a737437c8a
commit 14383311c2

View File

@ -478,23 +478,6 @@ def get_non_persistent_buffers(module: nn.Module, recurse: bool = False, fqns: b
return non_persistent_buffers_set
class FindTiedParametersResult(list):
"""
This is a subclass of a list to handle backward compatibility for Transformers. Do not rely on the fact this is not
a list or on the `values` method as in the future this will be removed.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def values(self):
warnings.warn(
"The 'values' method of FindTiedParametersResult is deprecated and will be removed in Accelerate v1.3.0. ",
FutureWarning,
)
return sum([x[1:] for x in self], [])
def check_tied_parameters_in_config(model: nn.Module):
"""
Check if there is any indication in the given model that some weights should be tied.
@ -568,7 +551,7 @@ def check_tied_parameters_on_same_device(tied_params, device_map):
)
def find_tied_parameters(model: torch.nn.Module, **kwargs):
def find_tied_parameters(model: torch.nn.Module, **kwargs) -> list[list[str]]:
"""
Find the tied parameters in a given model.
@ -620,7 +603,7 @@ def find_tied_parameters(model: torch.nn.Module, **kwargs):
tied_param_groups[param_name] = []
tied_param_groups[param_name].append(tied_param_name)
return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()])
return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]
def retie_parameters(model, tied_params):