Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00) — 226 lines, 7.8 KiB, Python.

Summary of the change this file came from:
* Added a C++ loader, AOTIModelPackageLoader, which can load the .pt2, build the .so, and create a runner. The Python-facing API is that users can directly call the `run` function, whereas in C++ users can directly access the `runner_` if they are more familiar with that.
* Added a new config, `aot_inductor.package_cpp_only`, which will **not** package the .so. This means that whenever the package is loaded, the .so must be rebuilt. It is off by default so that new environments do not need to rebuild their .so; torchchat intends to use `package_cpp_only` to give users flexibility.
* Added a new config, `aot_inductor.metadata`, which stores user-provided metadata, serialized into the .pt2 as a JSON file. It also stores the device used when exporting ("cuda" or "cpu") so that at load time it can be used to pick the right AOTIModelContainerRunner. The metadata is accessible through `loader.get_metadata()`. TODO: move this metadata to the top-level `package_aoti` function so it can be removed from config.
* Separated `package_aoti` out as a standalone function instead of having Inductor call it automatically, to prepare for users compiling multiple models and bundling them into one package (e.g. torchchat packaging separately exported encoder and decoder layers; see `test_multiple_methods`).
* `load_package` loads a single model, given the model name.
* The loader does not support Windows for now; the build commands likely need additional casing to work there.

