[aoti] package weights to disk and dedup (#155241)
We package the weights and save them in `data/weights/` (`WEIGHTS_DIR`). In addition, we store a `weights_config.json` in each model's folder to specify which weight file corresponds to which weight name. Models can share weights; we dedup the weights based on their underlying storage (`tensor.untyped_storage()`).

- Use the `"aot_inductor.package_constants_on_disk": True` config to produce the `Weights` object in aot_compile.
- If we see `Weights` in aoti_files, we automatically package them to disk.
- The `"aot_inductor.package_constants_on_disk"` and `"aot_inductor.package_constants_in_so"` configs work independently.
- Use `load_pt2(package_path, load_weights_from_disk=True)` to load the weights from disk; `load_weights_from_disk` defaults to False.

(See the usage sketch after the commit metadata below.)

Test Plan:

```
buck2 run @//mode/dev-nosan //caffe2/test/inductor:aot_inductor_package -- -r "test_package_shared_weights"
```

Tested with whisper at https://github.com/pytorch-labs/torchnative/pull/7

Differential Revision: D74747190
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155241
Approved by: https://github.com/desertfire
Committed by: PyTorch MergeBot
Parent: 6e185c5312
Commit: eaf704914e
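
For orientation, here is a minimal end-to-end sketch of the workflow this change enables, distilled from the tests in this diff; the toy model and the temporary-file handling are illustrative, not part of the change itself:

```python
import tempfile

import torch
from torch._inductor.package import package_aoti
from torch.export.pt2_archive._package import load_pt2


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.linear(x)


options = {
    "aot_inductor.package": True,
    "always_keep_tensor_constants": True,
    "aot_inductor.package_constants_in_so": False,  # keep weights out of the .so
    "aot_inductor.package_constants_on_disk": True,  # emit a Weights entry instead
}

x = torch.randn(3, 3)
ep = torch.export.export(M(), (x,))
# With package_constants_on_disk=True, aoti_files contains a Weights object
# alongside the generated file paths.
aoti_files = torch._inductor.aot_compile(ep.module(), (x,), options=options)

with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
    # Weights entries are automatically written to data/weights/ in the archive.
    package_path = package_aoti(f.name, {"model": aoti_files})
    # load_weights_from_disk defaults to False; opt in to read data/weights/.
    pt2_contents = load_pt2(package_path, load_weights_from_disk=True)
    runner = pt2_contents.aoti_runners["model"]
    print(runner(x))
```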
@@ -20,6 +20,7 @@ from torch._dynamo.device_interface import get_interface_for_device
 from torch._dynamo.testing import rand_strided, same
 from torch._dynamo.utils import counters
 from torch._inductor import config
+from torch._inductor.package import package_aoti
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch._inductor.test_case import TestCase
 from torch._inductor.utils import is_big_gpu, run_and_get_cpp_code
@@ -27,6 +28,7 @@ from torch._utils_internal import full_aoti_runtime_assert
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
 from torch.export import Dim, export, export_for_training
+from torch.export.pt2_archive._package import load_pt2
 from torch.testing import FileCheck
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM80OrLater
@@ -5379,6 +5381,44 @@ class AOTInductorTestsTemplate:
         output = runner_call(test_inputs)
         self.assertEqual(expected, output)
 
+    def test_weight_on_disk_legacy(self):
+        class Model(torch.nn.Module):
+            def __init__(self, n, k, device):
+                super().__init__()
+                self.weight = torch.randn(n, k, device=device)
+                self.bias = torch.randn(n, device=device)
+
+            def forward(self, a):
+                return torch.nn.functional.linear(a, self.weight, self.bias)
+
+        M, N, K = 128, 2048, 4096
+        model = Model(N, K, self.device)
+        a = torch.randn(M, K, device=self.device)
+        example_inputs = (a,)
+
+        with torch.no_grad(), config.patch(
+            {
+                "always_keep_tensor_constants": True,
+                "aot_inductor.package_constants_in_so": False,
+                "aot_inductor.package_constants_on_disk": True,
+                "aot_inductor.package": True,
+            }
+        ):
+            aoti_files = AOTIRunnerUtil.legacy_compile(
+                model=model,
+                example_inputs=example_inputs,
+            )
+
+        with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
+            package_path = package_aoti(
+                f.name,
+                {"model": aoti_files},
+            )
+            pt2_contents = load_pt2(package_path, load_weights_from_disk=True)
+            loaded1 = pt2_contents.aoti_runners["model"]
+
+            self.assertEqual(loaded1(a), model(a))
+
     def test_extract_constants_map(self):
         class Model(torch.nn.Module):
             def __init__(self, n, k, device):
@@ -20,6 +20,7 @@ from torch._inductor.package import AOTICompiledModel, load_package, package_aot
 from torch._inductor.test_case import TestCase
 from torch._inductor.utils import fresh_cache
 from torch.export import Dim
+from torch.export.pt2_archive._package import load_pt2, load_weights_to_pt2_contents
 from torch.testing._internal.common_utils import (
     IS_FBCODE,
     skipIfRocm,
@@ -661,6 +662,124 @@ class TestAOTInductorPackage(TestCase):
         output = compiled(test_inputs)
         self.assertEqual(expected, output)
 
+    @skipif(
+        lambda device, package_cpp_only: package_cpp_only,
+        "No support for cpp only",
+    )
+    def test_package_shared_weights(self):
+        options = {
+            "aot_inductor.package": True,
+            "aot_inductor.package_cpp_only": self.package_cpp_only,
+            "always_keep_tensor_constants": True,
+            "aot_inductor.package_constants_in_so": False,
+            "aot_inductor.package_constants_on_disk": True,
+        }
+
+        class Bar(torch.nn.Module):
+            def __init__(self, p1, p2):
+                super().__init__()
+                self.p1 = p1
+                self.register_buffer("p2", p2)
+
+            def forward(self):
+                self.p1 += 1
+                self.p2 += 1
+                return self.p1, self.p2
+
+        class Bar2(torch.nn.Module):
+            def __init__(self, p1, p2):
+                super().__init__()
+                self.p1 = p1
+                self.register_buffer("p2", p2[2:3])
+
+            def forward(self):
+                self.p1 += 3
+                self.p2 += 3
+                return self.p1, self.p2
+
+        x = torch.randn(3, 4)
+        y = torch.randn(3, 4)
+        buffer = torch.nn.Buffer(x.clone())
+        buffer2 = torch.nn.Buffer(y.clone())
+        bar1 = Bar(buffer, buffer2)
+        bar2 = Bar2(buffer, buffer2)
+        ep1 = torch.export.export(bar1, ())
+        ep2 = torch.export.export(bar2, ())
+        aoti_files1 = torch._inductor.aot_compile(ep1.module(), (), options=options)
+        aoti_files2 = torch._inductor.aot_compile(ep2.module(), (), options=options)
+
+        with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
+            package_path = package_aoti(
+                f.name,
+                {"model1": aoti_files1, "model2": aoti_files2},
+            )
+            pt2_contents = load_pt2(package_path, load_weights_from_disk=True)
+            loaded1 = pt2_contents.aoti_runners["model1"]
+            loaded2 = pt2_contents.aoti_runners["model2"]
+
+            # Note that loading as below doesn't work, because new weights
+            # would be loaded for each load_package call.
+            # loaded1 = load_package(package_path, "model1")
+            # loaded2 = load_package(package_path, "model2")
+
+            result_1_p1, result_1_p2 = loaded1()
+            self.assertEqual(result_1_p1, x + 1)
+            self.assertEqual(result_1_p2, y + 1)
+
+            result_2_p1, result_2_p2 = loaded2()
+            # The buffers were already incremented by 1 by the run above.
+            self.assertEqual(result_2_p1, x + 4)
+            self.assertEqual(result_2_p2, y[2:3] + 4)
+
+            # Note that the returned result does not change even though p2 changed.
+            self.assertEqual(result_1_p2, y + 1)
+
+            # Test shared weights that are user managed.
+            gm1 = ep1.module()
+            gm2 = ep2.module()
+            load_weights_to_pt2_contents(
+                pt2_contents, {"model1": gm1.state_dict(), "model2": gm2.state_dict()}
+            )
+            result_1_p1, result_1_p2 = loaded1()
+            self.assertEqual(result_1_p1, x + 1)
+            self.assertEqual(result_1_p2, y + 1)
+            self.assertEqual(gm1.p1, x + 1)
+            self.assertEqual(gm1.p2, y + 1)
+
+    @skipif(
+        lambda device, package_cpp_only: package_cpp_only,
+        "No support for cpp only",
+    )
+    def test_package_weights_on_disk_nested_module(self):
+        options = {
+            "aot_inductor.package": True,
+            "aot_inductor.package_cpp_only": self.package_cpp_only,
+            "always_keep_tensor_constants": True,
+            "aot_inductor.package_constants_in_so": False,
+            "aot_inductor.package_constants_on_disk": True,
+        }
+
+        # linear.weight's node name is linear_weight.
+        # This unit test checks that we package the right weight name
+        # `linear.weight`, not `linear_weight`.
+        class Bar(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(3, 3)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        x = torch.randn(3, 3).to(self.device)
+        bar1 = Bar().to(self.device)
+        ep = torch.export.export(bar1, (x,))
+        package_path = torch._inductor.aoti_compile_and_package(
+            ep, inductor_configs=options
+        )
+        pt2_contents = load_pt2(package_path, load_weights_from_disk=True)
+        loaded1 = pt2_contents.aoti_runners["model"]
+        self.assertEqual(loaded1(x), bar1(x))
+
 
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
@@ -86,7 +86,7 @@ def aot_compile(
     remove_runtime_assertions: bool = False,
     disable_constraint_solver: bool = False,
     same_signature: bool = True,
-) -> Union[list[str], str]:
+) -> Union[list[Any], str]:
     """
     Note: this function is not stable yet
 
@@ -15,6 +15,7 @@ from .standalone_compile import CompiledArtifact  # noqa: TC001
 if TYPE_CHECKING:
     from torch._inductor.utils import InputType
     from torch.export import ExportedProgram
+    from torch.export.pt2_archive._package_weights import Weights
     from torch.types import FileLike
 
 __all__ = [
@@ -197,13 +198,13 @@ def _aoti_compile_and_package_inner(
         path = [
             os.path.splitext(file)[0]
             for file in aoti_files
-            if os.path.splitext(file)[1] == ".so"
+            if isinstance(file, str) and os.path.splitext(file)[1] == ".so"
         ]
         if len(path) == 0:
             path = [
                 os.path.splitext(file)[0]
                 for file in aoti_files
-                if os.path.splitext(file)[1] == ".cpp"
+                if isinstance(file, str) and os.path.splitext(file)[1] == ".cpp"
             ]
         package_path = path[0] + ".pt2"
 
@@ -274,7 +275,7 @@ def aot_compile(
     kwargs: Optional[dict[str, Any]] = None,
     *,
     options: Optional[dict[str, Any]] = None,
-) -> Union[str, list[str]]:
+) -> Union[str, list[Union[str, Weights]]]:
     """
     Ahead-of-time compile a given FX graph with TorchInductor into a shared library.
 
@@ -103,6 +103,7 @@ from torch.compiler._cache import (
     CacheArtifactFactory,
     CacheArtifactManager,
 )
+from torch.export.pt2_archive._package_weights import TensorProperties, Weights
 from torch.export.pt2_archive.constants import CUSTOM_OBJ_FILENAME_PREFIX
 from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv
 from torch.utils._ordered_set import OrderedSet
@@ -1684,12 +1685,12 @@ class AotCodeCompiler:
         *,
         device_type: str,
         additional_files: list[str],
-    ) -> Union[list[str], str]:
+    ) -> Union[list[Union[str, Weights]], str]:
         """
         Returns the .so path, or returns a list of files that were generated if
         config.aot_inductor.package=True.
         """
-        generated_files = additional_files
+        generated_files: list[Union[str, Weights]] = additional_files  # type: ignore[assignment]
 
         if sys.platform == "win32":
             raise RuntimeError("AotCodeCompiler not yet supported for inductor")
@@ -1965,6 +1966,20 @@
         else:
             serialized_weights = b""
 
+        if config.aot_inductor.package_constants_on_disk:
+            # We need to return a storage key here because the original value tensor might be a clone
+            weights_dict = Weights(
+                {
+                    graph.allocated_constant_name[name]: (
+                        graph.get_original_value_of_constant(name),
+                        TensorProperties(graph.constants[name]),
+                    )
+                    for name in graph.constants.keys()
+                    if name not in graph.folded_constants
+                }
+            )
+            generated_files.append(weights_dict)
+
         consts_size = len(serialized_weights)
 
         # TODO: Fix mmap weights with cuda
@@ -128,6 +128,7 @@ if TYPE_CHECKING:
 
     from torch._inductor.output_code import _StrideExprStr
    from torch._ops import OpOverload
+    from torch.export.pt2_archive._package_weights import Weights
 
     from .ir import ExternKernelNode
 
@@ -1774,7 +1775,7 @@ def compile_fx_aot(
     example_inputs_: list[InputType],
     inner_compile: _CompileFxCallable = compile_fx_inner,
     config_patches: Optional[dict[str, str]] = None,
-) -> Union[list[str], str]:
+) -> Union[list[Union[str, Weights]], str]:
     assert isinstance(model_, GraphModule), model_
 
     # [See NOTE] Unwrapping subclasses AOT
@@ -1980,7 +1981,7 @@ def compile_fx(
     config_patches: Optional[dict[str, Any]] = None,
     decompositions: Optional[dict[OpOverload, Callable[..., Any]]] = None,
     ignore_shape_env: bool = False,
-) -> Union[Callable[[list[object]], Sequence[torch.Tensor]], str, list[str]]:
+) -> Union[Callable[[list[object]], Sequence[torch.Tensor]], str, list[str], Weights]:
     """
     Main entry point for compiling given FX graph. Despite the fact that this
     lives in :mod:`torch._inductor`, this function is responsible for calling
@@ -1353,6 +1353,9 @@ class aot_inductor:
     # Experimental. Flag to control whether to include weight in .so
     package_constants_in_so: bool = True
 
+    # Experimental. Flag to control whether to package weight separately on disk
+    package_constants_on_disk: bool = False
+
     # Experimental. Controls automatic precompiling of common AOTI include files.
     precompile_headers: bool = not is_fbcode()
 
@@ -61,6 +61,7 @@ if TYPE_CHECKING:
     from torch._inductor import metrics
     from torch._inductor.graph import GraphLowering
     from torch._library.fake_class_registry import FakeScriptObject
+    from torch.export.pt2_archive._package_weights import Weights
 
     from .compile_fx import _CompileFxKwargs
     from .triton_bundler import TritonBundle
@@ -718,7 +719,7 @@ class CompiledAOTI(OutputCode):
     Class holding an AOTInductor compiled so.
     """
 
-    filename: Union[str, list[str]]
+    filename: Union[str, list[Union[str, Weights]]]
 
     def __call__(self, inputs: Sequence[Any]) -> Any:
         raise NotImplementedError("NYI")
@@ -3,12 +3,17 @@ import json
 import logging
 import os
 import tempfile
-from typing import IO, Union
+from typing import IO
 
 import torch
 from torch._inductor import config
 from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
-from torch.export.pt2_archive._package import AOTICompiledModel, load_pt2, package_pt2
+from torch.export.pt2_archive._package import (
+    AOTI_FILES,
+    AOTICompiledModel,
+    load_pt2,
+    package_pt2,
+)
 from torch.types import FileLike
 
 
@@ -76,7 +81,7 @@ def compile_so(aoti_dir: str, aoti_files: list[str], so_path: str) -> str:
 
 def package_aoti(
     archive_file: FileLike,
-    aoti_files: Union[list[str], dict[str, list[str]]],
+    aoti_files: AOTI_FILES,
 ) -> FileLike:
     """
     Saves the AOTInductor generated files to the PT2Archive format.
@@ -88,7 +93,10 @@
     path to its AOTInductor generated files.
     """
 
-    return package_pt2(archive_file, aoti_files=aoti_files)
+    return package_pt2(
+        archive_file,
+        aoti_files=aoti_files,
+    )
 
 
 def load_package(
@@ -1,17 +1,24 @@
 import glob
 import io
 import json
 import logging
 import os
 import tempfile
 import zipfile
 from dataclasses import dataclass
-from typing import Any, IO, Optional, Union
+from typing import Any, IO, Optional, TYPE_CHECKING, Union
+from typing_extensions import TypeAlias
 
 import torch
 import torch.utils._pytree as pytree
 from torch._export.serde.serialize import deserialize, serialize, SerializedArtifact
 from torch.export._tree_utils import reorder_kwargs
 from torch.export.exported_program import ExportedProgram
+from torch.export.pt2_archive._package_weights import (
+    get_complete,
+    group_weights,
+    Weights,
+)
 from torch.export.pt2_archive.constants import (
     AOTINDUCTOR_DIR,
     ARCHIVE_FORMAT_PATH,
@@ -24,12 +31,20 @@ from torch.export.pt2_archive.constants import (
     MODELS_DIR,
     MODELS_FILENAME_FORMAT,
     SAMPLE_INPUTS_FILENAME_FORMAT,
+    WEIGHT_FILENAME_PREFIX,
     WEIGHTS_DIR,
 )
 from torch.types import FileLike
 
 
+if TYPE_CHECKING:
+    from torch.utils._ordered_set import OrderedSet
+
+
 DEFAULT_PICKLE_PROTOCOL = 2
+AOTI_FILES: TypeAlias = Union[
+    list[Union[str, Weights]], dict[str, list[Union[str, Weights]]]
+]
 
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -203,7 +218,8 @@ class PT2ArchiveReader:
 
 def _package_aoti_files(
     archive_writer: PT2ArchiveWriter,
-    aoti_files: Optional[Union[list[str], dict[str, list[str]]]],
+    aoti_files: Optional[AOTI_FILES],
+    pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL,
 ) -> None:
     if aoti_files is None:
         return
@@ -213,13 +229,23 @@ def _package_aoti_files(
 
     assert isinstance(aoti_files, dict)
 
+    all_weights: dict[str, Weights] = {}  # model_name -> weight
+    weights_configs: dict[
+        str, dict[str, Any]
+    ] = {}  # model_name -> (weight_name -> (filename, shape, stride, offset))
+
     for model_name, files in aoti_files.items():
         num_so_files = 0
+        weights_configs[model_name] = {}
 
         for file in files:
             if file == "":
                 continue
 
+            if isinstance(file, Weights):
+                all_weights[model_name] = file
+                continue
+
             if file.endswith(".so"):
                 num_so_files += 1
                 if num_so_files > 1:
@@ -242,6 +268,35 @@
                 file,
             )
 
+    if len(all_weights) > 0:
+        # Dedup weights
+        grouped_tensors: list[OrderedSet[tuple[str, str]]] = group_weights(all_weights)
+        for idx, group in enumerate(grouped_tensors):
+            filename = f"{WEIGHT_FILENAME_PREFIX}{idx}"
+            model_name, weight_name = get_complete(group, all_weights)
+            complete_tensor, _ = all_weights[model_name].get_weight(weight_name)
+            buffer = io.BytesIO()
+            torch.save(complete_tensor, buffer, pickle_protocol=pickle_protocol)
+            archive_writer.write_bytes(
+                os.path.join(WEIGHTS_DIR, filename), buffer.getvalue()
+            )
+            for model_name, weight_name in group:
+                _, w_property = all_weights[model_name].get_weight(weight_name)
+                weights_configs[model_name][weight_name] = (
+                    filename,
+                    w_property.shape,
+                    w_property.stride,
+                    w_property.offset,
+                )
+
+    for model_name, weights_config in weights_configs.items():
+        archive_writer.write_string(
+            os.path.join(AOTINDUCTOR_DIR, model_name, "weights_config.json"),
+            json.dumps(weights_config),
+        )
+        logger.debug("packaging weights_config for model %s", model_name)
+        logger.debug(weights_config)
+
 
 def _package_exported_programs(
     archive_writer: PT2ArchiveWriter,
@@ -263,6 +318,7 @@ def _package_exported_programs(
     archive_writer.write_bytes(
         MODELS_FILENAME_FORMAT.format(model_name), artifact.exported_program
     )
+    # TODO: Consider deduping this with the weights saved in package_aoti_files
     archive_writer.write_bytes(f"{WEIGHTS_DIR}{model_name}.pt", artifact.state_dict)
     archive_writer.write_bytes(
         f"{CONSTANTS_DIR}{model_name}.pt", artifact.constants
@@ -289,7 +345,7 @@ def package_pt2(
     exported_programs: Optional[
         Union[ExportedProgram, dict[str, ExportedProgram]]
     ] = None,
-    aoti_files: Optional[Union[list[str], dict[str, list[str]]]] = None,
+    aoti_files: Optional[AOTI_FILES] = None,
     extra_files: Optional[dict[str, Any]] = None,
     opset_version: Optional[dict[str, int]] = None,
     pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL,
@@ -347,8 +403,14 @@
         f = os.fspath(f)
 
     with PT2ArchiveWriter(f) as archive_writer:
-        _package_exported_programs(archive_writer, exported_programs)
-        _package_aoti_files(archive_writer, aoti_files)
+        _package_exported_programs(
+            archive_writer, exported_programs, pickle_protocol=pickle_protocol
+        )
+        _package_aoti_files(
+            archive_writer,
+            aoti_files,
+            pickle_protocol=pickle_protocol,
+        )
         _package_extra_files(archive_writer, extra_files)
 
     if isinstance(f, (io.IOBase, IO)):
@@ -474,6 +536,7 @@
     run_single_threaded: bool = False,
     num_runners: int = 1,
     device_index: int = -1,
+    load_weights_from_disk: bool = False,
 ) -> PT2ArchiveContents:  # type: ignore[type-arg]
     """
     Loads all the artifacts previously saved with ``package_pt2``.
@@ -514,6 +577,8 @@
     if isinstance(f, (str, os.PathLike)):
         f = os.fspath(f)
 
+    weights = {}
+    weight_maps = {}
     with PT2ArchiveReader(f) as archive_reader:
         version = archive_reader.read_string(ARCHIVE_VERSION_PATH)
         if version != ARCHIVE_VERSION_VALUE:
@@ -533,11 +598,23 @@
         aoti_model_names: set[str] = set()
         for file in file_names:
             if file.startswith(AOTINDUCTOR_DIR):
-                file = file[len(AOTINDUCTOR_DIR) :]  # remove data/aotinductor/ prefix
-                model_name = file.split("/")[
+                file_end = file[
+                    len(AOTINDUCTOR_DIR) :
+                ]  # remove data/aotinductor/ prefix
+                model_name = file_end.split("/")[
                     0
                 ]  # split "model_name/...cpp" into "model_name"
                 aoti_model_names.add(model_name)
+                if load_weights_from_disk and file.endswith("weights_config.json"):
+                    weight_map = json.loads(archive_reader.read_string(file))
+                    weight_maps[model_name] = weight_map
+            elif load_weights_from_disk and file.startswith(WEIGHTS_DIR):
+                weight_file_name = file[
+                    len(WEIGHTS_DIR) :
+                ]  # remove data/weights/ prefix
+                weight_bytes = archive_reader.read_bytes(file)
+                loaded_weight = torch.load(io.BytesIO(weight_bytes))
+                weights[weight_file_name] = loaded_weight
 
         if isinstance(f, (io.IOBase, IO)):
             if len(aoti_model_names) > 0:
@@ -572,4 +649,37 @@
         for model_name in aoti_model_names
     }
 
+    if weight_maps:
+        for model_name in aoti_model_names:
+            model_weights = {}
+            for weight_name, (file, shape, stride, storage_offset) in weight_maps[
+                model_name
+            ].items():
+                weight = weights[file]
+                model_weights[weight_name] = weight.as_strided(
+                    shape, stride, storage_offset
+                )
+
+            # user_managed=True ensures the weight updates are shared by all runners.
+            aoti_runners[model_name].load_constants(
+                model_weights, check_full_update=True, user_managed=True
+            )
+
     return PT2ArchiveContents(exported_programs, aoti_runners, extra_files)
 
 
+def load_weights_to_pt2_contents(
+    pt2_contents: PT2ArchiveContents, weights_map: dict[str, Any]
+) -> None:
+    """
+    Load weights into the models in the PT2 archive contents.
+
+    Args:
+        pt2_contents (PT2ArchiveContents): The contents of the PT2 archive.
+    """
+    for model_name, weights in weights_map.items():
+        if model_name not in pt2_contents.aoti_runners:
+            raise RuntimeError(f"Model {model_name} not found in PT2 archive contents.")
+        pt2_contents.aoti_runners[model_name].load_constants(
+            weights, check_full_update=True, user_managed=True
+        )
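To make the `weights_config.json` contract concrete, here is a hypothetical sketch of what one model's config could look like after dedup; the `weight_N` file names assume a particular rendering of `WEIGHT_FILENAME_PREFIX`, and all values are illustrative:

```python
# Hypothetical data/aotinductor/<model_name>/weights_config.json after dedup:
# weight_name -> (weight_file, shape, stride, storage_offset).
# load_pt2 reads the shared file from data/weights/<weight_file> and recovers
# each weight via weight.as_strided(shape, stride, storage_offset).
weights_config = {
    "p1": ("weight_0", [3, 4], [4, 1], 0),  # complete view of the shared file
    "p2": ("weight_1", [1, 4], [4, 1], 8),  # strided slice into another file
}
```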
torch/export/pt2_archive/_package_weights.py (new file, 101 lines)
@@ -0,0 +1,101 @@
+import collections
+
+import torch
+from torch.utils._ordered_set import OrderedSet
+
+
+def _end_ptr(tensor: torch.Tensor) -> int:
+    if tensor.nelement():
+        stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
+    else:
+        stop = tensor.data_ptr()
+    return stop
+
+
+class TensorProperties:
+    def __init__(self, tensor: torch.Tensor):
+        # info about underlying storage
+        self.storage_ptr = tensor.untyped_storage().data_ptr()
+        self.storage_size = tensor.untyped_storage().nbytes()
+
+        # info to recover tensor
+        self.shape = tensor.shape
+        self.stride = tensor.stride()
+        self.offset = tensor.storage_offset()
+
+        self.start = tensor.data_ptr()
+        self.end = _end_ptr(tensor)
+
+    def is_complete(self) -> bool:
+        """
+        Whether the tensor completely overlaps with its underlying storage
+        """
+        return (
+            self.start == self.storage_ptr
+            and self.end == self.storage_ptr + self.storage_size
+        )
+
+
+class Weights(dict):
+    """
+    A dictionary mapping from weight name to a tuple of (tensor, TensorProperties).
+    tensor represents the actual initial value of the weight.
+    TensorProperties represents the properties of the weight that are needed to recover the weight.
+
+    We use two separate entries because `tensor` could be a clone of the original weight tensor,
+    so it doesn't have the same properties as the original weight (such as the underlying storage pointer).
+    """
+
+    def __init__(self, weight_dict: dict[str, tuple[torch.Tensor, TensorProperties]]):
+        super().__init__(weight_dict)
+
+    def get_weight(self, name: str) -> tuple[torch.Tensor, TensorProperties]:
+        return self[name]
+
+    def get_weight_properties(self, name: str) -> TensorProperties:
+        return self[name][1]
+
+
+def get_complete(
+    group: OrderedSet[tuple[str, str]], models_weights: dict[str, Weights]
+) -> tuple[str, str]:
+    """
+    `group` is a set of (model_name, weight_name) tuples.
+    `models_weights` is a dictionary mapping from model name to its Weights.
+
+    At least one of the tensors in `group` must be complete, and they must all
+    share the same underlying storage.
+
+    Returns the (model_name, weight_name) of a complete tensor in the `group`.
+    If multiple tensors are complete, returns an arbitrary one.
+    """
+
+    def get_tensor_properties(name_tuple: tuple[str, str]) -> TensorProperties:
+        # returns the tensor properties
+        (model_name, weight_name) = name_tuple
+        return models_weights[model_name].get_weight_properties(weight_name)
+
+    for name_tuple in group:
+        tensor_property = get_tensor_properties(name_tuple)
+        if tensor_property.is_complete():
+            return name_tuple
+
+    raise RuntimeError("No complete tensor found in the group!")
+
+
+def group_weights(all_weights: dict[str, Weights]) -> list[OrderedSet[tuple[str, str]]]:
+    """
+    Group weights that share the same underlying storage.
+
+    Returns a list of sets; each set contains (model_name, weight_name) tuples.
+    """
+    weights_dict: dict[int, OrderedSet[tuple[str, str]]] = collections.defaultdict(
+        OrderedSet
+    )  # storage_key -> set(weight)
+
+    for model_name, weights in all_weights.items():
+        for weight_name, (_, properties) in weights.items():
+            weights_dict[properties.storage_ptr].add((model_name, weight_name))
+
+    return list(weights_dict.values())
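
Since `_package_weights.py` carries the dedup logic, here is a small self-contained sketch of how these helpers interact; the tensors and model names are illustrative:

```python
import torch
from torch.export.pt2_archive._package_weights import (
    get_complete,
    group_weights,
    TensorProperties,
    Weights,
)

base = torch.randn(4, 4)  # spans its underlying storage completely
view = base[2:3]          # aliases a slice of the same storage

print(TensorProperties(base).is_complete())  # True
print(TensorProperties(view).is_complete())  # False: covers only part of the storage

# Two "models" whose weights alias the same storage. As in codecache, the
# stored tensor may be a clone, while TensorProperties describes the original.
all_weights = {
    "model1": Weights({"w": (base.clone(), TensorProperties(base))}),
    "model2": Weights({"w": (view.clone(), TensorProperties(view))}),
}

groups = group_weights(all_weights)  # one group: both entries share a storage_ptr
assert len(groups) == 1
# get_complete picks an entry whose tensor fully covers the storage,
# so only that one tensor needs to be serialized to data/weights/.
print(get_complete(groups[0], all_weights))  # ("model1", "w")
```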
|
Reference in New Issue
Block a user