Lower complexity of get_balanced_memory by adding a set (#3776)

* Lower complexity by adding a set

* Push vibe coded eval script

* Clean
This commit is contained in:
Samuel Barry
2025-09-17 09:30:55 -07:00
committed by GitHub
parent 1b50d93999
commit 409b356f45

View File

@ -1041,7 +1041,8 @@ def get_balanced_memory(
# Compute mean of final modules. In the first dict of module sizes, leaves are the parameters
leaves = get_module_leaves(module_sizes)
module_sizes = {n: v for n, v in module_sizes.items() if n not in leaves}
leaves_set = set(leaves) # Convert to set for O(1) membership testing
module_sizes = {n: v for n, v in module_sizes.items() if n not in leaves_set}
# Once removed, leaves are the final modules.
leaves = get_module_leaves(module_sizes)
mean_leaves = int(sum([module_sizes[n] for n in leaves]) / max(len(leaves), 1))