[AOTI] codegen for static linkage (#157129)

Design doc: https://docs.google.com/document/d/1ncV7RpJ8xDwy8-_aCBfvZmpTTL824C-aoNPBLLVkOHM/edit?tab=t.0 (internal)

- Add codegen for static linkage
- Refactor test code for the test_compile_after_package tests

For now, the following option must be used together with `"aot_inductor.compile_standalone": True`:

`"aot_inductor.package_cpp_only": True`

A follow-up PR will change `"aot_inductor.package_cpp_only"` to be set to True automatically when `compile_standalone` is enabled.

```
python test/inductor/test_aot_inductor_package.py -k test_compile_after_package
python test/inductor/test_aot_inductor_package.py -k test_run_static_linkage_model
```
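
For reference, here is a minimal sketch of how the standalone flow is driven from Python, based on the `cmake_compile` test helper added in this PR (the toy `Model`, shapes, and variable names are illustrative only):

```python
import torch

class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y

model = Model()
example_inputs = (torch.randn(10, 10), torch.randn(10, 10))

# Export, then package cpp-only sources for standalone (static) linkage.
ep = torch.export.export(model, example_inputs)
package_path = torch._inductor.aoti_compile_and_package(
    ep,
    inductor_configs={
        "aot_inductor.compile_standalone": True,
        "aot_inductor.package_cpp_only": True,
    },
)
# The extracted package contains a CMakeLists.txt whose target builds a
# static library: libaoti_model.a by default, or lib<name>.a when
# aot_inductor.model_name_for_generated_files is set.
```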

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157129
Approved by: https://github.com/desertfire
Author: Shangdi Yu
Committed by: PyTorch MergeBot
Date: 2025-07-10 16:03:50 +00:00
Commit: 4781d72faa
Parent: 9bdf87e891

10 changed files with 509 additions and 88 deletions

View File

@@ -231,7 +231,8 @@ include_patterns = [
     'c10/**/*.cpp',
     'c10/**/*.h',
     'torch/*.h',
-    'torch/_inductor/codegen/aoti_runtime/interface.cpp',
+    'torch/_inductor/codegen/aoti_runtime/*.h',
+    'torch/_inductor/codegen/aoti_runtime/*.cpp',
     'torch/csrc/*.h',
     'torch/csrc/*.cpp',
     'torch/csrc/**/*.h',

View File

@@ -1310,7 +1310,9 @@ def main() -> None:
        "include/**/*.hpp",
        "include/*.cuh",
        "include/**/*.cuh",
+       "csrc/inductor/aoti_runtime/model.h",
        "_inductor/codegen/*.h",
+       "_inductor/codegen/aoti_runtime/*.h",
        "_inductor/codegen/aoti_runtime/*.cpp",
        "_inductor/script.ld",
        "_export/serde/*.yaml",

View File

@@ -30,6 +30,20 @@ from torch.testing._internal.common_utils import (
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
 
 
+try:
+    from test_static_linkage_utils import (
+        get_static_linkage_main_cpp_file,
+        get_static_linkage_makelist_file_cpu,
+        get_static_linkage_makelist_file_cuda,
+    )
+except ImportError:
+    from .test_static_linkage_utils import (
+        get_static_linkage_main_cpp_file,
+        get_static_linkage_makelist_file_cpu,
+        get_static_linkage_makelist_file_cuda,
+    )
+
+
 def skipif(predicate: Callable[[str, bool], bool], reason: str):
     def decorator(func):
         @functools.wraps(func)
@@ -126,6 +140,54 @@ class TestAOTInductorPackage(TestCase):
             self.assertEqual(actual, expected, atol=atol, rtol=rtol)
         return compiled_model
 
+    def check_package_cpp_only(self: TestCase) -> None:
+        """
+        Check if cmake and make are available.
+        Skip self.package_cpp_only=False tests
+        """
+        if not self.package_cpp_only:
+            raise unittest.SkipTest("Only meant to test cpp package")
+        if shutil.which("cmake") is None:
+            raise unittest.SkipTest("cmake is not available")
+        if shutil.which("make") is None:
+            raise unittest.SkipTest("make is not available")
+
+    def cmake_compile(self, model, example_inputs, options, tmp_dir):
+        """
+        Exports model, compiles it using AOTInductor, extracts the
+        generated files to tmp_dir, and builds the C++ code using CMake and Make.
+
+        Returns:
+            - build_path (Path): Path to the CMake build directory containing the compiled binary.
+            - tmp_path (Path): Path to the extracted model source directory.
+        """
+        ep = torch.export.export(model, example_inputs)
+        package_path = torch._inductor.aoti_compile_and_package(
+            ep, inductor_configs=options
+        )
+        with (
+            zipfile.ZipFile(package_path, "r") as zip_ref,
+        ):
+            filenames = zip_ref.namelist()
+            prefix = filenames[0].split("/")[0]
+            zip_ref.extractall(tmp_dir)
+        tmp_path = Path(tmp_dir) / prefix / "data" / "aotinductor" / "model"
+        self.assertTrue(tmp_path.exists())
+
+        # Create a build directory to run cmake
+        build_path = tmp_path / "build"
+        self.assertTrue(not build_path.exists())
+        build_path.mkdir()
+
+        custom_env = os.environ.copy()
+        custom_env["CMAKE_PREFIX_PATH"] = str(Path(torch.__file__).parent)
+        subprocess.run(
+            ["cmake", ".."],
+            cwd=build_path,
+            env=custom_env,
+            check=True,
+        )
+        subprocess.run(["make"], cwd=build_path, check=True)
+        return build_path, tmp_path
+
     def test_add(self):
         class Model(torch.nn.Module):
             def forward(self, x, y):
@@ -189,12 +251,7 @@ class TestAOTInductorPackage(TestCase):
     @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
     @skipIfXpu  # build system may be different
     def test_compile_after_package(self):
-        if not self.package_cpp_only:
-            raise unittest.SkipTest("Only meant to test cpp package")
-        if shutil.which("cmake") is None:
-            raise unittest.SkipTest("cmake is not available")
-        if shutil.which("make") is None:
-            raise unittest.SkipTest("make is not available")
+        self.check_package_cpp_only()
 
         class Model(torch.nn.Module):
             def __init__(self) -> None:
@@ -217,39 +274,19 @@ class TestAOTInductorPackage(TestCase):
             # Require kernels to be compiled into .o files
             "aot_inductor.embed_kernel_binary": True,
         }
-        ep = torch.export.export(model, example_inputs, strict=True)
-        package_path = torch._inductor.aoti_compile_and_package(
-            ep, inductor_configs=options
-        )
         with (
             tempfile.TemporaryDirectory() as tmp_dir,
-            zipfile.ZipFile(package_path, "r") as zip_ref,
         ):
-            filenames = zip_ref.namelist()
-            prefix = filenames[0].split("/")[0]
-            zip_ref.extractall(tmp_dir)
-            tmp_path = Path(tmp_dir) / prefix / "data" / "aotinductor" / "model"
-            self.assertTrue(tmp_path.exists())
+            build_path, tmp_path = self.cmake_compile(
+                model, example_inputs, options, tmp_dir
+            )
             if self.device == GPU_TYPE:
                 kernel_bin = get_kernel_bin_format(self.device)
                 self.assertTrue(not list(tmp_path.glob(f"*.{kernel_bin}")))
                 # Check if .cubin.o files exist and use unique kernel names
                 self.assertTrue(list(tmp_path.glob(f"triton_*.{kernel_bin}.o")))
 
-            build_path = tmp_path / "build"
-            self.assertTrue(not build_path.exists())
-            # Create a build directory to run cmake
-            build_path.mkdir()
-            custom_env = os.environ.copy()
-            custom_env["CMAKE_PREFIX_PATH"] = str(Path(torch.__file__).parent)
-            subprocess.run(
-                ["cmake", ".."],
-                cwd=build_path,
-                env=custom_env,
-            )
-            subprocess.run(["make"], cwd=build_path)
             # Check if the .so file was build successfully
             so_path = build_path / "libaoti_model.so"
             self.assertTrue(so_path.exists())
@@ -263,12 +300,7 @@ class TestAOTInductorPackage(TestCase):
     def test_compile_after_package_multi_arch(self):
         if self.device != GPU_TYPE:
             raise unittest.SkipTest("Only meant to test GPU_TYPE")
-        if not self.package_cpp_only:
-            raise unittest.SkipTest("Only meant to test cpp package")
-        if shutil.which("cmake") is None:
-            raise unittest.SkipTest("cmake is not available")
-        if shutil.which("make") is None:
-            raise unittest.SkipTest("make is not available")
+        self.check_package_cpp_only()
 
         class Model(torch.nn.Module):
             def __init__(self) -> None:
@@ -293,31 +325,12 @@ class TestAOTInductorPackage(TestCase):
             "aot_inductor.emit_multi_arch_kernel": True,
             "aot_inductor.embed_kernel_binary": True,
         }
-        ep = torch.export.export(model, example_inputs)
-        package_path = torch._inductor.aoti_compile_and_package(
-            ep, inductor_configs=options
-        )
         with (
             tempfile.TemporaryDirectory() as tmp_dir,
-            zipfile.ZipFile(package_path, "r") as zip_ref,
         ):
-            filenames = zip_ref.namelist()
-            prefix = filenames[0].split("/")[0]
-            zip_ref.extractall(tmp_dir)
-            tmp_path = Path(tmp_dir) / prefix / "data" / "aotinductor" / "model"
-            self.assertTrue(tmp_path.exists())
-            # Create a build directory to run cmake
-            build_path = tmp_path / "build"
-            build_path.mkdir()
-            custom_env = os.environ.copy()
-            custom_env["CMAKE_PREFIX_PATH"] = str(Path(torch.__file__).parent)
-            subprocess.run(
-                ["cmake", ".."],
-                cwd=build_path,
-                env=custom_env,
+            build_path, _ = self.cmake_compile(
+                model, example_inputs, options, tmp_dir
             )
-            subprocess.run(["make"], cwd=build_path)
             # Check if the .so file was build successfully
             so_path = build_path / "libaoti_model.so"
             self.assertTrue(so_path.exists())
@@ -325,6 +338,137 @@ class TestAOTInductorPackage(TestCase):
             actual = optimized(*example_inputs)
             self.assertTrue(torch.allclose(actual, expected))
 
+    @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
+    @skipIfXpu  # build system may be different
+    def test_compile_after_package_static(self):
+        # compile_standalone will set package_cpp_only=True
+        self.check_package_cpp_only()
+
+        class Model(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(10, 10)
+
+            def forward(self, x, y):
+                return x + self.linear(y)
+
+        with torch.no_grad():
+            example_inputs = (
+                torch.randn(10, 10, device=self.device),
+                torch.randn(10, 10, device=self.device),
+            )
+            model = Model().to(device=self.device)
+
+            # Test compilation when no name is passed in
+            options = {
+                "aot_inductor.compile_standalone": True,
+            }
+            with (
+                tempfile.TemporaryDirectory() as tmp_dir,
+            ):
+                build_path, _ = self.cmake_compile(
+                    model, example_inputs, options, tmp_dir
+                )
+                # Check if the .a file was build successfully
+                a_path = build_path / "libaoti_model.a"
+                self.assertTrue(a_path.exists())
+
+            # Test compilation when model name is passed in
+            options = {
+                "aot_inductor.compile_standalone": True,
+                "aot_inductor.model_name_for_generated_files": "linear",
+            }
+            with (
+                tempfile.TemporaryDirectory() as tmp_dir,
+            ):
+                build_path, _ = self.cmake_compile(
+                    model, example_inputs, options, tmp_dir
+                )
+                # Check if the .a file was build successfully
+                a_path = build_path / "liblinear.a"
+                self.assertTrue(a_path.exists())
+
+            # test invalid model name
+            options = {
+                "aot_inductor.compile_standalone": True,
+                "aot_inductor.model_name_for_generated_files": "linear/linear",
+            }
+            with self.assertRaisesRegex(Exception, "Invalid AOTI model name"):
+                self.cmake_compile(model, example_inputs, options, "")
+
+    @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
+    @skipIfRocm  # doesn't support multi-arch binary
+    @skipIfXpu  # doesn't support multi-arch binary
+    def test_run_static_linkage_model(self):
+        self.check_package_cpp_only()
+
+        class Model1(torch.nn.Module):
+            def forward(self, x, y):
+                return x + y
+
+        class Model2(torch.nn.Module):
+            def forward(self, x, y):
+                return x - y
+
+        example_inputs = (
+            torch.randn(10, 10, device=self.device),
+            torch.randn(10, 10, device=self.device),
+        )
+        model1 = Model1().to(self.device)
+        model2 = Model2().to(self.device)
+
+        models = [model1, model2]
+        i = 0
+        model_names = ["Plus", "Minus"]
+
+        with (
+            tempfile.TemporaryDirectory() as tmp_dir,
+        ):
+            for i in range(2):
+                model = models[i]
+                # TODO: should be done through _ExportPackage
+                ep = torch.export.export(model, example_inputs)
+                package_path = torch._inductor.aoti_compile_and_package(
+                    ep,
+                    inductor_configs={
+                        "aot_inductor.compile_standalone": True,
+                        "always_keep_tensor_constants": True,
+                        "aot_inductor.model_name_for_generated_files": model_names[i],
+                    },
+                )
+                with (
+                    zipfile.ZipFile(package_path, "r") as zip_ref,
+                ):
+                    zip_ref.extractall(tmp_dir)
+
+            file_str = get_static_linkage_main_cpp_file()
+            with open(Path(tmp_dir) / "main.cpp", "w") as f:
+                f.write(file_str)
+
+            if self.device == GPU_TYPE:
+                cmake_file_str = get_static_linkage_makelist_file_cuda()
+            else:
+                cmake_file_str = get_static_linkage_makelist_file_cpu()
+            with open(Path(tmp_dir) / "CMakeLists.txt", "w") as f:
+                f.write(cmake_file_str)
+
+            build_path = Path(tmp_dir) / "build"
+            build_path.mkdir()
+
+            custom_env = os.environ.copy()
+            custom_env["CMAKE_PREFIX_PATH"] = str(Path(torch.__file__).parent)
+            subprocess.run(
+                ["cmake", ".."],
+                cwd=build_path,
+                env=custom_env,
+            )
+            subprocess.run(["make"], cwd=build_path, check=True)
+
+            subprocess.run(
+                ["./main", f"{tmp_dir}/", self.device], cwd=build_path, check=True
+            )
+
     def test_metadata(self):
         class Model(torch.nn.Module):
             def __init__(self) -> None:

View File

@@ -0,0 +1,157 @@
+# Owner(s): ["module: inductor"]
+from torch.testing._internal.common_utils import run_tests
+
+
+def get_static_linkage_main_cpp_file():
+    return """
+#include <dlfcn.h>
+#include <iostream>
+#include <memory>
+#include <torch/torch.h>
+#include <vector>
+
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+
+// Include the AOTInductor headers
+#include "Minus.wrapper/data/aotinductor/model/Minus.h"
+#include "Plus.wrapper/data/aotinductor/model/Plus.h"
+
+#include <torch/csrc/inductor/aoti_runtime/model_container.h>
+#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
+
+using torch::aot_inductor::AOTInductorModelMinus;
+using torch::aot_inductor::AOTInductorModelPlus;
+using torch::aot_inductor::ConstantHandle;
+using torch::aot_inductor::ConstantMap;
+
+int main(int argc, char* argv[]) {
+  if (argc < 2) {
+    std::cerr
+        << "Usage: ./main <path> <device>"
+        << std::endl;
+    return 1;
+  }
+  std::string path = argv[1];
+  std::string device_str = argv[2];
+
+  try {
+    torch::Device device(device_str);
+
+    // Create two input tensors (10x10)
+    auto tensor1 = torch::ones({10, 10}, device);
+    auto tensor2 = torch::ones({10, 10}, device);
+    // Create two input tensors (10x10)
+    auto tensor3 = torch::ones({10, 10}, device);
+    auto tensor4 = torch::ones({10, 10}, device);
+
+    std::vector<at::Tensor> input_tensors = {tensor1, tensor2};
+    std::vector<at::Tensor> input_tensors2 = {tensor3, tensor4};
+
+    // Create array of input handles
+    auto input_handles1 =
+        torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(
+            input_tensors);
+    auto input_handles2 =
+        torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(
+            input_tensors2);
+
+    // Create array for output handle
+    AtenTensorHandle output_handle1;
+    AtenTensorHandle output_handle2;
+
+    auto constants_map = std::make_shared<ConstantMap>();
+    auto constants_array = std::make_shared<std::vector<ConstantHandle>>();
+    auto model1 = AOTInductorModelPlus::Create(
+        constants_map, constants_array, device_str,
+        path + "Plus.wrapper/data/"
+               "aotinductor/model/");
+    model1->load_constants();
+
+    auto constants_map2 = std::make_shared<ConstantMap>();
+    auto constants_array2 = std::make_shared<std::vector<ConstantHandle>>();
+    auto model2 = AOTInductorModelMinus::Create(
+        constants_map2, constants_array2, device_str,
+        path + "Minus.wrapper/data/"
+               "aotinductor/model/");
+    model2->load_constants();
+
+    // Run the model
+    torch::aot_inductor::DeviceStreamType stream1 = nullptr;
+    torch::aot_inductor::DeviceStreamType stream2 = nullptr;
+    model1->run(&input_handles1[0], &output_handle1, stream1, nullptr);
+    model2->run(&input_handles2[0], &output_handle2, stream2, nullptr);
+
+    // Convert output handle to tensor
+    auto output_tensor1 =
+        torch::aot_inductor::alloc_tensors_by_stealing_from_handles(
+            &output_handle1, 1);
+    auto output_tensor2 =
+        torch::aot_inductor::alloc_tensors_by_stealing_from_handles(
+            &output_handle2, 1);
+
+    if (!(torch::all(output_tensor1[0] == 2).item<bool>())) {
+      std::cout << "Wrong Output for Plus Model: " << output_tensor1 << std::endl;
+      throw std::runtime_error("Tensor does not contain only the expected value 2.");
+    }
+    if (!(torch::all(output_tensor2[0] == 0).item<bool>())) {
+      std::cout << "Wrong Output for Minus Model: " << output_tensor1 << std::endl;
+      throw std::runtime_error("Tensor does not contain only the expected value 0.");
+    }
+    return 0;
+  } catch (const std::exception &e) {
+    std::cerr << "Error: " << e.what() << std::endl;
+    return 1;
+  }
+}
+"""
+
+
+def get_static_linkage_makelist_file_cuda():
+    return """
+cmake_minimum_required(VERSION 3.10)
+project(TestProject)
+
+set(CMAKE_CXX_STANDARD 17)
+
+find_package(Torch REQUIRED)
+find_package(CUDA REQUIRED)
+
+add_subdirectory(Plus.wrapper/data/aotinductor/model/)
+add_subdirectory(Minus.wrapper/data/aotinductor/model/)
+
+# Create executable
+add_executable(main main.cpp)
+target_compile_definitions(main PRIVATE USE_CUDA)
+
+target_link_libraries(main PRIVATE torch cuda
+    ${CUDA_LIBRARIES}
+    Plus
+    Minus)
+"""
+
+
+def get_static_linkage_makelist_file_cpu():
+    return """
+cmake_minimum_required(VERSION 3.10)
+project(TestProject)
+
+set(CMAKE_CXX_STANDARD 17)
+
+find_package(Torch REQUIRED)
+
+add_subdirectory(Plus.wrapper/data/aotinductor/model/)
+add_subdirectory(Minus.wrapper/data/aotinductor/model/)
+
+# Create executable
+add_executable(main main.cpp)
+
+target_link_libraries(main PRIVATE torch
+    Plus
+    Minus)
+"""
+
+
+if __name__ == "__main__":
+    run_tests()

View File

@@ -1660,6 +1660,12 @@ class AotCodeCompiler:
             wrapper_code = "\n".join((wrapper_code, kernel_code))
             kernel_code = ""
 
+        from .utils import aoti_model_name_from_config
+
+        model_class_name = ""
+        if config.aot_inductor.compile_standalone:
+            model_class_name = aoti_model_name_from_config()
+
         wrapper_key, wrapper_path = write(
             wrapper_code,
             "wrapper.cpp",
@@ -1679,6 +1685,36 @@ class AotCodeCompiler:
             key=config.aot_inductor.model_name_for_generated_files,
         )
 
+        header_code = ""
+        header_path = ""
+        if config.aot_inductor.compile_standalone:
+            # to link statically, we also need a header file
+            with open(
+                os.path.join(
+                    os.path.dirname(os.path.dirname(__file__)),
+                    "csrc",
+                    "inductor",
+                    "aoti_runtime",
+                    "model.h",
+                )
+            ) as f:
+                class_name = f"AOTInductorModel{model_class_name}"
+                header_code = f.read()
+
+                # we replace like this to avoid replacing
+                # AOTInductorModelBase and AOTInductorModelKernelsBase
+                header_code = (
+                    header_code.replace("<AOTInductorModel>", f"<{class_name}>")
+                    .replace("AOTInductorModel(", f"{class_name}(")
+                    .replace("AOTInductorModel :", f"{class_name} :")
+                )
+            _, header_path = write(
+                header_code,
+                "h",
+                specified_dir=specified_output_path,
+                key=f"{model_class_name}",
+            )
+
         # Log the AOTInductor wrapper and kernel code, if needed.
         with tempfile.NamedTemporaryFile("w+") as t:
             t.writelines((wrapper_code, "\n", kernel_code, "\n"))
@@ -1689,6 +1725,8 @@ class AotCodeCompiler:
             generated_files.append(wrapper_path)
             if not config.aot_inductor.package_cpp_only:
                 generated_files.append(kernel_path)
+            if config.aot_inductor.compile_standalone:
+                generated_files.append(header_path)
 
             output_code_log.info("Wrapper code written to: %s", wrapper_path)
             output_code_log.info("Kernel code written to: %s", kernel_path)
@@ -1710,6 +1748,17 @@ class AotCodeCompiler:
                 },
                 payload_fn=lambda: kernel_code,
             )
+            if config.aot_inductor.compile_standalone:
+                output_code_log.info("Header code written to: %s", header_path)
+                trace_structured(
+                    "graph_dump",
+                    lambda: {
+                        "name": "inductor_aot_header_code",
+                        "type": "cpp",
+                        "filename": header_path,
+                    },
+                    payload_fn=lambda: header_code,
+                )
 
         # We use a file lock below to protect FS operations. The lock file
         # is scoped to the 'key', so make sure the consts_s is protected

View File

@@ -22,7 +22,13 @@ from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.symbol import symbol_is_type, SymT
 
 from .. import config, ir
-from ..utils import _align, DeferredLineBase, LineContext, normalize_name
+from ..utils import (
+    _align,
+    aoti_model_name_from_config,
+    DeferredLineBase,
+    LineContext,
+    normalize_name,
+)
 from ..virtualized import V
 from .aoti_hipify_utils import maybe_hipify_code_wrapper
 from .common import get_device_op_overrides, IndentedBuffer, Kernel
@@ -58,6 +64,10 @@ class CppWrapperCpu(PythonWrapperCodegen):
         self.device = "cpu"
         # must be initialized prior to calling super().__init__()
         self.included_devices: OrderedSet[str] = OrderedSet()
+        self.model_class_name_suffix = ""
+        if config.aot_inductor.compile_standalone:
+            self.model_class_name_suffix = aoti_model_name_from_config()
+        self.aoti_model_class_name = f"AOTInductorModel{self.model_class_name_suffix}"
         super().__init__()
         self.declare = "auto "
         self.declare_maybe_reference = "decltype(auto) "
@@ -208,10 +218,16 @@ class CppWrapperCpu(PythonWrapperCodegen):
         self.add_device_include(self.device)
 
         if V.graph.aot_mode:
-            with open(
-                os.path.join(os.path.dirname(__file__), "aoti_runtime", "interface.cpp")
-            ) as f:
-                self.header.splice(f.read())
+            if not config.aot_inductor.compile_standalone:
+                with open(
+                    os.path.join(
+                        os.path.dirname(__file__), "aoti_runtime", "interface.cpp"
+                    )
+                ) as f:
+                    self.header.splice(f.read())
+            else:
+                # we produce a separate model header for each model in static linkage
+                self.header.splice(f"""#include \"{self.model_class_name_suffix}.h\"""")
         self.header.splice("\n")
 
         enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [
@@ -508,12 +524,12 @@ class CppWrapperCpu(PythonWrapperCodegen):
         if V.graph.is_const_graph:
             self.prefix.splice(
-                """
-                void AOTInductorModel::_const_run_impl(
+                f"""
+                void {self.aoti_model_class_name}::_const_run_impl(
                     std::vector<AtenTensorHandle>& output_handles,
                     DeviceStreamType stream,
                     AOTIProxyExecutorHandle proxy_executor
-                ) {
+                ) {{
                 """
             )
         else:
@@ -521,18 +537,18 @@ class CppWrapperCpu(PythonWrapperCodegen):
                 # If we do not split the constant graph, we'll just create
                 # an empty implementation when wrapping the main module.
                 self.prefix.splice(
-                    """
-                    void AOTInductorModel::_const_run_impl(
+                    f"""
+                    void {self.aoti_model_class_name}::_const_run_impl(
                         std::vector<AtenTensorHandle>& output_handles,
                         DeviceStreamType stream,
                         AOTIProxyExecutorHandle proxy_executor
-                    ) {}
+                    ) {{}}
                     """
                 )
 
-        run_impl_proto = """
-        void AOTInductorModel::run_impl(
+        run_impl_proto = f"""
+        void {self.aoti_model_class_name}::run_impl(
             AtenTensorHandle*
                 input_handles, // array of input AtenTensorHandle; handles
                                // are stolen; the array itself is borrowed
@@ -542,7 +558,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
                 // borrowed
             DeviceStreamType stream,
             AOTIProxyExecutorHandle proxy_executor
-        ) {
+        ) {{
             __check_inputs_outputs(input_handles, output_handles);
         """
@@ -734,7 +750,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
         )
         self.prefix.splice(
             f"""
-            AOTInductorModel::AOTInductorModel(std::shared_ptr<ConstantMap> constants_map,
+            {self.aoti_model_class_name}::{self.aoti_model_class_name}(std::shared_ptr<ConstantMap> constants_map,
                                                std::shared_ptr<std::vector<ConstantHandle>> constants_array,
                                                const std::string& device_str,
                                                std::optional<std::string> cubin_dir)
@@ -891,12 +907,12 @@ class CppWrapperCpu(PythonWrapperCodegen):
         """
         self.prefix.splice(
-            """
-            std::unordered_map<std::string, AtenTensorHandle> AOTInductorModel::const_run_impl(
+            f"""
+            std::unordered_map<std::string, AtenTensorHandle> {self.aoti_model_class_name}::const_run_impl(
                 DeviceStreamType stream,
                 AOTIProxyExecutorHandle proxy_executor,
                 bool initialization
-            ) {
+            ) {{
             """
         )
         if not config.aot_inductor.use_runtime_constant_folding:
@@ -1079,7 +1095,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
     def generate_before_suffix(self, result):
         if not V.graph.is_const_graph:
             if V.graph.aot_mode:
-                result.writeline("} // AOTInductorModel::run_impl")
+                result.writeline(f"}} // {self.aoti_model_class_name}::run_impl")
             else:
                 result.writeline("} // inductor_entry_impl")
@@ -1087,7 +1103,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
         """Generates the end of the code block, and any code needed to call it."""
         if V.graph.aot_mode:
             if V.graph.is_const_graph:
-                result.writeline("} // AOTInductorModel::_const_run_impl")
+                result.writeline(f"}} // {self.aoti_model_class_name}::_const_run_impl")
             else:
                 result.writeline("} // namespace torch::aot_inductor\n\n\n")
             return

View File

@@ -2407,6 +2407,10 @@ def compile_fx(
         )
 
     if V.aot_compilation:
+        from .utils import is_valid_aoti_model_name
+
+        is_valid_aoti_model_name()
+
         with functorch_config.patch(unlift_effect_tokens=True):
             gm, graph_signature = aot_export_module(
                 model_,

View File

@@ -1421,6 +1421,13 @@ class aot_inductor:
     # If not None, the generated files with use this name in file stem.
     # If None, we will use a hash to name files.
+    #
+    # If package_cpp_only, this name is also used for the target name in CMakelists.txt
+    # The default target name is "aoti_model"
+    #
+    # If compile_standalone, the aoti model class name is f"AOTInductorModel{name}"
+    #
+    # This name can only contain letters, numbers, and underscores.
     model_name_for_generated_files: Optional[str] = None
 
     # Custom ops that have implemented C shim wrappers, defined as an op to C shim declaration dict

View File

@@ -28,6 +28,7 @@ from torch._dynamo.utils import dynamo_timed
 from torch._inductor import config, exc
 from torch._inductor.cpu_vec_isa import invalid_vec_isa, VecISA
 from torch._inductor.runtime.runtime_utils import cache_dir
+from torch._inductor.utils import aoti_model_name_from_config
 from torch.torch_version import TorchVersion
@@ -1505,6 +1506,7 @@ class CppBuilder:
         self._aot_mode: bool = False
 
         self._name = name
+        self._target_name = aoti_model_name_from_config()
 
         # Code start here, initial self internal variables firstly.
         self._build_option = BuildOption
@@ -1730,25 +1732,29 @@ class CppBuilder:
        """
        definitions = " ".join(self._build_option.get_definitions())
+       target_library_type = (
+           "STATIC" if config.aot_inductor.compile_standalone else "SHARED"
+       )
 
        contents = textwrap.dedent(
            f"""
            cmake_minimum_required(VERSION 3.27 FATAL_ERROR)
-           project(aoti_model LANGUAGES CXX)
+           project({self._target_name} LANGUAGES CXX)
            set(CMAKE_CXX_STANDARD 17)
 
            # May need to point CMAKE_PREFIX_PATH to the right torch location
            find_package(Torch REQUIRED)
 
            # Set a shared library target
-           add_library(aoti_model SHARED)
+           add_library({self._target_name} {target_library_type})
 
            # Add macro definitions
-           target_compile_definitions(aoti_model PRIVATE {definitions})
+           target_compile_definitions({self._target_name} PRIVATE {definitions})
 
            # Add compile flags
-           target_compile_options(aoti_model PRIVATE {self._cflags_args})
+           target_compile_options({self._target_name} PRIVATE {self._cflags_args})
 
            # Backend specific flags
-           target_compile_options(aoti_model PRIVATE {self._passthrough_parameters_args} -c)
+           target_compile_options({self._target_name} PRIVATE {self._passthrough_parameters_args} -c)
            """
        )
@@ -1823,7 +1829,7 @@ class CppBuilder:
             # Remove the directory part of file_path
             src_path = "${CMAKE_CURRENT_SOURCE_DIR}/" + Path(src_path).name
             with open(cmake_path, "a") as f:
-                f.write(f"target_sources(aoti_model PRIVATE {src_path})\n")
+                f.write(f"target_sources({self._target_name} PRIVATE {src_path})\n")
 
     def save_kernel_asm_to_cmake(self, cmake_path: str, asm_files: list[str]) -> None:
         # TODO: make this work beyond CUDA
@@ -1837,9 +1843,9 @@ class CppBuilder:
                 """
             )
             f.write(contents)
-            f.write("add_dependencies(aoti_model ${KERNEL_TARGETS})\n")
+            f.write(f"add_dependencies({self._target_name} ${{KERNEL_TARGETS}})\n")
             f.write(
-                "target_link_libraries(aoti_model PRIVATE ${KERNEL_OBJECT_FILES})\n"
+                f"target_link_libraries({self._target_name} PRIVATE ${{KERNEL_OBJECT_FILES}})\n"
             )
 
     def save_link_cmd_to_cmake(self, cmake_path: str) -> None:
@@ -1848,10 +1854,10 @@ class CppBuilder:
         contents = textwrap.dedent(
             f"""
             # Add linker flags
-            target_link_options(aoti_model PRIVATE {lflags})
+            target_link_options({self._target_name} PRIVATE {lflags})
 
             # Add libraries
-            target_link_libraries(aoti_model PRIVATE {libs})
+            target_link_libraries({self._target_name} PRIVATE {libs})
             """
         )

View File

@@ -3308,3 +3308,38 @@ def maybe_aoti_standalone_config(config_patches: dict[str, Any]) -> dict[str, An
             "Please set aot_inductor.package_cpp_only=True in your inductor config."
         )
     return config_patches
+
+
+def is_valid_aoti_model_name() -> bool:
+    """
+    Validates if a model name is suitable for use in code generation.
+    """
+    from torch._inductor import config
+
+    model_name = config.aot_inductor.model_name_for_generated_files
+
+    if model_name is None:
+        return True
+
+    if not isinstance(model_name, str):
+        raise ValueError("Invalid AOTI model name: Model name must be a string")
+
+    if model_name == "":
+        return True
+
+    # Can only contain alphanumeric characters and underscores
+    if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", model_name):
+        raise ValueError(
+            "Invalid AOTI model name: Model name can only contain letters, numbers, and underscores"
+        )
+
+    return True
+
+
+def aoti_model_name_from_config() -> str:
+    from torch._inductor import config
+
+    model_name = config.aot_inductor.model_name_for_generated_files
+    model_name = "aoti_model" if model_name is None else model_name
+    return model_name