mirror of
https://github.com/huggingface/accelerate.git
synced 2025-10-20 18:13:46 +08:00
Lower complexity of get_balanced_memory by adding a set (#3776)
* Lower complexity by adding a set * Push vibe coded eval script * Clean
This commit is contained in:
@ -1041,7 +1041,8 @@ def get_balanced_memory(
|
||||
|
||||
# Compute mean of final modules. In the first dict of module sizes, leaves are the parameters
|
||||
leaves = get_module_leaves(module_sizes)
|
||||
module_sizes = {n: v for n, v in module_sizes.items() if n not in leaves}
|
||||
leaves_set = set(leaves) # Convert to set for O(1) membership testing
|
||||
module_sizes = {n: v for n, v in module_sizes.items() if n not in leaves_set}
|
||||
# Once removed, leaves are the final modules.
|
||||
leaves = get_module_leaves(module_sizes)
|
||||
mean_leaves = int(sum([module_sizes[n] for n in leaves]) / max(len(leaves), 1))
|
||||
|
Reference in New Issue
Block a user