[aoti] package weights to disk and dedup (#155241)

We package the weights and save them in `data/weights/` (`WEIGHTS_DIR`). In addition, we store a `weights_config.json` in each model's folder to specify which weight file corresponds to which weight name.
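For illustration, each `weights_config.json` entry maps a weight name to the weight file it lives in, plus the shape/stride/storage-offset needed to rebuild the tensor view (the names and sizes below are made up):

```python
import json

# Hypothetical weights_config.json contents for one model:
# weight_name -> (weight file name, shape, stride, storage_offset)
weight_map = json.loads('{"linear.weight": ["weight_0", [2048, 4096], [4096, 1], 0]}')

# At load time, each tensor is recovered as a strided view over the
# (possibly shared) weight file, e.g.:
# weights["weight_0"].as_strided([2048, 4096], [4096, 1], 0)
```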

Models can share weights. We dedup the weights based on their underlying storage (`tensor.untyped_storage()`).
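As a minimal sketch of the grouping criterion (not the packaging code itself): two tensors that view the same buffer report the same storage pointer, so they land in one group and the buffer is saved only once.

```python
import torch

base = torch.randn(3, 4)
view = base[2:3]  # a view; shares storage with `base`

# Weights are grouped by the data pointer of their untyped storage.
assert base.untyped_storage().data_ptr() == view.untyped_storage().data_ptr()
```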

- Use `"aot_inductor.package_constants_on_disk": True` config to produce the `Weights` in aot_compile
- If we see `Weights` in aoti_files, we'll automatically package them to disk
- `"aot_inductor.package_constants_on_disk"` config and `"aot_inductor.package_constants_in_so"` config work independently.
- Use `load_pt2(package_path, load_weights_from_disk=True)` to load the weights from disk. `load_weights_from_disk` defaults to False.
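A minimal end-to-end sketch of the flow (the toy module and output path are placeholders; see `test_package_shared_weights` below for the real thing):

```python
import torch
from torch._inductor.package import package_aoti
from torch.export.pt2_archive._package import load_pt2


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


x = torch.randn(2, 4)
ep = torch.export.export(M(), (x,))
# Produces a list of generated files that includes a Weights object.
aoti_files = torch._inductor.aot_compile(
    ep.module(),
    (x,),
    options={
        "aot_inductor.package": True,
        "always_keep_tensor_constants": True,
        "aot_inductor.package_constants_in_so": False,
        "aot_inductor.package_constants_on_disk": True,
    },
)
# Weights objects in aoti_files are packaged to data/weights/ automatically.
package_path = package_aoti("model.pt2", {"model": aoti_files})
pt2_contents = load_pt2(package_path, load_weights_from_disk=True)
out = pt2_contents.aoti_runners["model"](x)
```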

Test Plan:
```
buck2 run @//mode/dev-nosan //caffe2/test/inductor:aot_inductor_package -- -r "test_package_shared_weights"
```

Tested with whisper at https://github.com/pytorch-labs/torchnative/pull/7

Rollback Plan:

Differential Revision: D74747190

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155241
Approved by: https://github.com/desertfire
Shangdi Yu
2025-06-19 17:17:17 +00:00
committed by PyTorch MergeBot
parent 6e185c5312
commit eaf704914e
11 changed files with 419 additions and 20 deletions

View File

@ -20,6 +20,7 @@ from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.testing import rand_strided, same
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.package import package_aoti
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.test_case import TestCase
from torch._inductor.utils import is_big_gpu, run_and_get_cpp_code
@ -27,6 +28,7 @@ from torch._utils_internal import full_aoti_runtime_assert
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.export import Dim, export, export_for_training
from torch.export.pt2_archive._package import load_pt2
from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM80OrLater
@ -5379,6 +5381,44 @@ class AOTInductorTestsTemplate:
        output = runner_call(test_inputs)
        self.assertEqual(expected, output)

    def test_weight_on_disk_legacy(self):
        class Model(torch.nn.Module):
            def __init__(self, n, k, device):
                super().__init__()
                self.weight = torch.randn(n, k, device=device)
                self.bias = torch.randn(n, device=device)

            def forward(self, a):
                return torch.nn.functional.linear(a, self.weight, self.bias)

        M, N, K = 128, 2048, 4096
        model = Model(N, K, self.device)
        a = torch.randn(M, K, device=self.device)
        example_inputs = (a,)

        with torch.no_grad(), config.patch(
            {
                "always_keep_tensor_constants": True,
                "aot_inductor.package_constants_in_so": False,
                "aot_inductor.package_constants_on_disk": True,
                "aot_inductor.package": True,
            }
        ):
            aoti_files = AOTIRunnerUtil.legacy_compile(
                model=model,
                example_inputs=example_inputs,
            )

        with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
            package_path = package_aoti(
                f.name,
                {"model": aoti_files},
            )
            pt2_contents = load_pt2(package_path, load_weights_from_disk=True)
            loaded1 = pt2_contents.aoti_runners["model"]
            self.assertEqual(loaded1(a), model(a))

    def test_extract_constants_map(self):
        class Model(torch.nn.Module):
            def __init__(self, n, k, device):

View File

@ -20,6 +20,7 @@ from torch._inductor.package import AOTICompiledModel, load_package, package_aot
from torch._inductor.test_case import TestCase
from torch._inductor.utils import fresh_cache
from torch.export import Dim
from torch.export.pt2_archive._package import load_pt2, load_weights_to_pt2_contents
from torch.testing._internal.common_utils import (
    IS_FBCODE,
    skipIfRocm,
@ -661,6 +662,124 @@ class TestAOTInductorPackage(TestCase):
        output = compiled(test_inputs)
        self.assertEqual(expected, output)

    @skipif(
        lambda device, package_cpp_only: package_cpp_only,
        "No support for cpp only",
    )
    def test_package_shared_weights(self):
        options = {
            "aot_inductor.package": True,
            "aot_inductor.package_cpp_only": self.package_cpp_only,
            "always_keep_tensor_constants": True,
            "aot_inductor.package_constants_in_so": False,
            "aot_inductor.package_constants_on_disk": True,
        }

        class Bar(torch.nn.Module):
            def __init__(self, p1, p2):
                super().__init__()
                self.p1 = p1
                self.register_buffer("p2", p2)

            def forward(self):
                self.p1 += 1
                self.p2 += 1
                return self.p1, self.p2

        class Bar2(torch.nn.Module):
            def __init__(self, p1, p2):
                super().__init__()
                self.p1 = p1
                self.register_buffer("p2", p2[2:3])

            def forward(self):
                self.p1 += 3
                self.p2 += 3
                return self.p1, self.p2

        x = torch.randn(3, 4)
        y = torch.randn(3, 4)
        buffer = torch.nn.Buffer(x.clone())
        buffer2 = torch.nn.Buffer(y.clone())
        bar1 = Bar(buffer, buffer2)
        bar2 = Bar2(buffer, buffer2)

        ep1 = torch.export.export(bar1, ())
        ep2 = torch.export.export(bar2, ())
        aoti_files1 = torch._inductor.aot_compile(ep1.module(), (), options=options)
        aoti_files2 = torch._inductor.aot_compile(ep2.module(), (), options=options)

        with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
            package_path = package_aoti(
                f.name,
                {"model1": aoti_files1, "model2": aoti_files2},
            )
            pt2_contents = load_pt2(package_path, load_weights_from_disk=True)
            loaded1 = pt2_contents.aoti_runners["model1"]
            loaded2 = pt2_contents.aoti_runners["model2"]
            # Note that loading like below doesn't work, because new weights
            # would be loaded for each load_package call:
            # loaded1 = load_package(package_path, "model1")
            # loaded2 = load_package(package_path, "model2")

            result_1_p1, result_1_p2 = loaded1()
            self.assertEqual(result_1_p1, x + 1)
            self.assertEqual(result_1_p2, y + 1)

            result_2_p1, result_2_p2 = loaded2()
            # the buffers were already incremented by 1 by the run above
            self.assertEqual(result_2_p1, x + 4)
            self.assertEqual(result_2_p2, y[2:3] + 4)
            # note that the returned result will not change even though p2 changed
            self.assertEqual(result_1_p2, y + 1)

            # test shared weights, but user-managed
            gm1 = ep1.module()
            gm2 = ep2.module()
            load_weights_to_pt2_contents(
                pt2_contents, {"model1": gm1.state_dict(), "model2": gm2.state_dict()}
            )
            result_1_p1, result_1_p2 = loaded1()
            self.assertEqual(result_1_p1, x + 1)
            self.assertEqual(result_1_p2, y + 1)
            self.assertEqual(gm1.p1, x + 1)
            self.assertEqual(gm1.p2, y + 1)

    @skipif(
        lambda device, package_cpp_only: package_cpp_only,
        "No support for cpp only",
    )
    def test_package_weights_on_disk_nested_module(self):
        options = {
            "aot_inductor.package": True,
            "aot_inductor.package_cpp_only": self.package_cpp_only,
            "always_keep_tensor_constants": True,
            "aot_inductor.package_constants_in_so": False,
            "aot_inductor.package_constants_on_disk": True,
        }

        # linear.weight's node name is linear_weight. This unit test checks
        # that we package the weight under the right name, `linear.weight`,
        # not `linear_weight`.
        class Bar(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, x):
                return self.linear(x)

        x = torch.randn(3, 3).to(self.device)
        bar1 = Bar().to(self.device)
        ep = torch.export.export(bar1, (x,))
        package_path = torch._inductor.aoti_compile_and_package(
            ep, inductor_configs=options
        )
        pt2_contents = load_pt2(package_path, load_weights_from_disk=True)
        loaded1 = pt2_contents.aoti_runners["model"]
        self.assertEqual(loaded1(x), bar1(x))


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

View File

@ -86,7 +86,7 @@ def aot_compile(
    remove_runtime_assertions: bool = False,
    disable_constraint_solver: bool = False,
    same_signature: bool = True,
) -> Union[list[str], str]:
) -> Union[list[Any], str]:
    """
    Note: this function is not stable yet
View File

@ -15,6 +15,7 @@ from .standalone_compile import CompiledArtifact # noqa: TC001
if TYPE_CHECKING:
    from torch._inductor.utils import InputType
    from torch.export import ExportedProgram
    from torch.export.pt2_archive._package_weights import Weights
    from torch.types import FileLike


__all__ = [
@ -197,13 +198,13 @@ def _aoti_compile_and_package_inner(
    path = [
        os.path.splitext(file)[0]
        for file in aoti_files
        if os.path.splitext(file)[1] == ".so"
        if isinstance(file, str) and os.path.splitext(file)[1] == ".so"
    ]
    if len(path) == 0:
        path = [
            os.path.splitext(file)[0]
            for file in aoti_files
            if os.path.splitext(file)[1] == ".cpp"
            if isinstance(file, str) and os.path.splitext(file)[1] == ".cpp"
        ]

    package_path = path[0] + ".pt2"
@ -274,7 +275,7 @@ def aot_compile(
    kwargs: Optional[dict[str, Any]] = None,
    *,
    options: Optional[dict[str, Any]] = None,
) -> Union[str, list[str]]:
) -> Union[str, list[Union[str, Weights]]]:
    """
    Ahead-of-time compile a given FX graph with TorchInductor into a shared library.
View File

@ -103,6 +103,7 @@ from torch.compiler._cache import (
    CacheArtifactFactory,
    CacheArtifactManager,
)
from torch.export.pt2_archive._package_weights import TensorProperties, Weights
from torch.export.pt2_archive.constants import CUSTOM_OBJ_FILENAME_PREFIX
from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv
from torch.utils._ordered_set import OrderedSet
@ -1684,12 +1685,12 @@ class AotCodeCompiler:
        *,
        device_type: str,
        additional_files: list[str],
    ) -> Union[list[str], str]:
    ) -> Union[list[Union[str, Weights]], str]:
        """
        Returns the .so path, or returns a list of files that were generated if
        config.aot_inductor.package=True.
        """
        generated_files = additional_files
        generated_files: list[Union[str, Weights]] = additional_files  # type: ignore[assignment]

        if sys.platform == "win32":
            raise RuntimeError("AotCodeCompiler not yet supported for inductor")
@ -1965,6 +1966,20 @@ class AotCodeCompiler:
        else:
            serialized_weights = b""

        if config.aot_inductor.package_constants_on_disk:
            # We need to return a storage key here because the original value tensor might be a clone
            weights_dict = Weights(
                {
                    graph.allocated_constant_name[name]: (
                        graph.get_original_value_of_constant(name),
                        TensorProperties(graph.constants[name]),
                    )
                    for name in graph.constants.keys()
                    if name not in graph.folded_constants
                }
            )
            generated_files.append(weights_dict)

        consts_size = len(serialized_weights)

        # TODO: Fix mmap weights with cuda

View File

@ -128,6 +128,7 @@ if TYPE_CHECKING:
    from torch._inductor.output_code import _StrideExprStr
    from torch._ops import OpOverload
    from torch.export.pt2_archive._package_weights import Weights

    from .ir import ExternKernelNode
@ -1774,7 +1775,7 @@ def compile_fx_aot(
    example_inputs_: list[InputType],
    inner_compile: _CompileFxCallable = compile_fx_inner,
    config_patches: Optional[dict[str, str]] = None,
) -> Union[list[str], str]:
) -> Union[list[Union[str, Weights]], str]:
    assert isinstance(model_, GraphModule), model_

    # [See NOTE] Unwrapping subclasses AOT
@ -1980,7 +1981,7 @@ def compile_fx(
    config_patches: Optional[dict[str, Any]] = None,
    decompositions: Optional[dict[OpOverload, Callable[..., Any]]] = None,
    ignore_shape_env: bool = False,
) -> Union[Callable[[list[object]], Sequence[torch.Tensor]], str, list[str]]:
) -> Union[Callable[[list[object]], Sequence[torch.Tensor]], str, list[str], Weights]:
    """
    Main entry point for compiling given FX graph. Despite the fact that this
    lives in :mod:`torch._inductor`, this function is responsible for calling

View File

@ -1353,6 +1353,9 @@ class aot_inductor:
    # Experimental. Flag to control whether to include weight in .so
    package_constants_in_so: bool = True

    # Experimental. Flag to control whether to package weight separately on disk
    package_constants_on_disk: bool = False

    # Experimental. Controls automatic precompiling of common AOTI include files.
    precompile_headers: bool = not is_fbcode()

View File

@ -61,6 +61,7 @@ if TYPE_CHECKING:
    from torch._inductor import metrics
    from torch._inductor.graph import GraphLowering
    from torch._library.fake_class_registry import FakeScriptObject
    from torch.export.pt2_archive._package_weights import Weights

    from .compile_fx import _CompileFxKwargs
    from .triton_bundler import TritonBundle
@ -718,7 +719,7 @@ class CompiledAOTI(OutputCode):
    Class holding an AOTInductor compiled so.
    """

    filename: Union[str, list[str]]
    filename: Union[str, list[Union[str, Weights]]]

    def __call__(self, inputs: Sequence[Any]) -> Any:
        raise NotImplementedError("NYI")

View File

@ -3,12 +3,17 @@ import json
import logging
import os
import tempfile
from typing import IO, Union
from typing import IO

import torch
from torch._inductor import config
from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
from torch.export.pt2_archive._package import AOTICompiledModel, load_pt2, package_pt2
from torch.export.pt2_archive._package import (
    AOTI_FILES,
    AOTICompiledModel,
    load_pt2,
    package_pt2,
)
from torch.types import FileLike
@ -76,7 +81,7 @@ def compile_so(aoti_dir: str, aoti_files: list[str], so_path: str) -> str:
def package_aoti(
    archive_file: FileLike,
    aoti_files: Union[list[str], dict[str, list[str]]],
    aoti_files: AOTI_FILES,
) -> FileLike:
    """
    Saves the AOTInductor generated files to the PT2Archive format.
@ -88,7 +93,10 @@ def package_aoti(
        path to its AOTInductor generated files.
    """

    return package_pt2(archive_file, aoti_files=aoti_files)
    return package_pt2(
        archive_file,
        aoti_files=aoti_files,
    )


def load_package(
View File

@ -1,17 +1,24 @@
import glob
import io
import json
import logging
import os
import tempfile
import zipfile
from dataclasses import dataclass
from typing import Any, IO, Optional, Union
from typing import Any, IO, Optional, TYPE_CHECKING, Union
from typing_extensions import TypeAlias

import torch
import torch.utils._pytree as pytree
from torch._export.serde.serialize import deserialize, serialize, SerializedArtifact
from torch.export._tree_utils import reorder_kwargs
from torch.export.exported_program import ExportedProgram
from torch.export.pt2_archive._package_weights import (
    get_complete,
    group_weights,
    Weights,
)
from torch.export.pt2_archive.constants import (
    AOTINDUCTOR_DIR,
    ARCHIVE_FORMAT_PATH,
@ -24,12 +31,20 @@ from torch.export.pt2_archive.constants import (
    MODELS_DIR,
    MODELS_FILENAME_FORMAT,
    SAMPLE_INPUTS_FILENAME_FORMAT,
    WEIGHT_FILENAME_PREFIX,
    WEIGHTS_DIR,
)
from torch.types import FileLike


if TYPE_CHECKING:
    from torch.utils._ordered_set import OrderedSet


DEFAULT_PICKLE_PROTOCOL = 2

AOTI_FILES: TypeAlias = Union[
    list[Union[str, Weights]], dict[str, list[Union[str, Weights]]]
]

logger: logging.Logger = logging.getLogger(__name__)
@ -203,7 +218,8 @@ class PT2ArchiveReader:
def _package_aoti_files(
    archive_writer: PT2ArchiveWriter,
    aoti_files: Optional[Union[list[str], dict[str, list[str]]]],
    aoti_files: Optional[AOTI_FILES],
    pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL,
) -> None:
    if aoti_files is None:
        return
@ -213,13 +229,23 @@ def _package_aoti_files(
    assert isinstance(aoti_files, dict)

    all_weights: dict[str, Weights] = {}  # model_name -> weight
    weights_configs: dict[
        str, dict[str, Any]
    ] = {}  # model_name -> (weight_name -> (filename, shape, stride, offset))

    for model_name, files in aoti_files.items():
        num_so_files = 0
        weights_configs[model_name] = {}

        for file in files:
            if file == "":
                continue

            if isinstance(file, Weights):
                all_weights[model_name] = file
                continue

            if file.endswith(".so"):
                num_so_files += 1
                if num_so_files > 1:
@ -242,6 +268,35 @@ def _package_aoti_files(
                file,
            )

    if len(all_weights) > 0:
        # Dedup weights
        grouped_tensors: list[OrderedSet[tuple[str, str]]] = group_weights(all_weights)
        for idx, group in enumerate(grouped_tensors):
            filename = f"{WEIGHT_FILENAME_PREFIX}{idx}"
            model_name, weight_name = get_complete(group, all_weights)
            complete_tensor, _ = all_weights[model_name].get_weight(weight_name)
            buffer = io.BytesIO()
            torch.save(complete_tensor, buffer, pickle_protocol=pickle_protocol)
            archive_writer.write_bytes(
                os.path.join(WEIGHTS_DIR, filename), buffer.getvalue()
            )
            for model_name, weight_name in group:
                _, w_property = all_weights[model_name].get_weight(weight_name)
                weights_configs[model_name][weight_name] = (
                    filename,
                    w_property.shape,
                    w_property.stride,
                    w_property.offset,
                )

        for model_name, weights_config in weights_configs.items():
            archive_writer.write_string(
                os.path.join(AOTINDUCTOR_DIR, model_name, "weights_config.json"),
                json.dumps(weights_config),
            )
            logger.debug("packaging weights_config for model %s", model_name)
            logger.debug(weights_config)


def _package_exported_programs(
    archive_writer: PT2ArchiveWriter,
@ -263,6 +318,7 @@ def _package_exported_programs(
    archive_writer.write_bytes(
        MODELS_FILENAME_FORMAT.format(model_name), artifact.exported_program
    )
    # TODO: Consider deduping this with the weights saved in _package_aoti_files
    archive_writer.write_bytes(f"{WEIGHTS_DIR}{model_name}.pt", artifact.state_dict)
    archive_writer.write_bytes(
        f"{CONSTANTS_DIR}{model_name}.pt", artifact.constants
@ -289,7 +345,7 @@ def package_pt2(
    exported_programs: Optional[
        Union[ExportedProgram, dict[str, ExportedProgram]]
    ] = None,
    aoti_files: Optional[Union[list[str], dict[str, list[str]]]] = None,
    aoti_files: Optional[AOTI_FILES] = None,
    extra_files: Optional[dict[str, Any]] = None,
    opset_version: Optional[dict[str, int]] = None,
    pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL,
@ -347,8 +403,14 @@ def package_pt2(
        f = os.fspath(f)

    with PT2ArchiveWriter(f) as archive_writer:
        _package_exported_programs(archive_writer, exported_programs)
        _package_aoti_files(archive_writer, aoti_files)
        _package_exported_programs(
            archive_writer, exported_programs, pickle_protocol=pickle_protocol
        )
        _package_aoti_files(
            archive_writer,
            aoti_files,
            pickle_protocol=pickle_protocol,
        )
        _package_extra_files(archive_writer, extra_files)

    if isinstance(f, (io.IOBase, IO)):
@ -474,6 +536,7 @@ def load_pt2(
    run_single_threaded: bool = False,
    num_runners: int = 1,
    device_index: int = -1,
    load_weights_from_disk: bool = False,
) -> PT2ArchiveContents:  # type: ignore[type-arg]
    """
    Loads all the artifacts previously saved with ``package_pt2``.
@ -514,6 +577,8 @@ def load_pt2(
    if isinstance(f, (str, os.PathLike)):
        f = os.fspath(f)

    weights = {}
    weight_maps = {}

    with PT2ArchiveReader(f) as archive_reader:
        version = archive_reader.read_string(ARCHIVE_VERSION_PATH)
        if version != ARCHIVE_VERSION_VALUE:
@ -533,11 +598,23 @@ def load_pt2(
        aoti_model_names: set[str] = set()
        for file in file_names:
            if file.startswith(AOTINDUCTOR_DIR):
                file = file[len(AOTINDUCTOR_DIR) :]  # remove data/aotinductor/ prefix
                model_name = file.split("/")[
                file_end = file[
                    len(AOTINDUCTOR_DIR) :
                ]  # remove data/aotinductor/ prefix
                model_name = file_end.split("/")[
                    0
                ]  # split "model_name/...cpp" into "model_name"
                aoti_model_names.add(model_name)
                if load_weights_from_disk and file.endswith("weights_config.json"):
                    weight_map = json.loads(archive_reader.read_string(file))
                    weight_maps[model_name] = weight_map
            elif load_weights_from_disk and file.startswith(WEIGHTS_DIR):
                weight_file_name = file[
                    len(WEIGHTS_DIR) :
                ]  # remove data/weights/ prefix
                weight_bytes = archive_reader.read_bytes(file)
                loaded_weight = torch.load(io.BytesIO(weight_bytes))
                weights[weight_file_name] = loaded_weight

    if isinstance(f, (io.IOBase, IO)):
        if len(aoti_model_names) > 0:
@ -572,4 +649,37 @@ def load_pt2(
            for model_name in aoti_model_names
        }

    if weight_maps:
        for model_name in aoti_model_names:
            model_weights = {}
            for weight_name, (file, shape, stride, storage_offset) in weight_maps[
                model_name
            ].items():
                weight = weights[file]
                model_weights[weight_name] = weight.as_strided(
                    shape, stride, storage_offset
                )
            # user_managed=True ensures that weight updates are shared by all runners.
            aoti_runners[model_name].load_constants(
                model_weights, check_full_update=True, user_managed=True
            )

    return PT2ArchiveContents(exported_programs, aoti_runners, extra_files)


def load_weights_to_pt2_contents(
    pt2_contents: PT2ArchiveContents, weights_map: dict[str, Any]
) -> None:
    """
    Load weights into the models in PT2 archive contents.

    Args:
        pt2_contents (PT2ArchiveContents): The contents of the PT2 archive.
        weights_map (dict[str, Any]): A mapping from model name to the weights
            (e.g. a state_dict) to load into that model.
    """
    for model_name, weights in weights_map.items():
        if model_name not in pt2_contents.aoti_runners:
            raise RuntimeError(f"Model {model_name} not found in PT2 archive contents.")
        pt2_contents.aoti_runners[model_name].load_constants(
            weights, check_full_update=True, user_managed=True
        )

View File

@ -0,0 +1,101 @@
import collections

import torch
from torch.utils._ordered_set import OrderedSet


def _end_ptr(tensor: torch.Tensor) -> int:
    if tensor.nelement():
        stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
    else:
        stop = tensor.data_ptr()
    return stop


class TensorProperties:
    def __init__(self, tensor: torch.Tensor):
        # info about underlying storage
        self.storage_ptr = tensor.untyped_storage().data_ptr()
        self.storage_size = tensor.untyped_storage().nbytes()

        # info to recover tensor
        self.shape = tensor.shape
        self.stride = tensor.stride()
        self.offset = tensor.storage_offset()

        self.start = tensor.data_ptr()
        self.end = _end_ptr(tensor)

    def is_complete(self) -> bool:
        """
        Whether the tensor completely overlaps with its underlying storage
        """
        return (
            self.start == self.storage_ptr
            and self.end == self.storage_ptr + self.storage_size
        )


class Weights(dict):
    """
    A dictionary mapping from weight name to a tuple of (tensor, TensorProperties).

    tensor represents the actual initial value of the weight.
    TensorProperties represents the properties of the weight that are needed to recover the weight.

    We use two separate entries because `tensor` could be a clone of the original weight tensor,
    so it doesn't have the same properties as the original weight (such as the underlying storage pointer).
    """

    def __init__(self, weight_dict: dict[str, tuple[torch.Tensor, TensorProperties]]):
        super().__init__(weight_dict)

    def get_weight(self, name: str) -> tuple[torch.Tensor, TensorProperties]:
        return self[name]

    def get_weight_properties(self, name: str) -> TensorProperties:
        return self[name][1]


def get_complete(
    group: OrderedSet[tuple[str, str]], models_weights: dict[str, Weights]
) -> tuple[str, str]:
    """
    `group` is an OrderedSet of (model_name, weight_name) tuples.
    `models_weights` is a dictionary mapping from model name to its Weights.

    All tensors in `group` must share the same underlying storage, and at
    least one of them must be complete.

    Returns the (model_name, weight_name) of a complete tensor in the `group`.
    If multiple tensors are complete, returns an arbitrary one.
    """

    def get_tensor_properties(name_tuple: tuple[str, str]) -> TensorProperties:
        # returns the tensor properties
        (model_name, weight_name) = name_tuple
        return models_weights[model_name].get_weight_properties(weight_name)

    for name_tuple in group:
        tensor_property = get_tensor_properties(name_tuple)
        if tensor_property.is_complete():
            return name_tuple
    raise RuntimeError("No complete tensor found in the group!")


def group_weights(all_weights: dict[str, Weights]) -> list[OrderedSet[tuple[str, str]]]:
    """
    Group weights that share the same underlying storage.

    Returns a list of sets; each set contains (model_name, weight_name) tuples.
    """
    weights_dict: dict[int, OrderedSet[tuple[str, str]]] = collections.defaultdict(
        OrderedSet
    )  # storage_ptr -> set of (model_name, weight_name)

    for model_name, weights in all_weights.items():
        for weight_name, (_, properties) in weights.items():
            weights_dict[properties.storage_ptr].add((model_name, weight_name))

    return list(weights_dict.values())