mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-28 02:04:53 +08:00
Compare commits
1 Commits
ciflow/tru
...
aoti_targe
| Author | SHA1 | Date | |
|---|---|---|---|
| 53f6698cc5 |
51
run.py
Normal file
51
run.py
Normal file
@ -0,0 +1,51 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Model(torch.nn.Module):
    """Small example network: Linear(10, 16) -> ReLU -> Linear(16, 1) -> Sigmoid."""

    def __init__(self):
        super().__init__()
        # Sub-modules are registered in the order they are applied in forward().
        self.fc1 = torch.nn.Linear(10, 16)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(16, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Run ``x`` through fc1, ReLU, fc2 and sigmoid, and return the result."""
        hidden = self.relu(self.fc1(x))
        return self.sigmoid(self.fc2(hidden))
|
||||
|
||||
|
||||
# rm -r /tmp/torchinductor_shangdiy/
# Export the example model and AOT-compile it into a package for a Windows target.
with torch.no_grad():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Model().to(device=device)
    example_inputs = (torch.randn(8, 10, device=device),)

    # [Optional] Specify the first dimension of the input x as dynamic.
    batch_dim = torch.export.Dim("batch", min=1, max=1024)
    exported = torch.export.export(
        model, example_inputs, dynamic_shapes={"x": {0: batch_dim}}
    )

    # [Optional] Where the generated package is written. If no path is passed,
    # the generated artifact is stored in your system temp directory.
    package_file = os.path.join(os.getcwd(), "model3.pt2")

    # NOTE(review): the tests changed alongside this script use the key
    # "aot_inductor_mode.compile_standalone" -- confirm that
    # "aot_inductor.compile_standalone" is still an accepted config key.
    inductor_options = {
        # optional, compile_standalone automatically makes it cpp_only
        "aot_inductor.package_cpp_only": True,
        "aot_inductor.cross_target_platform": "windows",
        "aot_inductor.compile_standalone": True,
        "aot_inductor.dynamic_linkage": True,
        "max_autotune": True,
        "max_autotune_gemm_backends": "TRITON,CPP",
        "max_autotune_conv_backends": "TRITON,CPP",
        "aot_inductor.model_name_for_generated_files": "model",
    }

    # [Note] In this example the exported module is fed directly to
    # aoti_compile_and_package. Depending on your use case, e.g. if your training
    # platform and inference platform are different, you may choose to save the
    # exported model using torch.export.save and then load it back using
    # torch.export.load on your inference platform to run AOT compilation.
    output_path = torch._inductor.aoti_compile_and_package(
        exported,
        package_path=package_file,
        inductor_configs=inductor_options,
    )
|
||||
1
setup.py
1
setup.py
@ -1693,6 +1693,7 @@ def main() -> None:
|
||||
"_inductor/codegen/*.h",
|
||||
"_inductor/codegen/aoti_runtime/*.h",
|
||||
"_inductor/codegen/aoti_runtime/*.cpp",
|
||||
"_inductor/codegen/aoti_runtime/windows_symbol_exports.def",
|
||||
"_inductor/script.ld",
|
||||
"_inductor/kernel/flex/templates/*.jinja",
|
||||
"_export/serde/*.yaml",
|
||||
|
||||
@ -7189,14 +7189,18 @@ class AOTInductorLoggingTest(LoggingTestCase):
|
||||
|
||||
class TestAOTInductorConfig(TestCase):
|
||||
def test_no_compile_standalone(self):
|
||||
with config.patch({"aot_inductor.compile_standalone": False}):
|
||||
with config.patch({"aot_inductor_mode.compile_standalone": False}):
|
||||
result = maybe_aoti_standalone_config({})
|
||||
self.assertEqual(result, {})
|
||||
|
||||
def test_compile_standalone_sets_package_cpp(self):
|
||||
result = maybe_aoti_standalone_config({"aot_inductor.compile_standalone": True})
|
||||
result = maybe_aoti_standalone_config(
|
||||
{
|
||||
"aot_inductor_mode.compile_standalone": True,
|
||||
}
|
||||
)
|
||||
self.assertEqual(result["aot_inductor.package_cpp_only"], True)
|
||||
self.assertEqual(result["aot_inductor.compile_standalone"], True)
|
||||
self.assertEqual(result["aot_inductor_mode.compile_standalone"], True)
|
||||
self.assertEqual(result["aot_inductor.embed_kernel_binary"], True)
|
||||
self.assertEqual(
|
||||
result["aot_inductor.emit_multi_arch_kernel"], not torch.version.hip
|
||||
@ -7207,7 +7211,7 @@ class TestAOTInductorConfig(TestCase):
|
||||
|
||||
def test_compile_standalone_explicit_set(self):
|
||||
patches = {
|
||||
"aot_inductor.compile_standalone": True,
|
||||
"aot_inductor_mode.compile_standalone": True,
|
||||
"aot_inductor.package_cpp_only": True,
|
||||
"aot_inductor.embed_kernel_binary": True,
|
||||
"aot_inductor.emit_multi_arch_kernel": not torch.version.hip,
|
||||
@ -7218,7 +7222,7 @@ class TestAOTInductorConfig(TestCase):
|
||||
|
||||
def test_compile_standalone_package_cpp_false_raises(self):
|
||||
patches = {
|
||||
"aot_inductor.compile_standalone": True,
|
||||
"aot_inductor_mode.compile_standalone": True,
|
||||
"aot_inductor.package_cpp_only": False,
|
||||
}
|
||||
with self.assertRaises(RuntimeError):
|
||||
@ -7226,7 +7230,7 @@ class TestAOTInductorConfig(TestCase):
|
||||
|
||||
with config.patch({"aot_inductor.package_cpp_only": False}):
|
||||
patches = {
|
||||
"aot_inductor.compile_standalone": True,
|
||||
"aot_inductor_mode.compile_standalone": True,
|
||||
}
|
||||
with self.assertRaises(RuntimeError):
|
||||
maybe_aoti_standalone_config(patches)
|
||||
|
||||
@ -393,7 +393,7 @@ class TestAOTInductorPackage(TestCase):
|
||||
|
||||
# Test compilation when no name is passed in
|
||||
options = {
|
||||
"aot_inductor.compile_standalone": True,
|
||||
"aot_inductor_mode.compile_standalone": True,
|
||||
}
|
||||
with (
|
||||
tempfile.TemporaryDirectory() as tmp_dir,
|
||||
@ -407,7 +407,7 @@ class TestAOTInductorPackage(TestCase):
|
||||
|
||||
# Test compilation when model name is passed in
|
||||
options = {
|
||||
"aot_inductor.compile_standalone": True,
|
||||
"aot_inductor_mode.compile_standalone": True,
|
||||
"aot_inductor.model_name_for_generated_files": "linear",
|
||||
}
|
||||
with (
|
||||
@ -422,7 +422,7 @@ class TestAOTInductorPackage(TestCase):
|
||||
|
||||
# test invalid model name
|
||||
options = {
|
||||
"aot_inductor.compile_standalone": True,
|
||||
"aot_inductor_mode.compile_standalone": True,
|
||||
"aot_inductor.model_name_for_generated_files": "linear/linear",
|
||||
}
|
||||
with self.assertRaisesRegex(Exception, "Invalid AOTI model name"):
|
||||
@ -448,7 +448,7 @@ class TestAOTInductorPackage(TestCase):
|
||||
|
||||
# Test compilation when model name is passed in
|
||||
options = {
|
||||
"aot_inductor.compile_standalone": True,
|
||||
"aot_inductor_mode.compile_standalone": True,
|
||||
"aot_inductor.model_name_for_generated_files": "cos",
|
||||
}
|
||||
with (
|
||||
|
||||
91
test/inductor/test_aot_inductor_windows.py
Normal file
91
test/inductor/test_aot_inductor_windows.py
Normal file
@ -0,0 +1,91 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
import sys
|
||||
import tempfile
|
||||
import zipfile
|
||||
|
||||
import torch
|
||||
import torch._inductor.config
|
||||
from torch._inductor.test_case import TestCase
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.inductor_utils import HAS_GPU, requires_gpu
|
||||
|
||||
|
||||
class Simple(torch.nn.Module):
    """Minimal model for the test: Linear(10, 16) -> ReLU -> Linear(16, 1) -> Sigmoid."""

    def __init__(self):
        super().__init__()
        # Sub-modules are registered in the order they are applied in forward().
        self.fc1 = torch.nn.Linear(10, 16)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(16, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Run ``x`` through fc1, ReLU, fc2 and sigmoid, and return the result."""
        hidden = self.relu(self.fc1(x))
        return self.sigmoid(self.fc2(hidden))
|
||||
|
||||
|
||||
class TestAOTInductorWindowsCrossCompilation(TestCase):
    """Checks on the project files AOTInductor emits when the target platform is Windows."""

    @requires_gpu()
    def test_simple_cpp_only(self):
        """Package ``Simple`` as C++ sources for Windows and check CMakeLists.txt.

        The model is exported with a dynamic batch dimension and packaged with
        ``package_cpp_only`` and ``cross_target_platform="windows"``. The
        CMakeLists.txt inside the resulting archive is then matched against the
        expected directives with FileCheck.
        """
        with torch.no_grad():
            device = "cuda"
            model = Simple().to(device=device)
            example_inputs = (torch.randn(8, 10, device=device),)
            batch_dim = torch.export.Dim("batch", min=1, max=1024)
            exported = torch.export.export(
                model, example_inputs, dynamic_shapes={"x": {0: batch_dim}}
            )
            package_path = torch._inductor.aoti_compile_and_package(
                exported,
                inductor_configs={
                    "aot_inductor.model_name_for_generated_files": "model",
                    "aot_inductor.package_cpp_only": True,
                    "aot_inductor.cross_target_platform": "windows",
                    "aot_inductor.link_libtorch": False,
                    # no fallback ops
                    "max_autotune": True,
                    "max_autotune_gemm_backends": "TRITON,CPP",
                    "max_autotune_conv_backends": "TRITON,CPP",
                },
            )

        with tempfile.TemporaryDirectory() as tmpdir:
            with zipfile.ZipFile(package_path, "r") as zf:
                zf.extractall(tmpdir)

            # Read through a context manager so the file handle is closed
            # before the temporary directory is removed.
            with open(
                f"{tmpdir}/model.wrapper/data/aotinductor/model/CMakeLists.txt"
            ) as makefile:
                makefile_content = makefile.read()

        FileCheck().check("add_library(model SHARED)").check(
            "target_compile_definitions(model PRIVATE NOMINMAX "
        ).check("USE_CUDA").check("target_compile_options(model").check(
            """set_target_properties(model PROPERTIES SUFFIX ".pyd" """
        ).check(
            """LINK_FLAGS "/DEF:${CMAKE_CURRENT_SOURCE_DIR}/windows_symbol_exports.def" )"""
        ).check(
            "target_sources(model PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/model.wrapper.cpp)"
        ).check(
            "target_sources(model PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/model_consts.weights.cpp)"
        ).check("embed_gpu_kernel(").check(
            "add_dependencies(model ${KERNEL_TARGETS})"
        ).check(
            "target_link_libraries(model PRIVATE ${KERNEL_OBJECT_FILES})"
        ).check(
            "target_link_options(model PRIVATE )"  # no libtorch
        ).check("target_link_libraries(model PRIVATE CUDA::cudart cuda)").run(
            makefile_content
        )

        # TODO: actually compile the package in the test later in windows CI
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._inductor.test_case import run_tests
|
||||
|
||||
if HAS_GPU or sys.platform == "darwin":
|
||||
run_tests(needs="filelock")
|
||||
@ -75,6 +75,7 @@ from torch._inductor.cpp_builder import (
|
||||
get_compiler_version_info,
|
||||
get_ld_and_objcopy,
|
||||
get_name_and_dir_from_output_file_path,
|
||||
is_target_windows,
|
||||
normalize_path_separator,
|
||||
run_asm_build_object,
|
||||
)
|
||||
@ -142,7 +143,6 @@ if TYPE_CHECKING:
|
||||
from .utils import InputType
|
||||
|
||||
|
||||
_IS_WINDOWS = sys.platform == "win32"
|
||||
LOCK_TIMEOUT = 600
|
||||
|
||||
output_code_log = torch._logging.getArtifactLogger(__name__, "output_code")
|
||||
@ -393,7 +393,7 @@ class WritableTempFile:
|
||||
try:
|
||||
os.unlink(self.temp_file.name)
|
||||
except OSError as e:
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
# On Windows, the temp file is sometimes still open and unlink fails; ignore the error.
|
||||
pass
|
||||
else:
|
||||
@ -447,7 +447,7 @@ def write_atomic(
|
||||
try:
|
||||
tmp_path.rename(target=path)
|
||||
except FileExistsError:
|
||||
if not _IS_WINDOWS:
|
||||
if not is_target_windows():
|
||||
raise
|
||||
# On Windows file exist is expected: https://docs.python.org/3/library/pathlib.html#pathlib.Path.rename
|
||||
# The two lines below are equivalent to `tmp_path.rename(path)` on non-Windows OSes.
|
||||
@ -1610,7 +1610,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
|
||||
@functools.cache
|
||||
def split_aot_inductor_output_path(path: str) -> tuple[str, str]:
|
||||
def get_module_ext_type() -> str:
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
return ".pyd"
|
||||
else:
|
||||
return ".so"
|
||||
@ -1777,7 +1777,7 @@ class AotCodeCompiler:
|
||||
|
||||
header_code = ""
|
||||
header_path = ""
|
||||
if config.aot_inductor.compile_standalone:
|
||||
if not config.aot_inductor.dynamic_linkage:
|
||||
# to link statically, we also need a header file
|
||||
with open(
|
||||
os.path.join(
|
||||
@ -1788,7 +1788,7 @@ class AotCodeCompiler:
|
||||
"model.h",
|
||||
)
|
||||
) as f:
|
||||
# model_name_for_generated_files is guaranteed to be non-empty when compile_standalone
|
||||
# model_name_for_generated_files is guaranteed to be non-empty when dynamic_linkage is False
|
||||
model_class_name = config.aot_inductor.model_name_for_generated_files
|
||||
class_name = f"AOTInductorModel{model_class_name}"
|
||||
header_code = f.read()
|
||||
@ -1827,7 +1827,7 @@ class AotCodeCompiler:
|
||||
generated_files.append(wrapper_path)
|
||||
if not config.aot_inductor.package_cpp_only:
|
||||
generated_files.append(kernel_path)
|
||||
if config.aot_inductor.compile_standalone:
|
||||
if not config.aot_inductor.dynamic_linkage:
|
||||
generated_files.append(header_path)
|
||||
|
||||
output_code_log.info("Wrapper code written to: %s", wrapper_path)
|
||||
@ -1850,7 +1850,7 @@ class AotCodeCompiler:
|
||||
},
|
||||
payload_fn=lambda: kernel_code,
|
||||
)
|
||||
if config.aot_inductor.compile_standalone:
|
||||
if not config.aot_inductor.dynamic_linkage:
|
||||
output_code_log.info("Header code written to: %s", header_path)
|
||||
trace_structured(
|
||||
"graph_dump",
|
||||
@ -1872,9 +1872,12 @@ class AotCodeCompiler:
|
||||
specified_sub_dir.mkdir(exist_ok=True)
|
||||
cmake_path = str(Path(specified_sub_dir) / "CMakeLists.txt")
|
||||
|
||||
def _compile_consts(consts: bytes, platform: str) -> str:
|
||||
def _compile_consts(consts: bytes, platform: str) -> tuple[str, Optional[str]]:
|
||||
# Load from aot_inductor, and update the value on demand.
|
||||
use_asm_build: bool = config.aot_inductor.use_consts_asm_build
|
||||
use_asm_build: bool = (
|
||||
config.aot_inductor.use_consts_asm_build
|
||||
and config.aot_inductor.cross_target_platform != "windows"
|
||||
)
|
||||
|
||||
if platform == "linux":
|
||||
if graph.mutated_buffers & OrderedSet(graph.constants.keys()):
|
||||
@ -1976,7 +1979,7 @@ ATTRIBUTE_NO_SANITIZE_ADDRESS\t\n"""
|
||||
Linux: Added '-pedantic' to disable zero-sized arrays in C++ compiler
|
||||
Windows: MSVC naturally rejects zero-sized arrays by default
|
||||
"""
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
# Windows ml64 supports a maximum alignment of 16, but this has no effect on zero-size data.
|
||||
asm_code = """
|
||||
option casemap:none
|
||||
@ -2020,7 +2023,7 @@ end
|
||||
consts_code,
|
||||
code_ext,
|
||||
specified_dir=str(specified_sub_dir),
|
||||
key=config.aot_inductor.model_name_for_generated_files,
|
||||
key=f"{config.aot_inductor.model_name_for_generated_files}_consts",
|
||||
)
|
||||
consts_s = Path(consts_s)
|
||||
object_build_options = CppTorchDeviceOptions(
|
||||
@ -2038,7 +2041,7 @@ end
|
||||
consts_o = object_builder.get_target_file_path()
|
||||
if use_asm_build is False and is_zero_size_consts:
|
||||
run_asm_build_object(str(consts_s), consts_o, str(consts_s.parent))
|
||||
else:
|
||||
elif config.aot_inductor.cross_target_platform != "windows":
|
||||
object_builder.build()
|
||||
|
||||
if is_large_consts and use_asm_build:
|
||||
@ -2058,10 +2061,12 @@ end
|
||||
rc = f.write(consts[pos:])
|
||||
pos += rc
|
||||
|
||||
# Remove the .S file to save space
|
||||
os.remove(consts_s)
|
||||
if config.aot_inductor.cross_target_platform != "windows":
|
||||
# Remove the .S file to save space
|
||||
os.remove(consts_s)
|
||||
return consts_o, None
|
||||
|
||||
return consts_o
|
||||
return consts_o, str(consts_s)
|
||||
|
||||
from torch.utils._filelock import FileLock
|
||||
|
||||
@ -2199,7 +2204,7 @@ end
|
||||
)
|
||||
|
||||
# potentially, precompile the AOT header for this device
|
||||
if config.aot_inductor.precompile_headers and not _IS_WINDOWS:
|
||||
if config.aot_inductor.precompile_headers and not is_target_windows():
|
||||
header_file = _get_cpp_wrapper_header(
|
||||
device_type, aot_mode=graph.aot_mode
|
||||
)
|
||||
@ -2248,6 +2253,33 @@ end
|
||||
wrapper_builder.save_compile_cmd_to_cmake(cmake_path, device_type)
|
||||
wrapper_builder.save_src_to_cmake(cmake_path, wrapper_path)
|
||||
generated_files.append(cmake_path)
|
||||
|
||||
if is_target_windows():
|
||||
with open(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)),
|
||||
"csrc",
|
||||
"inductor",
|
||||
"aoti_runtime",
|
||||
"windows_symbol_exports.def",
|
||||
)
|
||||
) as f:
|
||||
# model_name_for_generated_files is guaranteed to be non-empty when dynamic_linkage is False
|
||||
assert (
|
||||
config.aot_inductor.model_name_for_generated_files
|
||||
is not None
|
||||
)
|
||||
windows_symbol_exports = f.read().replace(
|
||||
"{config.aot_inductor.model_name_for_generated_files}",
|
||||
config.aot_inductor.model_name_for_generated_files,
|
||||
)
|
||||
_, expors_path = write(
|
||||
windows_symbol_exports,
|
||||
"def",
|
||||
specified_dir=str(specified_sub_dir),
|
||||
key="windows_symbol_exports",
|
||||
)
|
||||
generated_files.append(expors_path)
|
||||
else:
|
||||
try:
|
||||
wrapper_builder.build()
|
||||
@ -2268,7 +2300,7 @@ end
|
||||
)
|
||||
aot_constants = struct.pack("qq", consts_size + 8, magic_number)
|
||||
|
||||
consts_o = _compile_consts(aot_constants, sys.platform)
|
||||
consts_o, consts_asm = _compile_consts(aot_constants, sys.platform)
|
||||
custom_obj_idx = 0
|
||||
# Note that custom_objs_config.json file is different from the model_constants_config.json file produced
|
||||
# in package_sigmoid(). The keys in custom_objs_config.json directly correspond to the arg name in extern
|
||||
@ -2318,8 +2350,12 @@ end
|
||||
)
|
||||
|
||||
cubins_o = []
|
||||
asm_files = []
|
||||
if not _IS_WINDOWS:
|
||||
asm_files = [
|
||||
value["asm"]
|
||||
for value in CudaKernelParamCache.cache.values()
|
||||
if "asm" in value
|
||||
]
|
||||
if not is_target_windows():
|
||||
ld, objcopy = get_ld_and_objcopy(use_relative_path)
|
||||
kernels = getattr(V.graph.wrapper_code, "_kernel_name_to_body", {})
|
||||
for kernel_name, value in CudaKernelParamCache.cache.items():
|
||||
@ -2328,9 +2364,6 @@ end
|
||||
# than what the current graph uses
|
||||
continue
|
||||
|
||||
if asm_file := value["asm"]:
|
||||
asm_files.append(asm_file)
|
||||
|
||||
cubin_file = value[get_cpp_wrapper_cubin_path_name()]
|
||||
if (
|
||||
config.aot_inductor.emit_multi_arch_kernel
|
||||
@ -2338,7 +2371,7 @@ end
|
||||
):
|
||||
current_arch = _nvcc_arch_as_compile_option()
|
||||
cmd = (
|
||||
f"{_cuda_compiler()} -fatbin {asm_file} -o {cubin_file} "
|
||||
f"{_cuda_compiler()} -fatbin {value['asm']} -o {cubin_file} "
|
||||
# Triton only allows generating PTX version as same as the current arch
|
||||
f"-gencode arch=compute_{current_arch},code=compute_{current_arch} "
|
||||
# Include SASS for the current specific arch
|
||||
@ -2418,14 +2451,22 @@ end
|
||||
f_weights.write(struct.pack("q", magic_number))
|
||||
|
||||
generated_files.append(weight_file)
|
||||
elif config.aot_inductor.cross_target_platform == "windows":
|
||||
assert consts_asm is not None
|
||||
generated_files.append(consts_asm)
|
||||
so_builder.save_src_to_cmake(cmake_path, consts_asm)
|
||||
else:
|
||||
# TODO: unify to always use mmap_weights
|
||||
generated_files.append(consts_o)
|
||||
so_builder.save_src_to_cmake(cmake_path, consts_o)
|
||||
|
||||
if config.aot_inductor.emit_multi_arch_kernel:
|
||||
if (
|
||||
config.aot_inductor.emit_multi_arch_kernel
|
||||
or config.aot_inductor.cross_target_platform == "windows"
|
||||
):
|
||||
so_builder.save_kernel_asm_to_cmake(cmake_path, asm_files)
|
||||
generated_files.extend(asm_files)
|
||||
|
||||
else:
|
||||
obj_srcs = [*gpu_kernels_o, *cubins_o]
|
||||
generated_files.extend(obj_srcs)
|
||||
@ -2446,7 +2487,7 @@ end
|
||||
def get_page_size() -> int:
|
||||
# Don't use resource.getpagesize() on Windows, as it is a Unix specific package
|
||||
# as seen in https://docs.python.org/2/library/resource.html
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
from ctypes import ( # type: ignore[attr-defined]
|
||||
byref,
|
||||
Structure,
|
||||
@ -2574,7 +2615,7 @@ def _precompile_header(
|
||||
hashable_cmd_line: str,
|
||||
**compile_command: Any,
|
||||
) -> str:
|
||||
assert not _IS_WINDOWS, (
|
||||
assert not is_target_windows(), (
|
||||
"CppBuilder does not currently support precompiling on Windows!"
|
||||
)
|
||||
|
||||
@ -2759,7 +2800,7 @@ class CppCodeCache:
|
||||
lib = None
|
||||
|
||||
# if requested, pre-compile any headers
|
||||
if config.cpp_cache_precompile_headers and not _IS_WINDOWS:
|
||||
if config.cpp_cache_precompile_headers and not is_target_windows():
|
||||
if header := cls._get_uncompiled_header(device_type):
|
||||
main_build_option.precompiled_header = _precompile_header(
|
||||
header,
|
||||
@ -3737,11 +3778,9 @@ def _nvcc_host_compiler_options() -> list[str]:
|
||||
|
||||
def _nvcc_arch_as_compile_option() -> str:
|
||||
arch = cuda_env.get_cuda_arch()
|
||||
if arch == "90":
|
||||
if arch in ("90", "100", "120"):
|
||||
# Required by cutlass compilation.
|
||||
return "90a"
|
||||
if arch == "100":
|
||||
return "100a"
|
||||
return f"{arch}a"
|
||||
return arch
|
||||
|
||||
|
||||
|
||||
@ -60,9 +60,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
||||
# must be initialized prior to calling super().__init__()
|
||||
self.included_devices: OrderedSet[str] = OrderedSet()
|
||||
self.model_class_name_suffix = (
|
||||
config.aot_inductor.model_name_for_generated_files
|
||||
if config.aot_inductor.compile_standalone
|
||||
else ""
|
||||
""
|
||||
if config.aot_inductor.dynamic_linkage
|
||||
else config.aot_inductor.model_name_for_generated_files
|
||||
)
|
||||
self.aoti_model_class_name = f"AOTInductorModel{self.model_class_name_suffix}"
|
||||
|
||||
@ -222,7 +222,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
||||
self.add_device_include(self.device)
|
||||
|
||||
if V.graph.aot_mode:
|
||||
if not config.aot_inductor.compile_standalone:
|
||||
if config.aot_inductor.dynamic_linkage:
|
||||
with open(
|
||||
os.path.join(
|
||||
os.path.dirname(__file__), "aoti_runtime", "interface.cpp"
|
||||
|
||||
@ -1585,7 +1585,14 @@ class aot_inductor:
|
||||
# custom op libs that have implemented C shim wrappers
|
||||
custom_op_libs: Optional[list[str]] = None
|
||||
|
||||
compile_standalone: bool = False
|
||||
# If set to "windows", we will compile from WSL and generate a C++ project file for
|
||||
# further Windows native compilation
|
||||
# Only works with package_cpp_only=True
|
||||
cross_target_platform: Optional[str] = None
|
||||
|
||||
# If package_cpp_only is True, whether cpp files will be compiled to a
|
||||
# dynamically linked library or static linked library
|
||||
dynamic_linkage: bool = True
|
||||
|
||||
# Whether to enable link-time-optimization
|
||||
enable_lto = os.environ.get("AOT_INDUCTOR_ENABLE_LTO", "0") == "1"
|
||||
@ -1601,6 +1608,22 @@ class aot_inductor:
|
||||
# TODO: should consolidate this flag with compile_standalone
|
||||
libtorch_free_headers: Optional[list[str]] = None
|
||||
|
||||
# compile to TorchStandalone
|
||||
compile_with_torchstandalone: int = False
|
||||
|
||||
|
||||
# a convenient class that automatically sets a group of the configs in aot_inductor
|
||||
# it should only control the flags in aot_inductor.
|
||||
# it should not do anything else.
|
||||
class aot_inductor_mode:
    # aot_inductor settings listed for compile_standalone:
    #   dynamic_linkage=False
    #   link_libtorch=False
    #   package_cpp_only=True
    #   embed_kernel_binary=True
    #   emit_multi_arch_kernel=True
    #   compile_with_torchstandalone=True
    compile_standalone: bool = False
|
||||
|
||||
|
||||
class cuda:
|
||||
"""Settings for cuda backend, today this consists of cutlass"""
|
||||
|
||||
@ -65,14 +65,27 @@ _LINKER_SCRIPT = os.path.join(_TORCH_PATH, "_inductor/script.ld")
|
||||
# initialize variables for compilation
|
||||
_IS_LINUX = sys.platform.startswith("linux")
|
||||
_IS_MACOS = sys.platform.startswith("darwin")
|
||||
_IS_WINDOWS = sys.platform == "win32"
|
||||
|
||||
SUBPROCESS_DECODE_ARGS = ("utf-8",) if _IS_WINDOWS else ()
|
||||
SUBPROCESS_DECODE_ARGS = (
|
||||
("utf-8",)
|
||||
if sys.platform == "win32" or config.aot_inductor.cross_target_platform == "windows"
|
||||
else ()
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================== toolchain ===============================
|
||||
def is_target_windows() -> bool:
    """Return True when the build targets Windows.

    This holds when the current interpreter runs on Windows
    (``sys.platform == "win32"``) or when
    ``config.aot_inductor.cross_target_platform`` is ``"windows"``.
    """
    if sys.platform == "win32":
        return True
    return config.aot_inductor.cross_target_platform == "windows"
|
||||
|
||||
|
||||
def is_compiling_on_windows() -> bool:
    """Return True when the current interpreter itself runs on Windows."""
    host_platform = sys.platform
    return host_platform == "win32"
|
||||
|
||||
|
||||
@functools.lru_cache(1)
|
||||
def cpp_compiler_search(search: str) -> str:
|
||||
from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT
|
||||
@ -133,6 +146,9 @@ def check_compiler_exist_windows(compiler: str) -> None:
|
||||
"""
|
||||
Check that the compiler is available, in case the end user has not activated the MSVC environment.
|
||||
"""
|
||||
if config.aot_inductor.cross_target_platform == "windows":
|
||||
# Do not check compiler if cross target platform is windows.
|
||||
return
|
||||
try:
|
||||
subprocess.check_output([compiler, "/help"], stderr=subprocess.STDOUT)
|
||||
except FileNotFoundError as exc:
|
||||
@ -332,7 +348,7 @@ def check_msvc_cl_language_id(compiler: str) -> None:
|
||||
|
||||
|
||||
def get_cpp_compiler() -> str:
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
compiler = os.environ.get("CXX", "cl")
|
||||
compiler = normalize_path_separator(compiler)
|
||||
check_compiler_exist_windows(compiler)
|
||||
@ -349,7 +365,7 @@ def get_cpp_compiler() -> str:
|
||||
|
||||
|
||||
def get_ld_and_objcopy(use_relative_path: bool) -> tuple[str, str]:
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
raise RuntimeError("Windows is not supported yet.")
|
||||
else:
|
||||
if config.is_fbcode():
|
||||
@ -403,7 +419,7 @@ def _is_clang(cpp_compiler: str) -> bool:
|
||||
# Mac OS apple clang maybe named as gcc, need check compiler info.
|
||||
if sys.platform == "darwin":
|
||||
return _is_apple_clang(cpp_compiler)
|
||||
elif _IS_WINDOWS:
|
||||
elif is_target_windows():
|
||||
# The clang suite contains many compilers, and only clang-cl is supported.
|
||||
if re.search(r"((clang$)|(clang\+\+$))", cpp_compiler):
|
||||
raise RuntimeError(
|
||||
@ -423,7 +439,7 @@ def _is_gcc(cpp_compiler: str) -> bool:
|
||||
|
||||
@functools.cache
|
||||
def _is_msvc_cl(cpp_compiler: str) -> bool:
|
||||
if not _IS_WINDOWS:
|
||||
if not is_compiling_on_windows():
|
||||
return False
|
||||
|
||||
try:
|
||||
@ -445,7 +461,7 @@ def _is_intel_compiler(cpp_compiler: str) -> bool:
|
||||
"""
|
||||
On Windows: early version icx has `-print-file-name` issue, and can't preload correctly for inductor.
|
||||
"""
|
||||
min_version = "2024.2.1" if _IS_WINDOWS else "0.0.0"
|
||||
min_version = "2024.2.1" if is_target_windows() else "0.0.0"
|
||||
if compiler_version < TorchVersion(min_version):
|
||||
raise RuntimeError(
|
||||
f"Intel Compiler error: less than minimal version {min_version}."
|
||||
@ -461,7 +477,7 @@ def _is_intel_compiler(cpp_compiler: str) -> bool:
|
||||
)
|
||||
is_intel_compiler = "Intel" in output_msg.splitlines()[0]
|
||||
if is_intel_compiler:
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
if re.search(r"((icx$)|(icx-cc$))", cpp_compiler):
|
||||
raise RuntimeError(
|
||||
"Please use icx-cl, due to torch.compile only support MSVC-like CLI (compiler flags syntax)."
|
||||
@ -594,7 +610,7 @@ def run_compile_cmd(cmd_line: str, cwd: str) -> None:
|
||||
|
||||
|
||||
def normalize_path_separator(orig_path: str) -> str:
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
return orig_path.replace(os.sep, "/")
|
||||
return orig_path
|
||||
|
||||
@ -719,14 +735,14 @@ class BuildOptionsBase:
|
||||
|
||||
|
||||
def _get_warning_all_cflag(warning_all: bool = True) -> list[str]:
|
||||
if not _IS_WINDOWS:
|
||||
if not is_target_windows():
|
||||
return ["Wall"] if warning_all else []
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def _get_cpp_std_cflag(std_num: str = "c++17") -> list[str]:
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
"""
|
||||
On Windows, only c++20 can support `std::enable_if_t`.
|
||||
Ref: https://learn.microsoft.com/en-us/cpp/overview/cpp-conformance-improvements-2019?view=msvc-170#checking-for-abstract-class-types # noqa: B950
|
||||
@ -741,7 +757,7 @@ def _get_cpp_std_cflag(std_num: str = "c++17") -> list[str]:
|
||||
|
||||
|
||||
def _get_os_related_cpp_cflags(cpp_compiler: str) -> list[str]:
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
cflags = [
|
||||
"wd4819",
|
||||
"wd4251",
|
||||
@ -778,7 +794,7 @@ def _get_os_related_cpp_cflags(cpp_compiler: str) -> list[str]:
|
||||
|
||||
def _get_os_related_cpp_definitions(cpp_compiler: str) -> list[str]:
|
||||
os_definitions: list[str] = []
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
# On Windows, we need disable min/max macro to avoid C2589 error, as PyTorch CMake:
|
||||
# https://github.com/pytorch/pytorch/blob/9a41570199155eee92ebd28452a556075e34e1b4/CMakeLists.txt#L1118-L1119
|
||||
os_definitions.append("NOMINMAX")
|
||||
@ -788,7 +804,7 @@ def _get_os_related_cpp_definitions(cpp_compiler: str) -> list[str]:
|
||||
|
||||
|
||||
def _get_ffast_math_flags() -> list[str]:
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
flags = []
|
||||
else:
|
||||
# ffast-math is equivalent to these flags as in
|
||||
@ -825,7 +841,7 @@ def _get_inductor_debug_symbol_cflags() -> tuple[list[str], list[str]]:
|
||||
cflags: list[str] = []
|
||||
ldflags: list[str] = []
|
||||
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
cflags = ["ZI", "_DEBUG"]
|
||||
ldflags = ["DEBUG", "ASSEMBLYDEBUG ", "OPT:REF", "OPT:ICF"]
|
||||
else:
|
||||
@ -848,19 +864,19 @@ def _get_optimization_cflags(
|
||||
|
||||
if b_debug_build:
|
||||
cflags, ldflags = _get_inductor_debug_symbol_cflags()
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
cflags += ["Od", "Ob0", "Oy-"]
|
||||
else:
|
||||
cflags.append("O0")
|
||||
else:
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
cflags = ["O1" if min_optimize else "O2"]
|
||||
else:
|
||||
cflags = [wrapper_opt_level if min_optimize else "O3", "DNDEBUG"]
|
||||
|
||||
cflags += _get_ffast_math_flags()
|
||||
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
pass
|
||||
else:
|
||||
if sys.platform != "darwin":
|
||||
@ -882,7 +898,7 @@ def _get_optimization_cflags(
|
||||
|
||||
|
||||
def _get_shared_cflags(do_link: bool) -> list[str]:
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
"""
|
||||
MSVC `/MD` using python `ucrtbase.dll` lib as runtime.
|
||||
https://learn.microsoft.com/en-us/cpp/c-runtime-library/crt-library-features?view=msvc-170
|
||||
@ -923,7 +939,11 @@ def get_cpp_options(
|
||||
|
||||
definitions += _get_os_related_cpp_definitions(cpp_compiler)
|
||||
|
||||
if not _IS_WINDOWS and config.aot_inductor.enable_lto and _is_clang(cpp_compiler):
|
||||
if (
|
||||
not is_target_windows()
|
||||
and config.aot_inductor.enable_lto
|
||||
and _is_clang(cpp_compiler)
|
||||
):
|
||||
ldflags.append("fuse-ld=lld")
|
||||
ldflags.append("flto=thin")
|
||||
|
||||
@ -1005,7 +1025,7 @@ def _use_custom_generated_macros() -> list[str]:
|
||||
|
||||
|
||||
def _use_fb_internal_macros() -> list[str]:
|
||||
if not _IS_WINDOWS:
|
||||
if not is_target_windows():
|
||||
if config.is_fbcode():
|
||||
fb_internal_macros = [
|
||||
"C10_USE_GLOG",
|
||||
@ -1027,7 +1047,7 @@ def _setup_standard_sys_libs(
|
||||
cflags: list[str] = []
|
||||
include_dirs: list[str] = []
|
||||
passthrough_args: list[str] = []
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
return cflags, include_dirs, passthrough_args
|
||||
|
||||
if config.is_fbcode():
|
||||
@ -1098,7 +1118,7 @@ def _get_torch_related_args(
|
||||
else:
|
||||
libraries_dirs = []
|
||||
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows() and config.aot_inductor.link_libtorch:
|
||||
libraries.append("sleef")
|
||||
|
||||
return include_dirs, libraries_dirs, libraries
|
||||
@ -1120,12 +1140,12 @@ def _get_python_include_dirs() -> list[str]:
|
||||
def _get_python_related_args() -> tuple[list[str], list[str]]:
|
||||
python_include_dirs = _get_python_include_dirs()
|
||||
python_include_path = sysconfig.get_path(
|
||||
"include", scheme="nt" if _IS_WINDOWS else "posix_prefix"
|
||||
"include", scheme="nt" if is_target_windows() else "posix_prefix"
|
||||
)
|
||||
if python_include_path is not None:
|
||||
python_include_dirs.append(python_include_path)
|
||||
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
python_lib_path = [
|
||||
str(
|
||||
(
|
||||
@ -1273,7 +1293,7 @@ def _get_openmp_args(
|
||||
|
||||
# If openmp is still not available, we let the compiler have a try,
|
||||
# and raise an error together with instructions if compilation fails later.
|
||||
elif _IS_WINDOWS:
|
||||
elif is_target_windows():
|
||||
"""
|
||||
On Windows, `clang` and `icx` have their specific openmp implementation.
|
||||
And the openmp lib is in some sub-directory of the compiler.
|
||||
@ -1419,6 +1439,7 @@ def get_cpp_torch_options(
|
||||
ldflags = omp_ldflags
|
||||
libraries_dirs = python_libraries_dirs + torch_libraries_dirs + omp_lib_dir_paths
|
||||
libraries = torch_libraries + omp_lib
|
||||
|
||||
passthrough_args = (
|
||||
sys_libs_passthrough_args + isa_ps_args_build_flags + omp_passthrough_args
|
||||
)
|
||||
@ -1585,6 +1606,8 @@ def get_cpp_torch_device_options(
|
||||
libraries += ["c10_hip", "torch_hip"]
|
||||
definitions.append(" __HIP_PLATFORM_AMD__")
|
||||
else:
|
||||
if is_target_windows():
|
||||
libraries += ["CUDA::cudart"]
|
||||
if config.is_fbcode() or not link_libtorch:
|
||||
libraries += ["cuda"]
|
||||
else:
|
||||
@ -1597,7 +1620,7 @@ def get_cpp_torch_device_options(
|
||||
"Intel GPU driver is not properly installed, please follow the instruction "
|
||||
"in https://github.com/pytorch/pytorch?tab=readme-ov-file#intel-gpu-support."
|
||||
)
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
ze_root = os.getenv("LEVEL_ZERO_V1_SDK_PATH")
|
||||
if ze_root is None:
|
||||
raise OSError(xpu_error_string)
|
||||
@ -1766,26 +1789,26 @@ class CppBuilder:
|
||||
|
||||
@staticmethod
|
||||
def __get_python_module_flags() -> tuple[str, str]:
|
||||
extension = ".pyd" if _IS_WINDOWS else ".so"
|
||||
output_flags = "/Fe" if _IS_WINDOWS else "-o"
|
||||
extension = ".pyd" if is_target_windows() else ".so"
|
||||
output_flags = "/Fe" if is_target_windows() else "-o"
|
||||
return extension, output_flags
|
||||
|
||||
@staticmethod
|
||||
def __get_object_flags() -> tuple[str, str]:
|
||||
extension = ".obj" if _IS_WINDOWS else ".o"
|
||||
output_flags = "/c /Fo" if _IS_WINDOWS else "-c -o" # codespell:ignore
|
||||
extension = ".obj" if is_target_windows() else ".o"
|
||||
output_flags = "/c /Fo" if is_target_windows() else "-c -o" # codespell:ignore
|
||||
return extension, output_flags
|
||||
|
||||
@staticmethod
|
||||
def __get_precompiled_header_flags() -> tuple[str, str]:
|
||||
extension = ".pch" if _IS_WINDOWS or not is_gcc() else ".gch"
|
||||
output_flags = "/Fp" if _IS_WINDOWS else "-o"
|
||||
extension = ".pch" if is_target_windows() or not is_gcc() else ".gch"
|
||||
output_flags = "/Fp" if is_target_windows() else "-o"
|
||||
return extension, output_flags
|
||||
|
||||
@staticmethod
|
||||
def __get_preprocessor_output_flags() -> tuple[str, str]:
|
||||
extension = ".i"
|
||||
output_flags = "/EP /P" if _IS_WINDOWS else "-E -P -o"
|
||||
output_flags = "/EP /P" if is_target_windows() else "-E -P -o"
|
||||
return extension, output_flags
|
||||
|
||||
def __init__(
|
||||
@ -1837,7 +1860,7 @@ class CppBuilder:
|
||||
# MSVC produces two files when precompiling: the actual .pch file, as well as an
|
||||
# object file which must be linked into the final library. This class assumes
|
||||
# only one output file of note, so for now we'll error out here.
|
||||
assert not _IS_WINDOWS or not self._precompiling, (
|
||||
assert not is_target_windows() or not self._precompiling, (
|
||||
"Cannot currently precompile headers on Windows!"
|
||||
)
|
||||
|
||||
@ -1856,7 +1879,7 @@ class CppBuilder:
|
||||
if self._use_relative_path
|
||||
else self._target_file
|
||||
)
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
if self._preprocessing:
|
||||
# The target file name is automatically determined by MSVC.
|
||||
self._output = output_flags
|
||||
@ -1883,19 +1906,19 @@ class CppBuilder:
|
||||
self._sources_args = " ".join(sources)
|
||||
|
||||
for cflag in BuildOption.get_cflags():
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
self._cflags_args += f"/{cflag} "
|
||||
else:
|
||||
self._cflags_args += f"-{cflag} "
|
||||
|
||||
for definition in BuildOption.get_definitions():
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
self._definitions_args += f"/D {definition} "
|
||||
else:
|
||||
self._definitions_args += f"-D {definition} "
|
||||
|
||||
if precompiled_header := BuildOption.precompiled_header:
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
log.warning(
|
||||
"Precompiled header support for MSVC is currently unavailable; ignoring %s",
|
||||
precompiled_header,
|
||||
@ -1904,25 +1927,25 @@ class CppBuilder:
|
||||
self._include_dirs_args = f"-include {precompiled_header} "
|
||||
|
||||
for inc_dir in BuildOption.get_include_dirs():
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
self._include_dirs_args += f'/I "{inc_dir}" '
|
||||
else:
|
||||
self._include_dirs_args += f"-I{shlex.quote(inc_dir)} "
|
||||
|
||||
for ldflag in BuildOption.get_ldflags():
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
self._ldflags_args += f"/{ldflag} "
|
||||
else:
|
||||
self._ldflags_args += f"-{ldflag} "
|
||||
|
||||
for lib_dir in BuildOption.get_libraries_dirs():
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
self._libraries_dirs_args += f'/LIBPATH:"{lib_dir}" '
|
||||
else:
|
||||
self._libraries_dirs_args += f"-L{lib_dir} "
|
||||
|
||||
for lib in BuildOption.get_libraries():
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
self._libraries_args += f'"{lib}.lib" '
|
||||
else:
|
||||
self._libraries_args += f"-l{lib} "
|
||||
@ -1943,7 +1966,7 @@ class CppBuilder:
|
||||
passthrough_args: str,
|
||||
output: str,
|
||||
) -> str:
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
# https://learn.microsoft.com/en-us/cpp/build/walkthrough-compile-a-c-program-on-the-command-line?view=msvc-1704
|
||||
# https://stackoverflow.com/a/31566153
|
||||
cmd = (
|
||||
@ -2043,7 +2066,7 @@ class CppBuilder:
|
||||
|
||||
definitions = " ".join(self._build_option.get_definitions())
|
||||
target_library_type = (
|
||||
"STATIC" if config.aot_inductor.compile_standalone else "SHARED"
|
||||
"STATIC" if not config.aot_inductor.dynamic_linkage else "SHARED"
|
||||
)
|
||||
|
||||
contents = textwrap.dedent(
|
||||
@ -2058,10 +2081,7 @@ class CppBuilder:
|
||||
"""
|
||||
)
|
||||
|
||||
if (
|
||||
not config.aot_inductor.compile_standalone
|
||||
or config.test_configs.use_libtorch
|
||||
):
|
||||
if config.aot_inductor.link_libtorch or config.test_configs.use_libtorch:
|
||||
# When compile_standalone is True, the generated cpp project should
|
||||
# not use Torch. But for unit testing purpose, we need to use Torch here.
|
||||
contents += textwrap.dedent(
|
||||
@ -2071,24 +2091,7 @@ class CppBuilder:
|
||||
|
||||
"""
|
||||
)
|
||||
# flags and macros here are mostly CPU specific. Not emitting them for GPU models
|
||||
# will make the generated CMake file more portable and won't really hurt performance.
|
||||
# NOTE: standalone focuses on GPU now. For CPU, some of the flags and macros may
|
||||
# be still needed.
|
||||
contents += textwrap.dedent(
|
||||
f"""
|
||||
# Add macro definitions
|
||||
target_compile_definitions({self._target_name} PRIVATE {definitions})
|
||||
|
||||
# Add compile flags
|
||||
target_compile_options({self._target_name} PRIVATE {self._cflags_args})
|
||||
|
||||
# Backend-specific flags
|
||||
target_compile_options({self._target_name} PRIVATE {self._passthrough_parameters_args} -c)
|
||||
|
||||
"""
|
||||
)
|
||||
else:
|
||||
elif config.aot_inductor.compile_with_torchstandalone:
|
||||
# When compile_standalone is True, use TorchStandalone instead of Torch
|
||||
contents += textwrap.dedent(
|
||||
f"""
|
||||
@ -2100,73 +2103,148 @@ class CppBuilder:
|
||||
"""
|
||||
)
|
||||
|
||||
contents += textwrap.dedent(
|
||||
f"""
|
||||
# Add macro definitions
|
||||
target_compile_definitions({self._target_name} PRIVATE {definitions})
|
||||
|
||||
# Add compile flags
|
||||
target_compile_options({self._target_name} PRIVATE {self._cflags_args})
|
||||
|
||||
# Backend-specific flags
|
||||
target_compile_options({self._target_name} PRIVATE {self._passthrough_parameters_args} -c)
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
if device_type == "cuda" and torch.version.hip is None:
|
||||
from torch._inductor.codecache import _nvcc_arch_as_compile_option
|
||||
|
||||
current_arch = _nvcc_arch_as_compile_option()
|
||||
contents += textwrap.dedent(
|
||||
f"""
|
||||
"""
|
||||
enable_language(CUDA)
|
||||
set(CMAKE_CUDA_STANDARD 17)
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
target_include_directories({self._target_name} PRIVATE ${{CUDAToolkit_INCLUDE_DIRS}})
|
||||
target_compile_definitions({self._target_name} PRIVATE USE_CUDA)
|
||||
target_link_libraries({self._target_name} PRIVATE cuda CUDA::cudart_static)
|
||||
|
||||
find_program(OBJCOPY_EXECUTABLE objcopy)
|
||||
if(NOT OBJCOPY_EXECUTABLE)
|
||||
message(FATAL_ERROR "objcopy not found. Cannot embed fatbin as object file")
|
||||
endif()
|
||||
|
||||
set(KERNEL_TARGETS "")
|
||||
set(KERNEL_OBJECT_FILES "")
|
||||
# Function to embed a single kernel
|
||||
function(embed_gpu_kernel KERNEL_NAME PTX_FILE)
|
||||
set(FATBIN_BASENAME ${{KERNEL_NAME}}.fatbin)
|
||||
set(FATBIN_FILE ${{CMAKE_CURRENT_BINARY_DIR}}/${{FATBIN_BASENAME}})
|
||||
set(OBJECT_BASENAME ${{KERNEL_NAME}}.fatbin.o)
|
||||
set(OBJECT_FILE ${{CMAKE_CURRENT_BINARY_DIR}}/${{OBJECT_BASENAME}})
|
||||
|
||||
# --- Define UNIQUE C symbol names ---
|
||||
set(SYMBOL_START __${{KERNEL_NAME}}_start)
|
||||
set(SYMBOL_END __${{KERNEL_NAME}}_end)
|
||||
set(SYMBOL_SIZE __${{KERNEL_NAME}}_size)
|
||||
string(REGEX REPLACE "[^a-zA-Z0-9]" "_" MANGLED_BASENAME ${{FATBIN_FILE}})
|
||||
set(OBJCOPY_START_SYM _binary_${{MANGLED_BASENAME}}_start)
|
||||
set(OBJCOPY_END_SYM _binary_${{MANGLED_BASENAME}}_end)
|
||||
set(OBJCOPY_SIZE_SYM _binary_${{MANGLED_BASENAME}}_size)
|
||||
|
||||
# --- PTX to FATBIN Command & Target ---
|
||||
add_custom_command(
|
||||
OUTPUT ${{FATBIN_FILE}}
|
||||
COMMAND ${{CUDAToolkit_NVCC_EXECUTABLE}} --fatbin ${{PTX_FILE}} -o ${{FATBIN_FILE}} ${{NVCC_GENCODE_FLAGS}}
|
||||
-gencode arch=compute_{current_arch},code=compute_{current_arch}
|
||||
-gencode arch=compute_{current_arch},code=sm_{current_arch}
|
||||
DEPENDS ${{PTX_FILE}}
|
||||
)
|
||||
|
||||
# --- FATBIN to Object File (.o) Command ---
|
||||
add_custom_command(
|
||||
OUTPUT ${{OBJECT_FILE}}
|
||||
COMMAND ${{CMAKE_LINKER}} -r -b binary -z noexecstack -o ${{OBJECT_FILE}} ${{FATBIN_FILE}}
|
||||
COMMAND ${{OBJCOPY_EXECUTABLE}} --rename-section .data=.rodata,alloc,load,readonly,data,contents
|
||||
${{OBJECT_FILE}}
|
||||
COMMAND ${{OBJCOPY_EXECUTABLE}}
|
||||
--redefine-sym ${{OBJCOPY_START_SYM}}=${{SYMBOL_START}}
|
||||
--redefine-sym ${{OBJCOPY_END_SYM}}=${{SYMBOL_END}}
|
||||
--redefine-sym ${{OBJCOPY_SIZE_SYM}}=${{SYMBOL_SIZE}}
|
||||
${{OBJECT_FILE}}
|
||||
DEPENDS ${{FATBIN_FILE}}
|
||||
)
|
||||
add_custom_target(build_kernel_object_${{KERNEL_NAME}} DEPENDS ${{OBJECT_FILE}})
|
||||
|
||||
# --- Add to a list for linking later ---
|
||||
set(KERNEL_TARGETS ${{KERNEL_TARGETS}} build_kernel_object_${{KERNEL_NAME}} PARENT_SCOPE)
|
||||
set(KERNEL_OBJECT_FILES ${{KERNEL_OBJECT_FILES}} ${{OBJECT_FILE}} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
"""
|
||||
)
|
||||
if config.aot_inductor.cross_target_platform == "windows":
|
||||
exports_str = (
|
||||
"""LINK_FLAGS "/DEF:${CMAKE_CURRENT_SOURCE_DIR}/windows_symbol_exports.def" """
|
||||
if config.aot_inductor.dynamic_linkage
|
||||
else ""
|
||||
)
|
||||
|
||||
contents += textwrap.dedent(
|
||||
f"""
|
||||
# Make output use .pyd instead of .dll
|
||||
set_target_properties({self._target_name} PROPERTIES SUFFIX ".pyd" {exports_str})
|
||||
|
||||
set(KERNEL_TARGETS "")
|
||||
set(KERNEL_OBJECT_FILES "")
|
||||
# Function to compile ptx to cubin
|
||||
function(embed_gpu_kernel KERNEL_NAME PTX_FILE)
|
||||
set(CUBIN_BASENAME ${{KERNEL_NAME}}.cubin)
|
||||
set(CUBIN_FILE ${{CMAKE_CURRENT_BINARY_DIR}}/${{CUBIN_BASENAME}})
|
||||
# --- PTX to FATBIN Command & Target ---
|
||||
add_custom_command(
|
||||
OUTPUT ${{CUBIN_FILE}}
|
||||
COMMAND ${{CUDAToolkit_NVCC_EXECUTABLE}} --cubin ${{PTX_FILE}}
|
||||
-o ${{CUBIN_FILE}} ${{NVCC_GENCODE_FLAGS}}
|
||||
-gencode arch=compute_{current_arch},code=sm_{current_arch}
|
||||
DEPENDS ${{PTX_FILE}}
|
||||
)
|
||||
|
||||
add_custom_target(build_kernel_object_${{KERNEL_NAME}} DEPENDS ${{CUBIN_FILE}})
|
||||
set(KERNEL_TARGETS ${{KERNEL_TARGETS}} build_kernel_object_${{KERNEL_NAME}} PARENT_SCOPE)
|
||||
"""
|
||||
)
|
||||
if config.aot_inductor.embed_kernel_binary:
|
||||
contents += textwrap.indent(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
# CUBIN → C++ array via xxd -i
|
||||
set(C_SRC_BASENAME ${KERNEL_NAME}.c)
|
||||
set(C_SRC_FILE ${CMAKE_CURRENT_BINARY_DIR}/${C_SRC_BASENAME})
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${C_SRC_FILE}
|
||||
COMMAND xxd -i -n __${KERNEL_NAME}_start ${CUBIN_FILE} > ${C_SRC_FILE}
|
||||
DEPENDS ${CUBIN_FILE}
|
||||
COMMENT "Embedding ${CUBIN_FILE} as C array in ${C_SRC_FILE}"
|
||||
)
|
||||
|
||||
set_source_files_properties(${C_SRC_FILE}
|
||||
PROPERTIES GENERATED TRUE LANGUAGE C COMPILE_FLAGS "/TC") # /TC forces MSVC C mode
|
||||
|
||||
target_sources(model PRIVATE ${C_SRC_FILE})
|
||||
"""
|
||||
),
|
||||
" " * 4,
|
||||
)
|
||||
contents += textwrap.dedent(
|
||||
"""
|
||||
endfunction()
|
||||
"""
|
||||
)
|
||||
elif config.aot_inductor.embed_kernel_binary:
|
||||
contents += textwrap.dedent(
|
||||
f"""
|
||||
find_program(OBJCOPY_EXECUTABLE objcopy)
|
||||
if(NOT OBJCOPY_EXECUTABLE)
|
||||
message(FATAL_ERROR "objcopy not found. Cannot embed fatbin as object file")
|
||||
endif()
|
||||
|
||||
set(KERNEL_TARGETS "")
|
||||
set(KERNEL_OBJECT_FILES "")
|
||||
# Function to embed a single kernel
|
||||
function(embed_gpu_kernel KERNEL_NAME PTX_FILE)
|
||||
set(FATBIN_BASENAME ${{KERNEL_NAME}}.fatbin)
|
||||
set(FATBIN_FILE ${{CMAKE_CURRENT_BINARY_DIR}}/${{FATBIN_BASENAME}})
|
||||
set(OBJECT_BASENAME ${{KERNEL_NAME}}.fatbin.o)
|
||||
set(OBJECT_FILE ${{CMAKE_CURRENT_BINARY_DIR}}/${{OBJECT_BASENAME}})
|
||||
|
||||
# --- Define UNIQUE C symbol names ---
|
||||
set(SYMBOL_START __${{KERNEL_NAME}}_start)
|
||||
set(SYMBOL_END __${{KERNEL_NAME}}_end)
|
||||
set(SYMBOL_SIZE __${{KERNEL_NAME}}_size)
|
||||
string(REGEX REPLACE "[^a-zA-Z0-9]" "_" MANGLED_BASENAME ${{FATBIN_FILE}})
|
||||
set(OBJCOPY_START_SYM _binary_${{MANGLED_BASENAME}}_start)
|
||||
set(OBJCOPY_END_SYM _binary_${{MANGLED_BASENAME}}_end)
|
||||
set(OBJCOPY_SIZE_SYM _binary_${{MANGLED_BASENAME}}_size)
|
||||
|
||||
# --- PTX to FATBIN Command & Target ---
|
||||
add_custom_command(
|
||||
OUTPUT ${{FATBIN_FILE}}
|
||||
COMMAND ${{CUDAToolkit_NVCC_EXECUTABLE}} --fatbin ${{PTX_FILE}}
|
||||
-o ${{FATBIN_FILE}} ${{NVCC_GENCODE_FLAGS}}
|
||||
-gencode arch=compute_80,code=compute_80
|
||||
-gencode arch=compute_{current_arch},code=sm_{current_arch}
|
||||
DEPENDS ${{PTX_FILE}}
|
||||
)
|
||||
|
||||
# --- FATBIN to Object File (.o) Command ---
|
||||
add_custom_command(
|
||||
OUTPUT ${{OBJECT_FILE}}
|
||||
COMMAND ${{CMAKE_LINKER}} -r -b binary -z noexecstack -o ${{OBJECT_FILE}} ${{FATBIN_FILE}}
|
||||
COMMAND ${{OBJCOPY_EXECUTABLE}} --rename-section .data=.rodata,alloc,load,readonly,data,contents
|
||||
${{OBJECT_FILE}}
|
||||
COMMAND ${{OBJCOPY_EXECUTABLE}}
|
||||
--redefine-sym ${{OBJCOPY_START_SYM}}=${{SYMBOL_START}}
|
||||
--redefine-sym ${{OBJCOPY_END_SYM}}=${{SYMBOL_END}}
|
||||
--redefine-sym ${{OBJCOPY_SIZE_SYM}}=${{SYMBOL_SIZE}}
|
||||
${{OBJECT_FILE}}
|
||||
DEPENDS ${{FATBIN_FILE}}
|
||||
)
|
||||
add_custom_target(build_kernel_object_${{KERNEL_NAME}} DEPENDS ${{OBJECT_FILE}})
|
||||
|
||||
# --- Add to a list for linking later ---
|
||||
set(KERNEL_TARGETS ${{KERNEL_TARGETS}} build_kernel_object_${{KERNEL_NAME}} PARENT_SCOPE)
|
||||
set(KERNEL_OBJECT_FILES ${{KERNEL_OBJECT_FILES}} ${{OBJECT_FILE}} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
with open(cmake_path, "w") as f:
|
||||
f.write(contents)
|
||||
@ -2196,15 +2274,15 @@ class CppBuilder:
|
||||
)
|
||||
|
||||
def save_link_cmd_to_cmake(self, cmake_path: str) -> None:
|
||||
if (
|
||||
config.aot_inductor.compile_standalone
|
||||
and not config.test_configs.use_libtorch
|
||||
):
|
||||
# When compile_standalone is True, do not link with libtorch
|
||||
return
|
||||
|
||||
lflags = " ".join(self._build_option.get_ldflags())
|
||||
libs = " ".join(self._build_option.get_libraries())
|
||||
|
||||
if (
|
||||
config.aot_inductor.compile_with_torchstandalone
|
||||
and not config.test_configs.use_libtorch
|
||||
):
|
||||
libs += " ${TorchStandalone_LIBRARIES}"
|
||||
|
||||
contents = textwrap.dedent(
|
||||
f"""
|
||||
# Add linker flags
|
||||
@ -2224,7 +2302,7 @@ class CppBuilder:
|
||||
|
||||
def run_asm_build_object(src: str, target: str, cwd: str) -> None:
|
||||
def get_asm_compiler() -> str:
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
ASM_CC = "ml64"
|
||||
else:
|
||||
ASM_CC = get_cpp_compiler()
|
||||
@ -2234,7 +2312,7 @@ def run_asm_build_object(src: str, target: str, cwd: str) -> None:
|
||||
return ASM_CC
|
||||
|
||||
def get_command_line(asm_cc: str, src: str, target: str) -> str:
|
||||
if _IS_WINDOWS:
|
||||
if is_target_windows():
|
||||
# Format reference:
|
||||
# https://learn.microsoft.com/en-us/cpp/assembler/masm/ml-and-ml64-command-line-reference?view=msvc-170
|
||||
cmd = f"{asm_cc} {src} /c /Fo {target}" # codespell:ignore /Fo
|
||||
|
||||
@ -3487,10 +3487,18 @@ def maybe_aoti_standalone_config(config_patches: dict[str, Any]) -> dict[str, An
|
||||
"""
|
||||
Ensures the configuration is internally consistent for standalone AOTInductor.
|
||||
|
||||
If `aot_inductor.compile_standalone` is set to True in the provided
|
||||
If `aot_inductor_mode.compile_standalone` is set to True in the provided
|
||||
`config_patches` (or falls back to the global config), this function ensures
|
||||
that the following configs are also enabled:
|
||||
that the following configs are also set:
|
||||
- `aot_inductor.package_cpp_only`
|
||||
- `aot_inductor.embed_kernel_binary`
|
||||
- `aot_inductor.emit_multi_arch_kernel`
|
||||
- `aot_inductor.link_libtorch=False`
|
||||
- `aot_inductor.dynamic_linkage=False`
|
||||
|
||||
If `aot_inductor.dynamic_linkage` is set to False in the provided
|
||||
`config_patches` (or falls back to the global config):
|
||||
- `aot_inductor.model_name_for_generated_files` defaults to "aoti_model" if not set.
|
||||
|
||||
Args:
|
||||
config_patches (dict[str, Any]): A dictionary of user-provided config
|
||||
@ -3512,7 +3520,8 @@ def maybe_aoti_standalone_config(config_patches: dict[str, Any]) -> dict[str, An
|
||||
)
|
||||
|
||||
compile_standalone = config_patches.get(
|
||||
"aot_inductor.compile_standalone", config.aot_inductor.compile_standalone
|
||||
"aot_inductor_mode.compile_standalone",
|
||||
config.aot_inductor_mode.compile_standalone,
|
||||
)
|
||||
# Make a copy of the config_patches to avoid modifying the original dictionary, needed for testing
|
||||
config_patches = config_patches.copy()
|
||||
@ -3525,6 +3534,13 @@ def maybe_aoti_standalone_config(config_patches: dict[str, Any]) -> dict[str, An
|
||||
patch_config(
|
||||
config_patches, "aot_inductor.emit_multi_arch_kernel", not torch.version.hip
|
||||
)
|
||||
patch_config(config_patches, "aot_inductor.dynamic_linkage", False)
|
||||
patch_config(config_patches, "aot_inductor.compile_with_torchstandalone", True)
|
||||
|
||||
dynamic_linkage = config_patches.get(
|
||||
"aot_inductor.dynamic_linkage", config.aot_inductor.dynamic_linkage
|
||||
)
|
||||
if not dynamic_linkage:
|
||||
patch_config(
|
||||
config_patches, "aot_inductor.model_name_for_generated_files", "aoti_model"
|
||||
)
|
||||
|
||||
31
torch/csrc/inductor/aoti_runtime/windows_symbol_exports.def
Normal file
31
torch/csrc/inductor/aoti_runtime/windows_symbol_exports.def
Normal file
@ -0,0 +1,31 @@
|
||||
LIBRARY {config.aot_inductor.model_name_for_generated_files}
|
||||
EXPORTS
|
||||
AOTInductorModelContainerCreate
|
||||
AOTInductorModelContainerCreateWithDevice
|
||||
AOTInductorModelContainerRun
|
||||
AOTInductorModelContainerDelete
|
||||
AOTInductorModelContainerRunSingleThreaded
|
||||
AOTInductorModelContainerGetNumConstants
|
||||
AOTInductorModelContainerGetConstantName
|
||||
AOTInductorModelContainerGetConstantOriginalFQN
|
||||
AOTInductorModelContainerGetConstantFromFolded
|
||||
AOTInductorModelContainerGetConstantType
|
||||
AOTInductorModelContainerGetConstantDtype
|
||||
AOTInductorModelContainerGetConstantDataSize
|
||||
AOTInductorModelContainerExtractConstantsMap
|
||||
AOTInductorModelContainerUpdateUserManagedConstantBuffer
|
||||
AOTInductorModelContainerUpdateConstantBuffer
|
||||
AOTInductorModelContainerUpdateInactiveConstantBuffer
|
||||
AOTInductorModelContainerFreeInactiveConstantBuffer
|
||||
AOTInductorModelContainerRunConstantFolding
|
||||
AOTInductorModelContainerSwapConstantBuffer
|
||||
AOTInductorModelContainerGetNumInputs
|
||||
AOTInductorModelContainerGetInputName
|
||||
AOTInductorModelContainerGetNumOutputs
|
||||
AOTInductorModelContainerGetOutputName
|
||||
AOTInductorModelCreate
|
||||
AOTInductorModelRun
|
||||
AOTInductorModelUpdateConstantsMap
|
||||
AOTInductorModelDelete
|
||||
AOTInductorModelGetNumOutputs
|
||||
AOTInductorModelContainerGetCallSpec
|
||||
@ -362,7 +362,7 @@ class _ExportPackage:
|
||||
"always_keep_tensor_constants": True,
|
||||
# we'll change this back to False once we enable weight deduping for standalone mode
|
||||
"aot_inductor.package_constants_in_so": standalone,
|
||||
"aot_inductor.compile_standalone": standalone,
|
||||
"aot_inductor_mode.compile_standalone": standalone,
|
||||
}
|
||||
aoti_files_map = {}
|
||||
model_names = []
|
||||
|
||||
Reference in New Issue
Block a user