mirror of https://github.com/huggingface/accelerate.git
Remove deprecated FindTiedParametersResult (#3786)
Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
@@ -478,23 +478,6 @@ def get_non_persistent_buffers(module: nn.Module, recurse: bool = False, fqns: b
     return non_persistent_buffers_set
 
 
-class FindTiedParametersResult(list):
-    """
-    This is a subclass of a list to handle backward compatibility for Transformers. Do not rely on the fact this is not
-    a list or on the `values` method as in the future this will be removed.
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def values(self):
-        warnings.warn(
-            "The 'values' method of FindTiedParametersResult is deprecated and will be removed in Accelerate v1.3.0. ",
-            FutureWarning,
-        )
-        return sum([x[1:] for x in self], [])
-
-
 def check_tied_parameters_in_config(model: nn.Module):
     """
     Check if there is any indication in the given model that some weights should be tied.
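The class removed above existed only so that the result of `find_tied_parameters` could keep a deprecated `values()` method for backward compatibility with Transformers. A minimal migration sketch for callers that still used that method, assuming `find_tied_parameters` is importable from `accelerate.utils` (the toy model below is illustrative, not part of this commit):

    import torch.nn as nn
    from accelerate.utils import find_tied_parameters

    # Toy model whose two layers share a single weight tensor.
    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
    model[1].weight = model[0].weight

    tied_groups = find_tied_parameters(model)  # now a plain list of lists
    # The deprecated values() returned every name in a group except the first
    # (it was implemented as `sum([x[1:] for x in self], [])`), so the same
    # result is one expression over the plain list:
    tied_only = sum([group[1:] for group in tied_groups], [])
    print(tied_only)  # ['1.weight']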
@@ -568,7 +551,7 @@ def check_tied_parameters_on_same_device(tied_params, device_map):
         )
 
 
-def find_tied_parameters(model: torch.nn.Module, **kwargs):
+def find_tied_parameters(model: torch.nn.Module, **kwargs) -> list[list[str]]:
     """
     Find the tied parameters in a given model.
 
@@ -620,7 +603,7 @@ def find_tied_parameters(model: torch.nn.Module, **kwargs):
             tied_param_groups[param_name] = []
         tied_param_groups[param_name].append(tied_param_name)
 
-    return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()])
+    return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]
 
 
 def retie_parameters(model, tied_params):
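With the wrapper class gone, `find_tied_parameters` simply returns a `list[list[str]]` in which each inner list is one sorted group of tied parameter names. A short usage sketch under that assumption (the named toy model is illustrative):

    from collections import OrderedDict

    import torch.nn as nn
    from accelerate.utils import find_tied_parameters

    model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
    model.linear2.weight = model.linear1.weight  # tie the two weights

    print(find_tied_parameters(model))
    # [['linear1.weight', 'linear2.weight']]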