[aoti] Add cpp loader (#135374)

* Added a cpp loader, `AOTIModelPackageLoader`, which can load the .pt2, build the .so, and create a runner. The Python-facing API is that users directly call the `run` function, whereas in cpp users can directly access the `runner_` if they are more familiar with that. I couldn't figure out how to bind the `get_runner()` function to Python...
* Added a new config, `aot_inductor.package_cpp_only`, which will **not** package the .so. This means that whenever the package is loaded, the .so needs to be built. It is turned off by default so that new environments do not need to rebuild their .so. `package_cpp_only` is a feature which torchchat intends to use to provide flexibility to users.
* Added a new config, `aot_inductor.metadata`, which stores user-provided metadata, serialized into the .pt2 as a json file. It also stores the device used when exporting ("cuda" or "cpu") so that at load time we can use that data to determine which AOTIModelContainerRunner to use. The metadata can be accessed through `loader.get_metadata()`. TODO is to move this metadata to the toplevel `package_aoti` function so that we can remove the metadata as a config.
* Separated out `package_aoti` as a standalone function, instead of it automatically being called in inductor. This prepares for the case where users compile multiple models and want to bundle them into one package. The specific use case is torchchat, where we want to package the separately-exported encoder and decoder layers. An example of how to use this is in `test_multiple_methods`; see also the sketches after this list.
* `load_package` loads a single model from the package, given the model name.
* The loader doesn't support Windows for now; the build commands likely need some additional handling to work on Windows.
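
As a quick reference, a minimal sketch of the single-model flow, based on the updated test below; the module, the `m.pt2` path, and the metadata values are illustrative:

```python
import torch
import torch._inductor
from torch._inductor.package import load_package


class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y


example_inputs = (torch.randn(10, 10), torch.randn(10, 10))
ep = torch.export.export(M(), example_inputs)

# Compile with AOTInductor and package everything into a .pt2 in one step.
package_path = torch._inductor.aoti_compile_and_package(
    ep,
    example_inputs,
    package_path="m.pt2",
    inductor_configs={"aot_inductor.metadata": {"dummy": "moo"}},
)

# Load the default "model" entry; if only cpp was packaged, this builds the .so first.
compiled = load_package(package_path)
out = compiled(*example_inputs)
print(compiled.get_metadata())  # user metadata plus the AOTI_DEVICE_KEY entry
```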
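
And a sketch of the multi-model flow from `test_multiple_methods`, compiling two models separately and bundling them into one .pt2; the "encoder"/"decoder" names and the toy models are illustrative:

```python
import torch
import torch._inductor
from torch._inductor.package import load_package, package_aoti

options = {"aot_inductor.package": True}


class Encoder(torch.nn.Module):
    def forward(self, x):
        return x * 2


class Decoder(torch.nn.Module):
    def forward(self, x):
        return x + 1


ex1 = (torch.randn(2, 4),)
ex2 = (torch.randn(5, 5),)

# With aot_inductor.package=True, aot_compile returns a directory of AOTI-generated files.
files1 = torch._inductor.aot_compile(
    torch.export.export(Encoder(), ex1).module(), ex1, options=options
)
files2 = torch._inductor.aot_compile(
    torch.export.export(Decoder(), ex2).module(), ex2, options=options
)

# Bundle both models into a single archive, keyed by model name.
package_path = package_aoti("multi.pt2", {"encoder": files1, "decoder": files2})

encoder = load_package(package_path, "encoder")
decoder = load_package(package_path, "decoder")
print(encoder(*ex1), decoder(*ex2))
```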

Differential Revision: [D62329906](https://our.internmc.facebook.com/intern/diff/D62329906)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135374
Approved by: https://github.com/desertfire, https://github.com/malfet
angelayi
2024-09-10 16:08:07 -07:00
committed by PyTorch MergeBot
parent 26e5572dd2
commit cd9ee49a69
22 changed files with 890 additions and 246 deletions

View File

@ -337,3 +337,8 @@ onnxscript==0.1.0.dev20240817
#Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal
#Pinned versions:
#test that import:
parameterized==0.8.1
#Description: Parameterizes unittests, both the tests themselves and the entire testing class
#Pinned versions:
#test that import:

View File

@ -43,6 +43,9 @@ python -m pip install z3-solver==4.12.2.0
# Install tlparse for test\dynamo\test_structured_trace.py UTs.
python -m pip install tlparse==0.3.25
# Install parameterized
python -m pip install parameterized==0.8.1
run_tests() {
# Run nvidia-smi if available
for path in '/c/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe' /c/Windows/System32/nvidia-smi.exe; do

View File

@ -31,3 +31,4 @@ optree==0.12.1
# NB: test_hparams_* from test_tensorboard is failing with protobuf 5.26.0 in
# which the stringify metadata is wrong when escaping double quote
protobuf==3.20.2
parameterized==0.8.1

View File

@ -466,6 +466,7 @@ lazy_tensor_core_python_sources = [
]
inductor_core_resources = [
"torch/csrc/inductor/aoti_package/model_package_loader.cpp",
"torch/csrc/inductor/aoti_runner/model_container_runner.cpp",
"torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp",
"torch/csrc/inductor/aoti_torch/shim_common.cpp",
@ -841,6 +842,7 @@ libtorch_python_core_sources = [
"torch/csrc/fx/node.cpp",
"torch/csrc/mps/Module.cpp",
"torch/csrc/mtia/Module.cpp",
"torch/csrc/inductor/aoti_package/pybind.cpp",
"torch/csrc/inductor/aoti_runner/pybind.cpp",
"torch/csrc/inductor/aoti_eager/kernel_holder.cpp",
"torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp",

View File

@ -1324,6 +1324,7 @@ def main():
"include/torch/csrc/distributed/autograd/rpc_messages/*.h",
"include/torch/csrc/dynamo/*.h",
"include/torch/csrc/inductor/*.h",
"include/torch/csrc/inductor/aoti_package/*.h",
"include/torch/csrc/inductor/aoti_runner/*.h",
"include/torch/csrc/inductor/aoti_runtime/*.h",
"include/torch/csrc/inductor/aoti_torch/*.h",

View File

@ -1,78 +1,95 @@
# Owner(s): ["module: inductor"]
import copy
import sys
import tempfile
import unittest
from parameterized import parameterized_class
import torch
from torch._inductor import config
from torch._inductor.package import load_package
from torch._inductor.package import AOTICompiledModel, load_package, package_aoti
from torch._inductor.test_case import TestCase
from torch.testing._internal import common_utils
from torch.export import Dim
from torch.testing._internal.common_utils import IS_FBCODE
from torch.testing._internal.triton_utils import HAS_CUDA
try:
try:
from .test_torchinductor import copy_tests
except ImportError:
from test_torchinductor import copy_tests
except (unittest.SkipTest, ImportError) as e:
if __name__ == "__main__":
sys.exit(0)
raise
def compile(model, example_inputs, dynamic_shapes, options, device):
def compile(
model,
args,
kwargs=None,
*,
dynamic_shapes=None,
package_path=None,
inductor_configs=None,
) -> AOTICompiledModel:
ep = torch.export.export(
model,
example_inputs,
args,
kwargs,
dynamic_shapes=dynamic_shapes,
strict=False,
)
gm = ep.module()
package_path = torch._inductor.aot_compile(gm, example_inputs, options=options) # type: ignore[arg-type]
compiled_model = load_package(package_path, device)
return compiled_model
package_path = torch._inductor.aoti_compile_and_package(
ep, args, kwargs, package_path=package_path, inductor_configs=inductor_configs
) # type: ignore[arg-type]
loaded = load_package(package_path)
return loaded
def check_model(
self: TestCase,
model,
example_inputs,
options=None,
dynamic_shapes=None,
disable_constraint_solver=False,
atol=None,
rtol=None,
):
with torch.no_grad(), config.patch(
{
"aot_inductor.package": True,
# TODO: "aot_inductor.force_mmap_weights": True,
}
):
torch.manual_seed(0)
model = model.to(self.device)
ref_model = copy.deepcopy(model)
ref_inputs = copy.deepcopy(example_inputs)
expected = ref_model(*ref_inputs)
@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS")
@unittest.skipIf(IS_FBCODE, "This is for OSS only")
@parameterized_class(
[
{"device": "cpu", "package_cpp_only": False},
{"device": "cpu", "package_cpp_only": True},
]
+ (
[
{"device": "cuda", "package_cpp_only": False},
{"device": "cuda", "package_cpp_only": True},
]
if sys.platform != "darwin"
else []
),
class_name_func=lambda cls, _, params: f"{cls.__name__}{'Cpp' if params['package_cpp_only'] else ''}_{params['device']}",
)
class TestAOTInductorPackage(TestCase):
def check_model(
self: TestCase,
model,
example_inputs,
inductor_configs=None,
dynamic_shapes=None,
disable_constraint_solver=False,
atol=None,
rtol=None,
) -> AOTICompiledModel:
with torch.no_grad():
torch.manual_seed(0)
model = model.to(self.device)
ref_model = copy.deepcopy(model)
ref_inputs = copy.deepcopy(example_inputs)
expected = ref_model(*ref_inputs)
torch.manual_seed(0)
compiled_model = compile(
model,
example_inputs,
dynamic_shapes,
options,
self.device,
)
inductor_configs = inductor_configs or {}
inductor_configs["aot_inductor.package_cpp_only"] = self.package_cpp_only
actual = compiled_model(*example_inputs)
torch.manual_seed(0)
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
compiled_model = compile(
model,
example_inputs,
dynamic_shapes=dynamic_shapes,
inductor_configs=inductor_configs,
package_path=f.name,
)
self.assertEqual(actual, expected, atol=atol, rtol=rtol)
actual = compiled_model(*example_inputs)
self.assertEqual(actual, expected, atol=atol, rtol=rtol)
return compiled_model
class AOTInductorTestsTemplate:
def test_add(self):
class Model(torch.nn.Module):
def forward(self, x, y):
@ -99,34 +116,84 @@ class AOTInductorTestsTemplate:
)
self.check_model(Model(), example_inputs)
def test_metadata(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 10)
common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)
def forward(self, x, y):
return x + self.linear(y)
example_inputs = (
torch.randn(10, 10, device=self.device),
torch.randn(10, 10, device=self.device),
)
metadata = {"dummy": "moo"}
compiled_model = self.check_model(
Model(),
example_inputs,
inductor_configs={"aot_inductor.metadata": metadata},
)
@unittest.skipIf(sys.platform == "darwin" or IS_FBCODE, "No CUDA on MacOS")
class AOTInductorTestPackagedABICompatibleCuda(TestCase):
device = "cuda"
check_model = check_model
loaded_metadata = compiled_model.get_metadata() # type: ignore[attr-defined]
self.assertEqual(loaded_metadata.get("dummy"), "moo")
copy_tests(
AOTInductorTestsTemplate,
AOTInductorTestPackagedABICompatibleCuda,
"packaged_abi_compatible_cuda",
)
def test_multiple_methods(self):
options = {
"aot_inductor.package": True,
"aot_inductor.package_cpp_only": self.package_cpp_only,
}
class Model1(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@unittest.skipIf(IS_FBCODE, "This is for OSS only")
class AOTInductorTestPackagedABICompatibleCpu(TestCase):
device = "cpu"
check_model = check_model
def forward(self, a, b):
return torch.cat([a, b], dim=0)
b = torch.randn(3, 4, device=self.device)
dim0_a = Dim("dim0_a", min=1, max=10)
dim0_b = Dim("dim0_b", min=1, max=20)
dynamic_shapes = {"a": {0: dim0_a}, "b": {0: dim0_b}}
example_inputs1 = (
torch.randn(2, 4, device=self.device),
torch.randn(3, 4, device=self.device),
)
ep1 = torch.export.export(
Model1(), example_inputs1, dynamic_shapes=dynamic_shapes
)
aoti_files1 = torch._inductor.aot_compile(
ep1.module(), example_inputs1, options=options
)
class Model2(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.device = device
def forward(self, x):
t = torch.tensor(x.size(-1), device=self.device, dtype=torch.float)
t = torch.sqrt(t * 3)
return x * t
example_inputs2 = (torch.randn(5, 5, device=self.device),)
ep2 = torch.export.export(Model2(self.device), example_inputs2)
aoti_files2 = torch._inductor.aot_compile(
ep2.module(), example_inputs2, options=options
)
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
package_path = package_aoti(
f.name, {"model1": aoti_files1, "model2": aoti_files2}
)
loaded1 = load_package(package_path, "model1")
loaded2 = load_package(package_path, "model2")
self.assertEqual(loaded1(*example_inputs1), ep1.module()(*example_inputs1))
self.assertEqual(loaded2(*example_inputs2), ep2.module()(*example_inputs2))
copy_tests(
AOTInductorTestsTemplate,
AOTInductorTestPackagedABICompatibleCpu,
"packaged_abi_compatible_cpu",
)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests

View File

@ -18,3 +18,6 @@ def alloc_tensor_by_stealing_from_void_ptr(
class AOTIModelContainerRunnerCpu: ...
class AOTIModelContainerRunnerCuda: ...
# Defined in torch/csrc/inductor/aoti_package/pybind.cpp
class AOTIModelPackageLoader: ...

View File

@ -30,6 +30,48 @@ def compile(
return compile_fx(gm, example_inputs, config_patches=options)
def aoti_compile_and_package(
exported_program,
args: Tuple[Any],
kwargs: Optional[Dict[str, Any]] = None,
*,
package_path: Optional[str] = None,
inductor_configs: Optional[Dict[str, Any]] = None,
) -> str:
"""
Compiles the exported program with AOTInductor, and packages it into a .pt2
file specified by the input package_path.
"""
from torch._inductor.package import package_aoti
from torch.export import ExportedProgram
if not isinstance(exported_program, ExportedProgram):
raise ValueError("Only ExportedProgram is supported")
assert package_path is None or package_path.endswith(".pt2")
inductor_configs = inductor_configs or {}
if inductor_configs.get("aot_inductor.output_path"):
raise RuntimeError(
"Please pass in a package path to aot_inductor_compile() instead "
"of setting the aot_inductor.output_path config."
)
inductor_configs["aot_inductor.package"] = True
m = exported_program.module()
assert isinstance(m, torch.fx.GraphModule)
aoti_files = aot_compile(m, args, kwargs, options=inductor_configs) # type: ignore[arg-type]
if package_path is None:
package_path = aoti_files + ".pt2"
res = package_aoti(package_path, aoti_files)
assert res == package_path
return package_path
def aot_compile(
gm: torch.fx.GraphModule,
args: Tuple[Any],

View File

@ -1719,10 +1719,23 @@ class AotCodeCompiler:
# Currently, this only support serializing extern nodes in fbcode
# Eventually, we should also have a serializer for OSS.
if serialized_extern_kernel_nodes:
output_json = os.path.splitext(input_path)[0] + ".json"
with open(output_json, "w") as f:
extern_kernel_nodes_json = os.path.splitext(input_path)[0] + ".json"
with open(extern_kernel_nodes_json, "w") as f:
f.write(serialized_extern_kernel_nodes)
metadata = config.aot_inductor.metadata
metadata["AOTI_DEVICE_KEY"] = device_type
# Save user provided metadata
meta_json = os.path.splitext(input_path)[0] + "_metadata.json"
for k, v in config.aot_inductor.metadata.items():
assert isinstance(k, str) and isinstance(
v, (str)
), "Metadata must only contain strings"
with open(meta_json, "w") as f:
f.write(json.dumps(config.aot_inductor.metadata))
output_so = (
config.aot_inductor.output_path
if specified_so_name
@ -1755,54 +1768,29 @@ class AotCodeCompiler:
if config.aot_inductor.force_mmap_weights:
use_mmap_weights = True
if config.aot_inductor.package:
(
object_output_name,
object_output_dir,
) = get_name_and_dir_from_output_file_path(input_path)
object_build_options = CppTorchDeviceOptions(
vec_isa=picked_vec_isa,
device_type=device_type,
aot_mode=graph.aot_mode,
compile_only=True,
use_absolute_path=use_absolute_path,
use_mmap_weights=use_mmap_weights,
)
object_builder = CppBuilder(
name=object_output_name,
sources=input_path,
output_dir=object_output_dir,
BuildOption=object_build_options,
)
compile_cmd = object_builder.get_command_line()
output_o = object_builder.get_target_file_path()
(
object_output_name,
object_output_dir,
) = get_name_and_dir_from_output_file_path(input_path)
object_build_options = CppTorchDeviceOptions(
vec_isa=picked_vec_isa,
device_type=device_type,
aot_mode=graph.aot_mode,
compile_only=True,
use_absolute_path=use_absolute_path,
use_mmap_weights=use_mmap_weights,
)
object_builder = CppBuilder(
name=object_output_name,
sources=input_path,
output_dir=object_output_dir,
BuildOption=object_build_options,
)
compile_cmd = object_builder.get_command_line()
output_o = object_builder.get_target_file_path()
compile_flags = os.path.splitext(input_path)[0] + "_compile_flags.json"
object_build_options.save_flags_to_file(compile_flags)
else:
(
object_output_name,
object_output_dir,
) = get_name_and_dir_from_output_file_path(input_path)
object_build_options = CppTorchDeviceOptions(
vec_isa=picked_vec_isa,
device_type=device_type,
aot_mode=graph.aot_mode,
compile_only=True,
use_absolute_path=use_absolute_path,
use_mmap_weights=use_mmap_weights,
)
object_builder = CppBuilder(
name=object_output_name,
sources=input_path,
output_dir=object_output_dir,
BuildOption=object_build_options,
)
compile_cmd = object_builder.get_command_line()
output_o = object_builder.get_target_file_path()
log.debug("aot compilation command: %s", compile_cmd)
log.debug("aot compilation command: %s", compile_cmd)
if not config.aot_inductor.package_cpp_only:
if fbcode_aot_cpu_re:
output_o = os.path.splitext(input_path)[0] + ".o"
compile_file(input_path, output_o, compile_cmd.split())
@ -1810,6 +1798,10 @@ class AotCodeCompiler:
else:
run_command_and_check(compile_cmd)
if config.aot_inductor.package:
compile_flags = os.path.splitext(input_path)[0] + "_compile_flags.json"
object_build_options.save_flags_to_file(compile_flags)
def _to_bytes(t: torch.Tensor, all_cuda: bool) -> bytes:
def _pad_to_alignment(raw_bytes: bytes) -> bytes:
padded_bytes = raw_bytes.ljust(
@ -1859,29 +1851,37 @@ class AotCodeCompiler:
"darwin": _compile_consts_darwin,
}[sys.platform](aot_constants)
if config.aot_inductor.package:
output_name, output_dir = get_name_and_dir_from_output_file_path(
output_so
)
so_build_options = CppTorchDeviceOptions(
vec_isa=picked_vec_isa,
device_type=device_type,
aot_mode=graph.aot_mode,
use_absolute_path=use_absolute_path,
)
so_builder = CppBuilder(
name=output_name,
sources=[output_o, consts_o],
output_dir=output_dir,
BuildOption=so_build_options,
)
link_cmd = so_builder.get_command_line()
output_so = so_builder.get_target_file_path()
output_name, output_dir = get_name_and_dir_from_output_file_path(output_so)
so_build_options = CppTorchDeviceOptions(
vec_isa=picked_vec_isa,
device_type=device_type,
aot_mode=graph.aot_mode,
use_absolute_path=use_absolute_path,
)
so_builder = CppBuilder(
name=output_name,
sources=[output_o, consts_o],
output_dir=output_dir,
BuildOption=so_build_options,
)
link_cmd = so_builder.get_command_line()
output_so = so_builder.get_target_file_path()
log.debug("aot linkage command: %s", link_cmd)
# Append cmds to the end of codegen-ed wrapper file
with open(input_path, "a") as f:
f.write("\n")
f.write(f"// Compile cmd\n// {compile_cmd}\n")
f.write(f"// Link cmd\n// {link_cmd}\n")
if config.aot_inductor.package:
linker_flags = os.path.splitext(input_path)[0] + "_linker_flags.json"
so_build_options.save_flags_to_file(linker_flags)
from torch._inductor.package import package_aoti
if config.aot_inductor.package_cpp_only:
# If we only want to package the cpp, then we need to save the
# weights separately into a bin, and we also need to prevent compiling the so
if use_mmap_weights:
weight_file = (
@ -1891,28 +1891,7 @@ class AotCodeCompiler:
f_weights.write(serialized_weights)
f_weights.write(struct.pack("q", magic_number))
archive_path = package_aoti(os.path.split(input_path)[0])
return archive_path
else:
output_name, output_dir = get_name_and_dir_from_output_file_path(
output_so
)
so_build_options = CppTorchDeviceOptions(
vec_isa=picked_vec_isa,
device_type=device_type,
aot_mode=graph.aot_mode,
use_absolute_path=use_absolute_path,
)
so_builder = CppBuilder(
name=output_name,
sources=[output_o, consts_o],
output_dir=output_dir,
BuildOption=so_build_options,
)
link_cmd = so_builder.get_command_line()
output_so = so_builder.get_target_file_path()
log.debug("aot linkage command: %s", link_cmd)
if fbcode_aot_cpu_re:
output_so = (
config.aot_inductor.output_path
@ -1937,11 +1916,10 @@ class AotCodeCompiler:
f_so.write(serialized_weights)
f_so.write(struct.pack("q", magic_number))
# Append cmds to the end of codegen-ed wrapper file
with open(input_path, "a") as f:
f.write("\n")
f.write(f"// Compile cmd\n// {compile_cmd}\n")
f.write(f"// Link cmd\n// {link_cmd}\n")
if config.aot_inductor.package:
# We want to return the directory that contains all the AOTI
# generated files, not just the so
return os.path.split(output_so)[0]
return output_so

View File

@ -1132,6 +1132,7 @@ def compile_fx_aot(
if config_patches is None
else {**config_patches, "cpp_wrapper": True}
)
if (
"aot_inductor.output_path" not in config_patches
and not config.aot_inductor.output_path

View File

@ -1006,9 +1006,11 @@ class aot_inductor:
)
# Serialized tree spec for flattening inputs
# TODO: Move this into metadata
serialized_in_spec = ""
# Serialized tree spec for flattening outputs
# TODO: Move this into metadata
serialized_out_spec = ""
# flag to decide whether to create a submodule for constant graph.
@ -1019,6 +1021,11 @@ class aot_inductor:
force_mmap_weights: bool = False
package: bool = False
package_cpp_only: bool = False
# Dictionary of metadata users might want to save to pass to the runtime.
# TODO: Move this somewhere else, since it's no longer really a config
metadata: Dict[str, str] = {}
class cuda:

View File

@ -1 +1 @@
from .package import load_package, package_aoti
from .package import AOTICompiledModel, load_package, package_aoti

View File

@ -1,24 +1,25 @@
import glob
import json
import logging
import os
import shlex
import subprocess
import tempfile
import zipfile
from pathlib import Path
from typing import Callable, List, Optional, Union
from typing import Dict, List, Optional, Union
import torch
import torch._inductor
import torch.utils._pytree as pytree
from torch._inductor import config, exc
from torch._inductor import exc
from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
from torch.export._tree_utils import reorder_kwargs
from .build_package import build_package_contents
from .pt2_archive_constants import AOTINDUCTOR_DIR, ARCHIVE_VERSION
log = logging.getLogger(__name__)
class PT2ArchiveWriter:
def __init__(self, archive_path: str) -> None:
self.archive_path: str = archive_path
@ -154,84 +155,71 @@ def compile_so(aoti_dir: str, aoti_files: List[str], so_path: str) -> str:
return output_so
def package_aoti(aoti_output_dir: str) -> str:
def package_aoti(archive_file: str, aoti_files: Union[str, Dict[str, str]]) -> str:
"""
Saves the AOTInductor generated files to the PT2Archive format.
Args:
archive_file: The file name to save the package to.
aoti_files: This can either be a singular path to a directory containing
the AOTInductor files, or a dictionary mapping the model name to the
path to its AOTInductor generated files.
"""
if isinstance(aoti_files, str):
aoti_files = {"model": aoti_files}
assert isinstance(aoti_files, dict)
assert archive_file.endswith(".pt2")
# Save using the PT2 packaging format
# (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a)
with PT2ArchiveWriter(archive_file) as archive_writer:
for model_name, aoti_output_dir in aoti_files.items():
log.debug(
"Packaging AOTInductor files from %s with model name, %s",
aoti_output_dir,
model_name,
)
for root, dirs, files in os.walk(aoti_output_dir):
for file in files:
log.debug(
"Saving AOTI generated file %s to archive in %s%s/%s",
os.path.join(root, file),
AOTINDUCTOR_DIR,
model_name,
file,
)
archive_writer.write_file(
f"{AOTINDUCTOR_DIR}{model_name}/{file}",
os.path.join(root, file),
)
return archive_file
class AOTICompiledModel:
"""
Callable AOT Inductor loaded model from a .pt2
"""
# Add a makefile and python script
build_package_filename = "build_package.py"
with open(os.path.join(aoti_output_dir, build_package_filename), "w") as f:
f.write(build_package_contents)
def __init__(self, loader: torch._C._aoti.AOTIModelPackageLoader) -> None:
self.loader = loader
with open(os.path.join(aoti_output_dir, "Makefile"), "w") as f:
f.write(f"all:\n\tpython3 {build_package_filename}\n")
if config.aot_inductor.output_path.endswith(".so"):
raise RuntimeError(
"Unable to save package as a .so. It should be a .pt2 format or a directory."
)
elif config.aot_inductor.output_path.endswith(".pt2"):
# Save using the PT2 packaging format
# (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a)
archive_path = config.aot_inductor.output_path
with PT2ArchiveWriter(archive_path) as archive_writer:
package_files = glob.glob(f"{aoti_output_dir}/*")
for path in package_files:
filename = os.path.basename(path)
archive_writer.write_file(f"{AOTINDUCTOR_DIR}{filename}", path)
return archive_path
else:
# Directly put the files into the directory, without any archiving
return aoti_output_dir
def load_package(path: str, device: str) -> Callable: # type: ignore[type-arg]
if path.endswith(".so"):
raise RuntimeError(
"Unable to load .so. It should be a .pt2 format or a directory."
)
elif path.endswith(".pt2"):
so_path = os.path.splitext(path)[0]
with PT2ArchiveReader(path) as archive_reader:
file_names = archive_reader.get_file_names()
with tempfile.TemporaryDirectory() as tmp_dir:
archive_reader.extractall(tmp_dir)
file_names = archive_reader.get_file_names()
aoti_files = [
file for file in file_names if file.startswith(AOTINDUCTOR_DIR)
]
so_path = compile_so(tmp_dir, aoti_files, so_path)
else:
assert os.path.isdir(path), "Must specify a directory or a .pt2 file"
aoti_files = [
os.path.join(root, file)
for root, dirs, files in os.walk(path)
for file in files
]
so_path = compile_so(path, aoti_files, path)
if device == "cpu":
runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) # type: ignore[call-arg]
elif device == "cuda" or device.startswith("cuda:"):
runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) # type: ignore[assignment, call-arg]
else:
raise RuntimeError("Unsupported device " + device)
def optimized(*args, **kwargs): # type: ignore[no-untyped-def]
call_spec = runner.get_call_spec() # type: ignore[attr-defined]
def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def]
call_spec = self.loader.get_call_spec() # type: ignore[attr-defined]
in_spec = pytree.treespec_loads(call_spec[0])
out_spec = pytree.treespec_loads(call_spec[1])
flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0]
flat_outputs = runner.run(flat_inputs) # type: ignore[attr-defined]
flat_outputs = self.loader.run(flat_inputs) # type: ignore[attr-defined]
return pytree.tree_unflatten(flat_outputs, out_spec)
return optimized
def get_metadata(self) -> Dict[str, str]:
return self.loader.get_metadata() # type: ignore[attr-defined]
def load_package(path: str, model_name: str = "model") -> AOTICompiledModel: # type: ignore[type-arg]
if not path.endswith(".pt2"):
raise RuntimeError("Unable to load package. Path must be a .pt2 file.")
loader = torch._C._aoti.AOTIModelPackageLoader(path, model_name) # type: ignore[call-arg]
return AOTICompiledModel(loader)

View File

@ -69,6 +69,7 @@
#include <torch/csrc/dynamo/init.h>
#include <torch/csrc/functorch/init.h>
#include <torch/csrc/fx/node.h>
#include <torch/csrc/inductor/aoti_package/pybind.h>
#include <torch/csrc/inductor/aoti_runner/pybind.h>
#include <torch/csrc/instruction_counter/Module.h>
#include <torch/csrc/jit/python/init.h>
@ -1687,6 +1688,7 @@ PyObject* initModule() {
torch::python::init_bindings(module);
torch::lazy::initLazyBindings(module);
torch::inductor::initAOTIRunnerBindings(module);
torch::inductor::initAOTIPackageBindings(module);
#ifdef USE_ITT
torch::profiler::initIttBindings(module);
#endif

View File

@ -0,0 +1,411 @@
#if !defined(C10_MOBILE) && !defined(ANDROID)
#include <torch/csrc/inductor/aoti_package/model_package_loader.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
#include <fmt/format.h>
#include <miniz.h>
#include <nlohmann/json.hpp>
#include <fstream>
#include <iostream>
// TODO: Investigate why this is necessary, but fixes build problems in FRL
#if __has_include("filesystem")
#include <filesystem>
namespace fs = std::filesystem;
#else
#include <experimental/filesystem>
namespace fs = std::experimental::filesystem;
#endif
#ifndef _WIN32
#include <sys/stat.h>
#endif
// TODO: C++17 has the filesystem header, which may replace these
#ifdef _WIN32
// On Windows, the POSIX implementations are considered deprecated. We simply
// map to the newer variant.
#include <direct.h>
#include <io.h>
#include <process.h>
#define access _access
#define F_OK 0
#else
#include <sys/types.h>
#include <unistd.h>
#endif
namespace {
bool file_exists(std::string& path) {
#ifdef _WIN32
return fs::exists(path);
#else
struct stat rc;
return lstat(path.c_str(), &rc) == 0;
#endif
}
std::string create_temp_dir() {
#ifdef _WIN32
throw std::runtime_error("Not implemented");
#else
std::string temp_dir = "/tmp/XXXXXX";
if (mkdtemp(temp_dir.data()) == nullptr) {
throw std::runtime_error(
std::string("Failed to create temporary directory: ") +
strerror(errno));
}
return temp_dir;
#endif
}
} // namespace
namespace torch::inductor {
const nlohmann::json& AOTIModelPackageLoader::load_json_file(
std::string json_path) {
if (!file_exists(json_path)) {
throw std::runtime_error("File found: " + json_path);
}
std::ifstream json_file(json_path);
TORCH_CHECK(json_file.is_open());
static nlohmann::json json_obj;
json_file >> json_obj;
return json_obj;
}
void AOTIModelPackageLoader::load_metadata(const std::string& cpp_filename) {
// Parse metadata json file (if it exists) into the metadata_ map
size_t lastindex = cpp_filename.find_last_of('.');
std::string metadata_json_path =
cpp_filename.substr(0, lastindex) + "_metadata.json";
const nlohmann::json metadata_json_obj = load_json_file(metadata_json_path);
for (auto& item : metadata_json_obj.items()) {
metadata_[item.key()] = item.value().get<std::string>();
}
}
std::tuple<std::string, std::string> AOTIModelPackageLoader::
get_cpp_compile_command(
const std::string& filename,
const std::vector<std::string>& sources,
const nlohmann::json& compile_options,
const std::string& output_dir = "") {
// Construct the cpp command
std::string compiler = compile_options["compiler"].get<std::string>();
bool compile_only = compile_options["compile_only"].get<bool>();
std::string source_args = "";
for (const std::string& source : sources) {
source_args += source + " ";
}
std::string file_ext = compile_only ? ".o" : ".so";
std::string target_file = output_dir + filename + file_ext;
std::string cflags_args = "";
for (auto& arg : compile_options["cflags"]) {
cflags_args += "-" + arg.get<std::string>() + " ";
}
std::string definitions_args = "";
for (auto& arg : compile_options["definitions"]) {
definitions_args += "-D " + arg.get<std::string>() + " ";
}
std::string include_dirs_args = "";
for (auto& arg : compile_options["include_dirs"]) {
include_dirs_args += "-I" + arg.get<std::string>() + " ";
}
std::string ldflags_args = "";
for (auto& arg : compile_options["ldflags"]) {
ldflags_args += "-" + arg.get<std::string>() + " ";
}
std::string libraries_dirs_args = "";
for (auto& arg : compile_options["libraries_dirs"]) {
libraries_dirs_args += "-L" + arg.get<std::string>() + " ";
}
std::string libraries_args = "";
for (auto& arg : compile_options["libraries"]) {
libraries_args += "-l" + arg.get<std::string>() + " ";
}
std::string passthrough_parameters_args = "";
for (auto& arg : compile_options["passthrough_args"]) {
passthrough_parameters_args += arg.get<std::string>() + " ";
}
std::string compile_only_arg = compile_only ? "-c" : "";
std::string cmd = fmt::format(
"{} {} {} {} {} {} {} {} {} {} -o {}",
compiler,
source_args,
definitions_args,
cflags_args,
include_dirs_args,
passthrough_parameters_args,
ldflags_args,
libraries_args,
libraries_dirs_args,
compile_only_arg,
target_file);
return std::make_tuple(cmd, target_file);
}
bool AOTIModelPackageLoader::recursive_mkdir(const std::string& dir) {
// Creates directories recursively, copied from jit_utils.cpp
// Check if current dir exists
const char* p_dir = dir.c_str();
const bool dir_exists = (access(p_dir, F_OK) == 0);
if (dir_exists) {
return true;
}
// Try to create current directory
#ifdef _WIN32
int ret = _mkdir(dir.c_str());
#else
int ret = mkdir(dir.c_str(), S_IRWXU | S_IRWXG | S_IRWXO);
#endif
// Success
if (ret == 0) {
return true;
}
// Find folder separator and check if we are at the top
auto pos = dir.find_last_of("/\\");
if (pos == std::string::npos) {
return false;
}
// Try to create parent directory
if (!(recursive_mkdir(dir.substr(0, pos)))) {
return false;
}
// Try to create complete path again
#ifdef _WIN32
ret = _mkdir(dir.c_str());
#else
ret = mkdir(dir.c_str(), S_IRWXU | S_IRWXG | S_IRWXO);
#endif
return ret == 0;
}
std::string AOTIModelPackageLoader::compile_so(
const std::string& cpp_filename,
const std::string& consts_filename) {
// Compile the cpp file into a .so
size_t lastindex = cpp_filename.find_last_of('.');
std::string filename = cpp_filename.substr(0, lastindex);
std::string compile_flags_path = filename + "_compile_flags.json";
const nlohmann::json compile_flags = load_json_file(compile_flags_path);
auto compile_result =
get_cpp_compile_command(filename, {cpp_filename}, compile_flags);
std::string compile_cmd = std::get<0>(compile_result);
std::string output_o = std::get<1>(compile_result);
std::string linker_flags_path =
cpp_filename.substr(0, lastindex) + "_linker_flags.json";
const nlohmann::json linker_flags = load_json_file(linker_flags_path);
auto link_result = get_cpp_compile_command(
filename, {output_o, consts_filename}, linker_flags);
std::string link_cmd = std::get<0>(link_result);
std::string output_so = std::get<1>(link_result);
// Run the commands to generate a .so file
int status = system(compile_cmd.c_str());
if (status != 0) {
throw std::runtime_error("Failed to compile cpp file.");
}
status = system(link_cmd.c_str());
if (status != 0) {
throw std::runtime_error("Failed to link files.");
}
// Move the mmapped weights onto the .so
std::string serialized_weights_path = filename + "_serialized_weights.bin";
if (file_exists(serialized_weights_path)) {
std::ifstream serialized_weights_file(
serialized_weights_path, std::ios::binary);
if (!serialized_weights_file.is_open()) {
throw std::runtime_error("Failed to open serialized weights file");
}
std::vector<char> serialized_weights(
(std::istreambuf_iterator<char>(serialized_weights_file)),
std::istreambuf_iterator<char>());
serialized_weights_file.close();
std::ofstream output_so_file(output_so, std::ios::binary | std::ios::app);
if (!output_so_file.is_open()) {
throw std::runtime_error("Failed to open output .so file");
}
// Page align the weights
std::streampos so_size = output_so_file.tellp();
std::vector<char> padding(16384 - so_size % 16384, ' ');
output_so_file.write(
padding.data(), static_cast<std::streamsize>(padding.size()));
output_so_file.write(
serialized_weights.data(),
static_cast<std::streamsize>(serialized_weights.size()));
output_so_file.close();
}
return output_so;
}
AOTIModelPackageLoader::AOTIModelPackageLoader(
const std::string& model_package_path)
: AOTIModelPackageLoader(model_package_path, "model") {}
AOTIModelPackageLoader::AOTIModelPackageLoader(
const std::string& model_package_path,
const std::string& model_name = "model") {
// Extract all files within the zipfile to a temporary directory
mz_zip_archive zip_archive;
memset(&zip_archive, 0, sizeof(zip_archive));
if (!mz_zip_reader_init_file(&zip_archive, model_package_path.c_str(), 0)) {
throw std::runtime_error(
std::string("Failed to initialize zip archive: ") +
mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive)));
}
std::string temp_dir = create_temp_dir();
std::string so_filename = "";
std::string cpp_filename = "";
std::string consts_filename = "";
std::string found_filenames = ""; // Saving for bookkeeping
for (uint32_t i = 0; i < zip_archive.m_total_files; i++) {
uint32_t filename_len =
mz_zip_reader_get_filename(&zip_archive, i, nullptr, 0);
if (filename_len == 0) {
throw std::runtime_error("Failed to read filename");
}
char* filename = new char[filename_len + 1];
if (!mz_zip_reader_get_filename(&zip_archive, i, filename, filename_len)) {
throw std::runtime_error("Failed to read filename");
}
std::string filename_str(filename);
found_filenames += filename_str;
found_filenames += " ";
// Only compile files in the specified model directory
std::string model_directory = "data/aotinductor/" + model_name;
if (filename_str.length() >= model_directory.length() &&
filename_str.substr(0, model_directory.length()) == model_directory) {
std::string output_path_str = temp_dir;
output_path_str += "/";
output_path_str += filename_str;
// Create the parent directory if it doesn't exist
size_t parent_path_idx = output_path_str.find_last_of("/\\");
if (parent_path_idx == std::string::npos) {
throw std::runtime_error(
"Failed to find parent path in " + output_path_str);
}
std::string parent_path = output_path_str.substr(0, parent_path_idx);
if (!recursive_mkdir(parent_path.c_str())) {
throw std::runtime_error(fmt::format(
"Failed to create directory {}: {}", parent_path, strerror(errno)));
}
// Extracts file to the temp directory
mz_zip_reader_extract_file_to_file(
&zip_archive, filename, output_path_str.c_str(), 0);
// Save the file for bookkeeping
size_t extension_idx = output_path_str.find_last_of('.');
if (extension_idx != std::string::npos) {
std::string filename_extension = output_path_str.substr(extension_idx);
if (filename_extension == ".cpp") {
cpp_filename = output_path_str;
}
if (filename_extension == ".o") {
consts_filename = output_path_str;
}
if (filename_extension == ".so") {
so_filename = output_path_str;
}
}
}
}
// Close the zip archive as we have extracted all files to the temp directory
if (!mz_zip_reader_end(&zip_archive)) {
throw std::runtime_error(
std::string("Failed to close zip archive: {}") +
mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive)));
}
if (cpp_filename.empty() && so_filename.empty()) {
throw std::runtime_error(
"No AOTInductor generate cpp file or so file found in zip archive. Loaded the following:\n" +
found_filenames);
}
// Compile the .so
std::string so_path = !so_filename.empty()
? so_filename
: compile_so(cpp_filename, consts_filename);
// Load metadata which can be queried by user
load_metadata(cpp_filename);
// Construct the runner depending on the device information
std::string device = metadata_["AOTI_DEVICE_KEY"];
if (device.empty()) {
throw std::runtime_error("No device information found.");
}
std::unordered_map<std::string, CreateAOTIModelRunnerFunc>
registered_aoti_runner = getAOTIModelRunnerRegistry();
if (registered_aoti_runner.find(device) == registered_aoti_runner.end()) {
throw std::runtime_error("Unsupported device found: " + device);
}
runner_ = registered_aoti_runner[device](so_path, 1, device, "");
std::remove(temp_dir.c_str());
}
AOTIModelContainerRunner* AOTIModelPackageLoader::get_runner() {
return runner_.get();
}
std::vector<at::Tensor> AOTIModelPackageLoader::run(
std::vector<at::Tensor>& inputs) {
return runner_->run(inputs);
}
std::unordered_map<std::string, std::string> AOTIModelPackageLoader::
get_metadata() {
return metadata_;
}
std::vector<std::string> AOTIModelPackageLoader::get_call_spec() {
return runner_->get_call_spec();
}
} // namespace torch::inductor
#endif

View File

@ -0,0 +1,40 @@
#if !defined(C10_MOBILE) && !defined(ANDROID)
#pragma once
#include <ATen/Tensor.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
#include <nlohmann/json.hpp>
namespace torch::inductor {
class TORCH_API AOTIModelPackageLoader {
public:
AOTIModelPackageLoader(const std::string& model_package_path);
AOTIModelPackageLoader(
const std::string& model_package_path,
const std::string& model_name);
AOTIModelContainerRunner* get_runner();
std::unordered_map<std::string, std::string> get_metadata();
std::vector<at::Tensor> run(std::vector<at::Tensor>& inputs);
std::vector<std::string> get_call_spec();
private:
std::unique_ptr<AOTIModelContainerRunner> runner_;
std::unordered_map<std::string, std::string> metadata_;
void load_metadata(const std::string& cpp_filename);
std::string compile_so(
const std::string& cpp_filename,
const std::string& consts_filename);
const nlohmann::json& load_json_file(std::string json_path);
std::tuple<std::string, std::string> get_cpp_compile_command(
const std::string& filename,
const std::vector<std::string>& sources,
const nlohmann::json& compile_options,
const std::string& output_dir);
bool recursive_mkdir(const std::string& dir);
};
} // namespace torch::inductor
#endif

View File

@ -0,0 +1,24 @@
#include <torch/csrc/inductor/aoti_package/model_package_loader.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
#ifdef USE_CUDA
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
#endif
#include <torch/csrc/inductor/aoti_runner/pybind.h>
#include <torch/csrc/utils/pybind.h>
namespace torch::inductor {
void initAOTIPackageBindings(PyObject* module) {
auto rootModule = py::handle(module).cast<py::module>();
auto m = rootModule.def_submodule("_aoti");
py::class_<AOTIModelPackageLoader>(m, "AOTIModelPackageLoader")
.def(py::init<const std::string&, const std::string&>())
.def(py::init<const std::string&>())
.def("get_metadata", &AOTIModelPackageLoader::get_metadata)
.def("run", &AOTIModelPackageLoader::run)
.def("get_call_spec", &AOTIModelPackageLoader::get_call_spec);
}
} // namespace torch::inductor

View File

@ -0,0 +1,7 @@
#include <torch/csrc/python_headers.h>
namespace torch::inductor {
void initAOTIPackageBindings(PyObject* module);
} // namespace torch::inductor

View File

@ -82,7 +82,7 @@ class TORCH_API AOTIModelContainerRunner {
std::unique_ptr<torch::aot_inductor::ProxyExecutor> proxy_executor_;
};
using CreateAOTIModelRunnerFunc = std::shared_ptr<AOTIModelContainerRunner> (*)(
using CreateAOTIModelRunnerFunc = std::unique_ptr<AOTIModelContainerRunner> (*)(
const std::string& model_so_path,
size_t num_models,
const std::string& device_str,

View File

@ -17,5 +17,21 @@ std::vector<at::Tensor> AOTIModelContainerRunnerCpu::run(
return AOTIModelContainerRunner::run(inputs);
}
namespace {
std::unique_ptr<AOTIModelContainerRunner> create_aoti_runner_cpu(
const std::string& model_so_path,
size_t num_models,
const std::string& device_str,
const std::string& cubin_dir) {
if (device_str != "cpu") {
throw std::runtime_error("Incorrect device passed to aoti_runner_cpu");
}
return std::make_unique<AOTIModelContainerRunnerCpu>(
model_so_path, num_models);
}
} // namespace
RegisterAOTIModelRunner register_cpu_runner("cpu", &create_aoti_runner_cpu);
} // namespace torch::inductor
#endif

View File

@ -30,5 +30,18 @@ std::vector<at::Tensor> AOTIModelContainerRunnerCuda::run_with_cuda_stream(
inputs, reinterpret_cast<AOTInductorStreamHandle>(cuda_stream.stream()));
}
namespace {
std::unique_ptr<AOTIModelContainerRunner> create_aoti_runner_cuda(
const std::string& model_so_path,
size_t num_models,
const std::string& device_str,
const std::string& cubin_dir) {
return std::make_unique<AOTIModelContainerRunnerCuda>(
model_so_path, num_models, device_str, cubin_dir);
}
} // namespace
RegisterAOTIModelRunner register_cuda_runner("cuda", &create_aoti_runner_cuda);
} // namespace torch::inductor
#endif

View File

@ -55,6 +55,39 @@ namespace fs = std::filesystem;
namespace fs = std::experimental::filesystem;
#endif
#ifndef _WIN32
#include <sys/stat.h>
#endif
// HACK for failed builds in ARVR, where it cannot find these symbols within
// std::experimental::filesystem
namespace {
fs::path get_current_path() {
#if __has_include("filesystem")
return fs::current_path();
#else
throw std::runtime_error("Not implemented");
#endif
}
bool file_exists(std::string& path) {
#ifdef _WIN32
return fs::exists(path);
#else
struct stat rc;
return lstat(path.c_str(), &rc) == 0;
#endif
}
bool create_directories(const std::string& path) {
#if __has_include("filesystem")
return fs::create_directories(path);
#else
throw std::runtime_error("Not implemented");
#endif
}
} // namespace
using namespace torch::aot_inductor;
namespace {
@ -1004,14 +1037,14 @@ AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle(
at::Tensor* t = tensor_handle_to_tensor_pointer(self);
#ifndef C10_MOBILE
// Save tensor to tmp .pt file for tensors and can be torch.load'ed later
std::string cwd = fs::current_path().string();
std::string cwd = get_current_path().string();
std::string tmp_folder = cwd + "/tmp/aoti_torch/";
if (!fs::exists(tmp_folder)) {
if (!file_exists(tmp_folder)) {
std::cout
<< "aoti_torch_save_tensor_handle: Path does not exist, creating it..."
<< tmp_folder << std::endl;
if (!fs::create_directories(tmp_folder)) {
if (!create_directories(tmp_folder)) {
std::cout << "aoti_torch_save_tensor_handle: Error creating directory: "
<< tmp_folder << std::endl;
return;