Files
pytorch/torch/_inductor/package/package.py
angelayi cd9ee49a69 [aoti] Add cpp loader (#135374)
* Added a cpp loader, AOTIModelPackageLoader, which can load the .pt2, build the .so, and create a runner. The python-facing API is that users can directly call the `run` function, whereas in cpp users can directly access the `runner_` if they are more familiar with that. I couldn't figure out how to bind the `get_runner()` function to python...
* Added a new config, `aot_inductor.package_cpp_only` which will **not** package the so. This means that whenever the package is loaded, we will need to build the so. This is turned off by default so that new environments do not need to rebuild their so. The `package_cpp_only` is a feature which torchchat intends to use to provide flexibility to users.
* Added a new config, `aot_inductor.metadata` which stores user-provided metadata, serialized to the pt2 as a json file. It also stores the device used when exporting, "cuda" or "cpu", so that during load time, we can use that data to determine which AOTIModelContainerRunner to use. The metadata can be accessed through `loader.get_metadata()`. TODO is to move this metadata to the toplevel `package_aoti` function so that we can remove the metadata as a config.
* Separated out `package_aoti` as a standalone function, instead of it automatically being called in inductor. This is to prepare for the case where users will compile multiple models, and want to bundle it in one package. The specific use case is in torchchat, where we want to package the separately-exported encoder and decoder layers. An example of how to use this is in `test_multiple_methods`.
* `load_package` will load a singular model, given the model name.
* The loader doesn't support Windows for now; I think some additional special-casing is needed to make the build commands work on Windows.

Differential Revision: [D62329906](https://our.internmc.facebook.com/intern/diff/D62329906)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135374
Approved by: https://github.com/desertfire, https://github.com/malfet
2024-09-11 03:00:01 +00:00

226 lines
7.8 KiB
Python

import json
import logging
import os
import shlex
import subprocess
import zipfile
from pathlib import Path
from typing import Dict, List, Optional, Union
import torch
import torch._inductor
import torch.utils._pytree as pytree
from torch._inductor import exc
from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
from torch.export._tree_utils import reorder_kwargs
from .pt2_archive_constants import AOTINDUCTOR_DIR, ARCHIVE_VERSION
log = logging.getLogger(__name__)
class PT2ArchiveWriter:
    """
    Context manager that writes entries into a ``.pt2`` zip archive.

    On entry it opens the archive for writing and records the archive
    version and format marker; on exit it closes the underlying zip file.
    """

    def __init__(self, archive_path: str) -> None:
        self.archive_path: str = archive_path
        # Populated on __enter__, reset to None on __exit__.
        self.archive_file: Optional[zipfile.ZipFile] = None

    def __enter__(self) -> "PT2ArchiveWriter":
        assert self.archive_file is None
        self.archive_file = zipfile.ZipFile(
            self.archive_path, "w", compression=zipfile.ZIP_STORED
        )
        # Every PT2 archive begins with its version and format identifier.
        self.writestr("version", str(ARCHIVE_VERSION))
        self.writestr("archive_format", "pt2")
        return self

    def __exit__(self, *args) -> None:  # type: ignore[no-untyped-def]
        assert self.archive_file is not None
        self.archive_file.close()
        self.archive_file = None
        return None

    def writestr(self, name: str, data: Union[bytes, str]) -> None:
        """Write ``data`` into the archive under the entry name ``name``."""
        zf = self.archive_file
        assert zf is not None
        zf.writestr(name, data)

    def write_file(self, name: str, file_path: str) -> None:
        """
        Copy a file from disk into the archive.

        name: The destination entry name inside the archive.
        file_path: The source file on disk; must be an existing file.
        """
        assert Path(file_path).is_file(), f"{file_path} is not a valid file path"
        zf = self.archive_file
        assert zf is not None
        zf.write(file_path, arcname=name)
class PT2ArchiveReader:
    """
    Context manager for reading entries out of a ``.pt2`` zip archive.

    Opens the archive for reading on entry and closes it on exit.
    """

    def __init__(self, archive_path: str) -> None:
        self.archive_path: str = archive_path
        # Populated on __enter__; methods assert it is open before use.
        self.archive_file: Optional[zipfile.ZipFile] = None

    def __enter__(self) -> "PT2ArchiveReader":
        self.archive_file = zipfile.ZipFile(
            self.archive_path, "r", compression=zipfile.ZIP_STORED
        )
        return self

    def __exit__(self, *args) -> None:  # type: ignore[no-untyped-def]
        # The handle may legitimately be absent; close only if open.
        if self.archive_file is not None:
            self.archive_file.close()
        return None

    def read(self, name: str) -> bytes:
        """Return the raw bytes of the archive entry ``name``."""
        zf = self.archive_file
        assert zf is not None
        return zf.read(name)

    def extract_to_path(self, member: str, path: str) -> str:
        """Extract a single ``member`` under ``path``; return the extracted path."""
        zf = self.archive_file
        assert zf is not None
        return zf.extract(member, path)

    def extractall(self, path: str) -> None:
        """Extract every entry of the archive under ``path``."""
        zf = self.archive_file
        assert zf is not None
        zf.extractall(path)

    def get_file_names(self) -> List[str]:
        """Return the names of all entries in the archive."""
        zf = self.archive_file
        assert zf is not None
        return zf.namelist()
def _run_command_and_check(cmd: str) -> None:
cmd = shlex.split(cmd)
try:
subprocess.run(cmd, check=True)
except subprocess.CalledProcessError as e:
raise exc.CppCompileError(cmd, e.output) from e
def compile_so(aoti_dir: str, aoti_files: List[str], so_path: str) -> str:
    """
    Compile the AOTInductor-generated files in ``aoti_dir`` into a shared
    library.

    Args:
        aoti_dir: Directory containing the AOTInductor generated files.
        aoti_files: Names of the generated files; must include exactly one
            ``.cpp`` source and one consts ``.o`` object file.
        so_path: Desired output location for the .so; its basename is used
            as the link target name and it is passed to the linker builder
            as ``output_dir``.

    Returns:
        The path of the linked shared library.

    Raises:
        RuntimeError: If no file with a required suffix is found.
        exc.CppCompileError: If the compile or link command fails.
    """

    def get_aoti_file_with_suffix(suffix: str) -> str:
        # Return the first generated file with the given suffix.
        for file in aoti_files:
            if file.endswith(suffix):
                return file
        raise RuntimeError(f"Unable to find file with suffix {suffix}")

    # Compile all the files into a .so
    cpp_file = os.path.join(aoti_dir, get_aoti_file_with_suffix(".cpp"))
    consts_o = os.path.join(aoti_dir, get_aoti_file_with_suffix(".o"))
    file_name = os.path.splitext(cpp_file)[0]

    # Parse compile flags and build the .o file
    with open(file_name + "_compile_flags.json") as f:
        compile_flags = json.load(f)
    compile_options = BuildOptionsBase(**compile_flags)
    object_builder = CppBuilder(
        name=file_name,
        sources=cpp_file,
        BuildOption=compile_options,
    )
    compile_cmd = object_builder.get_command_line()
    output_o = object_builder.get_target_file_path()
    _run_command_and_check(compile_cmd)

    # Parse linker flags and build the .so file
    with open(file_name + "_linker_flags.json") as f:
        linker_flags = json.load(f)
    linker_options = BuildOptionsBase(**linker_flags)
    so_builder = CppBuilder(
        name=os.path.split(so_path)[-1],
        sources=[output_o, consts_o],
        BuildOption=linker_options,
        output_dir=so_path,
    )
    link_cmd = so_builder.get_command_line()
    output_so = so_builder.get_target_file_path()
    _run_command_and_check(link_cmd)

    # mmapped weights: if serialized weights were generated, append them to
    # the end of the .so, page-aligned so the runtime can map them directly.
    # NOTE(review): `serialized_weights_filename` is a path joined with
    # `aoti_dir`, while `aoti_files` appears to hold names relative to that
    # directory (they are joined with it above) — confirm this membership
    # test actually matches at all call sites.
    serialized_weights_filename = file_name + "_serialized_weights.bin"
    if serialized_weights_filename in aoti_files:
        with open(serialized_weights_filename, "rb") as f_weights:
            serialized_weights = f_weights.read()
        with open(output_so, "a+b") as f_so:
            so_size = f_so.tell()
            # Page align the weights. `(-so_size) % 16384` is 0 when the
            # file is already aligned (the previous `16384 - so_size % 16384`
            # wrote a full extra page in that case).
            f_so.write(b" " * (-so_size % 16384))
            f_so.write(serialized_weights)
    return output_so
def package_aoti(archive_file: str, aoti_files: Union[str, Dict[str, str]]) -> str:
    """
    Save the AOTInductor generated files into the PT2Archive format.

    Args:
        archive_file: Destination ``.pt2`` file name.
        aoti_files: Either a single path to a directory of AOTInductor
            files (packaged under the default model name "model"), or a
            mapping from model name to the directory holding that model's
            AOTInductor generated files.

    Returns:
        The path of the written archive (``archive_file``).
    """
    model_dirs = {"model": aoti_files} if isinstance(aoti_files, str) else aoti_files
    assert isinstance(model_dirs, dict)
    assert archive_file.endswith(".pt2")

    # Save using the PT2 packaging format
    # (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a)
    with PT2ArchiveWriter(archive_file) as writer:
        for model_name, output_dir in model_dirs.items():
            log.debug(
                "Packaging AOTInductor files from %s with model name, %s",
                output_dir,
                model_name,
            )
            # Walk the whole output directory; every file is stored flat
            # under the model's subdirectory in the archive.
            for root, _dirs, filenames in os.walk(output_dir):
                for fname in filenames:
                    src_path = os.path.join(root, fname)
                    log.debug(
                        "Saving AOTI generated file %s to archive in %s%s/%s",
                        src_path,
                        AOTINDUCTOR_DIR,
                        model_name,
                        fname,
                    )
                    writer.write_file(
                        f"{AOTINDUCTOR_DIR}{model_name}/{fname}",
                        src_path,
                    )
    return archive_file
class AOTICompiledModel:
    """
    Callable AOT Inductor loaded model from a .pt2
    """

    def __init__(self, loader: torch._C._aoti.AOTIModelPackageLoader) -> None:
        self.loader = loader

    def __call__(self, *args, **kwargs):  # type: ignore[no-untyped-def]
        # The loader records the pytree specs of the exported model's
        # inputs and outputs as serialized strings.
        specs = self.loader.get_call_spec()  # type: ignore[attr-defined]
        input_spec = pytree.treespec_loads(specs[0])
        output_spec = pytree.treespec_loads(specs[1])
        # Put kwargs into the order the exported program expects, then
        # flatten everything into the flat input list the runner consumes.
        ordered_kwargs = reorder_kwargs(kwargs, input_spec)
        flat_inputs, _ = pytree.tree_flatten((args, ordered_kwargs))
        flat_outputs = self.loader.run(flat_inputs)  # type: ignore[attr-defined]
        return pytree.tree_unflatten(flat_outputs, output_spec)

    def get_metadata(self) -> Dict[str, str]:
        """Return the user/metadata dict serialized into the package."""
        return self.loader.get_metadata()  # type: ignore[attr-defined]
def load_package(path: str, model_name: str = "model") -> AOTICompiledModel: # type: ignore[type-arg]
if not path.endswith(".pt2"):
raise RuntimeError("Unable to load package. Path must be a .pt2 file.")
loader = torch._C._aoti.AOTIModelPackageLoader(path, model_name) # type: ignore[call-arg]
return AOTICompiledModel(loader)