Differential Revision: D62329906 (https://our.internmc.facebook.com/intern/diff/D62329906)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135374
Approved by: https://github.com/desertfire, https://github.com/malfet
import json
import logging
import os
import shlex
import subprocess
import zipfile
from pathlib import Path
from typing import Dict, List, Optional, Union

import torch
import torch._inductor
import torch.utils._pytree as pytree
from torch._inductor import exc
from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
from torch.export._tree_utils import reorder_kwargs

from .pt2_archive_constants import AOTINDUCTOR_DIR, ARCHIVE_VERSION
log = logging.getLogger(__name__)
class PT2ArchiveWriter:
    """Context manager that produces a PT2 archive on disk.

    The archive is a plain (uncompressed, ZIP_STORED) zip file. Entering
    the context creates the zip and stamps it with the archive version and
    the "pt2" format marker; entries are then added with writestr() /
    write_file(). Exiting closes the archive.
    """

    def __init__(self, archive_path: str) -> None:
        # Destination path of the archive; the zipfile handle is lazily
        # created on __enter__ and reset to None on __exit__.
        self.archive_path: str = archive_path
        self.archive_file: Optional[zipfile.ZipFile] = None

    def __enter__(self) -> "PT2ArchiveWriter":
        assert self.archive_file is None
        opened = zipfile.ZipFile(
            self.archive_path, "w", compression=zipfile.ZIP_STORED
        )
        self.archive_file = opened
        # Every PT2 archive begins with these two marker entries.
        self.writestr("version", str(ARCHIVE_VERSION))
        self.writestr("archive_format", "pt2")
        return self

    def __exit__(self, *args) -> None:  # type: ignore[no-untyped-def]
        assert self.archive_file is not None
        self.archive_file.close()
        self.archive_file = None
        return None

    def writestr(self, name: str, data: Union[bytes, str]) -> None:
        """Store an in-memory payload as archive entry ``name``."""
        assert self.archive_file is not None
        self.archive_file.writestr(name, data)

    def write_file(self, name: str, file_path: str) -> None:
        """
        Copy a file into the archive.

        name: The destination file inside the archive.
        file_path: The source file on disk.
        """
        assert Path(file_path).is_file(), f"{file_path} is not a valid file path"
        assert self.archive_file is not None
        self.archive_file.write(file_path, arcname=name)
class PT2ArchiveReader:
    """Context manager that reads entries back out of a PT2 archive.

    Mirrors PT2ArchiveWriter: the archive is an uncompressed zip file,
    opened on __enter__ and closed on __exit__.
    """

    def __init__(self, archive_path: str) -> None:
        # Path of the archive to read; the zipfile handle is created lazily.
        self.archive_path: str = archive_path
        self.archive_file: Optional[zipfile.ZipFile] = None

    def __enter__(self) -> "PT2ArchiveReader":
        handle = zipfile.ZipFile(
            self.archive_path, "r", compression=zipfile.ZIP_STORED
        )
        self.archive_file = handle
        return self

    def __exit__(self, *args) -> None:  # type: ignore[no-untyped-def]
        # Unlike the writer, tolerate an archive that was never opened.
        if self.archive_file is not None:
            self.archive_file.close()
        return None

    def read(self, name: str) -> bytes:
        """Return the raw bytes of entry ``name``."""
        assert self.archive_file is not None
        return self.archive_file.read(name)

    def extract_to_path(self, member: str, path: str) -> str:
        """Extract a single entry under ``path`` and return the extracted path."""
        assert self.archive_file is not None
        return self.archive_file.extract(member, path)

    def extractall(self, path: str) -> None:
        """Extract every entry in the archive under ``path``."""
        assert self.archive_file is not None
        self.archive_file.extractall(path)

    def get_file_names(self) -> List[str]:
        """Return the names of all entries contained in the archive."""
        assert self.archive_file is not None
        return self.archive_file.namelist()
def _run_command_and_check(cmd: str) -> None:
|
|
cmd = shlex.split(cmd)
|
|
try:
|
|
subprocess.run(cmd, check=True)
|
|
except subprocess.CalledProcessError as e:
|
|
raise exc.CppCompileError(cmd, e.output) from e
def compile_so(aoti_dir: str, aoti_files: List[str], so_path: str) -> str:
    """Build a shared library from AOTInductor-generated files.

    Compiles the generated .cpp into an object file, links it with the
    constants object into a .so at ``so_path``, and (if present) appends the
    serialized weights blob to the end of the .so, page-aligned.

    Args:
        aoti_dir: Directory containing the AOTInductor-generated files.
        aoti_files: File names (relative to ``aoti_dir``) produced by AOTI.
        so_path: Destination used to derive the linked library's name/dir.

    Returns:
        The path of the linked shared library, as reported by the builder.

    Raises:
        RuntimeError: If no .cpp or .o file is found in ``aoti_files``.
        exc.CppCompileError: If the compile or link command fails.
    """

    def get_aoti_file_with_suffix(suffix: str) -> str:
        # Return the first generated file ending in `suffix`; the AOTI output
        # is expected to contain exactly one .cpp and one consts .o.
        for file in aoti_files:
            if file.endswith(suffix):
                return file
        raise RuntimeError(f"Unable to find file with suffix {suffix}")

    # Compile all the files into a .so
    cpp_file = os.path.join(aoti_dir, get_aoti_file_with_suffix(".cpp"))
    consts_o = os.path.join(aoti_dir, get_aoti_file_with_suffix(".o"))

    # Base path (no extension) shared by the .cpp and its sidecar JSON files.
    file_name = os.path.splitext(cpp_file)[0]

    # Parse compile flags and build the .o file
    # (AOTI writes the flags it compiled with to <base>_compile_flags.json.)
    with open(file_name + "_compile_flags.json") as f:
        compile_flags = json.load(f)

    compile_options = BuildOptionsBase(**compile_flags)
    object_builder = CppBuilder(
        name=file_name,
        sources=cpp_file,
        BuildOption=compile_options,
    )
    compile_cmd = object_builder.get_command_line()
    output_o = object_builder.get_target_file_path()

    _run_command_and_check(compile_cmd)

    # Parse linker flags and build the .so file
    with open(file_name + "_linker_flags.json") as f:
        linker_flags = json.load(f)

    linker_options = BuildOptionsBase(**linker_flags)
    so_builder = CppBuilder(
        # Only the file name portion of so_path names the target; the
        # directory portion is passed separately as output_dir.
        name=os.path.split(so_path)[-1],
        sources=[output_o, consts_o],
        BuildOption=linker_options,
        output_dir=so_path,
    )
    link_cmd = so_builder.get_command_line()
    output_so = so_builder.get_target_file_path()

    _run_command_and_check(link_cmd)

    # mmapped weights
    serialized_weights_filename = file_name + "_serialized_weights.bin"
    # NOTE(review): serialized_weights_filename is a full path (file_name is
    # derived from os.path.join(aoti_dir, ...)) while aoti_files appears to
    # hold names relative to aoti_dir, so this membership test may never
    # match — confirm against the caller's contract for aoti_files.
    if serialized_weights_filename in aoti_files:
        with open(serialized_weights_filename, "rb") as f_weights:
            serialized_weights = f_weights.read()

        with open(output_so, "a+b") as f_so:
            so_size = f_so.tell()
            # Page align the weights
            # (16 KiB pages; presumably matched by the C++ loader when it
            # mmaps the weights — keep the two formulas in sync. When so_size
            # is already aligned this writes a full extra page of padding.)
            f_so.write(b" " * (16384 - so_size % 16384))
            f_so.write(serialized_weights)

    return output_so
def package_aoti(archive_file: str, aoti_files: Union[str, Dict[str, str]]) -> str:
    """
    Saves the AOTInductor generated files to the PT2Archive format.

    Args:
        archive_file: The file name to save the package to.
        aoti_files: This can either be a singular path to a directory containing
            the AOTInductor files, or a dictionary mapping the model name to the
            path to its AOTInductor generated files.

    Returns:
        The path of the written archive (same as ``archive_file``).
    """
    # Normalize the single-directory form into the {model_name: dir} form.
    if isinstance(aoti_files, str):
        aoti_files = {"model": aoti_files}

    assert isinstance(aoti_files, dict)
    assert archive_file.endswith(".pt2")

    # Save using the PT2 packaging format
    # (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a)

    with PT2ArchiveWriter(archive_file) as writer:
        for name, output_dir in aoti_files.items():
            log.debug(
                "Packaging AOTInductor files from %s with model name, %s",
                output_dir,
                name,
            )
            for dirpath, _subdirs, filenames in os.walk(output_dir):
                for fname in filenames:
                    source_path = os.path.join(dirpath, fname)
                    log.debug(
                        "Saving AOTI generated file %s to archive in %s%s/%s",
                        source_path,
                        AOTINDUCTOR_DIR,
                        name,
                        fname,
                    )
                    # The archive entry keeps only the file name, so any
                    # sub-directory structure under output_dir is flattened.
                    writer.write_file(
                        f"{AOTINDUCTOR_DIR}{name}/{fname}",
                        source_path,
                    )
    return archive_file
class AOTICompiledModel:
|
|
"""
|
|
Callable AOT Inductor loaded model from a .pt2
|
|
"""
|
|
|
|
def __init__(self, loader: torch._C._aoti.AOTIModelPackageLoader) -> None:
|
|
self.loader = loader
|
|
|
|
def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def]
|
|
call_spec = self.loader.get_call_spec() # type: ignore[attr-defined]
|
|
in_spec = pytree.treespec_loads(call_spec[0])
|
|
out_spec = pytree.treespec_loads(call_spec[1])
|
|
flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0]
|
|
flat_outputs = self.loader.run(flat_inputs) # type: ignore[attr-defined]
|
|
return pytree.tree_unflatten(flat_outputs, out_spec)
|
|
|
|
def get_metadata(self) -> Dict[str, str]:
|
|
return self.loader.get_metadata() # type: ignore[attr-defined]
def load_package(path: str, model_name: str = "model") -> AOTICompiledModel:  # type: ignore[type-arg]
    """Load one named model out of a .pt2 package and wrap it as a callable.

    Args:
        path: Path to the .pt2 archive on disk.
        model_name: Which model inside the package to load (defaults to "model").

    Raises:
        RuntimeError: If ``path`` does not end in ".pt2".
    """
    if path.endswith(".pt2"):
        loader = torch._C._aoti.AOTIModelPackageLoader(path, model_name)  # type: ignore[call-arg]
        return AOTICompiledModel(loader)

    raise RuntimeError("Unable to load package. Path must be a .pt2 file.")