# pytorch/torch/export/pt2_archive/_package.py
import glob
import io
import logging
import os
import tempfile
import zipfile
from dataclasses import dataclass
from typing import Any, IO, Optional, Union

import torch
import torch.utils._pytree as pytree
from torch._export.serde.serialize import deserialize, serialize, SerializedArtifact
from torch.export._tree_utils import reorder_kwargs
from torch.export.exported_program import ExportedProgram
from torch.export.pt2_archive.constants import (
    AOTINDUCTOR_DIR,
    ARCHIVE_FORMAT_PATH,
    ARCHIVE_FORMAT_VALUE,
    ARCHIVE_VERSION_PATH,
    ARCHIVE_VERSION_VALUE,
    CONSTANTS_DIR,
    CUSTOM_OBJ_FILENAME_PREFIX,
    EXTRA_DIR,
    MODELS_DIR,
    MODELS_FILENAME_FORMAT,
    SAMPLE_INPUTS_FILENAME_FORMAT,
    WEIGHTS_DIR,
)
from torch.types import FileLike


DEFAULT_PICKLE_PROTOCOL = 2

logger: logging.Logger = logging.getLogger(__name__)


def is_pt2_package(serialized_model: Union[bytes, str]) -> bool:
    """
    Check if the serialized model is a PT2 Archive package.
    """
    try:
        zip_reader = zipfile.ZipFile(
            io.BytesIO(serialized_model)
            if isinstance(serialized_model, bytes)
            else serialized_model
        )
        # Zip entry names always use forward slashes, independent of the
        # host platform's os.path.sep.
        root_folder = zip_reader.namelist()[0].split("/")[0]
        archive_format_path = f"{root_folder}/{ARCHIVE_FORMAT_PATH}"
        if archive_format_path in zip_reader.namelist():
            return zip_reader.read(archive_format_path) == b"pt2"
    except Exception as ex:
        logger.info("Model is not a PT2 package: %s", str(ex))
    return False
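

# Example usage of ``is_pt2_package``: a minimal sketch, assuming a
# hypothetical "model.pt2" archive already exists on disk. The function
# accepts either raw bytes or a file path.
#
#   with open("model.pt2", "rb") as fp:
#       print(is_pt2_package(fp.read()))  # True for a PT2 archive
#   print(is_pt2_package("model.pt2"))    # the path form works too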


class PT2ArchiveWriter:
    """
    Context manager for writing a PT2 archive.
    """

    def __init__(self, archive_path_or_buffer: FileLike):
        self.archive_file = torch._C.PyTorchFileWriter(archive_path_or_buffer)  # type: ignore[arg-type]
        # NOTICE: the version here is different from archive_version. This is
        # the version of the zip container format used by PyTorchFileWriter,
        # which writes to .data/version. archive_version is the version of the
        # PT2 archive spec, which is written to /archive_version.
        self.archive_file.set_min_version(6)

    def __enter__(self) -> "PT2ArchiveWriter":
        return self

    def __exit__(self, *args: Any) -> None:
        if not self.has_record(ARCHIVE_FORMAT_PATH):
            self.write_string(ARCHIVE_FORMAT_PATH, ARCHIVE_FORMAT_VALUE)

        if not self.has_record(ARCHIVE_VERSION_PATH):
            self.write_string(ARCHIVE_VERSION_PATH, ARCHIVE_VERSION_VALUE)

        self.close()

    def has_record(self, name: str) -> bool:
        """
        Check if a record exists in the archive.
        """
        return name in self.archive_file.get_all_written_records()

    def count_prefix(self, prefix: str) -> int:
        """
        Count the number of records that start with a given prefix.
        """
        return sum(
            1
            for record in self.archive_file.get_all_written_records()
            if record.startswith(prefix)
        )

    def write_bytes(self, name: str, data: bytes) -> None:
        """
        Write a bytes object to the archive.

        name: The destination file inside the archive.
        data: The bytes object to write.
        """
        assert isinstance(data, bytes), f"Expected bytes but got {type(data)}"
        self.archive_file.write_record(name, data, len(data))

    def write_string(self, name: str, data: str) -> None:
        """
        Write a string object to the archive.

        name: The destination file inside the archive.
        data: The string object to write.
        """
        assert isinstance(data, str), f"Expected string but got {type(data)}"
        data_bytes = data.encode()
        self.write_bytes(name, data_bytes)

    def write_file(self, name: str, file_path: str) -> None:
        """
        Copy a file into the archive.

        name: The destination file inside the archive.
        file_path: The source file on disk.
        """
        assert os.path.isfile(file_path), f"{file_path} is not a valid file path"
        with open(file_path, "rb") as f:
            file_bytes = f.read()
            self.write_bytes(name, file_bytes)

    def write_folder(self, archive_dir: str, folder_dir: str) -> None:
        """
        Copy a folder into the archive.

        archive_dir: The destination folder inside the archive.
        folder_dir: The source folder on disk.
        """
        assert os.path.isdir(folder_dir), f"{folder_dir} is not a valid directory path"

        file_paths = filter(
            os.path.isfile, glob.glob(f"{folder_dir}/**", recursive=True)
        )
        for file_path in file_paths:
            filename = os.path.relpath(file_path, folder_dir)
            archive_path = os.path.join(archive_dir, filename)
            self.write_file(archive_path, file_path)

    def close(self) -> None:
        """
        Close the archive.
        """
        self.archive_file.write_end_of_file()
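

# Example usage of ``PT2ArchiveWriter``: a minimal sketch writing two records
# to a hypothetical "archive.pt2". The record names are illustrative;
# ``__exit__`` fills in the format and version records automatically.
#
#   with PT2ArchiveWriter("archive.pt2") as writer:
#       writer.write_string("extra/note.txt", "hello")
#       writer.write_bytes("extra/blob.bin", b"\x00\x01")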


class PT2ArchiveReader:
    """
    Context manager for reading a PT2 archive.
    """

    def __init__(self, archive_path_or_buffer: FileLike):
        self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer)  # type: ignore[arg-type]
        assert self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE, (
            "Invalid archive format"
        )

    def __enter__(self) -> "PT2ArchiveReader":
        return self

    def __exit__(self, *args: Any) -> None:
        # torch._C.PyTorchFileReader doesn't have a close method
        pass

    def read_bytes(self, name: str) -> bytes:
        """
        Read a bytes object from the archive.

        name: The source file inside the archive.
        """
        return self.archive_file.get_record(name)

    def read_string(self, name: str) -> str:
        """
        Read a string object from the archive.

        name: The source file inside the archive.
        """
        data = self.read_bytes(name)
        return data.decode()

    def archive_version(self) -> int:
        """
        Get the archive version.
        """
        try:
            archive_version = self.read_string(ARCHIVE_VERSION_PATH)
        except Exception:
            # If the archive_version record is missing, the archive predates
            # versioning, so treat it as version 0.
            archive_version = "0"

        return int(archive_version)

    def get_file_names(self) -> list[str]:
        """
        Get the file names in the archive.
        """
        return self.archive_file.get_all_records()
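

# Example usage of ``PT2ArchiveReader``: a minimal sketch reading back the
# hypothetical "archive.pt2" written above.
#
#   with PT2ArchiveReader("archive.pt2") as reader:
#       print(reader.archive_version())
#       print(reader.get_file_names())
#       print(reader.read_string("extra/note.txt"))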


def _package_aoti_files(
    archive_writer: PT2ArchiveWriter,
    aoti_files: Optional[Union[list[str], dict[str, list[str]]]],
) -> None:
    if aoti_files is None:
        return

    if isinstance(aoti_files, list):
        aoti_files = {"model": aoti_files}

    assert isinstance(aoti_files, dict)

    for model_name, files in aoti_files.items():
        num_so_files = 0

        for file in files:
            if file == "":
                continue

            if file.endswith(".so"):
                num_so_files += 1
                if num_so_files > 1:
                    raise RuntimeError(
                        f"Multiple .so files found in {files}. "
                        "You might need to clear your cache "
                        "directory before calling aoti_compile again."
                    )

            filename = os.path.basename(file)
            if filename.startswith(CUSTOM_OBJ_FILENAME_PREFIX):
                new_filepath = os.path.join(CONSTANTS_DIR, filename)
            else:
                new_filepath = os.path.join(AOTINDUCTOR_DIR, model_name, filename)

            logger.debug(
                "Saving AOTI generated file %s to archive in %s", file, new_filepath
            )
            archive_writer.write_file(
                str(new_filepath),
                file,
            )


def _package_exported_programs(
    archive_writer: PT2ArchiveWriter,
    exported_programs: Optional[Union[ExportedProgram, dict[str, ExportedProgram]]],
    opset_version: Optional[dict[str, int]] = None,
    pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL,
) -> None:
    if exported_programs is None:
        return

    if isinstance(exported_programs, ExportedProgram):
        exported_programs = {"model": exported_programs}

    assert isinstance(exported_programs, dict)

    for model_name, ep in exported_programs.items():
        artifact: SerializedArtifact = serialize(ep, opset_version, pickle_protocol)

        archive_writer.write_bytes(
            MODELS_FILENAME_FORMAT.format(model_name), artifact.exported_program
        )
        archive_writer.write_bytes(f"{WEIGHTS_DIR}{model_name}.pt", artifact.state_dict)
        archive_writer.write_bytes(
            f"{CONSTANTS_DIR}{model_name}.pt", artifact.constants
        )
        archive_writer.write_bytes(
            SAMPLE_INPUTS_FILENAME_FORMAT.format(model_name),
            artifact.example_inputs,
        )


def _package_extra_files(
    archive_writer: PT2ArchiveWriter, extra_files: Optional[dict[str, Any]]
) -> None:
    if extra_files is None:
        return

    for extra_file_name, content in extra_files.items():
        archive_writer.write_string(f"{EXTRA_DIR}{extra_file_name}", content)


def package_pt2(
    f: FileLike,
    *,
    exported_programs: Optional[
        Union[ExportedProgram, dict[str, ExportedProgram]]
    ] = None,
    aoti_files: Optional[Union[list[str], dict[str, list[str]]]] = None,
    extra_files: Optional[dict[str, Any]] = None,
    opset_version: Optional[dict[str, int]] = None,
    pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL,
) -> FileLike:
    """
    Saves the artifacts to a PT2Archive format
    (https://docs.google.com/document/d/1RQ4cmywilnFUT1VE-4oTGxwXdc8vowCSZsrRgo3wFA8/edit?tab=t.0#heading=h.v2y2jgnwc56a).

    The artifact can then be loaded using ``load_pt2``.

    Args:
        f (str | os.PathLike[str] | IO[bytes]): A file-like object (has to
            implement write and flush) or a string containing a file name.

        exported_programs (Union[ExportedProgram, dict[str, ExportedProgram]]):
            The exported program to save, or a dictionary mapping model name to an
            exported program to save. The exported program will be saved under
            models/*.json. If only one ExportedProgram is specified, it will
            automatically be named "model".

        aoti_files (Union[list[str], dict[str, list[str]]]): A list of files
            generated by AOTInductor via
            ``torch._inductor.aot_compile(..., {"aot_inductor.package": True})``,
            or a dictionary mapping model name to its AOTInductor generated files.
            If only one set of files is specified, it will automatically be named
            "model".

        extra_files (Optional[dict[str, Any]]): Map from filename to contents
            which will be stored as part of the pt2.

        opset_version (Optional[dict[str, int]]): A map of opset names
            to the version of this opset.

        pickle_protocol: Can be specified to override the default protocol.
    """
    assert not (
        exported_programs is None and aoti_files is None and extra_files is None
    ), (
        "No value passed in for `exported_programs`, `aoti_files`, and "
        "`extra_files`, implying that you do not plan on saving anything."
    )

    if not (
        (isinstance(f, (io.IOBase, IO)) and f.writable() and f.seekable())
        or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2"))
    ):
        # TODO: turn this into an error
        logger.warning(
            "Expected the archive file to be a buffer or a file ending in .pt2. "
            "Instead got: %s",
            f,
        )

    if isinstance(f, (str, os.PathLike)):
        f = os.fspath(f)

    with PT2ArchiveWriter(f) as archive_writer:
        _package_exported_programs(
            archive_writer, exported_programs, opset_version, pickle_protocol
        )
        _package_aoti_files(archive_writer, aoti_files)
        _package_extra_files(archive_writer, extra_files)

    if isinstance(f, (io.IOBase, IO)):
        f.seek(0)

    return f
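

# Example usage of ``package_pt2``: a minimal sketch, assuming a hypothetical
# ``torch.nn.Module`` subclass ``M`` that takes a single tensor.
# ``torch.export.export`` produces the ExportedProgram that gets packaged.
#
#   ep = torch.export.export(M(), (torch.randn(2, 3),))
#   package_pt2(
#       "model.pt2",
#       exported_programs=ep,  # saved as models/model.json
#       extra_files={"metadata.txt": "v1"},
#   )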


class AOTICompiledModel:
    """
    A callable AOTInductor-compiled model loaded from a .pt2 archive.
    """

    def __init__(self, loader: torch._C._aoti.AOTIModelPackageLoader) -> None:
        self.loader = loader

    def __call__(self, *args, **kwargs):  # type: ignore[no-untyped-def]
        call_spec = self.loader.get_call_spec()
        in_spec = pytree.treespec_loads(call_spec[0])
        out_spec = pytree.treespec_loads(call_spec[1])
        flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0]
        flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
        flat_outputs = self.loader.boxed_run(flat_inputs)
        return pytree.tree_unflatten(flat_outputs, out_spec)

    def get_metadata(self) -> dict[str, str]:
        return self.loader.get_metadata()

    def load_constants(
        self,
        constants_map: dict[str, torch.Tensor],
        *,
        check_full_update: bool,
        user_managed: bool = False,
    ) -> None:
        """
        Given a mapping of constant fqns to tensors, load the constants into
        the model. You can use ``get_constant_fqns`` to get the list of
        constant fqns that are needed in the compiled model.

        Args:
            constants_map: A mapping of constant fqns to tensors.
            check_full_update: Whether to check that all the constants
                are updated and have values.
        """
        self.loader.load_constants(
            constants_map, False, check_full_update, user_managed
        )

    def get_constant_fqns(self) -> list[str]:
        return self.loader.get_constant_fqns()

    def __deepcopy__(self, memo: Optional[dict[Any, Any]]) -> "AOTICompiledModel":
        logger.warning(
            "AOTICompiledModel deepcopy warning: AOTICompiledModel.loader is not deepcopied."
        )
        return AOTICompiledModel(self.loader)
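

# Example usage of ``AOTICompiledModel``: a minimal sketch, assuming a
# hypothetical "model.pt2" that was packaged with AOTInductor files under the
# model name "model". ``load_pt2`` (defined below) returns the runners.
#
#   runner = load_pt2("model.pt2").aoti_runners["model"]
#   out = runner(torch.randn(2, 3))  # input shape is hypothetical
#   print(runner.get_metadata())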


@dataclass
class PT2ArchiveContents:
    exported_programs: dict[str, ExportedProgram]
    aoti_runners: dict[str, AOTICompiledModel]
    extra_files: dict[str, Any]


def _load_exported_programs(
    archive_reader: PT2ArchiveReader,
    file_names: list[str],
    expected_opset_version: Optional[dict[str, int]],
) -> dict[str, ExportedProgram]:
    exported_program_files = [
        file for file in file_names if file.startswith(MODELS_DIR)
    ]
    exported_programs = {}
    for file in exported_program_files:
        # Split "models/{}.json" into "models/" and ".json"
        prefix, suffix = MODELS_FILENAME_FORMAT.split("{}")
        # Given "models/foo.json" we can now get "foo"
        model_name = file[len(prefix) : -len(suffix)]

        weights_file = f"{WEIGHTS_DIR}{model_name}.pt"
        constants_file = f"{CONSTANTS_DIR}{model_name}.pt"
        sample_inputs_file = SAMPLE_INPUTS_FILENAME_FORMAT.format(model_name)

        serialized_exported_program = archive_reader.read_bytes(file)
        serialized_weights = archive_reader.read_bytes(weights_file)
        serialized_constants = archive_reader.read_bytes(constants_file)
        serialized_sample_inputs = archive_reader.read_bytes(sample_inputs_file)

        artifact: SerializedArtifact = SerializedArtifact(
            serialized_exported_program,
            serialized_weights,
            serialized_constants,
            serialized_sample_inputs,
        )

        # Deserialize ExportedProgram
        ep = deserialize(artifact, expected_opset_version)
        exported_programs[model_name] = ep

    return exported_programs


def _load_extra_files(
    archive_reader: PT2ArchiveReader, file_names: list[str]
) -> dict[str, Any]:
    extra_files = [file for file in file_names if file.startswith(EXTRA_DIR)]

    extra_file_contents: dict[str, Any] = {}
    for file in extra_files:
        contents = archive_reader.read_string(file)
        extra_file_contents[file[len(EXTRA_DIR) :]] = contents

    return extra_file_contents


def load_pt2(
    f: FileLike,
    *,
    expected_opset_version: Optional[dict[str, int]] = None,
    run_single_threaded: bool = False,
    num_runners: int = 1,
    device_index: int = -1,
) -> PT2ArchiveContents:  # type: ignore[type-arg]
    """
    Loads all the artifacts previously saved with ``package_pt2``.

    Args:
        f (str | os.PathLike[str] | IO[bytes]): A file-like object (has to
            implement read and seek) or a string containing a file name.

        expected_opset_version (Optional[dict[str, int]]): A map of opset names
            to expected opset versions.

        run_single_threaded (bool): Whether the model should be run without
            thread synchronization logic. This is useful to avoid conflicts with
            CUDAGraphs.

        num_runners (int): Number of runners to load AOTInductor artifacts with.

        device_index (int): The index of the device to which the PT2 package is
            to be loaded. By default, ``device_index=-1`` is used, which corresponds
            to the device ``cuda`` when using CUDA. Passing ``device_index=1`` would
            load the package to ``cuda:1``, for example.

    Returns:
        A ``PT2ArchiveContents`` object which contains all the objects in the PT2.
    """
    if not (
        (isinstance(f, (io.IOBase, IO)) and f.readable() and f.seekable())
        or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2"))
    ):
        # TODO: turn this into an error in 2.9
        logger.warning(
            "Unable to load package. f must be a buffer or a file ending in "
            ".pt2. Instead got: %s",
            f,
        )

    if isinstance(f, (str, os.PathLike)):
        f = os.fspath(f)

    with PT2ArchiveReader(f) as archive_reader:
        version = archive_reader.read_string(ARCHIVE_VERSION_PATH)
        if version != ARCHIVE_VERSION_VALUE:
            raise ValueError(
                f"Saved archive version {version} does not match our current "
                f"archive version {ARCHIVE_VERSION_VALUE}."
            )

        file_names = archive_reader.get_file_names()

        exported_programs = _load_exported_programs(
            archive_reader, file_names, expected_opset_version
        )
        extra_files = _load_extra_files(archive_reader, file_names)

        # Get a list of AOTI model names
        aoti_model_names: set[str] = set()
        for file in file_names:
            if file.startswith(AOTINDUCTOR_DIR):
                # Remove the "data/aotinductor/" prefix, then split
                # "model_name/...cpp" to recover "model_name"
                file = file[len(AOTINDUCTOR_DIR) :]
                model_name = file.split("/")[0]
                aoti_model_names.add(model_name)

    if isinstance(f, (io.IOBase, IO)):
        if len(aoti_model_names) > 0:
            # Workaround for AOTIModelPackageLoader not reading buffers
            with tempfile.NamedTemporaryFile(suffix=".pt2") as tf:
                f.seek(0)
                tf.write(f.read())
                f.seek(0)
                logger.debug("Writing buffer to tmp file located at %s.", tf.name)

                aoti_runners = {
                    model_name: AOTICompiledModel(
                        torch._C._aoti.AOTIModelPackageLoader(
                            tf.name,
                            model_name,
                            run_single_threaded,
                            num_runners,
                            device_index,
                        )
                    )
                    for model_name in aoti_model_names
                }
        else:
            aoti_runners = {}
    else:
        aoti_runners = {
            model_name: AOTICompiledModel(
                torch._C._aoti.AOTIModelPackageLoader(
                    f, model_name, run_single_threaded, num_runners, device_index
                )
            )
            for model_name in aoti_model_names
        }

    return PT2ArchiveContents(exported_programs, aoti_runners, extra_files)
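

# End-to-end example: a minimal sketch pairing ``package_pt2`` with
# ``load_pt2``. The module ``M``, input shapes, and file name are
# hypothetical.
#
#   ep = torch.export.export(M(), (torch.randn(2, 3),))
#   package_pt2("model.pt2", exported_programs=ep)
#
#   contents = load_pt2("model.pt2")
#   loaded_ep = contents.exported_programs["model"]
#   y = loaded_ep.module()(torch.randn(2, 3))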