import glob import io import logging import os import tempfile import zipfile from dataclasses import dataclass from typing import Any, IO, Optional, Union import torch import torch.utils._pytree as pytree from torch._export.serde.serialize import deserialize, serialize, SerializedArtifact from torch.export._tree_utils import reorder_kwargs from torch.export.exported_program import ExportedProgram from torch.export.pt2_archive.constants import ( AOTINDUCTOR_DIR, ARCHIVE_FORMAT_PATH, ARCHIVE_FORMAT_VALUE, ARCHIVE_VERSION_PATH, ARCHIVE_VERSION_VALUE, CONSTANTS_DIR, CUSTOM_OBJ_FILENAME_PREFIX, EXTRA_DIR, MODELS_DIR, MODELS_FILENAME_FORMAT, SAMPLE_INPUTS_FILENAME_FORMAT, WEIGHTS_DIR, ) from torch.types import FileLike DEFAULT_PICKLE_PROTOCOL = 2 logger: logging.Logger = logging.getLogger(__name__) def is_pt2_package(serialized_model: Union[bytes, str]) -> bool: """ Check if the serialized model is a PT2 Archive package. """ try: zip_reader = zipfile.ZipFile( io.BytesIO(serialized_model) if isinstance(serialized_model, bytes) else serialized_model ) root_folder = zip_reader.namelist()[0].split(os.path.sep)[0] archive_format_path = f"{root_folder}/{ARCHIVE_FORMAT_PATH}" if archive_format_path in zip_reader.namelist(): return zip_reader.read(archive_format_path) == b"pt2" except Exception as ex: logger.info("Model is not a PT2 package: %s", str(ex)) return False class PT2ArchiveWriter: """ Context manager for writing a PT2 archive. """ def __init__(self, archive_path_or_buffer: FileLike): self.archive_file = torch._C.PyTorchFileWriter(archive_path_or_buffer) # type: ignore[arg-type] # NOTICE: version here is different from the archive_version # this is the version of zip file format, which is used by PyTorchFileWriter, which write to /.data/version # archive_version is the version of the PT2 archive spec, which write to /archive_version self.archive_file.set_min_version(6) def __enter__(self) -> "PT2ArchiveWriter": return self def __exit__(self, *args: Any) -> None: if not self.has_record(ARCHIVE_FORMAT_PATH): self.write_string(ARCHIVE_FORMAT_PATH, ARCHIVE_FORMAT_VALUE) if not self.has_record(ARCHIVE_VERSION_PATH): self.write_string(ARCHIVE_VERSION_PATH, ARCHIVE_VERSION_VALUE) self.close() def has_record(self, name: str) -> bool: """ Check if a record exists in the archive. """ return name in self.archive_file.get_all_written_records() def count_prefix(self, prefix: str) -> int: """ Count the number of records that start with a given prefix. """ return sum( 1 for record in self.archive_file.get_all_written_records() if record.startswith(prefix) ) def write_bytes(self, name: str, data: bytes) -> None: """ Write a bytes object to the archive. name: The destination file inside the archive. data: The bytes object to write. """ assert isinstance(data, bytes), f"Expected bytes but got {type(data)}" self.archive_file.write_record(name, data, len(data)) def write_string(self, name: str, data: str) -> None: """ Write a string object to the archive. name: The destination file inside the archive. data: The string object to write. """ assert isinstance(data, str), f"Expected string but got {type(data)}" data_bytes = data.encode() self.write_bytes(name, data_bytes) def write_file(self, name: str, file_path: str) -> None: """ Copy a file into the archive. name: The destination file inside the archive. file_path: The source file on disk. """ assert os.path.isfile(file_path), f"{file_path} is not a valid file path" with open(file_path, "rb") as f: file_bytes = f.read() self.write_bytes(name, file_bytes) def write_folder(self, archive_dir: str, folder_dir: str) -> None: """ Copy a folder into the archive. archive_dir: The destination folder inside the archive. folder_dir: The source folder on disk. """ assert os.path.isdir(folder_dir), f"{folder_dir} is not a valid directory path" file_paths = filter( os.path.isfile, glob.glob(f"{folder_dir}/**", recursive=True) ) for file_path in file_paths: filename = os.path.relpath(file_path, folder_dir) archive_path = os.path.join(archive_dir, filename) self.write_file(archive_path, file_path) def close(self) -> None: """ Close the archive. """ self.archive_file.write_end_of_file() class PT2ArchiveReader: """ Context manager for reading a PT2 archive. """ def __init__(self, archive_path_or_buffer: FileLike): self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer) # type: ignore[arg-type] assert self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE, ( "Invalid archive format" ) def __enter__(self) -> "PT2ArchiveReader": return self def __exit__(self, *args: Any) -> None: # torch._C.PyTorchFileReader doesn't have a close method pass def read_bytes(self, name: str) -> bytes: """ Read a bytes object from the archive. name: The source file inside the archive. """ return self.archive_file.get_record(name) def read_string(self, name: str) -> str: """ Read a string object from the archive. name: The source file inside the archive. """ data = self.read_bytes(name) return data.decode() def archive_version(self) -> int: """ Get the archive version. """ try: archive_version = self.read_string(ARCHIVE_VERSION_PATH) except Exception: # if archive_version is not found, it means the archive is older than version 0. # In this case, we assume the archive is version 0. archive_version = "0" return int(archive_version) def get_file_names(self) -> list[str]: """ Get the file names in the archive. """ return self.archive_file.get_all_records() def _package_aoti_files( archive_writer: PT2ArchiveWriter, aoti_files: Optional[Union[list[str], dict[str, list[str]]]], ) -> None: if aoti_files is None: return if isinstance(aoti_files, list): aoti_files = {"model": aoti_files} assert isinstance(aoti_files, dict) for model_name, files in aoti_files.items(): num_so_files = 0 for file in files: if file == "": continue if file.endswith(".so"): num_so_files += 1 if num_so_files > 1: raise RuntimeError( f"Multiple .so files found in {files}. " "You might need to clear your cache " "directory before calling aoti_compile again." ) filename = os.path.basename(file) if filename.startswith(CUSTOM_OBJ_FILENAME_PREFIX): new_filepath = os.path.join(CONSTANTS_DIR, filename) else: new_filepath = os.path.join(AOTINDUCTOR_DIR, model_name, filename) logger.debug( "Saving AOTI generated file %s to archive in %s", file, new_filepath ) archive_writer.write_file( str(new_filepath), file, ) def _package_exported_programs( archive_writer: PT2ArchiveWriter, exported_programs: Optional[Union[ExportedProgram, dict[str, ExportedProgram]]], opset_version: Optional[dict[str, int]] = None, pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, ) -> None: if exported_programs is None: return if isinstance(exported_programs, ExportedProgram): exported_programs = {"model", exported_programs} # type: ignore[assignment] assert isinstance(exported_programs, dict) for model_name, ep in exported_programs.items(): artifact: SerializedArtifact = serialize(ep, opset_version, pickle_protocol) archive_writer.write_bytes( MODELS_FILENAME_FORMAT.format(model_name), artifact.exported_program ) archive_writer.write_bytes(f"{WEIGHTS_DIR}{model_name}.pt", artifact.state_dict) archive_writer.write_bytes( f"{CONSTANTS_DIR}{model_name}.pt", artifact.constants ) archive_writer.write_bytes( SAMPLE_INPUTS_FILENAME_FORMAT.format(model_name), artifact.example_inputs, ) def _package_extra_files( archive_writer: PT2ArchiveWriter, extra_files: Optional[dict[str, Any]] ) -> None: if extra_files is None: return for extra_file_name, content in extra_files.items(): archive_writer.write_string(f"{EXTRA_DIR}{extra_file_name}", content) def package_pt2( f: FileLike, *, exported_programs: Optional[ Union[ExportedProgram, dict[str, ExportedProgram]] ] = None, aoti_files: Optional[Union[list[str], dict[str, list[str]]]] = None, extra_files: Optional[dict[str, Any]] = None, opset_version: Optional[dict[str, int]] = None, pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, ) -> FileLike: """ Saves the artifacts to a PT2Archive format (https://docs.google.com/document/d/1RQ4cmywilnFUT1VE-4oTGxwXdc8vowCSZsrRgo3wFA8/edit?tab=t.0#heading=h.v2y2jgnwc56a). The artifact can then be loaded using ``load_pt2``. Args: f (str | os.PathLike[str] | IO[bytes]) A file-like object (has to implement write and flush) or a string containing a file name. exported_programs (Union[ExportedProgram, dict[str, ExportedProgram]]): The exported program to save, or a dictionary mapping model name to an exported program to save. The exported program will be saved under models/*.json. If only one ExportedProgram is specified, this will automatically be named "model". aoti_files (Union[list[str], dict[str, list[str]]): A list of files generated by AOTInductor via ``torch._inductor.aot_compile(..., {"aot_inductor.package": True})``, or a dictionary mapping model name to its AOTInductor generated files. If only one set of files is specified, this will automatically be named "model". extra_files (Optional[Dict[str, Any]]): Map from filename to contents which will be stored as part of the pt2. opset_version (Optional[Dict[str, int]]): A map of opset names to the version of this opset pickle_protocol: can be specified to override the default protocol """ assert not ( exported_programs is None and aoti_files is None and extra_files is None ), ( "No value passed in for `exported_programs`, `aoti_files`, and " "`extra_files`, implying that you do not plan on saving anything." ) if not ( (isinstance(f, (io.IOBase, IO)) and f.writable() and f.seekable()) or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2")) ): # TODO: turn this into an error logger.warning( "Expect archive file to be a file ending in .pt2, or is a buffer. " "Instead got {%s}", f, ) if isinstance(f, (str, os.PathLike)): f = os.fspath(f) with PT2ArchiveWriter(f) as archive_writer: _package_exported_programs(archive_writer, exported_programs) _package_aoti_files(archive_writer, aoti_files) _package_extra_files(archive_writer, extra_files) if isinstance(f, (io.IOBase, IO)): f.seek(0) return f class AOTICompiledModel: """ Callable AOT Inductor loaded model from a .pt2 """ def __init__(self, loader: torch._C._aoti.AOTIModelPackageLoader) -> None: self.loader = loader def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] call_spec = self.loader.get_call_spec() in_spec = pytree.treespec_loads(call_spec[0]) out_spec = pytree.treespec_loads(call_spec[1]) flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0] flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)] flat_outputs = self.loader.boxed_run(flat_inputs) return pytree.tree_unflatten(flat_outputs, out_spec) def get_metadata(self) -> dict[str, str]: return self.loader.get_metadata() def load_constants( self, constants_map: dict[str, torch.Tensor], *, check_full_update: bool, user_managed: bool = False, ) -> None: """ Given a mapping of constant fqns to tensors, load the constants into the model. You can use ``get_constant_fqns`` to get the list of constant fqns that are needed in the compiled model. Args: constants_map: A mapping of constant fqns to tensors. check_full_update: Whether to add check to see if all the constants are updated and have values. """ self.loader.load_constants( constants_map, False, check_full_update, user_managed ) def get_constant_fqns(self) -> list[str]: return self.loader.get_constant_fqns() def __deepcopy__(self, memo: Optional[dict[Any, Any]]) -> "AOTICompiledModel": logger.warning( "AOTICompiledModel deepcopy warning: AOTICompiledModel.loader is not deepcopied." ) return AOTICompiledModel(self.loader) @dataclass class PT2ArchiveContents: exported_programs: dict[str, ExportedProgram] aoti_runners: dict[str, AOTICompiledModel] extra_files: dict[str, Any] def _load_exported_programs( archive_reader: PT2ArchiveReader, file_names: list[str], expected_opset_version: Optional[dict[str, int]], ) -> dict[str, ExportedProgram]: exported_program_files = [ file for file in file_names if file.startswith(MODELS_DIR) ] exported_programs = {} for file in exported_program_files: prefix, suffix = MODELS_FILENAME_FORMAT.split( "{}" ) # split "models/{}.json" into "models/" and "json" model_name = file[ len(prefix) : -len(suffix) ] # given "models/foo.json" we can now get "foo" weights_file = f"{WEIGHTS_DIR}{model_name}.pt" constants_file = f"{CONSTANTS_DIR}{model_name}.pt" sample_inputs_file = SAMPLE_INPUTS_FILENAME_FORMAT.format(model_name) serialized_exported_program = archive_reader.read_bytes(file) serialized_weights = archive_reader.read_bytes(weights_file) serialized_constants = archive_reader.read_bytes(constants_file) serialized_sample_inputs = archive_reader.read_bytes(sample_inputs_file) artifact: SerializedArtifact = SerializedArtifact( serialized_exported_program, serialized_weights, serialized_constants, serialized_sample_inputs, ) # Deserialize ExportedProgram ep = deserialize(artifact, expected_opset_version) exported_programs[model_name] = ep return exported_programs def _load_extra_files( archive_reader: PT2ArchiveReader, file_names: list[str] ) -> dict[str, Any]: extra_files = [file for file in file_names if file.startswith(EXTRA_DIR)] extra_file_contents: dict[str, Any] = {} for file in extra_files: contents = archive_reader.read_string(file) extra_file_contents[file[len(EXTRA_DIR) :]] = contents return extra_file_contents def load_pt2( f: FileLike, *, expected_opset_version: Optional[dict[str, int]] = None, run_single_threaded: bool = False, num_runners: int = 1, device_index: int = -1, ) -> PT2ArchiveContents: # type: ignore[type-arg] """ Loads all the artifacts previously saved with ``package_pt2``. Args: f (str | os.PathLike[str] | IO[bytes]): A file-like object (has to implement write and flush) or a string containing a file name. expected_opset_version (Optional[Dict[str, int]]): A map of opset names to expected opset versions num_runners (int): Number of runners to load AOTInductor artifacts run_single_threaded (bool): Whether the model should be run without thread synchronization logic. This is useful to avoid conflicts with CUDAGraphs. device_index (int): The index of the device to which the PT2 package is to be loaded. By default, `device_index=-1` is used, which corresponds to the device `cuda` when using CUDA. Passing `device_index=1` would load the package to `cuda:1`, for example. Returns: A ``PT2ArchiveContents`` object which contains all the objects in the PT2. """ if not ( (isinstance(f, (io.IOBase, IO)) and f.readable() and f.seekable()) or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2")) ): # TODO: turn this into an error in 2.9 logger.warning( "Unable to load package. f must be a buffer or a file ending in " ".pt2. Instead got {%s}", f, ) if isinstance(f, (str, os.PathLike)): f = os.fspath(f) with PT2ArchiveReader(f) as archive_reader: version = archive_reader.read_string(ARCHIVE_VERSION_PATH) if version != ARCHIVE_VERSION_VALUE: raise ValueError( f"Saved archive version {version} does not match our current " f"archive version {ARCHIVE_VERSION_VALUE}." ) file_names = archive_reader.get_file_names() exported_programs = _load_exported_programs( archive_reader, file_names, expected_opset_version ) extra_files = _load_extra_files(archive_reader, file_names) # Get a list of AOTI model names aoti_model_names: set[str] = set() for file in file_names: if file.startswith(AOTINDUCTOR_DIR): file = file[len(AOTINDUCTOR_DIR) :] # remove data/aotinductor/ prefix model_name = file.split("/")[ 0 ] # split "model_name/...cpp" into "model_name" aoti_model_names.add(model_name) if isinstance(f, (io.IOBase, IO)): if len(aoti_model_names) > 0: # Workaround for AOTIModelPackageLoader not reading buffers with tempfile.NamedTemporaryFile(suffix=".pt2") as tf: f.seek(0) tf.write(f.read()) f.seek(0) logger.debug("Writing buffer to tmp file located at %s.", tf.name) aoti_runners = { model_name: AOTICompiledModel( torch._C._aoti.AOTIModelPackageLoader( tf.name, model_name, run_single_threaded, num_runners, device_index, ) ) for model_name in aoti_model_names } else: aoti_runners = {} else: aoti_runners = { model_name: AOTICompiledModel( torch._C._aoti.AOTIModelPackageLoader( f, model_name, run_single_threaded, num_runners, device_index ) ) for model_name in aoti_model_names } return PT2ArchiveContents(exported_programs, aoti_runners, extra_files)