Compare commits

1 commit

4a45e48a3c  [aoti] Add cpp loader  (2024-08-30 14:26:09 -07:00)
ghstack-source-id: ccb800e2667132afdd1ab6f2b974be635f581e24
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134865
14 changed files with 518 additions and 81 deletions

View File

@@ -466,6 +466,7 @@ lazy_tensor_core_python_sources = [
]
inductor_core_resources = [
"torch/csrc/inductor/aoti_package/model_package_loader.cpp",
"torch/csrc/inductor/aoti_runner/model_container_runner.cpp",
"torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp",
"torch/csrc/inductor/aoti_torch/shim_common.cpp",
@@ -840,6 +841,7 @@ libtorch_python_core_sources = [
"torch/csrc/fx/node.cpp",
"torch/csrc/mps/Module.cpp",
"torch/csrc/mtia/Module.cpp",
"torch/csrc/inductor/aoti_package/pybind.cpp",
"torch/csrc/inductor/aoti_runner/pybind.cpp",
"torch/csrc/inductor/aoti_eager/kernel_holder.cpp",
"torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp",

View File

@@ -1318,6 +1318,7 @@ def main():
"include/torch/csrc/distributed/autograd/rpc_messages/*.h",
"include/torch/csrc/dynamo/*.h",
"include/torch/csrc/inductor/*.h",
"include/torch/csrc/inductor/aoti_package/*.h",
"include/torch/csrc/inductor/aoti_runner/*.h",
"include/torch/csrc/inductor/aoti_runtime/*.h",
"include/torch/csrc/inductor/aoti_torch/*.h",

View File

@@ -1,11 +1,14 @@
# Owner(s): ["module: inductor"]
import copy
import sys
import tempfile
import unittest
from dataclasses import dataclass
from typing import Callable
import torch
from torch._inductor import config
from torch._inductor.package import load_package
from torch._inductor.package import load_package, package_aoti
from torch._inductor.test_case import TestCase
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import IS_FBCODE
@@ -23,7 +26,13 @@ except (unittest.SkipTest, ImportError) as e:
raise
def compile(model, example_inputs, dynamic_shapes, options, device):
@dataclass
class CompiledResult:
loader: torch._C._aoti.AOTIModelPackageLoader
compiled_model: Callable
def compile(model, example_inputs, dynamic_shapes, options, device) -> CompiledResult:
ep = torch.export.export(
model,
example_inputs,
@@ -31,9 +40,12 @@ def compile(model, example_inputs, dynamic_shapes, options, device):
strict=False,
)
gm = ep.module()
package_path = torch._inductor.aot_compile(gm, example_inputs, options=options) # type: ignore[arg-type]
compiled_model = load_package(package_path, device)
return compiled_model
aoti_files = torch._inductor.aot_compile(gm, example_inputs, options=options) # type: ignore[arg-type]
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
package_path = package_aoti(f.name, aoti_files)
compiled_model = load_package(package_path)
loader = torch._C._aoti.AOTIModelPackageLoader(package_path) # type: ignore[call-arg]
return CompiledResult(loader, compiled_model)
def check_model(
@@ -45,7 +57,7 @@ def check_model(
disable_constraint_solver=False,
atol=None,
rtol=None,
):
) -> CompiledResult:
with torch.no_grad(), config.patch(
{
"aot_inductor.package": True,
@@ -59,7 +71,7 @@ def check_model(
expected = ref_model(*ref_inputs)
torch.manual_seed(0)
compiled_model = compile(
compiled_result = compile(
model,
example_inputs,
dynamic_shapes,
@@ -67,9 +79,10 @@
self.device,
)
actual = compiled_model(*example_inputs)
actual = compiled_result.compiled_model(*example_inputs)
self.assertEqual(actual, expected, atol=atol, rtol=rtol)
return compiled_result
class AOTInductorTestsTemplate:
@@ -99,6 +112,28 @@ class AOTInductorTestsTemplate:
)
self.check_model(Model(), example_inputs)
def test_metadata(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x, y):
return x + self.linear(y)
example_inputs = (
torch.randn(10, 10, device=self.device),
torch.randn(10, 10, device=self.device),
)
metadata = {"dummy": "moo"}
compiled_result = self.check_model(
Model(), example_inputs, options={"aot_inductor.metadata": metadata}
)
loaded_metadata = compiled_result.loader.get_metadata() # type: ignore[attr-defined]
self.assertEqual(loaded_metadata.get("dummy"), "moo")
common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)
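
Putting the pieces together, the tests now insert an explicit packaging step between compilation and loading. A minimal end-to-end sketch of that flow, assuming a CPU-only build (M is a stand-in module, not from the patch):

import tempfile

import torch
from torch._inductor import config
from torch._inductor.package import load_package, package_aoti


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1


example_inputs = (torch.randn(10, 10),)
ep = torch.export.export(M(), example_inputs, strict=False)

# With aot_inductor.package=True, aot_compile returns the AOTInductor
# output files instead of the path to a ready-made .so.
with config.patch({"aot_inductor.package": True}):
    aoti_files = torch._inductor.aot_compile(ep.module(), example_inputs)

with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
    package_path = package_aoti(f.name, aoti_files)
    compiled = load_package(package_path)
    print(compiled(*example_inputs))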

View File

@@ -18,3 +18,6 @@ def alloc_tensor_by_stealing_from_void_ptr(
class AOTIModelContainerRunnerCpu: ...
class AOTIModelContainerRunnerCuda: ...
# Defined in torch/csrc/inductor/aoti_package/pybind.cpp
class AOTIModelPackageLoader: ...

View File

@@ -1706,10 +1706,23 @@ class AotCodeCompiler:
# Currently, this only supports serializing extern nodes in fbcode
# Eventually, we should also have a serializer for OSS.
if serialized_extern_kernel_nodes:
output_json = os.path.splitext(input_path)[0] + ".json"
with open(output_json, "w") as f:
extern_kernel_nodes_json = os.path.splitext(input_path)[0] + ".json"
with open(extern_kernel_nodes_json, "w") as f:
f.write(serialized_extern_kernel_nodes)
metadata = config.aot_inductor.metadata
metadata["AOTI_DEVICE_KEY"] = "cuda" if cuda else "cpu"
# Save user provided metadata
meta_json = os.path.splitext(input_path)[0] + "_metadata.json"
for k, v in config.aot_inductor.metadata.items():
assert isinstance(k, str) and isinstance(v, str), "Metadata must only contain strings"
with open(meta_json, "w") as f:
f.write(json.dumps(config.aot_inductor.metadata))
output_so = (
config.aot_inductor.output_path
if specified_so_name
@@ -1860,8 +1873,6 @@ class AotCodeCompiler:
linker_flags = os.path.splitext(input_path)[0] + "_linker_flags.json"
so_build_options.save_flags_to_file(linker_flags)
from torch._inductor.package import package_aoti
if use_mmap_weights:
weight_file = (
os.path.splitext(input_path)[0] + "_serialized_weights.bin"
@@ -1870,8 +1881,7 @@ class AotCodeCompiler:
f_weights.write(serialized_weights)
f_weights.write(struct.pack("q", magic_number))
archive_path = package_aoti(os.path.split(input_path)[0])
return archive_path
return os.path.split(input_path)[0]
else:
output_name, output_dir = get_name_and_dir_from_output_file_path(
output_so
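
The metadata sidecar is plain JSON written next to the generated cpp file. The same logic as a standalone sketch (save_aoti_metadata is a hypothetical helper, not part of the patch):

import json
import os


def save_aoti_metadata(input_path: str, metadata: dict, cuda: bool) -> str:
    # Stamp the device key the loader later uses to pick a runner
    # (copying the dict here rather than mutating the caller's config).
    metadata = dict(metadata)
    metadata["AOTI_DEVICE_KEY"] = "cuda" if cuda else "cpu"
    for k, v in metadata.items():
        assert isinstance(k, str) and isinstance(v, str), "Metadata must only contain strings"
    # Write the *_metadata.json sidecar next to the generated cpp file.
    meta_json = os.path.splitext(input_path)[0] + "_metadata.json"
    with open(meta_json, "w") as f:
        f.write(json.dumps(metadata))
    return meta_json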

View File

@@ -1129,6 +1129,16 @@ def compile_fx_aot(
if config_patches is None
else {**config_patches, "cpp_wrapper": True}
)
if config_patches.get("aot_inductor.package", False) or config.aot_inductor.package:
assert (
"aot_inductor.output_path" not in config_patches
and not config.aot_inductor.output_path
), (
"Specifying an aot_inductor.output_path is prohibited, as this "
"should be done subsequently in a call to package_aoti()"
)
if (
"aot_inductor.output_path" not in config_patches
and not config.aot_inductor.output_path
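
The new guard means that when packaging is enabled, the final artifact path belongs to the later package_aoti() call rather than to compilation. A sketch of the allowed and rejected combinations (gm and example_inputs assumed from the earlier sketch):

import torch
from torch._inductor import config

# Allowed: no output_path; the .pt2 destination is chosen later by package_aoti().
with config.patch({"aot_inductor.package": True}):
    aoti_files = torch._inductor.aot_compile(gm, example_inputs)

# Rejected: package=True combined with an explicit aot_inductor.output_path
# trips the assertion above.
with config.patch(
    {"aot_inductor.package": True, "aot_inductor.output_path": "model.so"}
):
    torch._inductor.aot_compile(gm, example_inputs)  # AssertionError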

View File

@@ -946,6 +946,9 @@ class aot_inductor:
package: bool = False
# Dictionary of metadata users might want to save to pass to the runtime.
metadata: Dict[str, str] = {}
class cuda:
# CUDA arch to use for CUDA template kernel compilation.

View File

@@ -1,24 +1,25 @@
import glob
import json
import logging
import os
import shlex
import subprocess
import tempfile
import zipfile
from pathlib import Path
from typing import Callable, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union
import torch
import torch._inductor
import torch.utils._pytree as pytree
from torch._inductor import config, exc
from torch._inductor import exc
from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
from torch.export._tree_utils import reorder_kwargs
from .build_package import build_package_contents
from .pt2_archive_constants import AOTINDUCTOR_DIR, ARCHIVE_VERSION
log = logging.getLogger(__name__)
class PT2ArchiveWriter:
def __init__(self, archive_path: str) -> None:
self.archive_path: str = archive_path
@@ -154,84 +155,61 @@ def compile_so(aoti_dir: str, aoti_files: List[str], so_path: str) -> str:
return output_so
def package_aoti(aoti_output_dir: str) -> str:
def package_aoti(archive_file: str, aoti_files: Union[str, Dict[str, str]]) -> str:
"""
Saves the AOTInductor-generated files to the PT2Archive format.
Args:
archive_file: The file name to save the package to.
aoti_files: This can either be a singular path to a directory containing
the AOTInductor files, or a dictionary mapping the model name to the
path to its AOTInductor generated files.
"""
if isinstance(aoti_files, str):
aoti_files = {"model": aoti_files}
# Add a makefile and python script
build_package_filename = "build_package.py"
with open(os.path.join(aoti_output_dir, build_package_filename), "w") as f:
f.write(build_package_contents)
assert isinstance(aoti_files, dict)
assert archive_file.endswith(".pt2")
with open(os.path.join(aoti_output_dir, "Makefile"), "w") as f:
f.write(f"all:\n\tpython3 {build_package_filename}\n")
# Save using the PT2 packaging format
# (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a)
if config.aot_inductor.output_path.endswith(".so"):
raise RuntimeError(
"Unable to save package as a .so. It should be a .pt2 format or a directory."
for model_name, aoti_output_dir in aoti_files.items():
log.debug(
"Packaging AOTInductor files from %s with model name, %s",
aoti_output_dir,
model_name,
)
elif config.aot_inductor.output_path.endswith(".pt2"):
# Save using the PT2 packaging format
# (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a)
archive_path = config.aot_inductor.output_path
with PT2ArchiveWriter(archive_file) as archive_writer:
for root, dirs, files in os.walk(aoti_output_dir):
for file in files:
log.debug(
"Saving file %s to archive in %s%s/%s",
os.path.join(root, file),
AOTINDUCTOR_DIR,
model_name,
file,
)
archive_writer.write_file(
f"{AOTINDUCTOR_DIR}{model_name}/{file}",
os.path.join(root, file),
)
with PT2ArchiveWriter(archive_path) as archive_writer:
package_files = glob.glob(f"{aoti_output_dir}/*")
for path in package_files:
filename = os.path.basename(path)
archive_writer.write_file(f"{AOTINDUCTOR_DIR}{filename}", path)
return archive_path
else:
# Directly put the files into the directory, without any archiving
return aoti_output_dir
return archive_file
def load_package(path: str, device: str) -> Callable: # type: ignore[type-arg]
if path.endswith(".so"):
raise RuntimeError(
"Unable to load .so. It should be a .pt2 format or a directory."
)
def load_package(path: str, model_name: str = "model") -> Callable: # type: ignore[type-arg]
if not path.endswith(".pt2"):
raise RuntimeError("Unable to load package. Path must be a .pt2 file.")
elif path.endswith(".pt2"):
so_path = os.path.splitext(path)[0]
with PT2ArchiveReader(path) as archive_reader:
file_names = archive_reader.get_file_names()
with tempfile.TemporaryDirectory() as tmp_dir:
archive_reader.extractall(tmp_dir)
file_names = archive_reader.get_file_names()
aoti_files = [
file for file in file_names if file.startswith(AOTINDUCTOR_DIR)
]
so_path = compile_so(tmp_dir, aoti_files, so_path)
else:
assert os.path.isdir(path), "Must specify a directory or a .pt2 file"
aoti_files = [
os.path.join(root, file)
for root, dirs, files in os.walk(path)
for file in files
]
so_path = compile_so(path, aoti_files, path)
if device == "cpu":
runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) # type: ignore[call-arg]
elif device == "cuda" or device.startswith("cuda:"):
runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) # type: ignore[assignment, call-arg]
else:
raise RuntimeError("Unsupported device " + device)
loader = torch._C._aoti.AOTIModelPackageLoader(path, model_name) # type: ignore[call-arg]
def optimized(*args, **kwargs): # type: ignore[no-untyped-def]
call_spec = runner.get_call_spec() # type: ignore[attr-defined]
call_spec = loader.get_call_spec() # type: ignore[attr-defined]
in_spec = pytree.treespec_loads(call_spec[0])
out_spec = pytree.treespec_loads(call_spec[1])
flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0]
flat_outputs = runner.run(flat_inputs) # type: ignore[attr-defined]
flat_outputs = loader.run(flat_inputs) # type: ignore[attr-defined]
return pytree.tree_unflatten(flat_outputs, out_spec)
return optimized
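
Because aoti_files may be a dictionary, a single .pt2 archive can carry several models, each stored under data/aotinductor/<model_name>/ and selected by name at load time. A sketch, where the two *_aoti_dir variables stand for the outputs of separate aot_compile calls:

from torch._inductor.package import load_package, package_aoti

package_aoti(
    "multi.pt2",
    {"encoder": encoder_aoti_dir, "decoder": decoder_aoti_dir},
)

encoder = load_package("multi.pt2", model_name="encoder")
decoder = load_package("multi.pt2", model_name="decoder")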

View File

@@ -69,6 +69,7 @@
#include <torch/csrc/dynamo/init.h>
#include <torch/csrc/functorch/init.h>
#include <torch/csrc/fx/node.h>
#include <torch/csrc/inductor/aoti_package/pybind.h>
#include <torch/csrc/inductor/aoti_runner/pybind.h>
#include <torch/csrc/instruction_counter/Module.h>
#include <torch/csrc/jit/python/init.h>
@@ -1687,6 +1688,7 @@ PyObject* initModule() {
torch::python::init_bindings(module);
torch::lazy::initLazyBindings(module);
torch::inductor::initAOTIRunnerBindings(module);
torch::inductor::initAOTIPackageBindings(module);
#ifdef USE_ITT
torch::profiler::initIttBindings(module);
#endif

View File

@@ -0,0 +1,315 @@
#if !defined(C10_MOBILE) && !defined(ANDROID)
#include <torch/csrc/inductor/aoti_package/model_package_loader.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
#ifdef USE_CUDA
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
#endif
#include <fmt/format.h>
#include <miniz.h>
#include <nlohmann/json.hpp>
#include <fstream>
#include <iostream>
// TODO: Investigate why this is necessary; it fixes build problems in FRL
#if __has_include("filesystem")
#include <filesystem>
namespace fs = std::filesystem;
#else
#include <experimental/filesystem>
namespace fs = std::experimental::filesystem;
#endif
#ifndef _WIN32
#include <sys/stat.h>
#endif
namespace {
bool file_exists(std::string& path) {
#ifdef _WIN32
return fs::exists(path);
#else
struct stat rc;
return lstat(path.c_str(), &rc) == 0;
#endif
}
} // namespace
namespace torch::inductor {
const nlohmann::json& AOTIModelPackageLoader::load_json_file(
std::string json_path) {
if (!file_exists(json_path)) {
throw std::runtime_error(fmt::format("File found: {}", json_path));
}
std::ifstream json_file(json_path);
TORCH_CHECK(json_file.is_open());
static nlohmann::json json_obj;
json_file >> json_obj;
return json_obj;
}
void AOTIModelPackageLoader::load_metadata(const std::string& cpp_filename) {
// Parse metadata json file (if it exists) into the metadata_ map
size_t lastindex = cpp_filename.find_last_of('.');
std::string metadata_json_path =
cpp_filename.substr(0, lastindex) + "_metadata.json";
const nlohmann::json metadata_json_obj = load_json_file(metadata_json_path);
for (auto& item : metadata_json_obj.items()) {
metadata_[item.key()] = item.value().get<std::string>();
}
}
std::tuple<std::string, std::string> AOTIModelPackageLoader::
get_cpp_compile_command(
fs::path filename,
const std::vector<std::string>& sources,
const nlohmann::json& compile_options,
const std::string& output_dir = "") {
// Construct the cpp command
std::string compiler = compile_options["compiler"].get<std::string>();
bool compile_only = compile_options["compile_only"].get<bool>();
std::string source_args = "";
for (const std::string& source : sources) {
source_args += source + " ";
}
std::string file_ext = compile_only ? ".o" : ".so";
fs::path target_file = output_dir / filename.replace_extension(file_ext);
std::string cflags_args = "";
for (auto& arg : compile_options["cflags"]) {
cflags_args += "-" + arg.get<std::string>() + " ";
}
std::string definitions_args = "";
for (auto& arg : compile_options["definitions"]) {
definitions_args += "-D " + arg.get<std::string>() + " ";
}
std::string include_dirs_args = "";
for (auto& arg : compile_options["include_dirs"]) {
include_dirs_args += "-I" + arg.get<std::string>() + " ";
}
std::string ldflags_args = "";
for (auto& arg : compile_options["ldflags"]) {
ldflags_args += "-" + arg.get<std::string>() + " ";
}
std::string libraries_dirs_args = "";
for (auto& arg : compile_options["libraries_dirs"]) {
libraries_dirs_args += "-L" + arg.get<std::string>() + " ";
}
std::string libraries_args = "";
for (auto& arg : compile_options["libraries"]) {
libraries_args += "-l" + arg.get<std::string>() + " ";
}
std::string passthrough_parameters_args = "";
for (auto& arg : compile_options["passthrough_args"]) {
passthrough_parameters_args += arg.get<std::string>() + " ";
}
std::string compile_only_arg = compile_only ? "-c" : "";
std::string cmd = fmt::format(
"{} {} {} {} {} {} {} {} {} {} -o {}",
compiler,
source_args,
definitions_args,
cflags_args,
include_dirs_args,
passthrough_parameters_args,
ldflags_args,
libraries_args,
libraries_dirs_args,
compile_only_arg,
target_file.string());
return std::make_tuple(cmd, target_file.string());
}
std::string AOTIModelPackageLoader::compile_so(
const std::string& cpp_filename,
const std::string& consts_filename) {
// Compile the cpp file into a .so
size_t lastindex = cpp_filename.find_last_of('.');
std::string filename = cpp_filename.substr(0, lastindex);
std::string compile_flags_path = filename + "_compile_flags.json";
const nlohmann::json compile_flags = load_json_file(compile_flags_path);
auto compile_result =
get_cpp_compile_command(filename, {cpp_filename}, compile_flags);
std::string compile_cmd = std::get<0>(compile_result);
std::string output_o = std::get<1>(compile_result);
std::string linker_flags_path =
cpp_filename.substr(0, lastindex) + "_linker_flags.json";
const nlohmann::json linker_flags = load_json_file(linker_flags_path);
auto link_result = get_cpp_compile_command(
filename, {output_o, consts_filename}, linker_flags);
std::string link_cmd = std::get<0>(link_result);
std::string output_so = std::get<1>(link_result);
// Run the commands to generate a .so file
int status = system(compile_cmd.c_str());
if (status != 0) {
throw std::runtime_error("Failed to compile cpp file.");
}
status = system(link_cmd.c_str());
if (status != 0) {
throw std::runtime_error("Failed to link files.");
}
// Move the mmapped weights onto the .so
std::string serialized_weights_path = filename + "_serialized_weights.bin";
if (file_exists(serialized_weights_path)) {
std::ifstream serialized_weights_file(
serialized_weights_path, std::ios::binary);
if (!serialized_weights_file.is_open()) {
throw std::runtime_error("Failed to open serialized weights file");
}
std::vector<char> serialized_weights(
(std::istreambuf_iterator<char>(serialized_weights_file)),
std::istreambuf_iterator<char>());
serialized_weights_file.close();
std::ofstream output_so_file(output_so, std::ios::binary | std::ios::app);
if (!output_so_file.is_open()) {
throw std::runtime_error("Failed to open output .so file");
}
// Page align the weights
std::streampos so_size = output_so_file.tellp();
std::vector<char> padding(16384 - so_size % 16384, ' ');
output_so_file.write(
padding.data(), static_cast<std::streamsize>(padding.size()));
output_so_file.write(
serialized_weights.data(),
static_cast<std::streamsize>(serialized_weights.size()));
output_so_file.close();
}
return output_so;
}
AOTIModelPackageLoader::AOTIModelPackageLoader(
const std::string& model_package_path)
: AOTIModelPackageLoader(model_package_path, "model") {}
AOTIModelPackageLoader::AOTIModelPackageLoader(
const std::string& model_package_path,
const std::string& model_name = "model") {
// Extract all files within the zipfile to a temporary directory
mz_zip_archive zip_archive;
memset(&zip_archive, 0, sizeof(zip_archive));
if (!mz_zip_reader_init_file(&zip_archive, model_package_path.c_str(), 0)) {
throw std::runtime_error(fmt::format(
"Failed to initialize zip archive: {}",
mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive))));
}
fs::path temp_dir = fs::temp_directory_path() / std::tmpnam(nullptr);
std::filesystem::create_directories(temp_dir);
std::string cpp_filename = "";
std::string consts_filename = "";
std::string found_filenames = ""; // Saving for bookkeeping
for (uint i = 0; i < zip_archive.m_total_files; i++) {
uint filename_len = mz_zip_reader_get_filename(&zip_archive, i, nullptr, 0);
if (filename_len == 0) {
throw std::runtime_error("Failed to read filename");
}
char* filename = new char[filename_len + 1];
if (!mz_zip_reader_get_filename(&zip_archive, i, filename, filename_len)) {
throw std::runtime_error("Failed to read filename");
}
fs::path filepath(filename);
if (filepath.parent_path() !=
fmt::format("data/aotinductor/{}", model_name)) {
continue;
}
found_filenames += filename;
found_filenames += "\n";
fs::path output_path = temp_dir / filename;
fs::create_directories(output_path.parent_path());
mz_zip_reader_extract_file_to_file(
&zip_archive, filename, output_path.c_str(), 0);
if (output_path.extension() == ".cpp") {
cpp_filename = output_path;
}
if (output_path.extension() == ".o") {
consts_filename = output_path;
}
}
// Close the zip archive as we have extracted all files to the temp directory
mz_zip_reader_end(&zip_archive);
if (cpp_filename.empty()) {
throw std::runtime_error(fmt::format(
"No AOTInductor generate cpp file found in zip archive. Loaded the following:\n{}",
found_filenames));
}
// Compile the .so
std::string so_path = compile_so(cpp_filename, consts_filename);
// Load metadata which can be queried by user
load_metadata(cpp_filename);
// Construct the runner depending on the device information
std::string device = metadata_["AOTI_DEVICE_KEY"];
if (device.empty()) {
throw std::runtime_error("No device information found.");
#ifdef USE_CUDA
} else if (device == "cuda") {
runner_ = new AOTIModelContainerRunnerCuda(so_path);
#endif
} else if (device == "cpu") {
runner_ = new AOTIModelContainerRunnerCpu(so_path);
} else {
throw std::runtime_error(
fmt::format("Unsupported device found: {}", device));
}
fs::remove_all(temp_dir);
}
AOTIModelContainerRunner* AOTIModelPackageLoader::get_runner() {
return runner_;
}
std::vector<at::Tensor> AOTIModelPackageLoader::run(
std::vector<at::Tensor>& inputs) {
return runner_->run(inputs);
}
std::unordered_map<std::string, std::string> AOTIModelPackageLoader::
get_metadata() {
return metadata_;
}
std::vector<std::string> AOTIModelPackageLoader::get_call_spec() {
return runner_->get_call_spec();
}
} // namespace torch::inductor
#endif
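
The loader consumes only the archive entries under data/aotinductor/<model_name>/: it recompiles the .cpp with the *_compile_flags.json and *_linker_flags.json recorded at build time, appends the page-aligned serialized weights when present, and chooses a runner based on the AOTI_DEVICE_KEY metadata. A quick way to inspect what a package offers it (entry names illustrative):

import zipfile

with zipfile.ZipFile("model.pt2") as z:
    for name in z.namelist():
        if name.startswith("data/aotinductor/model/"):
            print(name)

# Expected entries, roughly:
#   .../xyz.cpp                      recompiled into the .so
#   .../xyz_compile_flags.json       compiler, cflags, includes
#   .../xyz_linker_flags.json        link step flags
#   .../xyz_metadata.json            user metadata plus AOTI_DEVICE_KEY
#   .../xyz.o                        consts object passed to the link
#   .../xyz_serialized_weights.bin   mmapped weights, if use_mmap_weights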

View File

@@ -0,0 +1,46 @@
#if !defined(C10_MOBILE) && !defined(ANDROID)
#pragma once
#include <ATen/Tensor.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
#include <nlohmann/json.hpp>
#if __has_include("filesystem")
#include <filesystem>
namespace fs = std::filesystem;
#else
#include <experimental/filesystem>
namespace fs = std::experimental::filesystem;
#endif
namespace torch::inductor {
class TORCH_API AOTIModelPackageLoader {
public:
AOTIModelPackageLoader(const std::string& model_package_path);
AOTIModelPackageLoader(
const std::string& model_package_path,
const std::string& model_name);
AOTIModelContainerRunner* get_runner();
std::unordered_map<std::string, std::string> get_metadata();
std::vector<at::Tensor> run(std::vector<at::Tensor>& inputs);
std::vector<std::string> get_call_spec();
private:
AOTIModelContainerRunner* runner_;
std::unordered_map<std::string, std::string> metadata_;
void load_metadata(const std::string& cpp_filename);
std::string compile_so(
const std::string& cpp_filename,
const std::string& consts_filename);
const nlohmann::json& load_json_file(std::string json_path);
std::tuple<std::string, std::string> get_cpp_compile_command(
fs::path filename,
const std::vector<std::string>& sources,
const nlohmann::json& compile_options,
const std::string& output_dir);
};
} // namespace torch::inductor
#endif

View File

@@ -0,0 +1,24 @@
#include <torch/csrc/inductor/aoti_package/model_package_loader.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
#ifdef USE_CUDA
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
#endif
#include <torch/csrc/inductor/aoti_runner/pybind.h>
#include <torch/csrc/utils/pybind.h>
namespace torch::inductor {
void initAOTIPackageBindings(PyObject* module) {
auto rootModule = py::handle(module).cast<py::module>();
auto m = rootModule.def_submodule("_aoti");
py::class_<AOTIModelPackageLoader>(m, "AOTIModelPackageLoader")
.def(py::init<const std::string&, const std::string&>())
.def(py::init<const std::string&>())
.def("get_metadata", &AOTIModelPackageLoader::get_metadata)
.def("run", &AOTIModelPackageLoader::run)
.def("get_call_spec", &AOTIModelPackageLoader::get_call_spec);
}
} // namespace torch::inductor
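
The same loader can be driven directly from Python, mirroring what load_package() does internally; a sketch for a CPU package named model.pt2 (run() takes a flat list of tensors, and the call spec restores the original pytree structure):

import torch
import torch.utils._pytree as pytree

loader = torch._C._aoti.AOTIModelPackageLoader("model.pt2", "model")
print(loader.get_metadata())  # e.g. {"dummy": "moo", "AOTI_DEVICE_KEY": "cpu"}

call_spec = loader.get_call_spec()
out_spec = pytree.treespec_loads(call_spec[1])
flat_inputs = pytree.tree_flatten(((torch.randn(10, 10),), {}))[0]
flat_outputs = loader.run(flat_inputs)
print(pytree.tree_unflatten(flat_outputs, out_spec))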

View File

@@ -0,0 +1,7 @@
#include <torch/csrc/python_headers.h>
namespace torch::inductor {
void initAOTIPackageBindings(PyObject* module);
} // namespace torch::inductor

View File

@@ -2,6 +2,7 @@
#ifdef USE_CUDA
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
#endif
#include <torch/csrc/inductor/aoti_package/model_package_loader.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>