Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00

[AOTI] codegen for static linkage (#157129)

Design doc: https://docs.google.com/document/d/1ncV7RpJ8xDwy8-_aCBfvZmpTTL824C-aoNPBLLVkOHM/edit?tab=t.0 (internal)

- Add codegen for static linkage
- Refactor test code for the test_compile_after_package tests

For now, the following option must be used together with `"aot_inductor.compile_standalone": True`:

    "aot_inductor.package_cpp_only": True,

A follow-up PR will change `"aot_inductor.package_cpp_only"` to be set to True automatically when compiling standalone.

```
python test/inductor/test_aot_inductor_package.py -k test_compile_after_package
python test/inductor/test_aot_inductor_package.py -k test_run_static_linkage_model
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157129
Approved by: https://github.com/desertfire

Commit 4781d72faa (parent 9bdf87e891), committed by PyTorch MergeBot.
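For orientation, the end-to-end flow this PR enables looks roughly like the sketch below. It is not part of the diff; it only uses config keys and APIs that appear in the commit message and the tests further down, and the `M` module and tensor shapes are illustrative:

```python
import torch

class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y

# Export the model, then ask AOTInductor for a standalone C++ package.
ep = torch.export.export(M(), (torch.randn(10, 10), torch.randn(10, 10)))
package_path = torch._inductor.aoti_compile_and_package(
    ep,
    inductor_configs={
        # Currently these two options must be used together; a follow-up PR
        # will set package_cpp_only automatically.
        "aot_inductor.compile_standalone": True,
        "aot_inductor.package_cpp_only": True,
    },
)
# package_path points at a zip of generated C++ sources that can then be
# built with CMake/Make into a static library (libaoti_model.a by default).
```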
@@ -231,7 +231,8 @@ include_patterns = [
     'c10/**/*.cpp',
     'c10/**/*.h',
     'torch/*.h',
-    'torch/_inductor/codegen/aoti_runtime/interface.cpp',
+    'torch/_inductor/codegen/aoti_runtime/*.h',
+    'torch/_inductor/codegen/aoti_runtime/*.cpp',
     'torch/csrc/*.h',
     'torch/csrc/*.cpp',
     'torch/csrc/**/*.h',
setup.py (2 changes)

@@ -1310,7 +1310,9 @@ def main() -> None:
                 "include/**/*.hpp",
                 "include/*.cuh",
                 "include/**/*.cuh",
+                "csrc/inductor/aoti_runtime/model.h",
                 "_inductor/codegen/*.h",
+                "_inductor/codegen/aoti_runtime/*.h",
                 "_inductor/codegen/aoti_runtime/*.cpp",
                 "_inductor/script.ld",
                 "_export/serde/*.yaml",
test/inductor/test_aot_inductor_package.py

@@ -30,6 +30,20 @@ from torch.testing._internal.common_utils import (
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


+try:
+    from test_static_linkage_utils import (
+        get_static_linkage_main_cpp_file,
+        get_static_linkage_makelist_file_cpu,
+        get_static_linkage_makelist_file_cuda,
+    )
+except ImportError:
+    from .test_static_linkage_utils import (
+        get_static_linkage_main_cpp_file,
+        get_static_linkage_makelist_file_cpu,
+        get_static_linkage_makelist_file_cuda,
+    )
+
+
 def skipif(predicate: Callable[[str, bool], bool], reason: str):
     def decorator(func):
         @functools.wraps(func)
@@ -126,6 +140,54 @@ class TestAOTInductorPackage(TestCase):
             self.assertEqual(actual, expected, atol=atol, rtol=rtol)
             return compiled_model

+    def check_package_cpp_only(self: TestCase) -> None:
+        """
+        Check if cmake and make are available.
+        Skip self.package_cpp_only=False tests
+        """
+        if not self.package_cpp_only:
+            raise unittest.SkipTest("Only meant to test cpp package")
+        if shutil.which("cmake") is None:
+            raise unittest.SkipTest("cmake is not available")
+        if shutil.which("make") is None:
+            raise unittest.SkipTest("make is not available")
+
+    def cmake_compile(self, model, example_inputs, options, tmp_dir):
+        """
+        Exports the model, compiles it using AOTInductor, extracts the
+        generated files to tmp_dir, and builds the C++ code using CMake and Make.
+
+        Returns:
+            - build_path (Path): Path to the CMake build directory containing the compiled binary.
+            - tmp_path (Path): Path to the extracted model source directory.
+        """
+        ep = torch.export.export(model, example_inputs)
+        package_path = torch._inductor.aoti_compile_and_package(
+            ep, inductor_configs=options
+        )
+        with (
+            zipfile.ZipFile(package_path, "r") as zip_ref,
+        ):
+            filenames = zip_ref.namelist()
+            prefix = filenames[0].split("/")[0]
+            zip_ref.extractall(tmp_dir)
+        tmp_path = Path(tmp_dir) / prefix / "data" / "aotinductor" / "model"
+        self.assertTrue(tmp_path.exists())
+        # Create a build directory to run cmake
+        build_path = tmp_path / "build"
+        self.assertTrue(not build_path.exists())
+        build_path.mkdir()
+        custom_env = os.environ.copy()
+        custom_env["CMAKE_PREFIX_PATH"] = str(Path(torch.__file__).parent)
+        subprocess.run(
+            ["cmake", ".."],
+            cwd=build_path,
+            env=custom_env,
+            check=True,
+        )
+        subprocess.run(["make"], cwd=build_path, check=True)
+        return build_path, tmp_path
+
     def test_add(self):
         class Model(torch.nn.Module):
             def forward(self, x, y):
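As a usage note (not part of the diff): with the helper above, the refactored tests below reduce to a pattern like this sketch:

```python
# Inside a TestAOTInductorPackage test method: compile, extract, and build
# in one call, then assert on the artifacts that CMake/Make produced.
with tempfile.TemporaryDirectory() as tmp_dir:
    build_path, tmp_path = self.cmake_compile(
        model, example_inputs, options, tmp_dir
    )
    self.assertTrue((build_path / "libaoti_model.so").exists())
```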
@@ -189,12 +251,7 @@ class TestAOTInductorPackage(TestCase):
     @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
     @skipIfXpu  # build system may be different
     def test_compile_after_package(self):
-        if not self.package_cpp_only:
-            raise unittest.SkipTest("Only meant to test cpp package")
-        if shutil.which("cmake") is None:
-            raise unittest.SkipTest("cmake is not available")
-        if shutil.which("make") is None:
-            raise unittest.SkipTest("make is not available")
+        self.check_package_cpp_only()

         class Model(torch.nn.Module):
             def __init__(self) -> None:
@@ -217,39 +274,19 @@ class TestAOTInductorPackage(TestCase):
             # Require kernels to be compiled into .o files
             "aot_inductor.embed_kernel_binary": True,
         }
-        ep = torch.export.export(model, example_inputs, strict=True)
-        package_path = torch._inductor.aoti_compile_and_package(
-            ep, inductor_configs=options
-        )
         with (
             tempfile.TemporaryDirectory() as tmp_dir,
-            zipfile.ZipFile(package_path, "r") as zip_ref,
         ):
-            filenames = zip_ref.namelist()
-            prefix = filenames[0].split("/")[0]
-            zip_ref.extractall(tmp_dir)
-            tmp_path = Path(tmp_dir) / prefix / "data" / "aotinductor" / "model"
-            self.assertTrue(tmp_path.exists())
+            build_path, tmp_path = self.cmake_compile(
+                model, example_inputs, options, tmp_dir
+            )

             if self.device == GPU_TYPE:
                 kernel_bin = get_kernel_bin_format(self.device)
                 self.assertTrue(not list(tmp_path.glob(f"*.{kernel_bin}")))
                 # Check if .cubin.o files exist and use unique kernel names
                 self.assertTrue(list(tmp_path.glob(f"triton_*.{kernel_bin}.o")))

-            build_path = tmp_path / "build"
-            self.assertTrue(not build_path.exists())
-
-            # Create a build directory to run cmake
-            build_path.mkdir()
-            custom_env = os.environ.copy()
-            custom_env["CMAKE_PREFIX_PATH"] = str(Path(torch.__file__).parent)
-            subprocess.run(
-                ["cmake", ".."],
-                cwd=build_path,
-                env=custom_env,
-            )
-            subprocess.run(["make"], cwd=build_path)
             # Check if the .so file was build successfully
             so_path = build_path / "libaoti_model.so"
             self.assertTrue(so_path.exists())
@@ -263,12 +300,7 @@ class TestAOTInductorPackage(TestCase):
     def test_compile_after_package_multi_arch(self):
         if self.device != GPU_TYPE:
             raise unittest.SkipTest("Only meant to test GPU_TYPE")
-        if not self.package_cpp_only:
-            raise unittest.SkipTest("Only meant to test cpp package")
-        if shutil.which("cmake") is None:
-            raise unittest.SkipTest("cmake is not available")
-        if shutil.which("make") is None:
-            raise unittest.SkipTest("make is not available")
+        self.check_package_cpp_only()

         class Model(torch.nn.Module):
             def __init__(self) -> None:
@@ -293,31 +325,12 @@ class TestAOTInductorPackage(TestCase):
             "aot_inductor.emit_multi_arch_kernel": True,
             "aot_inductor.embed_kernel_binary": True,
         }
-        ep = torch.export.export(model, example_inputs)
-        package_path = torch._inductor.aoti_compile_and_package(
-            ep, inductor_configs=options
-        )
         with (
             tempfile.TemporaryDirectory() as tmp_dir,
-            zipfile.ZipFile(package_path, "r") as zip_ref,
         ):
-            filenames = zip_ref.namelist()
-            prefix = filenames[0].split("/")[0]
-            zip_ref.extractall(tmp_dir)
-            tmp_path = Path(tmp_dir) / prefix / "data" / "aotinductor" / "model"
-            self.assertTrue(tmp_path.exists())
-            # Create a build directory to run cmake
-            build_path = tmp_path / "build"
-            build_path.mkdir()
-            custom_env = os.environ.copy()
-            custom_env["CMAKE_PREFIX_PATH"] = str(Path(torch.__file__).parent)
-            subprocess.run(
-                ["cmake", ".."],
-                cwd=build_path,
-                env=custom_env,
-            )
-            subprocess.run(["make"], cwd=build_path)
-
+            build_path, _ = self.cmake_compile(
+                model, example_inputs, options, tmp_dir
+            )
             # Check if the .so file was build successfully
             so_path = build_path / "libaoti_model.so"
             self.assertTrue(so_path.exists())
@@ -325,6 +338,137 @@ class TestAOTInductorPackage(TestCase):
             actual = optimized(*example_inputs)
             self.assertTrue(torch.allclose(actual, expected))

+    @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
+    @skipIfXpu  # build system may be different
+    def test_compile_after_package_static(self):
+        # compile_standalone will set package_cpp_only=True
+        self.check_package_cpp_only()
+
+        class Model(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(10, 10)
+
+            def forward(self, x, y):
+                return x + self.linear(y)
+
+        with torch.no_grad():
+            example_inputs = (
+                torch.randn(10, 10, device=self.device),
+                torch.randn(10, 10, device=self.device),
+            )
+            model = Model().to(device=self.device)
+
+            # Test compilation when no name is passed in
+            options = {
+                "aot_inductor.compile_standalone": True,
+            }
+            with (
+                tempfile.TemporaryDirectory() as tmp_dir,
+            ):
+                build_path, _ = self.cmake_compile(
+                    model, example_inputs, options, tmp_dir
+                )
+                # Check if the .a file was build successfully
+                a_path = build_path / "libaoti_model.a"
+                self.assertTrue(a_path.exists())
+
+            # Test compilation when model name is passed in
+            options = {
+                "aot_inductor.compile_standalone": True,
+                "aot_inductor.model_name_for_generated_files": "linear",
+            }
+            with (
+                tempfile.TemporaryDirectory() as tmp_dir,
+            ):
+                build_path, _ = self.cmake_compile(
+                    model, example_inputs, options, tmp_dir
+                )
+                # Check if the .a file was build successfully
+                a_path = build_path / "liblinear.a"
+                self.assertTrue(a_path.exists())
+
+            # test invalid model name
+            options = {
+                "aot_inductor.compile_standalone": True,
+                "aot_inductor.model_name_for_generated_files": "linear/linear",
+            }
+            with self.assertRaisesRegex(Exception, "Invalid AOTI model name"):
+                self.cmake_compile(model, example_inputs, options, "")
+
+    @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
+    @skipIfRocm  # doesn't support multi-arch binary
+    @skipIfXpu  # doesn't support multi-arch binary
+    def test_run_static_linkage_model(self):
+        self.check_package_cpp_only()
+
+        class Model1(torch.nn.Module):
+            def forward(self, x, y):
+                return x + y
+
+        class Model2(torch.nn.Module):
+            def forward(self, x, y):
+                return x - y
+
+        example_inputs = (
+            torch.randn(10, 10, device=self.device),
+            torch.randn(10, 10, device=self.device),
+        )
+
+        model1 = Model1().to(self.device)
+        model2 = Model2().to(self.device)
+
+        models = [model1, model2]
+
+        i = 0
+        model_names = ["Plus", "Minus"]
+        with (
+            tempfile.TemporaryDirectory() as tmp_dir,
+        ):
+            for i in range(2):
+                model = models[i]
+                # TODO: should be done through _ExportPackage
+                ep = torch.export.export(model, example_inputs)
+
+                package_path = torch._inductor.aoti_compile_and_package(
+                    ep,
+                    inductor_configs={
+                        "aot_inductor.compile_standalone": True,
+                        "always_keep_tensor_constants": True,
+                        "aot_inductor.model_name_for_generated_files": model_names[i],
+                    },
+                )
+                with (
+                    zipfile.ZipFile(package_path, "r") as zip_ref,
+                ):
+                    zip_ref.extractall(tmp_dir)
+
+            file_str = get_static_linkage_main_cpp_file()
+            with open(Path(tmp_dir) / "main.cpp", "w") as f:
+                f.write(file_str)
+
+            if self.device == GPU_TYPE:
+                cmake_file_str = get_static_linkage_makelist_file_cuda()
+            else:
+                cmake_file_str = get_static_linkage_makelist_file_cpu()
+            with open(Path(tmp_dir) / "CMakeLists.txt", "w") as f:
+                f.write(cmake_file_str)
+
+            build_path = Path(tmp_dir) / "build"
+            build_path.mkdir()
+            custom_env = os.environ.copy()
+            custom_env["CMAKE_PREFIX_PATH"] = str(Path(torch.__file__).parent)
+            subprocess.run(
+                ["cmake", ".."],
+                cwd=build_path,
+                env=custom_env,
+            )
+
+            subprocess.run(["make"], cwd=build_path, check=True)
+            subprocess.run(
+                ["./main", f"{tmp_dir}/", self.device], cwd=build_path, check=True
+            )
+
     def test_metadata(self):
         class Model(torch.nn.Module):
             def __init__(self) -> None:
test/inductor/test_static_linkage_utils.py (new file, 157 lines)

@@ -0,0 +1,157 @@
+# Owner(s): ["module: inductor"]
+from torch.testing._internal.common_utils import run_tests
+
+
+def get_static_linkage_main_cpp_file():
+    return """
+#include <dlfcn.h>
+#include <iostream>
+#include <memory>
+#include <torch/torch.h>
+#include <vector>
+
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+// Include the AOTInductor headers
+#include "Minus.wrapper/data/aotinductor/model/Minus.h"
+#include "Plus.wrapper/data/aotinductor/model/Plus.h"
+#include <torch/csrc/inductor/aoti_runtime/model_container.h>
+#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
+
+using torch::aot_inductor::AOTInductorModelMinus;
+using torch::aot_inductor::AOTInductorModelPlus;
+using torch::aot_inductor::ConstantHandle;
+using torch::aot_inductor::ConstantMap;
+
+
+int main(int argc, char* argv[]) {
+    if (argc < 2) {
+        std::cerr
+            << "Usage: ./main <path> <device>"
+            << std::endl;
+        return 1;
+    }
+    std::string path = argv[1];
+    std::string device_str = argv[2];
+    try {
+        torch::Device device(device_str);
+
+        // Create two input tensors (10x10)
+        auto tensor1 = torch::ones({10, 10}, device);
+        auto tensor2 = torch::ones({10, 10}, device);
+        // Create two input tensors (10x10)
+        auto tensor3 = torch::ones({10, 10}, device);
+        auto tensor4 = torch::ones({10, 10}, device);
+
+        std::vector<at::Tensor> input_tensors = {tensor1, tensor2};
+        std::vector<at::Tensor> input_tensors2 = {tensor3, tensor4};
+
+        // Create array of input handles
+        auto input_handles1 =
+            torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(
+                input_tensors);
+        auto input_handles2 =
+            torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(
+                input_tensors2);
+
+        // Create array for output handle
+        AtenTensorHandle output_handle1;
+        AtenTensorHandle output_handle2;
+
+        auto constants_map = std::make_shared<ConstantMap>();
+        auto constants_array = std::make_shared<std::vector<ConstantHandle>>();
+        auto model1 = AOTInductorModelPlus::Create(
+            constants_map, constants_array, device_str,
+            path + "Plus.wrapper/data/"
+                   "aotinductor/model/");
+        model1->load_constants();
+
+        auto constants_map2 = std::make_shared<ConstantMap>();
+        auto constants_array2 = std::make_shared<std::vector<ConstantHandle>>();
+        auto model2 = AOTInductorModelMinus::Create(
+            constants_map2, constants_array2, device_str,
+            path + "Minus.wrapper/data/"
+                   "aotinductor/model/");
+        model2->load_constants();
+
+        // Run the model
+        torch::aot_inductor::DeviceStreamType stream1 = nullptr;
+        torch::aot_inductor::DeviceStreamType stream2 = nullptr;
+        model1->run(&input_handles1[0], &output_handle1, stream1, nullptr);
+        model2->run(&input_handles2[0], &output_handle2, stream2, nullptr);
+
+        // Convert output handle to tensor
+        auto output_tensor1 =
+            torch::aot_inductor::alloc_tensors_by_stealing_from_handles(
+                &output_handle1, 1);
+        auto output_tensor2 =
+            torch::aot_inductor::alloc_tensors_by_stealing_from_handles(
+                &output_handle2, 1);
+
+        if (!(torch::all(output_tensor1[0] == 2).item<bool>())){
+            std::cout << "Wrong Output for Plus Model: " << output_tensor1 << std::endl;
+            throw std::runtime_error("Tensor does not contain only the expected value 2.");
+        }
+        if (!(torch::all(output_tensor2[0] == 0).item<bool>())){
+            std::cout << "Wrong Output for Minus Model: " << output_tensor1 << std::endl;
+            throw std::runtime_error("Tensor does not contain only the expected value 0.");
+        }
+
+        return 0;
+    } catch (const std::exception &e) {
+        std::cerr << "Error: " << e.what() << std::endl;
+        return 1;
+    }
+}
+
+"""
+
+
+def get_static_linkage_makelist_file_cuda():
+    return """
+cmake_minimum_required(VERSION 3.10)
+project(TestProject)
+
+set(CMAKE_CXX_STANDARD 17)
+
+find_package(Torch REQUIRED)
+find_package(CUDA REQUIRED)
+
+add_subdirectory(Plus.wrapper/data/aotinductor/model/)
+add_subdirectory(Minus.wrapper/data/aotinductor/model/)
+
+# Create executable
+add_executable(main main.cpp)
+
+target_compile_definitions(main PRIVATE USE_CUDA)
+
+target_link_libraries(main PRIVATE torch cuda
+    ${CUDA_LIBRARIES}
+    Plus
+    Minus)
+"""
+
+
+def get_static_linkage_makelist_file_cpu():
+    return """
+cmake_minimum_required(VERSION 3.10)
+project(TestProject)
+
+set(CMAKE_CXX_STANDARD 17)
+
+find_package(Torch REQUIRED)
+
+add_subdirectory(Plus.wrapper/data/aotinductor/model/)
+add_subdirectory(Minus.wrapper/data/aotinductor/model/)
+
+# Create executable
+add_executable(main main.cpp)
+
+target_link_libraries(main PRIVATE torch
+    Plus
+    Minus)
+"""
+
+
+if __name__ == "__main__":
+    run_tests()
@@ -1660,6 +1660,12 @@ class AotCodeCompiler:
             wrapper_code = "\n".join((wrapper_code, kernel_code))
             kernel_code = ""

+        from .utils import aoti_model_name_from_config
+
+        model_class_name = ""
+        if config.aot_inductor.compile_standalone:
+            model_class_name = aoti_model_name_from_config()
+
         wrapper_key, wrapper_path = write(
             wrapper_code,
             "wrapper.cpp",
@@ -1679,6 +1685,36 @@ class AotCodeCompiler:
             key=config.aot_inductor.model_name_for_generated_files,
         )

+        header_code = ""
+        header_path = ""
+        if config.aot_inductor.compile_standalone:
+            # to link statically, we also need a header file
+            with open(
+                os.path.join(
+                    os.path.dirname(os.path.dirname(__file__)),
+                    "csrc",
+                    "inductor",
+                    "aoti_runtime",
+                    "model.h",
+                )
+            ) as f:
+                class_name = f"AOTInductorModel{model_class_name}"
+                header_code = f.read()
+
+                # we replace like this to avoid replacing
+                # AOTInductorModelBase and AOTInductorModelKernelsBase
+                header_code = (
+                    header_code.replace("<AOTInductorModel>", f"<{class_name}>")
+                    .replace("AOTInductorModel(", f"{class_name}(")
+                    .replace("AOTInductorModel :", f"{class_name} :")
+                )
+            _, header_path = write(
+                header_code,
+                "h",
+                specified_dir=specified_output_path,
+                key=f"{model_class_name}",
+            )
+
         # Log the AOTInductor wrapper and kernel code, if needed.
         with tempfile.NamedTemporaryFile("w+") as t:
             t.writelines((wrapper_code, "\n", kernel_code, "\n"))
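To see what the three `replace` calls above do, here is a hypothetical miniature stand-in for `model.h` (the real header is much larger; `Plus` is an example model name). `AOTInductorModelBase` survives because none of the three patterns match inside it:

```python
header_code = """\
class AOTInductorModel : public AOTInductorModelBase<AOTInductorModel> {
 public:
  AOTInductorModel(std::shared_ptr<ConstantMap> constants_map);
};
"""
class_name = "AOTInductorModelPlus"  # f"AOTInductorModel{model_class_name}"
header_code = (
    header_code.replace("<AOTInductorModel>", f"<{class_name}>")
    .replace("AOTInductorModel(", f"{class_name}(")
    .replace("AOTInductorModel :", f"{class_name} :")
)
# Result:
# class AOTInductorModelPlus : public AOTInductorModelBase<AOTInductorModelPlus> {
#  public:
#   AOTInductorModelPlus(std::shared_ptr<ConstantMap> constants_map);
# };
```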
@@ -1689,6 +1725,8 @@ class AotCodeCompiler:
         generated_files.append(wrapper_path)
         if not config.aot_inductor.package_cpp_only:
             generated_files.append(kernel_path)
+        if config.aot_inductor.compile_standalone:
+            generated_files.append(header_path)

         output_code_log.info("Wrapper code written to: %s", wrapper_path)
         output_code_log.info("Kernel code written to: %s", kernel_path)
@@ -1710,6 +1748,17 @@ class AotCodeCompiler:
             },
             payload_fn=lambda: kernel_code,
         )
+        if config.aot_inductor.compile_standalone:
+            output_code_log.info("Header code written to: %s", header_path)
+            trace_structured(
+                "graph_dump",
+                lambda: {
+                    "name": "inductor_aot_header_code",
+                    "type": "cpp",
+                    "filename": header_path,
+                },
+                payload_fn=lambda: header_code,
+            )

         # We use a file lock below to protect FS operations. The lock file
         # is scoped to the 'key', so make sure the consts_s is protected
@@ -22,7 +22,13 @@ from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.symbol import symbol_is_type, SymT

 from .. import config, ir
-from ..utils import _align, DeferredLineBase, LineContext, normalize_name
+from ..utils import (
+    _align,
+    aoti_model_name_from_config,
+    DeferredLineBase,
+    LineContext,
+    normalize_name,
+)
 from ..virtualized import V
 from .aoti_hipify_utils import maybe_hipify_code_wrapper
 from .common import get_device_op_overrides, IndentedBuffer, Kernel
@@ -58,6 +64,10 @@ class CppWrapperCpu(PythonWrapperCodegen):
         self.device = "cpu"
         # must be initialized prior to calling super().__init__()
         self.included_devices: OrderedSet[str] = OrderedSet()
+        self.model_class_name_suffix = ""
+        if config.aot_inductor.compile_standalone:
+            self.model_class_name_suffix = aoti_model_name_from_config()
+        self.aoti_model_class_name = f"AOTInductorModel{self.model_class_name_suffix}"
         super().__init__()
         self.declare = "auto "
         self.declare_maybe_reference = "decltype(auto) "
@@ -208,10 +218,16 @@ class CppWrapperCpu(PythonWrapperCodegen):
         self.add_device_include(self.device)

         if V.graph.aot_mode:
-            with open(
-                os.path.join(os.path.dirname(__file__), "aoti_runtime", "interface.cpp")
-            ) as f:
-                self.header.splice(f.read())
+            if not config.aot_inductor.compile_standalone:
+                with open(
+                    os.path.join(
+                        os.path.dirname(__file__), "aoti_runtime", "interface.cpp"
+                    )
+                ) as f:
+                    self.header.splice(f.read())
+            else:
+                # we produce a separate model header for each model in static linkage
+                self.header.splice(f"""#include \"{self.model_class_name_suffix}.h\"""")
         self.header.splice("\n")

         enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [
@@ -508,12 +524,12 @@ class CppWrapperCpu(PythonWrapperCodegen):

         if V.graph.is_const_graph:
             self.prefix.splice(
-                """
-                void AOTInductorModel::_const_run_impl(
+                f"""
+                void {self.aoti_model_class_name}::_const_run_impl(
                     std::vector<AtenTensorHandle>& output_handles,
                     DeviceStreamType stream,
                     AOTIProxyExecutorHandle proxy_executor
-                ) {
+                ) {{
                 """
             )
         else:
@@ -521,18 +537,18 @@ class CppWrapperCpu(PythonWrapperCodegen):
             # If we do not split the constant graph, we'll just create
             # an empty implementation when wrapping the main module.
             self.prefix.splice(
-                """
-                void AOTInductorModel::_const_run_impl(
+                f"""
+                void {self.aoti_model_class_name}::_const_run_impl(
                     std::vector<AtenTensorHandle>& output_handles,
                     DeviceStreamType stream,
                     AOTIProxyExecutorHandle proxy_executor
-                ) {}
+                ) {{}}

                 """
             )

-        run_impl_proto = """
-            void AOTInductorModel::run_impl(
+        run_impl_proto = f"""
+            void {self.aoti_model_class_name}::run_impl(
                 AtenTensorHandle*
                     input_handles, // array of input AtenTensorHandle; handles
                                    // are stolen; the array itself is borrowed
@@ -542,7 +558,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
                                // borrowed
                 DeviceStreamType stream,
                 AOTIProxyExecutorHandle proxy_executor
-            ) {
+            ) {{
                 __check_inputs_outputs(input_handles, output_handles);
             """

@@ -734,7 +750,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
         )
         self.prefix.splice(
             f"""
-            AOTInductorModel::AOTInductorModel(std::shared_ptr<ConstantMap> constants_map,
+            {self.aoti_model_class_name}::{self.aoti_model_class_name}(std::shared_ptr<ConstantMap> constants_map,
                                    std::shared_ptr<std::vector<ConstantHandle>> constants_array,
                                    const std::string& device_str,
                                    std::optional<std::string> cubin_dir)
@@ -891,12 +907,12 @@ class CppWrapperCpu(PythonWrapperCodegen):
             """

         self.prefix.splice(
-            """
-            std::unordered_map<std::string, AtenTensorHandle> AOTInductorModel::const_run_impl(
+            f"""
+            std::unordered_map<std::string, AtenTensorHandle> {self.aoti_model_class_name}::const_run_impl(
                 DeviceStreamType stream,
                 AOTIProxyExecutorHandle proxy_executor,
                 bool initialization
-            ) {
+            ) {{
             """
         )
         if not config.aot_inductor.use_runtime_constant_folding:
@@ -1079,7 +1095,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
     def generate_before_suffix(self, result):
         if not V.graph.is_const_graph:
             if V.graph.aot_mode:
-                result.writeline("} // AOTInductorModel::run_impl")
+                result.writeline(f"}} // {self.aoti_model_class_name}::run_impl")
             else:
                 result.writeline("} // inductor_entry_impl")

@@ -1087,7 +1103,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
         """Generates the end of the code block, and any code needed to call it."""
         if V.graph.aot_mode:
             if V.graph.is_const_graph:
-                result.writeline("} // AOTInductorModel::_const_run_impl")
+                result.writeline(f"}} // {self.aoti_model_class_name}::_const_run_impl")
             else:
                 result.writeline("} // namespace torch::aot_inductor\n\n\n")
                 return
@@ -2407,6 +2407,10 @@ def compile_fx(
         )

     if V.aot_compilation:
+        from .utils import is_valid_aoti_model_name
+
+        is_valid_aoti_model_name()
+
         with functorch_config.patch(unlift_effect_tokens=True):
             gm, graph_signature = aot_export_module(
                 model_,
@@ -1421,6 +1421,13 @@ class aot_inductor:

     # If not None, the generated files with use this name in file stem.
     # If None, we will use a hash to name files.
+    #
+    # If package_cpp_only, this name is also used for the target name in CMakelists.txt
+    # The default target name is "aoti_model"
+    #
+    # If compile_standalone, the aoti model class name is f"AOTInductorModel{name}"
+    #
+    # This name can only contain letters, numbers, and underscores.
     model_name_for_generated_files: Optional[str] = None

     # Custom ops that have implemented C shim wrappers, defined as an op to C shim declaration dict
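Hypothetical examples of how the naming rule documented above plays out, matching the regex in `is_valid_aoti_model_name` further down (`^[a-zA-Z_][a-zA-Z0-9_]*$`):

```python
# Accepted: letters, numbers, and underscores, not starting with a digit.
ok = ["linear", "Plus", "my_model_2", "_private"]
# Rejected with ValueError("Invalid AOTI model name: ..."):
bad = ["linear/linear", "2fast", "my-model", "name with spaces"]
```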
@@ -28,6 +28,7 @@ from torch._dynamo.utils import dynamo_timed
 from torch._inductor import config, exc
 from torch._inductor.cpu_vec_isa import invalid_vec_isa, VecISA
 from torch._inductor.runtime.runtime_utils import cache_dir
+from torch._inductor.utils import aoti_model_name_from_config
 from torch.torch_version import TorchVersion

@@ -1505,6 +1506,7 @@ class CppBuilder:
         self._aot_mode: bool = False

         self._name = name
+        self._target_name = aoti_model_name_from_config()

         # Code start here, initial self internal variables firstly.
         self._build_option = BuildOption
@@ -1730,25 +1732,29 @@ class CppBuilder:
         """

         definitions = " ".join(self._build_option.get_definitions())
+        target_library_type = (
+            "STATIC" if config.aot_inductor.compile_standalone else "SHARED"
+        )
+
         contents = textwrap.dedent(
             f"""
             cmake_minimum_required(VERSION 3.27 FATAL_ERROR)
-            project(aoti_model LANGUAGES CXX)
+            project({self._target_name} LANGUAGES CXX)
             set(CMAKE_CXX_STANDARD 17)

             # May need to point CMAKE_PREFIX_PATH to the right torch location
             find_package(Torch REQUIRED)

             # Set a shared library target
-            add_library(aoti_model SHARED)
+            add_library({self._target_name} {target_library_type})

             # Add macro definitions
-            target_compile_definitions(aoti_model PRIVATE {definitions})
+            target_compile_definitions({self._target_name} PRIVATE {definitions})

             # Add compile flags
-            target_compile_options(aoti_model PRIVATE {self._cflags_args})
+            target_compile_options({self._target_name} PRIVATE {self._cflags_args})
             # Backend specific flags
-            target_compile_options(aoti_model PRIVATE {self._passthrough_parameters_args} -c)
+            target_compile_options({self._target_name} PRIVATE {self._passthrough_parameters_args} -c)

             """
         )
@@ -1823,7 +1829,7 @@ class CppBuilder:
             # Remove the directory part of file_path
             src_path = "${CMAKE_CURRENT_SOURCE_DIR}/" + Path(src_path).name
             with open(cmake_path, "a") as f:
-                f.write(f"target_sources(aoti_model PRIVATE {src_path})\n")
+                f.write(f"target_sources({self._target_name} PRIVATE {src_path})\n")

     def save_kernel_asm_to_cmake(self, cmake_path: str, asm_files: list[str]) -> None:
         # TODO: make this work beyond CUDA
@@ -1837,9 +1843,9 @@ class CppBuilder:
                 """
             )
             f.write(contents)
-            f.write("add_dependencies(aoti_model ${KERNEL_TARGETS})\n")
+            f.write(f"add_dependencies({self._target_name} ${{KERNEL_TARGETS}})\n")
             f.write(
-                "target_link_libraries(aoti_model PRIVATE ${KERNEL_OBJECT_FILES})\n"
+                f"target_link_libraries({self._target_name} PRIVATE ${{KERNEL_OBJECT_FILES}})\n"
             )

     def save_link_cmd_to_cmake(self, cmake_path: str) -> None:
@@ -1848,10 +1854,10 @@ class CppBuilder:
         contents = textwrap.dedent(
             f"""
             # Add linker flags
-            target_link_options(aoti_model PRIVATE {lflags})
+            target_link_options({self._target_name} PRIVATE {lflags})

             # Add libraries
-            target_link_libraries(aoti_model PRIVATE {libs})
+            target_link_libraries({self._target_name} PRIVATE {libs})
             """
         )
|
@ -3308,3 +3308,38 @@ def maybe_aoti_standalone_config(config_patches: dict[str, Any]) -> dict[str, An
|
|||||||
"Please set aot_inductor.package_cpp_only=True in your inductor config."
|
"Please set aot_inductor.package_cpp_only=True in your inductor config."
|
||||||
)
|
)
|
||||||
return config_patches
|
return config_patches
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_aoti_model_name() -> bool:
|
||||||
|
"""
|
||||||
|
Validates if a model name is suitable for use in code generation.
|
||||||
|
|
||||||
|
"""
|
||||||
|
from torch._inductor import config
|
||||||
|
|
||||||
|
model_name = config.aot_inductor.model_name_for_generated_files
|
||||||
|
|
||||||
|
if model_name is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if not isinstance(model_name, str):
|
||||||
|
raise ValueError("Invalid AOTI model name: Model name must be a string")
|
||||||
|
|
||||||
|
if model_name == "":
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Can only contain alphanumeric characters and underscores
|
||||||
|
if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", model_name):
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid AOTI model name: Model name can only contain letters, numbers, and underscores"
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def aoti_model_name_from_config() -> str:
|
||||||
|
from torch._inductor import config
|
||||||
|
|
||||||
|
model_name = config.aot_inductor.model_name_for_generated_files
|
||||||
|
model_name = "aoti_model" if model_name is None else model_name
|
||||||
|
return model_name
|
||||||
|
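A quick sanity-check sketch for the two helpers added above (assumes `torch._inductor.config.patch` accepts dotted config keys, as used elsewhere in the inductor test suite):

```python
from torch._inductor import config
from torch._inductor.utils import (
    aoti_model_name_from_config,
    is_valid_aoti_model_name,
)

with config.patch({"aot_inductor.model_name_for_generated_files": None}):
    # Default target/library name when no model name is configured.
    assert aoti_model_name_from_config() == "aoti_model"
    assert is_valid_aoti_model_name()

with config.patch({"aot_inductor.model_name_for_generated_files": "linear"}):
    assert aoti_model_name_from_config() == "linear"  # -> liblinear.a
```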