mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	Compare commits
	
		
			1 Commits
		
	
	
		
			ciflow/tru
			...
			aoti_targe
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 53f6698cc5 | 
							
								
								
									
										51
									
								
								run.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								run.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,51 @@
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Model(torch.nn.Module):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.fc1 = torch.nn.Linear(10, 16)
 | 
			
		||||
        self.relu = torch.nn.ReLU()
 | 
			
		||||
        self.fc2 = torch.nn.Linear(16, 1)
 | 
			
		||||
        self.sigmoid = torch.nn.Sigmoid()
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        x = self.fc1(x)
 | 
			
		||||
        x = self.relu(x)
 | 
			
		||||
        x = self.fc2(x)
 | 
			
		||||
        x = self.sigmoid(x)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# rm -r /tmp/torchinductor_shangdiy/
 | 
			
		||||
with torch.no_grad():
 | 
			
		||||
    device = "cuda" if torch.cuda.is_available() else "cpu"
 | 
			
		||||
    model = Model().to(device=device)
 | 
			
		||||
    example_inputs = (torch.randn(8, 10, device=device),)
 | 
			
		||||
    batch_dim = torch.export.Dim("batch", min=1, max=1024)
 | 
			
		||||
    # [Optional] Specify the first dimension of the input x as dynamic.
 | 
			
		||||
    exported = torch.export.export(
 | 
			
		||||
        model, example_inputs, dynamic_shapes={"x": {0: batch_dim}}
 | 
			
		||||
    )
 | 
			
		||||
    # [Note] In this example we directly feed the exported module to aoti_compile_and_package.
 | 
			
		||||
    # Depending on your use case, e.g. if your training platform and inference platform
 | 
			
		||||
    # are different, you may choose to save the exported model using torch.export.save and
 | 
			
		||||
    # then load it back using torch.export.load on your inference platform to run AOT compilation.
 | 
			
		||||
    output_path = torch._inductor.aoti_compile_and_package(
 | 
			
		||||
        exported,
 | 
			
		||||
        # [Optional] Specify the generated shared library path. If not specified,
 | 
			
		||||
        # the generated artifact is stored in your system temp directory.
 | 
			
		||||
        package_path=os.path.join(os.getcwd(), "model3.pt2"),
 | 
			
		||||
        inductor_configs={
 | 
			
		||||
            "aot_inductor.package_cpp_only": True,  # optional, compile_standalone automatically makes it cpp_only
 | 
			
		||||
            "aot_inductor.cross_target_platform": "windows",
 | 
			
		||||
            "aot_inductor.compile_standalone": True,
 | 
			
		||||
            "aot_inductor.dynamic_linkage": True,
 | 
			
		||||
            "max_autotune": True,
 | 
			
		||||
            "max_autotune_gemm_backends": "TRITON,CPP",
 | 
			
		||||
            "max_autotune_conv_backends": "TRITON,CPP",
 | 
			
		||||
            "aot_inductor.model_name_for_generated_files": "model",
 | 
			
		||||
        },
 | 
			
		||||
    )
 | 
			
		||||
							
								
								
									
										1
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								setup.py
									
									
									
									
									
								
							@ -1693,6 +1693,7 @@ def main() -> None:
 | 
			
		||||
        "_inductor/codegen/*.h",
 | 
			
		||||
        "_inductor/codegen/aoti_runtime/*.h",
 | 
			
		||||
        "_inductor/codegen/aoti_runtime/*.cpp",
 | 
			
		||||
        "_inductor/codegen/aoti_runtime/windows_symbol_exports.def",
 | 
			
		||||
        "_inductor/script.ld",
 | 
			
		||||
        "_inductor/kernel/flex/templates/*.jinja",
 | 
			
		||||
        "_export/serde/*.yaml",
 | 
			
		||||
 | 
			
		||||
@ -7189,14 +7189,18 @@ class AOTInductorLoggingTest(LoggingTestCase):
 | 
			
		||||
 | 
			
		||||
class TestAOTInductorConfig(TestCase):
 | 
			
		||||
    def test_no_compile_standalone(self):
 | 
			
		||||
        with config.patch({"aot_inductor.compile_standalone": False}):
 | 
			
		||||
        with config.patch({"aot_inductor_mode.compile_standalone": False}):
 | 
			
		||||
            result = maybe_aoti_standalone_config({})
 | 
			
		||||
            self.assertEqual(result, {})
 | 
			
		||||
 | 
			
		||||
    def test_compile_standalone_sets_package_cpp(self):
 | 
			
		||||
        result = maybe_aoti_standalone_config({"aot_inductor.compile_standalone": True})
 | 
			
		||||
        result = maybe_aoti_standalone_config(
 | 
			
		||||
            {
 | 
			
		||||
                "aot_inductor_mode.compile_standalone": True,
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(result["aot_inductor.package_cpp_only"], True)
 | 
			
		||||
        self.assertEqual(result["aot_inductor.compile_standalone"], True)
 | 
			
		||||
        self.assertEqual(result["aot_inductor_mode.compile_standalone"], True)
 | 
			
		||||
        self.assertEqual(result["aot_inductor.embed_kernel_binary"], True)
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            result["aot_inductor.emit_multi_arch_kernel"], not torch.version.hip
 | 
			
		||||
@ -7207,7 +7211,7 @@ class TestAOTInductorConfig(TestCase):
 | 
			
		||||
 | 
			
		||||
    def test_compile_standalone_explicit_set(self):
 | 
			
		||||
        patches = {
 | 
			
		||||
            "aot_inductor.compile_standalone": True,
 | 
			
		||||
            "aot_inductor_mode.compile_standalone": True,
 | 
			
		||||
            "aot_inductor.package_cpp_only": True,
 | 
			
		||||
            "aot_inductor.embed_kernel_binary": True,
 | 
			
		||||
            "aot_inductor.emit_multi_arch_kernel": not torch.version.hip,
 | 
			
		||||
@ -7218,7 +7222,7 @@ class TestAOTInductorConfig(TestCase):
 | 
			
		||||
 | 
			
		||||
    def test_compile_standalone_package_cpp_false_raises(self):
 | 
			
		||||
        patches = {
 | 
			
		||||
            "aot_inductor.compile_standalone": True,
 | 
			
		||||
            "aot_inductor_mode.compile_standalone": True,
 | 
			
		||||
            "aot_inductor.package_cpp_only": False,
 | 
			
		||||
        }
 | 
			
		||||
        with self.assertRaises(RuntimeError):
 | 
			
		||||
@ -7226,7 +7230,7 @@ class TestAOTInductorConfig(TestCase):
 | 
			
		||||
 | 
			
		||||
        with config.patch({"aot_inductor.package_cpp_only": False}):
 | 
			
		||||
            patches = {
 | 
			
		||||
                "aot_inductor.compile_standalone": True,
 | 
			
		||||
                "aot_inductor_mode.compile_standalone": True,
 | 
			
		||||
            }
 | 
			
		||||
            with self.assertRaises(RuntimeError):
 | 
			
		||||
                maybe_aoti_standalone_config(patches)
 | 
			
		||||
 | 
			
		||||
@ -393,7 +393,7 @@ class TestAOTInductorPackage(TestCase):
 | 
			
		||||
 | 
			
		||||
            # Test compilation when no name is passed in
 | 
			
		||||
            options = {
 | 
			
		||||
                "aot_inductor.compile_standalone": True,
 | 
			
		||||
                "aot_inductor_mode.compile_standalone": True,
 | 
			
		||||
            }
 | 
			
		||||
            with (
 | 
			
		||||
                tempfile.TemporaryDirectory() as tmp_dir,
 | 
			
		||||
@ -407,7 +407,7 @@ class TestAOTInductorPackage(TestCase):
 | 
			
		||||
 | 
			
		||||
            # Test compilation when model name is passed in
 | 
			
		||||
            options = {
 | 
			
		||||
                "aot_inductor.compile_standalone": True,
 | 
			
		||||
                "aot_inductor_mode.compile_standalone": True,
 | 
			
		||||
                "aot_inductor.model_name_for_generated_files": "linear",
 | 
			
		||||
            }
 | 
			
		||||
            with (
 | 
			
		||||
@ -422,7 +422,7 @@ class TestAOTInductorPackage(TestCase):
 | 
			
		||||
 | 
			
		||||
            # test invalid model name
 | 
			
		||||
            options = {
 | 
			
		||||
                "aot_inductor.compile_standalone": True,
 | 
			
		||||
                "aot_inductor_mode.compile_standalone": True,
 | 
			
		||||
                "aot_inductor.model_name_for_generated_files": "linear/linear",
 | 
			
		||||
            }
 | 
			
		||||
            with self.assertRaisesRegex(Exception, "Invalid AOTI model name"):
 | 
			
		||||
@ -448,7 +448,7 @@ class TestAOTInductorPackage(TestCase):
 | 
			
		||||
 | 
			
		||||
            # Test compilation when model name is passed in
 | 
			
		||||
            options = {
 | 
			
		||||
                "aot_inductor.compile_standalone": True,
 | 
			
		||||
                "aot_inductor_mode.compile_standalone": True,
 | 
			
		||||
                "aot_inductor.model_name_for_generated_files": "cos",
 | 
			
		||||
            }
 | 
			
		||||
            with (
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										91
									
								
								test/inductor/test_aot_inductor_windows.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										91
									
								
								test/inductor/test_aot_inductor_windows.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,91 @@
 | 
			
		||||
# Owner(s): ["module: inductor"]
 | 
			
		||||
import sys
 | 
			
		||||
import tempfile
 | 
			
		||||
import zipfile
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch._inductor.config
 | 
			
		||||
from torch._inductor.test_case import TestCase
 | 
			
		||||
from torch.testing import FileCheck
 | 
			
		||||
from torch.testing._internal.inductor_utils import HAS_GPU, requires_gpu
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Simple(torch.nn.Module):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.fc1 = torch.nn.Linear(10, 16)
 | 
			
		||||
        self.relu = torch.nn.ReLU()
 | 
			
		||||
        self.fc2 = torch.nn.Linear(16, 1)
 | 
			
		||||
        self.sigmoid = torch.nn.Sigmoid()
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        x = self.fc1(x)
 | 
			
		||||
        x = self.relu(x)
 | 
			
		||||
        x = self.fc2(x)
 | 
			
		||||
        x = self.sigmoid(x)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestAOTInductorWindowsCrossCompilation(TestCase):
 | 
			
		||||
    @requires_gpu()
 | 
			
		||||
    def test_simple_cpp_only(self):
 | 
			
		||||
        # rm -r /tmp/torchinductor_shangdiy/
 | 
			
		||||
        with torch.no_grad():
 | 
			
		||||
            device = "cuda"
 | 
			
		||||
            model = Simple().to(device=device)
 | 
			
		||||
            example_inputs = (torch.randn(8, 10, device=device),)
 | 
			
		||||
            batch_dim = torch.export.Dim("batch", min=1, max=1024)
 | 
			
		||||
            exported = torch.export.export(
 | 
			
		||||
                model, example_inputs, dynamic_shapes={"x": {0: batch_dim}}
 | 
			
		||||
            )
 | 
			
		||||
            package_path = torch._inductor.aoti_compile_and_package(
 | 
			
		||||
                exported,
 | 
			
		||||
                inductor_configs={
 | 
			
		||||
                    "aot_inductor.model_name_for_generated_files": "model",
 | 
			
		||||
                    "aot_inductor.package_cpp_only": True,
 | 
			
		||||
                    "aot_inductor.cross_target_platform": "windows",
 | 
			
		||||
                    "aot_inductor.link_libtorch": False,
 | 
			
		||||
                    # no fallback ops
 | 
			
		||||
                    "max_autotune": True,
 | 
			
		||||
                    "max_autotune_gemm_backends": "TRITON,CPP",
 | 
			
		||||
                    "max_autotune_conv_backends": "TRITON,CPP",
 | 
			
		||||
                },
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            with tempfile.TemporaryDirectory() as tmpdir:
 | 
			
		||||
                with zipfile.ZipFile(package_path, "r") as zf:
 | 
			
		||||
                    zf.extractall(tmpdir)
 | 
			
		||||
 | 
			
		||||
                makefile = open(
 | 
			
		||||
                    f"{tmpdir}/model.wrapper/data/aotinductor/model/CMakeLists.txt"
 | 
			
		||||
                )
 | 
			
		||||
                makefile_content = makefile.read()
 | 
			
		||||
 | 
			
		||||
                FileCheck().check("add_library(model SHARED)").check(
 | 
			
		||||
                    "target_compile_definitions(model PRIVATE NOMINMAX "
 | 
			
		||||
                ).check("USE_CUDA").check("target_compile_options(model").check(
 | 
			
		||||
                    """set_target_properties(model PROPERTIES SUFFIX ".pyd" """
 | 
			
		||||
                ).check(
 | 
			
		||||
                    """LINK_FLAGS "/DEF:${CMAKE_CURRENT_SOURCE_DIR}/windows_symbol_exports.def" )"""
 | 
			
		||||
                ).check(
 | 
			
		||||
                    "target_sources(model PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/model.wrapper.cpp)"
 | 
			
		||||
                ).check(
 | 
			
		||||
                    "target_sources(model PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/model_consts.weights.cpp)"
 | 
			
		||||
                ).check("embed_gpu_kernel(").check(
 | 
			
		||||
                    "add_dependencies(model ${KERNEL_TARGETS})"
 | 
			
		||||
                ).check(
 | 
			
		||||
                    "target_link_libraries(model PRIVATE ${KERNEL_OBJECT_FILES})"
 | 
			
		||||
                ).check(
 | 
			
		||||
                    "target_link_options(model PRIVATE )"  # no libtorch
 | 
			
		||||
                ).check("target_link_libraries(model PRIVATE CUDA::cudart cuda)").run(
 | 
			
		||||
                    makefile_content
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                # TODO: actually compile the package in the test later in windows CI
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    from torch._inductor.test_case import run_tests
 | 
			
		||||
 | 
			
		||||
    if HAS_GPU or sys.platform == "darwin":
 | 
			
		||||
        run_tests(needs="filelock")
 | 
			
		||||
@ -75,6 +75,7 @@ from torch._inductor.cpp_builder import (
 | 
			
		||||
    get_compiler_version_info,
 | 
			
		||||
    get_ld_and_objcopy,
 | 
			
		||||
    get_name_and_dir_from_output_file_path,
 | 
			
		||||
    is_target_windows,
 | 
			
		||||
    normalize_path_separator,
 | 
			
		||||
    run_asm_build_object,
 | 
			
		||||
)
 | 
			
		||||
@ -142,7 +143,6 @@ if TYPE_CHECKING:
 | 
			
		||||
    from .utils import InputType
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_IS_WINDOWS = sys.platform == "win32"
 | 
			
		||||
LOCK_TIMEOUT = 600
 | 
			
		||||
 | 
			
		||||
output_code_log = torch._logging.getArtifactLogger(__name__, "output_code")
 | 
			
		||||
@ -393,7 +393,7 @@ class WritableTempFile:
 | 
			
		||||
        try:
 | 
			
		||||
            os.unlink(self.temp_file.name)
 | 
			
		||||
        except OSError as e:
 | 
			
		||||
            if _IS_WINDOWS:
 | 
			
		||||
            if is_target_windows():
 | 
			
		||||
                # On Windows, some case temp file is opened and fail to unlink. Need to ignore it.
 | 
			
		||||
                pass
 | 
			
		||||
            else:
 | 
			
		||||
@ -447,7 +447,7 @@ def write_atomic(
 | 
			
		||||
    try:
 | 
			
		||||
        tmp_path.rename(target=path)
 | 
			
		||||
    except FileExistsError:
 | 
			
		||||
        if not _IS_WINDOWS:
 | 
			
		||||
        if not is_target_windows():
 | 
			
		||||
            raise
 | 
			
		||||
        # On Windows file exist is expected: https://docs.python.org/3/library/pathlib.html#pathlib.Path.rename
 | 
			
		||||
        # Below two lines code is equal to `tmp_path.rename(path)` on non-Windows OS.
 | 
			
		||||
@ -1610,7 +1610,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
 | 
			
		||||
@functools.cache
 | 
			
		||||
def split_aot_inductor_output_path(path: str) -> tuple[str, str]:
 | 
			
		||||
    def get_module_ext_type() -> str:
 | 
			
		||||
        if _IS_WINDOWS:
 | 
			
		||||
        if is_target_windows():
 | 
			
		||||
            return ".pyd"
 | 
			
		||||
        else:
 | 
			
		||||
            return ".so"
 | 
			
		||||
@ -1777,7 +1777,7 @@ class AotCodeCompiler:
 | 
			
		||||
 | 
			
		||||
        header_code = ""
 | 
			
		||||
        header_path = ""
 | 
			
		||||
        if config.aot_inductor.compile_standalone:
 | 
			
		||||
        if not config.aot_inductor.dynamic_linkage:
 | 
			
		||||
            # to link statically, we also need a header file
 | 
			
		||||
            with open(
 | 
			
		||||
                os.path.join(
 | 
			
		||||
@ -1788,7 +1788,7 @@ class AotCodeCompiler:
 | 
			
		||||
                    "model.h",
 | 
			
		||||
                )
 | 
			
		||||
            ) as f:
 | 
			
		||||
                # model_name_for_generated_files is guaranteed to be non-empty when compile_standalone
 | 
			
		||||
                # model_name_for_generated_files is guaranteed to be non-empty when dynamic_linkage is False
 | 
			
		||||
                model_class_name = config.aot_inductor.model_name_for_generated_files
 | 
			
		||||
                class_name = f"AOTInductorModel{model_class_name}"
 | 
			
		||||
                header_code = f.read()
 | 
			
		||||
@ -1827,7 +1827,7 @@ class AotCodeCompiler:
 | 
			
		||||
            generated_files.append(wrapper_path)
 | 
			
		||||
            if not config.aot_inductor.package_cpp_only:
 | 
			
		||||
                generated_files.append(kernel_path)
 | 
			
		||||
            if config.aot_inductor.compile_standalone:
 | 
			
		||||
            if not config.aot_inductor.dynamic_linkage:
 | 
			
		||||
                generated_files.append(header_path)
 | 
			
		||||
 | 
			
		||||
        output_code_log.info("Wrapper code written to: %s", wrapper_path)
 | 
			
		||||
@ -1850,7 +1850,7 @@ class AotCodeCompiler:
 | 
			
		||||
            },
 | 
			
		||||
            payload_fn=lambda: kernel_code,
 | 
			
		||||
        )
 | 
			
		||||
        if config.aot_inductor.compile_standalone:
 | 
			
		||||
        if not config.aot_inductor.dynamic_linkage:
 | 
			
		||||
            output_code_log.info("Header code written to: %s", header_path)
 | 
			
		||||
            trace_structured(
 | 
			
		||||
                "graph_dump",
 | 
			
		||||
@ -1872,9 +1872,12 @@ class AotCodeCompiler:
 | 
			
		||||
            specified_sub_dir.mkdir(exist_ok=True)
 | 
			
		||||
        cmake_path = str(Path(specified_sub_dir) / "CMakeLists.txt")
 | 
			
		||||
 | 
			
		||||
        def _compile_consts(consts: bytes, platform: str) -> str:
 | 
			
		||||
        def _compile_consts(consts: bytes, platform: str) -> tuple[str, Optional[str]]:
 | 
			
		||||
            # Load from aot_inductor, and update the value on demand.
 | 
			
		||||
            use_asm_build: bool = config.aot_inductor.use_consts_asm_build
 | 
			
		||||
            use_asm_build: bool = (
 | 
			
		||||
                config.aot_inductor.use_consts_asm_build
 | 
			
		||||
                and config.aot_inductor.cross_target_platform != "windows"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            if platform == "linux":
 | 
			
		||||
                if graph.mutated_buffers & OrderedSet(graph.constants.keys()):
 | 
			
		||||
@ -1976,7 +1979,7 @@ ATTRIBUTE_NO_SANITIZE_ADDRESS\t\n"""
 | 
			
		||||
                    Linux: Added '-pedantic' to disable zero-sized arrays in C++ compiler
 | 
			
		||||
                    Windows: MSVC naturally rejects zero-sized arrays by default
 | 
			
		||||
                """
 | 
			
		||||
                if _IS_WINDOWS:
 | 
			
		||||
                if is_target_windows():
 | 
			
		||||
                    # Windows ml64 is max support align to 16, but it is no effect to zero size data.
 | 
			
		||||
                    asm_code = """
 | 
			
		||||
option casemap:none
 | 
			
		||||
@ -2020,7 +2023,7 @@ end
 | 
			
		||||
                consts_code,
 | 
			
		||||
                code_ext,
 | 
			
		||||
                specified_dir=str(specified_sub_dir),
 | 
			
		||||
                key=config.aot_inductor.model_name_for_generated_files,
 | 
			
		||||
                key=f"{config.aot_inductor.model_name_for_generated_files}_consts",
 | 
			
		||||
            )
 | 
			
		||||
            consts_s = Path(consts_s)
 | 
			
		||||
            object_build_options = CppTorchDeviceOptions(
 | 
			
		||||
@ -2038,7 +2041,7 @@ end
 | 
			
		||||
            consts_o = object_builder.get_target_file_path()
 | 
			
		||||
            if use_asm_build is False and is_zero_size_consts:
 | 
			
		||||
                run_asm_build_object(str(consts_s), consts_o, str(consts_s.parent))
 | 
			
		||||
            else:
 | 
			
		||||
            elif config.aot_inductor.cross_target_platform != "windows":
 | 
			
		||||
                object_builder.build()
 | 
			
		||||
 | 
			
		||||
            if is_large_consts and use_asm_build:
 | 
			
		||||
@ -2058,10 +2061,12 @@ end
 | 
			
		||||
                        rc = f.write(consts[pos:])
 | 
			
		||||
                        pos += rc
 | 
			
		||||
 | 
			
		||||
            # Remove the .S file to save space
 | 
			
		||||
            os.remove(consts_s)
 | 
			
		||||
            if config.aot_inductor.cross_target_platform != "windows":
 | 
			
		||||
                # Remove the .S file to save space
 | 
			
		||||
                os.remove(consts_s)
 | 
			
		||||
                return consts_o, None
 | 
			
		||||
 | 
			
		||||
            return consts_o
 | 
			
		||||
            return consts_o, str(consts_s)
 | 
			
		||||
 | 
			
		||||
        from torch.utils._filelock import FileLock
 | 
			
		||||
 | 
			
		||||
@ -2199,7 +2204,7 @@ end
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # potentially, precompile the AOT header for this device
 | 
			
		||||
            if config.aot_inductor.precompile_headers and not _IS_WINDOWS:
 | 
			
		||||
            if config.aot_inductor.precompile_headers and not is_target_windows():
 | 
			
		||||
                header_file = _get_cpp_wrapper_header(
 | 
			
		||||
                    device_type, aot_mode=graph.aot_mode
 | 
			
		||||
                )
 | 
			
		||||
@ -2248,6 +2253,33 @@ end
 | 
			
		||||
                wrapper_builder.save_compile_cmd_to_cmake(cmake_path, device_type)
 | 
			
		||||
                wrapper_builder.save_src_to_cmake(cmake_path, wrapper_path)
 | 
			
		||||
                generated_files.append(cmake_path)
 | 
			
		||||
 | 
			
		||||
                if is_target_windows():
 | 
			
		||||
                    with open(
 | 
			
		||||
                        os.path.join(
 | 
			
		||||
                            os.path.dirname(os.path.dirname(__file__)),
 | 
			
		||||
                            "csrc",
 | 
			
		||||
                            "inductor",
 | 
			
		||||
                            "aoti_runtime",
 | 
			
		||||
                            "windows_symbol_exports.def",
 | 
			
		||||
                        )
 | 
			
		||||
                    ) as f:
 | 
			
		||||
                        # model_name_for_generated_files is guaranteed to be non-empty when dynamic_linkage is False
 | 
			
		||||
                        assert (
 | 
			
		||||
                            config.aot_inductor.model_name_for_generated_files
 | 
			
		||||
                            is not None
 | 
			
		||||
                        )
 | 
			
		||||
                        windows_symbol_exports = f.read().replace(
 | 
			
		||||
                            "{config.aot_inductor.model_name_for_generated_files}",
 | 
			
		||||
                            config.aot_inductor.model_name_for_generated_files,
 | 
			
		||||
                        )
 | 
			
		||||
                        _, expors_path = write(
 | 
			
		||||
                            windows_symbol_exports,
 | 
			
		||||
                            "def",
 | 
			
		||||
                            specified_dir=str(specified_sub_dir),
 | 
			
		||||
                            key="windows_symbol_exports",
 | 
			
		||||
                        )
 | 
			
		||||
                        generated_files.append(expors_path)
 | 
			
		||||
            else:
 | 
			
		||||
                try:
 | 
			
		||||
                    wrapper_builder.build()
 | 
			
		||||
@ -2268,7 +2300,7 @@ end
 | 
			
		||||
                )
 | 
			
		||||
                aot_constants = struct.pack("qq", consts_size + 8, magic_number)
 | 
			
		||||
 | 
			
		||||
            consts_o = _compile_consts(aot_constants, sys.platform)
 | 
			
		||||
            consts_o, consts_asm = _compile_consts(aot_constants, sys.platform)
 | 
			
		||||
            custom_obj_idx = 0
 | 
			
		||||
            # Note that custom_objs_config.json file is different from the model_constants_config.json file produced
 | 
			
		||||
            # in package_sigmoid(). The keys in custom_objs_config.json directly correspond to the arg name in extern
 | 
			
		||||
@ -2318,8 +2350,12 @@ end
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            cubins_o = []
 | 
			
		||||
            asm_files = []
 | 
			
		||||
            if not _IS_WINDOWS:
 | 
			
		||||
            asm_files = [
 | 
			
		||||
                value["asm"]
 | 
			
		||||
                for value in CudaKernelParamCache.cache.values()
 | 
			
		||||
                if "asm" in value
 | 
			
		||||
            ]
 | 
			
		||||
            if not is_target_windows():
 | 
			
		||||
                ld, objcopy = get_ld_and_objcopy(use_relative_path)
 | 
			
		||||
                kernels = getattr(V.graph.wrapper_code, "_kernel_name_to_body", {})
 | 
			
		||||
                for kernel_name, value in CudaKernelParamCache.cache.items():
 | 
			
		||||
@ -2328,9 +2364,6 @@ end
 | 
			
		||||
                        # than what the current graph uses
 | 
			
		||||
                        continue
 | 
			
		||||
 | 
			
		||||
                    if asm_file := value["asm"]:
 | 
			
		||||
                        asm_files.append(asm_file)
 | 
			
		||||
 | 
			
		||||
                    cubin_file = value[get_cpp_wrapper_cubin_path_name()]
 | 
			
		||||
                    if (
 | 
			
		||||
                        config.aot_inductor.emit_multi_arch_kernel
 | 
			
		||||
@ -2338,7 +2371,7 @@ end
 | 
			
		||||
                    ):
 | 
			
		||||
                        current_arch = _nvcc_arch_as_compile_option()
 | 
			
		||||
                        cmd = (
 | 
			
		||||
                            f"{_cuda_compiler()} -fatbin {asm_file} -o {cubin_file} "
 | 
			
		||||
                            f"{_cuda_compiler()} -fatbin {value['asm']} -o {cubin_file} "
 | 
			
		||||
                            # Triton only allows generating PTX version as same as the current arch
 | 
			
		||||
                            f"-gencode arch=compute_{current_arch},code=compute_{current_arch} "
 | 
			
		||||
                            # Include SASS for the current specific arch
 | 
			
		||||
@ -2418,14 +2451,22 @@ end
 | 
			
		||||
                        f_weights.write(struct.pack("q", magic_number))
 | 
			
		||||
 | 
			
		||||
                    generated_files.append(weight_file)
 | 
			
		||||
                elif config.aot_inductor.cross_target_platform == "windows":
 | 
			
		||||
                    assert consts_asm is not None
 | 
			
		||||
                    generated_files.append(consts_asm)
 | 
			
		||||
                    so_builder.save_src_to_cmake(cmake_path, consts_asm)
 | 
			
		||||
                else:
 | 
			
		||||
                    # TODO: unify to always use mmap_weights
 | 
			
		||||
                    generated_files.append(consts_o)
 | 
			
		||||
                    so_builder.save_src_to_cmake(cmake_path, consts_o)
 | 
			
		||||
 | 
			
		||||
                if config.aot_inductor.emit_multi_arch_kernel:
 | 
			
		||||
                if (
 | 
			
		||||
                    config.aot_inductor.emit_multi_arch_kernel
 | 
			
		||||
                    or config.aot_inductor.cross_target_platform == "windows"
 | 
			
		||||
                ):
 | 
			
		||||
                    so_builder.save_kernel_asm_to_cmake(cmake_path, asm_files)
 | 
			
		||||
                    generated_files.extend(asm_files)
 | 
			
		||||
 | 
			
		||||
                else:
 | 
			
		||||
                    obj_srcs = [*gpu_kernels_o, *cubins_o]
 | 
			
		||||
                    generated_files.extend(obj_srcs)
 | 
			
		||||
@ -2446,7 +2487,7 @@ end
 | 
			
		||||
                    def get_page_size() -> int:
 | 
			
		||||
                        # Don't use resource.getpagesize() on Windows, as it is a Unix specific package
 | 
			
		||||
                        # as seen in https://docs.python.org/2/library/resource.html
 | 
			
		||||
                        if _IS_WINDOWS:
 | 
			
		||||
                        if is_target_windows():
 | 
			
		||||
                            from ctypes import (  # type: ignore[attr-defined]
 | 
			
		||||
                                byref,
 | 
			
		||||
                                Structure,
 | 
			
		||||
@ -2574,7 +2615,7 @@ def _precompile_header(
 | 
			
		||||
    hashable_cmd_line: str,
 | 
			
		||||
    **compile_command: Any,
 | 
			
		||||
) -> str:
 | 
			
		||||
    assert not _IS_WINDOWS, (
 | 
			
		||||
    assert not is_target_windows(), (
 | 
			
		||||
        "CppBuilder does not currently support precompiling on Windows!"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -2759,7 +2800,7 @@ class CppCodeCache:
 | 
			
		||||
            lib = None
 | 
			
		||||
 | 
			
		||||
            # if requested, pre-compile any headers
 | 
			
		||||
            if config.cpp_cache_precompile_headers and not _IS_WINDOWS:
 | 
			
		||||
            if config.cpp_cache_precompile_headers and not is_target_windows():
 | 
			
		||||
                if header := cls._get_uncompiled_header(device_type):
 | 
			
		||||
                    main_build_option.precompiled_header = _precompile_header(
 | 
			
		||||
                        header,
 | 
			
		||||
@ -3737,11 +3778,9 @@ def _nvcc_host_compiler_options() -> list[str]:
 | 
			
		||||
 | 
			
		||||
def _nvcc_arch_as_compile_option() -> str:
 | 
			
		||||
    arch = cuda_env.get_cuda_arch()
 | 
			
		||||
    if arch == "90":
 | 
			
		||||
    if arch in ("90", "100", "120"):
 | 
			
		||||
        # Required by cutlass compilation.
 | 
			
		||||
        return "90a"
 | 
			
		||||
    if arch == "100":
 | 
			
		||||
        return "100a"
 | 
			
		||||
        return f"{arch}a"
 | 
			
		||||
    return arch
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -60,9 +60,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
 | 
			
		||||
        # must be initialized prior to calling super().__init__()
 | 
			
		||||
        self.included_devices: OrderedSet[str] = OrderedSet()
 | 
			
		||||
        self.model_class_name_suffix = (
 | 
			
		||||
            config.aot_inductor.model_name_for_generated_files
 | 
			
		||||
            if config.aot_inductor.compile_standalone
 | 
			
		||||
            else ""
 | 
			
		||||
            ""
 | 
			
		||||
            if config.aot_inductor.dynamic_linkage
 | 
			
		||||
            else config.aot_inductor.model_name_for_generated_files
 | 
			
		||||
        )
 | 
			
		||||
        self.aoti_model_class_name = f"AOTInductorModel{self.model_class_name_suffix}"
 | 
			
		||||
 | 
			
		||||
@ -222,7 +222,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
 | 
			
		||||
        self.add_device_include(self.device)
 | 
			
		||||
 | 
			
		||||
        if V.graph.aot_mode:
 | 
			
		||||
            if not config.aot_inductor.compile_standalone:
 | 
			
		||||
            if config.aot_inductor.dynamic_linkage:
 | 
			
		||||
                with open(
 | 
			
		||||
                    os.path.join(
 | 
			
		||||
                        os.path.dirname(__file__), "aoti_runtime", "interface.cpp"
 | 
			
		||||
 | 
			
		||||
@ -1585,7 +1585,14 @@ class aot_inductor:
 | 
			
		||||
    # custom op libs that have implemented C shim wrappers
 | 
			
		||||
    custom_op_libs: Optional[list[str]] = None
 | 
			
		||||
 | 
			
		||||
    compile_standalone: bool = False
 | 
			
		||||
    # If set to "windows", we will compile from WSL and generate a C++ project file for
 | 
			
		||||
    # further Windows native compilation
 | 
			
		||||
    # Only works with package_cpp_only=True
 | 
			
		||||
    cross_target_platform: Optional[str] = None
 | 
			
		||||
 | 
			
		||||
    # If package_cpp_only is True, whether cpp files will be compiled to a
 | 
			
		||||
    # dynamically linked library or static linked library
 | 
			
		||||
    dynamic_linkage: bool = True
 | 
			
		||||
 | 
			
		||||
    # Whether to enable link-time-optimization
 | 
			
		||||
    enable_lto = os.environ.get("AOT_INDUCTOR_ENABLE_LTO", "0") == "1"
 | 
			
		||||
@ -1601,6 +1608,22 @@ class aot_inductor:
 | 
			
		||||
    # TODO: should consolidate this flag with compile_standalone
 | 
			
		||||
    libtorch_free_headers: Optional[list[str]] = None
 | 
			
		||||
 | 
			
		||||
    # compile to TorchStandalone
    # Boolean switch (default off); annotated as bool to match its default value
    # and the way it is tested in the CMake generation code.
    compile_with_torchstandalone: bool = False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# a convenient class that automatically sets a group of the configs in aot_inductor
 | 
			
		||||
# it should only control the flags in aot_inductor.
 | 
			
		||||
# it should not do anything else.
 | 
			
		||||
class aot_inductor_mode:
    # NOTE(review): the list below presumably enumerates the aot_inductor flags
    # (with their values) that turning on compile_standalone is meant to apply;
    # the code that applies them is not part of this class -- confirm.
    # dynamic_linkage=False
    # link_libtorch=False
    # package_cpp_only=True
    # embed_kernel_binary=True
    # emit_multi_arch_kernel=True
    # compile_with_torchstandalone=True
    compile_standalone: bool = False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class cuda:
 | 
			
		||||
    """Settings for cuda backend, today this consists of cutlass"""
 | 
			
		||||
 | 
			
		||||
@ -65,14 +65,27 @@ _LINKER_SCRIPT = os.path.join(_TORCH_PATH, "_inductor/script.ld")
 | 
			
		||||
# initialize variables for compilation
 | 
			
		||||
_IS_LINUX = sys.platform.startswith("linux")
 | 
			
		||||
_IS_MACOS = sys.platform.startswith("darwin")
 | 
			
		||||
_IS_WINDOWS = sys.platform == "win32"
 | 
			
		||||
 | 
			
		||||
# Positional arguments forwarded to ``bytes.decode`` for subprocess output:
# ("utf-8",) when the host is Windows or the cross-compilation target is
# Windows, otherwise no arguments.
# NOTE: this is evaluated once, when the module is imported; a later change to
# ``config.aot_inductor.cross_target_platform`` is not reflected here.
SUBPROCESS_DECODE_ARGS = (
    ("utf-8",)
    if sys.platform == "win32" or config.aot_inductor.cross_target_platform == "windows"
    else ()
)
 | 
			
		||||
 | 
			
		||||
log = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# =============================== toolchain ===============================
 | 
			
		||||
def is_target_windows() -> bool:
    """Tell whether the code being built is meant for Windows.

    True when the host itself is Windows (``sys.platform == "win32"``) or when
    ``config.aot_inductor.cross_target_platform`` is set to ``"windows"``.
    """
    if sys.platform == "win32":
        return True
    return config.aot_inductor.cross_target_platform == "windows"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_compiling_on_windows() -> bool:
    """Tell whether the machine running the compilation is Windows.

    Only ``sys.platform`` is consulted; the cross-target setting is not read.
    """
    host_platform = sys.platform
    return host_platform == "win32"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@functools.lru_cache(1)
 | 
			
		||||
def cpp_compiler_search(search: str) -> str:
 | 
			
		||||
    from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT
 | 
			
		||||
@ -133,6 +146,9 @@ def check_compiler_exist_windows(compiler: str) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    Check if compiler is ready, in case end user not activate MSVC environment.
 | 
			
		||||
    """
 | 
			
		||||
    if config.aot_inductor.cross_target_platform == "windows":
 | 
			
		||||
        # Do not check compiler if cross target platform is windows.
 | 
			
		||||
        return
 | 
			
		||||
    try:
 | 
			
		||||
        subprocess.check_output([compiler, "/help"], stderr=subprocess.STDOUT)
 | 
			
		||||
    except FileNotFoundError as exc:
 | 
			
		||||
@ -332,7 +348,7 @@ def check_msvc_cl_language_id(compiler: str) -> None:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_cpp_compiler() -> str:
 | 
			
		||||
    if _IS_WINDOWS:
 | 
			
		||||
    if is_target_windows():
 | 
			
		||||
        compiler = os.environ.get("CXX", "cl")
 | 
			
		||||
        compiler = normalize_path_separator(compiler)
 | 
			
		||||
        check_compiler_exist_windows(compiler)
 | 
			
		||||
@ -349,7 +365,7 @@ def get_cpp_compiler() -> str:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_ld_and_objcopy(use_relative_path: bool) -> tuple[str, str]:
 | 
			
		||||
    if _IS_WINDOWS:
 | 
			
		||||
    if is_target_windows():
 | 
			
		||||
        raise RuntimeError("Windows is not supported yet.")
 | 
			
		||||
    else:
 | 
			
		||||
        if config.is_fbcode():
 | 
			
		||||
@ -403,7 +419,7 @@ def _is_clang(cpp_compiler: str) -> bool:
 | 
			
		||||
    # Mac OS apple clang maybe named as gcc, need check compiler info.
 | 
			
		||||
    if sys.platform == "darwin":
 | 
			
		||||
        return _is_apple_clang(cpp_compiler)
 | 
			
		||||
    elif _IS_WINDOWS:
 | 
			
		||||
    elif is_target_windows():
 | 
			
		||||
        # clang suite have many compilers, and only clang-cl is supported.
 | 
			
		||||
        if re.search(r"((clang$)|(clang\+\+$))", cpp_compiler):
 | 
			
		||||
            raise RuntimeError(
 | 
			
		||||
@ -423,7 +439,7 @@ def _is_gcc(cpp_compiler: str) -> bool:
 | 
			
		||||
 | 
			
		||||
@functools.cache
 | 
			
		||||
def _is_msvc_cl(cpp_compiler: str) -> bool:
 | 
			
		||||
    if not _IS_WINDOWS:
 | 
			
		||||
    if not is_compiling_on_windows():
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
@ -445,7 +461,7 @@ def _is_intel_compiler(cpp_compiler: str) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
        On Windows: early version icx has `-print-file-name` issue, and can't preload correctly for inductor.
 | 
			
		||||
        """
 | 
			
		||||
        min_version = "2024.2.1" if _IS_WINDOWS else "0.0.0"
 | 
			
		||||
        min_version = "2024.2.1" if is_target_windows() else "0.0.0"
 | 
			
		||||
        if compiler_version < TorchVersion(min_version):
 | 
			
		||||
            raise RuntimeError(
 | 
			
		||||
                f"Intel Compiler error: less than minimal version {min_version}."
 | 
			
		||||
@ -461,7 +477,7 @@ def _is_intel_compiler(cpp_compiler: str) -> bool:
 | 
			
		||||
        )
 | 
			
		||||
        is_intel_compiler = "Intel" in output_msg.splitlines()[0]
 | 
			
		||||
        if is_intel_compiler:
 | 
			
		||||
            if _IS_WINDOWS:
 | 
			
		||||
            if is_target_windows():
 | 
			
		||||
                if re.search(r"((icx$)|(icx-cc$))", cpp_compiler):
 | 
			
		||||
                    raise RuntimeError(
 | 
			
		||||
                        "Please use icx-cl, due to torch.compile only support MSVC-like CLI (compiler flags syntax)."
 | 
			
		||||
@ -594,7 +610,7 @@ def run_compile_cmd(cmd_line: str, cwd: str) -> None:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def normalize_path_separator(orig_path: str) -> str:
    """Return *orig_path* using "/" as the separator when targeting Windows.

    When the target is Windows, every occurrence of the host's ``os.sep`` is
    replaced with "/". For any other target the path is returned unchanged.
    """
    if is_target_windows():
        return orig_path.replace(os.sep, "/")
    return orig_path
 | 
			
		||||
 | 
			
		||||
@ -719,14 +735,14 @@ class BuildOptionsBase:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_warning_all_cflag(warning_all: bool = True) -> list[str]:
    """Return the "enable all warnings" compiler flag, when applicable.

    Gives ``["Wall"]`` if *warning_all* is true and the target is not Windows;
    otherwise an empty list. The flag name carries no leading "-" or "/".
    """
    if warning_all and not is_target_windows():
        return ["Wall"]
    return []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_cpp_std_cflag(std_num: str = "c++17") -> list[str]:
 | 
			
		||||
    if _IS_WINDOWS:
 | 
			
		||||
    if is_target_windows():
 | 
			
		||||
        """
 | 
			
		||||
        On Windows, only c++20 can support `std::enable_if_t`.
 | 
			
		||||
        Ref: https://learn.microsoft.com/en-us/cpp/overview/cpp-conformance-improvements-2019?view=msvc-170#checking-for-abstract-class-types # noqa: B950
 | 
			
		||||
@ -741,7 +757,7 @@ def _get_cpp_std_cflag(std_num: str = "c++17") -> list[str]:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_os_related_cpp_cflags(cpp_compiler: str) -> list[str]:
 | 
			
		||||
    if _IS_WINDOWS:
 | 
			
		||||
    if is_target_windows():
 | 
			
		||||
        cflags = [
 | 
			
		||||
            "wd4819",
 | 
			
		||||
            "wd4251",
 | 
			
		||||
@ -778,7 +794,7 @@ def _get_os_related_cpp_cflags(cpp_compiler: str) -> list[str]:
 | 
			
		||||
 | 
			
		||||
def _get_os_related_cpp_definitions(cpp_compiler: str) -> list[str]:
 | 
			
		||||
    os_definitions: list[str] = []
 | 
			
		||||
    if _IS_WINDOWS:
 | 
			
		||||
    if is_target_windows():
 | 
			
		||||
        # On Windows, we need disable min/max macro to avoid C2589 error, as PyTorch CMake:
 | 
			
		||||
        # https://github.com/pytorch/pytorch/blob/9a41570199155eee92ebd28452a556075e34e1b4/CMakeLists.txt#L1118-L1119
 | 
			
		||||
        os_definitions.append("NOMINMAX")
 | 
			
		||||
@ -788,7 +804,7 @@ def _get_os_related_cpp_definitions(cpp_compiler: str) -> list[str]:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_ffast_math_flags() -> list[str]:
 | 
			
		||||
    if _IS_WINDOWS:
 | 
			
		||||
    if is_target_windows():
 | 
			
		||||
        flags = []
 | 
			
		||||
    else:
 | 
			
		||||
        # ffast-math is equivalent to these flags as in
 | 
			
		||||
@ -825,7 +841,7 @@ def _get_inductor_debug_symbol_cflags() -> tuple[list[str], list[str]]:
 | 
			
		||||
    cflags: list[str] = []
 | 
			
		||||
    ldflags: list[str] = []
 | 
			
		||||
 | 
			
		||||
    if _IS_WINDOWS:
 | 
			
		||||
    if is_target_windows():
 | 
			
		||||
        cflags = ["ZI", "_DEBUG"]
 | 
			
		||||
        ldflags = ["DEBUG", "ASSEMBLYDEBUG ", "OPT:REF", "OPT:ICF"]
 | 
			
		||||
    else:
 | 
			
		||||
@ -848,19 +864,19 @@ def _get_optimization_cflags(
 | 
			
		||||
 | 
			
		||||
    if b_debug_build:
 | 
			
		||||
        cflags, ldflags = _get_inductor_debug_symbol_cflags()
 | 
			
		||||
        if _IS_WINDOWS:
 | 
			
		||||
        if is_target_windows():
 | 
			
		||||
            cflags += ["Od", "Ob0", "Oy-"]
 | 
			
		||||
        else:
 | 
			
		||||
            cflags.append("O0")
 | 
			
		||||
    else:
 | 
			
		||||
        if _IS_WINDOWS:
 | 
			
		||||
        if is_target_windows():
 | 
			
		||||
            cflags = ["O1" if min_optimize else "O2"]
 | 
			
		||||
        else:
 | 
			
		||||
            cflags = [wrapper_opt_level if min_optimize else "O3", "DNDEBUG"]
 | 
			
		||||
 | 
			
		||||
    cflags += _get_ffast_math_flags()
 | 
			
		||||
 | 
			
		||||
    if _IS_WINDOWS:
 | 
			
		||||
    if is_target_windows():
 | 
			
		||||
        pass
 | 
			
		||||
    else:
 | 
			
		||||
        if sys.platform != "darwin":
 | 
			
		||||
@ -882,7 +898,7 @@ def _get_optimization_cflags(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_shared_cflags(do_link: bool) -> list[str]:
 | 
			
		||||
    if _IS_WINDOWS:
 | 
			
		||||
    if is_target_windows():
 | 
			
		||||
        """
 | 
			
		||||
        MSVC `/MD` using python `ucrtbase.dll` lib as runtime.
 | 
			
		||||
        https://learn.microsoft.com/en-us/cpp/c-runtime-library/crt-library-features?view=msvc-170
 | 
			
		||||
@ -923,7 +939,11 @@ def get_cpp_options(
 | 
			
		||||
 | 
			
		||||
    definitions += _get_os_related_cpp_definitions(cpp_compiler)
 | 
			
		||||
 | 
			
		||||
    if not _IS_WINDOWS and config.aot_inductor.enable_lto and _is_clang(cpp_compiler):
 | 
			
		||||
    if (
 | 
			
		||||
        not is_target_windows()
 | 
			
		||||
        and config.aot_inductor.enable_lto
 | 
			
		||||
        and _is_clang(cpp_compiler)
 | 
			
		||||
    ):
 | 
			
		||||
        ldflags.append("fuse-ld=lld")
 | 
			
		||||
        ldflags.append("flto=thin")
 | 
			
		||||
 | 
			
		||||
@ -1005,7 +1025,7 @@ def _use_custom_generated_macros() -> list[str]:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _use_fb_internal_macros() -> list[str]:
 | 
			
		||||
    if not _IS_WINDOWS:
 | 
			
		||||
    if not is_target_windows():
 | 
			
		||||
        if config.is_fbcode():
 | 
			
		||||
            fb_internal_macros = [
 | 
			
		||||
                "C10_USE_GLOG",
 | 
			
		||||
@ -1027,7 +1047,7 @@ def _setup_standard_sys_libs(
 | 
			
		||||
    cflags: list[str] = []
 | 
			
		||||
    include_dirs: list[str] = []
 | 
			
		||||
    passthrough_args: list[str] = []
 | 
			
		||||
    if _IS_WINDOWS:
 | 
			
		||||
    if is_target_windows():
 | 
			
		||||
        return cflags, include_dirs, passthrough_args
 | 
			
		||||
 | 
			
		||||
    if config.is_fbcode():
 | 
			
		||||
@ -1098,7 +1118,7 @@ def _get_torch_related_args(
 | 
			
		||||
    else:
 | 
			
		||||
        libraries_dirs = []
 | 
			
		||||
 | 
			
		||||
    if _IS_WINDOWS:
 | 
			
		||||
    if is_target_windows() and config.aot_inductor.link_libtorch:
 | 
			
		||||
        libraries.append("sleef")
 | 
			
		||||
 | 
			
		||||
    return include_dirs, libraries_dirs, libraries
 | 
			
		||||
@ -1120,12 +1140,12 @@ def _get_python_include_dirs() -> list[str]:
 | 
			
		||||
def _get_python_related_args() -> tuple[list[str], list[str]]:
 | 
			
		||||
    python_include_dirs = _get_python_include_dirs()
 | 
			
		||||
    python_include_path = sysconfig.get_path(
 | 
			
		||||
        "include", scheme="nt" if _IS_WINDOWS else "posix_prefix"
 | 
			
		||||
        "include", scheme="nt" if is_target_windows() else "posix_prefix"
 | 
			
		||||
    )
 | 
			
		||||
    if python_include_path is not None:
 | 
			
		||||
        python_include_dirs.append(python_include_path)
 | 
			
		||||
 | 
			
		||||
    if _IS_WINDOWS:
 | 
			
		||||
    if is_target_windows():
 | 
			
		||||
        python_lib_path = [
 | 
			
		||||
            str(
 | 
			
		||||
                (
 | 
			
		||||
@ -1273,7 +1293,7 @@ def _get_openmp_args(
 | 
			
		||||
 | 
			
		||||
        # if openmp is still not available, we let the compiler to have a try,
 | 
			
		||||
        # and raise error together with instructions at compilation error later
 | 
			
		||||
    elif _IS_WINDOWS:
 | 
			
		||||
    elif is_target_windows():
 | 
			
		||||
        """
 | 
			
		||||
        On Windows, `clang` and `icx` have their specific openmp implenmention.
 | 
			
		||||
        And the openmp lib is in compiler's some sub-directory.
 | 
			
		||||
@ -1419,6 +1439,7 @@ def get_cpp_torch_options(
 | 
			
		||||
    ldflags = omp_ldflags
 | 
			
		||||
    libraries_dirs = python_libraries_dirs + torch_libraries_dirs + omp_lib_dir_paths
 | 
			
		||||
    libraries = torch_libraries + omp_lib
 | 
			
		||||
 | 
			
		||||
    passthrough_args = (
 | 
			
		||||
        sys_libs_passthrough_args + isa_ps_args_build_flags + omp_passthrough_args
 | 
			
		||||
    )
 | 
			
		||||
@ -1585,6 +1606,8 @@ def get_cpp_torch_device_options(
 | 
			
		||||
                libraries += ["c10_hip", "torch_hip"]
 | 
			
		||||
            definitions.append(" __HIP_PLATFORM_AMD__")
 | 
			
		||||
        else:
 | 
			
		||||
            if is_target_windows():
 | 
			
		||||
                libraries += ["CUDA::cudart"]
 | 
			
		||||
            if config.is_fbcode() or not link_libtorch:
 | 
			
		||||
                libraries += ["cuda"]
 | 
			
		||||
            else:
 | 
			
		||||
@ -1597,7 +1620,7 @@ def get_cpp_torch_device_options(
 | 
			
		||||
            "Intel GPU driver is not properly installed, please follow the instruction "
 | 
			
		||||
            "in https://github.com/pytorch/pytorch?tab=readme-ov-file#intel-gpu-support."
 | 
			
		||||
        )
 | 
			
		||||
        if _IS_WINDOWS:
 | 
			
		||||
        if is_target_windows():
 | 
			
		||||
            ze_root = os.getenv("LEVEL_ZERO_V1_SDK_PATH")
 | 
			
		||||
            if ze_root is None:
 | 
			
		||||
                raise OSError(xpu_error_string)
 | 
			
		||||
@ -1766,26 +1789,26 @@ class CppBuilder:
 | 
			
		||||
 | 
			
		||||
    @staticmethod
    def __get_python_module_flags() -> tuple[str, str]:
        """Return ``(file extension, output flag)`` for a Python extension module.

        Windows targets use ``".pyd"`` with ``"/Fe"``; all others use ``".so"``
        with ``"-o"``.
        """
        if is_target_windows():
            return ".pyd", "/Fe"
        return ".so", "-o"
 | 
			
		||||
 | 
			
		||||
    @staticmethod
    def __get_object_flags() -> tuple[str, str]:
        """Return ``(file extension, output flags)`` for compiling an object file.

        Windows targets use ``".obj"`` with ``"/c /Fo"``; all others use ``".o"``
        with ``"-c -o"``.
        """
        if is_target_windows():
            return ".obj", "/c /Fo"  # codespell:ignore
        return ".o", "-c -o"
 | 
			
		||||
 | 
			
		||||
    @staticmethod
    def __get_precompiled_header_flags() -> tuple[str, str]:
        """Return ``(file extension, output flag)`` for a precompiled header.

        Windows targets use ``".pch"`` with ``"/Fp"``. Other targets use ``"-o"``
        with ``".gch"`` when ``is_gcc()`` is true and ``".pch"`` otherwise.
        """
        if is_target_windows():
            return ".pch", "/Fp"
        extension = ".gch" if is_gcc() else ".pch"
        return extension, "-o"
 | 
			
		||||
 | 
			
		||||
    @staticmethod
    def __get_preprocessor_output_flags() -> tuple[str, str]:
        """Return ``(file extension, output flags)`` for preprocessor-only output.

        The extension is always ``".i"``; the flags are ``"/EP /P"`` for Windows
        targets and ``"-E -P -o"`` otherwise.
        """
        output_flags = "/EP /P" if is_target_windows() else "-E -P -o"
        return ".i", output_flags
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
@ -1837,7 +1860,7 @@ class CppBuilder:
 | 
			
		||||
        # MSVC produces two files when precompiling: the actual .pch file, as well as an
 | 
			
		||||
        # object file which must be linked into the final library.  This class assumes
 | 
			
		||||
        # only one output file of note, so for now we'll error out here.
 | 
			
		||||
        assert not _IS_WINDOWS or not self._precompiling, (
 | 
			
		||||
        assert not is_target_windows() or not self._precompiling, (
 | 
			
		||||
            "Cannot currently precompile headers on Windows!"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
@ -1856,7 +1879,7 @@ class CppBuilder:
 | 
			
		||||
            if self._use_relative_path
 | 
			
		||||
            else self._target_file
 | 
			
		||||
        )
 | 
			
		||||
        if _IS_WINDOWS:
 | 
			
		||||
        if is_target_windows():
 | 
			
		||||
            if self._preprocessing:
 | 
			
		||||
                # The target file name is automatically determined by MSVC.
 | 
			
		||||
                self._output = output_flags
 | 
			
		||||
@ -1883,19 +1906,19 @@ class CppBuilder:
 | 
			
		||||
            self._sources_args = " ".join(sources)
 | 
			
		||||
 | 
			
		||||
        for cflag in BuildOption.get_cflags():
 | 
			
		||||
            if _IS_WINDOWS:
 | 
			
		||||
            if is_target_windows():
 | 
			
		||||
                self._cflags_args += f"/{cflag} "
 | 
			
		||||
            else:
 | 
			
		||||
                self._cflags_args += f"-{cflag} "
 | 
			
		||||
 | 
			
		||||
        for definition in BuildOption.get_definitions():
 | 
			
		||||
            if _IS_WINDOWS:
 | 
			
		||||
            if is_target_windows():
 | 
			
		||||
                self._definitions_args += f"/D {definition} "
 | 
			
		||||
            else:
 | 
			
		||||
                self._definitions_args += f"-D {definition} "
 | 
			
		||||
 | 
			
		||||
        if precompiled_header := BuildOption.precompiled_header:
 | 
			
		||||
            if _IS_WINDOWS:
 | 
			
		||||
            if is_target_windows():
 | 
			
		||||
                log.warning(
 | 
			
		||||
                    "Precompiled header support for MSVC is currently unavailable; ignoring %s",
 | 
			
		||||
                    precompiled_header,
 | 
			
		||||
@ -1904,25 +1927,25 @@ class CppBuilder:
 | 
			
		||||
                self._include_dirs_args = f"-include {precompiled_header} "
 | 
			
		||||
 | 
			
		||||
        for inc_dir in BuildOption.get_include_dirs():
 | 
			
		||||
            if _IS_WINDOWS:
 | 
			
		||||
            if is_target_windows():
 | 
			
		||||
                self._include_dirs_args += f'/I "{inc_dir}" '
 | 
			
		||||
            else:
 | 
			
		||||
                self._include_dirs_args += f"-I{shlex.quote(inc_dir)} "
 | 
			
		||||
 | 
			
		||||
        for ldflag in BuildOption.get_ldflags():
 | 
			
		||||
            if _IS_WINDOWS:
 | 
			
		||||
            if is_target_windows():
 | 
			
		||||
                self._ldflags_args += f"/{ldflag} "
 | 
			
		||||
            else:
 | 
			
		||||
                self._ldflags_args += f"-{ldflag} "
 | 
			
		||||
 | 
			
		||||
        for lib_dir in BuildOption.get_libraries_dirs():
 | 
			
		||||
            if _IS_WINDOWS:
 | 
			
		||||
            if is_target_windows():
 | 
			
		||||
                self._libraries_dirs_args += f'/LIBPATH:"{lib_dir}" '
 | 
			
		||||
            else:
 | 
			
		||||
                self._libraries_dirs_args += f"-L{lib_dir} "
 | 
			
		||||
 | 
			
		||||
        for lib in BuildOption.get_libraries():
 | 
			
		||||
            if _IS_WINDOWS:
 | 
			
		||||
            if is_target_windows():
 | 
			
		||||
                self._libraries_args += f'"{lib}.lib" '
 | 
			
		||||
            else:
 | 
			
		||||
                self._libraries_args += f"-l{lib} "
 | 
			
		||||
@ -1943,7 +1966,7 @@ class CppBuilder:
 | 
			
		||||
            passthrough_args: str,
 | 
			
		||||
            output: str,
 | 
			
		||||
        ) -> str:
 | 
			
		||||
            if _IS_WINDOWS:
 | 
			
		||||
            if is_target_windows():
 | 
			
		||||
                # https://learn.microsoft.com/en-us/cpp/build/walkthrough-compile-a-c-program-on-the-command-line?view=msvc-1704
 | 
			
		||||
                # https://stackoverflow.com/a/31566153
 | 
			
		||||
                cmd = (
 | 
			
		||||
@ -2043,7 +2066,7 @@ class CppBuilder:
 | 
			
		||||
 | 
			
		||||
        definitions = " ".join(self._build_option.get_definitions())
 | 
			
		||||
        target_library_type = (
 | 
			
		||||
            "STATIC" if config.aot_inductor.compile_standalone else "SHARED"
 | 
			
		||||
            "STATIC" if not config.aot_inductor.dynamic_linkage else "SHARED"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        contents = textwrap.dedent(
 | 
			
		||||
@ -2058,10 +2081,7 @@ class CppBuilder:
 | 
			
		||||
            """
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if (
 | 
			
		||||
            not config.aot_inductor.compile_standalone
 | 
			
		||||
            or config.test_configs.use_libtorch
 | 
			
		||||
        ):
 | 
			
		||||
        if config.aot_inductor.link_libtorch or config.test_configs.use_libtorch:
 | 
			
		||||
            # When link_libtorch is False, the generated cpp project should
            # not use Torch. But for unit testing purposes, we need to use Torch here.
 | 
			
		||||
            contents += textwrap.dedent(
 | 
			
		||||
@ -2071,24 +2091,7 @@ class CppBuilder:
 | 
			
		||||
 | 
			
		||||
                """
 | 
			
		||||
            )
 | 
			
		||||
            # flags and macros here are mostly CPU specific. Not emitting them for GPU models
 | 
			
		||||
            # will make the generated CMake file more portable and won't really hurt performance.
 | 
			
		||||
            # NOTE: standalone focuses on GPU now. For CPU, some of the flags and macros may
 | 
			
		||||
            # be still needed.
 | 
			
		||||
            contents += textwrap.dedent(
 | 
			
		||||
                f"""
 | 
			
		||||
                # Add macro definitions
 | 
			
		||||
                target_compile_definitions({self._target_name} PRIVATE {definitions})
 | 
			
		||||
 | 
			
		||||
                # Add compile flags
 | 
			
		||||
                target_compile_options({self._target_name} PRIVATE {self._cflags_args})
 | 
			
		||||
 | 
			
		||||
                # Backend-specific flags
 | 
			
		||||
                target_compile_options({self._target_name} PRIVATE {self._passthrough_parameters_args} -c)
 | 
			
		||||
 | 
			
		||||
                """
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
        elif config.aot_inductor.compile_with_torchstandalone:
 | 
			
		||||
            # When compile_with_torchstandalone is set, use TorchStandalone instead of Torch
 | 
			
		||||
            contents += textwrap.dedent(
 | 
			
		||||
                f"""
 | 
			
		||||
@ -2100,73 +2103,148 @@ class CppBuilder:
 | 
			
		||||
                """
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        contents += textwrap.dedent(
 | 
			
		||||
            f"""
 | 
			
		||||
            # Add macro definitions
 | 
			
		||||
            target_compile_definitions({self._target_name} PRIVATE {definitions})
 | 
			
		||||
 | 
			
		||||
            # Add compile flags
 | 
			
		||||
            target_compile_options({self._target_name} PRIVATE {self._cflags_args})
 | 
			
		||||
 | 
			
		||||
            # Backend-specific flags
 | 
			
		||||
            target_compile_options({self._target_name} PRIVATE {self._passthrough_parameters_args} -c)
 | 
			
		||||
 | 
			
		||||
            """
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if device_type == "cuda" and torch.version.hip is None:
 | 
			
		||||
            from torch._inductor.codecache import _nvcc_arch_as_compile_option
 | 
			
		||||
 | 
			
		||||
            current_arch = _nvcc_arch_as_compile_option()
 | 
			
		||||
            contents += textwrap.dedent(
 | 
			
		||||
                f"""
 | 
			
		||||
                """
 | 
			
		||||
                enable_language(CUDA)
 | 
			
		||||
                set(CMAKE_CUDA_STANDARD 17)
 | 
			
		||||
                find_package(CUDAToolkit REQUIRED)
 | 
			
		||||
                target_include_directories({self._target_name} PRIVATE ${{CUDAToolkit_INCLUDE_DIRS}})
 | 
			
		||||
                target_compile_definitions({self._target_name} PRIVATE USE_CUDA)
 | 
			
		||||
                target_link_libraries({self._target_name} PRIVATE cuda CUDA::cudart_static)
 | 
			
		||||
 | 
			
		||||
                find_program(OBJCOPY_EXECUTABLE objcopy)
 | 
			
		||||
                if(NOT OBJCOPY_EXECUTABLE)
 | 
			
		||||
                    message(FATAL_ERROR "objcopy not found. Cannot embed fatbin as object file")
 | 
			
		||||
                endif()
 | 
			
		||||
 | 
			
		||||
                set(KERNEL_TARGETS "")
 | 
			
		||||
                set(KERNEL_OBJECT_FILES "")
 | 
			
		||||
                # Function to embed a single kernel
 | 
			
		||||
                function(embed_gpu_kernel KERNEL_NAME PTX_FILE)
 | 
			
		||||
                    set(FATBIN_BASENAME ${{KERNEL_NAME}}.fatbin)
 | 
			
		||||
                    set(FATBIN_FILE ${{CMAKE_CURRENT_BINARY_DIR}}/${{FATBIN_BASENAME}})
 | 
			
		||||
                    set(OBJECT_BASENAME ${{KERNEL_NAME}}.fatbin.o)
 | 
			
		||||
                    set(OBJECT_FILE ${{CMAKE_CURRENT_BINARY_DIR}}/${{OBJECT_BASENAME}})
 | 
			
		||||
 | 
			
		||||
                    # --- Define UNIQUE C symbol names ---
 | 
			
		||||
                    set(SYMBOL_START __${{KERNEL_NAME}}_start)
 | 
			
		||||
                    set(SYMBOL_END __${{KERNEL_NAME}}_end)
 | 
			
		||||
                    set(SYMBOL_SIZE __${{KERNEL_NAME}}_size)
 | 
			
		||||
                    string(REGEX REPLACE "[^a-zA-Z0-9]" "_" MANGLED_BASENAME ${{FATBIN_FILE}})
 | 
			
		||||
                    set(OBJCOPY_START_SYM _binary_${{MANGLED_BASENAME}}_start)
 | 
			
		||||
                    set(OBJCOPY_END_SYM _binary_${{MANGLED_BASENAME}}_end)
 | 
			
		||||
                    set(OBJCOPY_SIZE_SYM _binary_${{MANGLED_BASENAME}}_size)
 | 
			
		||||
 | 
			
		||||
                    # --- PTX to FATBIN Command & Target ---
 | 
			
		||||
                    add_custom_command(
 | 
			
		||||
                        OUTPUT ${{FATBIN_FILE}}
 | 
			
		||||
                        COMMAND ${{CUDAToolkit_NVCC_EXECUTABLE}} --fatbin ${{PTX_FILE}} -o ${{FATBIN_FILE}} ${{NVCC_GENCODE_FLAGS}}
 | 
			
		||||
                                -gencode arch=compute_{current_arch},code=compute_{current_arch}
 | 
			
		||||
                                -gencode arch=compute_{current_arch},code=sm_{current_arch}
 | 
			
		||||
                        DEPENDS ${{PTX_FILE}}
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                    # --- FATBIN to Object File (.o) Command ---
 | 
			
		||||
                    add_custom_command(
 | 
			
		||||
                        OUTPUT ${{OBJECT_FILE}}
 | 
			
		||||
                        COMMAND ${{CMAKE_LINKER}} -r -b binary -z noexecstack -o ${{OBJECT_FILE}} ${{FATBIN_FILE}}
 | 
			
		||||
                        COMMAND ${{OBJCOPY_EXECUTABLE}} --rename-section .data=.rodata,alloc,load,readonly,data,contents
 | 
			
		||||
                                ${{OBJECT_FILE}}
 | 
			
		||||
                        COMMAND ${{OBJCOPY_EXECUTABLE}}
 | 
			
		||||
                                --redefine-sym ${{OBJCOPY_START_SYM}}=${{SYMBOL_START}}
 | 
			
		||||
                                --redefine-sym ${{OBJCOPY_END_SYM}}=${{SYMBOL_END}}
 | 
			
		||||
                                --redefine-sym ${{OBJCOPY_SIZE_SYM}}=${{SYMBOL_SIZE}}
 | 
			
		||||
                                ${{OBJECT_FILE}}
 | 
			
		||||
                        DEPENDS ${{FATBIN_FILE}}
 | 
			
		||||
                    )
 | 
			
		||||
                    add_custom_target(build_kernel_object_${{KERNEL_NAME}} DEPENDS ${{OBJECT_FILE}})
 | 
			
		||||
 | 
			
		||||
                    # --- Add to a list for linking later ---
 | 
			
		||||
                    set(KERNEL_TARGETS ${{KERNEL_TARGETS}} build_kernel_object_${{KERNEL_NAME}} PARENT_SCOPE)
 | 
			
		||||
                    set(KERNEL_OBJECT_FILES ${{KERNEL_OBJECT_FILES}} ${{OBJECT_FILE}} PARENT_SCOPE)
 | 
			
		||||
                endfunction()
 | 
			
		||||
 | 
			
		||||
                """
 | 
			
		||||
            )
 | 
			
		||||
            if config.aot_inductor.cross_target_platform == "windows":
 | 
			
		||||
                exports_str = (
 | 
			
		||||
                    """LINK_FLAGS "/DEF:${CMAKE_CURRENT_SOURCE_DIR}/windows_symbol_exports.def" """
 | 
			
		||||
                    if config.aot_inductor.dynamic_linkage
 | 
			
		||||
                    else ""
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                contents += textwrap.dedent(
 | 
			
		||||
                    f"""
 | 
			
		||||
                    # Make output use .pyd instead of .dll
 | 
			
		||||
                    set_target_properties({self._target_name} PROPERTIES SUFFIX ".pyd" {exports_str})
 | 
			
		||||
 | 
			
		||||
                    set(KERNEL_TARGETS "")
 | 
			
		||||
                    set(KERNEL_OBJECT_FILES "")
 | 
			
		||||
                    # Function to compile ptx to cubin
 | 
			
		||||
                    function(embed_gpu_kernel KERNEL_NAME PTX_FILE)
 | 
			
		||||
                        set(CUBIN_BASENAME ${{KERNEL_NAME}}.cubin)
 | 
			
		||||
                        set(CUBIN_FILE ${{CMAKE_CURRENT_BINARY_DIR}}/${{CUBIN_BASENAME}})
 | 
			
		||||
                        # --- PTX to FATBIN Command & Target ---
 | 
			
		||||
                        add_custom_command(
 | 
			
		||||
                            OUTPUT ${{CUBIN_FILE}}
 | 
			
		||||
                            COMMAND ${{CUDAToolkit_NVCC_EXECUTABLE}} --cubin ${{PTX_FILE}}
 | 
			
		||||
                                    -o ${{CUBIN_FILE}} ${{NVCC_GENCODE_FLAGS}}
 | 
			
		||||
                                    -gencode arch=compute_{current_arch},code=sm_{current_arch}
 | 
			
		||||
                            DEPENDS ${{PTX_FILE}}
 | 
			
		||||
                        )
 | 
			
		||||
 | 
			
		||||
                        add_custom_target(build_kernel_object_${{KERNEL_NAME}} DEPENDS ${{CUBIN_FILE}})
 | 
			
		||||
                        set(KERNEL_TARGETS ${{KERNEL_TARGETS}} build_kernel_object_${{KERNEL_NAME}} PARENT_SCOPE)
 | 
			
		||||
                        """
 | 
			
		||||
                )
 | 
			
		||||
                if config.aot_inductor.embed_kernel_binary:
 | 
			
		||||
                    contents += textwrap.indent(
 | 
			
		||||
                        textwrap.dedent(
 | 
			
		||||
                            """
 | 
			
		||||
                        # CUBIN → C++ array via xxd -i
 | 
			
		||||
                        set(C_SRC_BASENAME ${KERNEL_NAME}.c)
 | 
			
		||||
                        set(C_SRC_FILE ${CMAKE_CURRENT_BINARY_DIR}/${C_SRC_BASENAME})
 | 
			
		||||
 | 
			
		||||
                        add_custom_command(
 | 
			
		||||
                            OUTPUT ${C_SRC_FILE}
 | 
			
		||||
                            COMMAND xxd -i -n __${KERNEL_NAME}_start ${CUBIN_FILE} > ${C_SRC_FILE}
 | 
			
		||||
                            DEPENDS ${CUBIN_FILE}
 | 
			
		||||
                            COMMENT "Embedding ${CUBIN_FILE} as C array in ${C_SRC_FILE}"
 | 
			
		||||
                        )
 | 
			
		||||
 | 
			
		||||
                        set_source_files_properties(${C_SRC_FILE}
 | 
			
		||||
                            PROPERTIES GENERATED TRUE LANGUAGE C COMPILE_FLAGS "/TC")  # /TC forces MSVC C mode
 | 
			
		||||
 | 
			
		||||
                        target_sources(model PRIVATE ${C_SRC_FILE})
 | 
			
		||||
                        """
 | 
			
		||||
                        ),
 | 
			
		||||
                        " " * 4,
 | 
			
		||||
                    )
 | 
			
		||||
                contents += textwrap.dedent(
 | 
			
		||||
                    """
 | 
			
		||||
                    endfunction()
 | 
			
		||||
                    """
 | 
			
		||||
                )
 | 
			
		||||
            elif config.aot_inductor.embed_kernel_binary:
 | 
			
		||||
                contents += textwrap.dedent(
 | 
			
		||||
                    f"""
 | 
			
		||||
                    find_program(OBJCOPY_EXECUTABLE objcopy)
 | 
			
		||||
                    if(NOT OBJCOPY_EXECUTABLE)
 | 
			
		||||
                        message(FATAL_ERROR "objcopy not found. Cannot embed fatbin as object file")
 | 
			
		||||
                    endif()
 | 
			
		||||
 | 
			
		||||
                    set(KERNEL_TARGETS "")
 | 
			
		||||
                    set(KERNEL_OBJECT_FILES "")
 | 
			
		||||
                    # Function to embed a single kernel
 | 
			
		||||
                    function(embed_gpu_kernel KERNEL_NAME PTX_FILE)
 | 
			
		||||
                        set(FATBIN_BASENAME ${{KERNEL_NAME}}.fatbin)
 | 
			
		||||
                        set(FATBIN_FILE ${{CMAKE_CURRENT_BINARY_DIR}}/${{FATBIN_BASENAME}})
 | 
			
		||||
                        set(OBJECT_BASENAME ${{KERNEL_NAME}}.fatbin.o)
 | 
			
		||||
                        set(OBJECT_FILE ${{CMAKE_CURRENT_BINARY_DIR}}/${{OBJECT_BASENAME}})
 | 
			
		||||
 | 
			
		||||
                        # --- Define UNIQUE C symbol names ---
 | 
			
		||||
                        set(SYMBOL_START __${{KERNEL_NAME}}_start)
 | 
			
		||||
                        set(SYMBOL_END __${{KERNEL_NAME}}_end)
 | 
			
		||||
                        set(SYMBOL_SIZE __${{KERNEL_NAME}}_size)
 | 
			
		||||
                        string(REGEX REPLACE "[^a-zA-Z0-9]" "_" MANGLED_BASENAME ${{FATBIN_FILE}})
 | 
			
		||||
                        set(OBJCOPY_START_SYM _binary_${{MANGLED_BASENAME}}_start)
 | 
			
		||||
                        set(OBJCOPY_END_SYM _binary_${{MANGLED_BASENAME}}_end)
 | 
			
		||||
                        set(OBJCOPY_SIZE_SYM _binary_${{MANGLED_BASENAME}}_size)
 | 
			
		||||
 | 
			
		||||
                        # --- PTX to FATBIN Command & Target ---
 | 
			
		||||
                        add_custom_command(
 | 
			
		||||
                            OUTPUT ${{FATBIN_FILE}}
 | 
			
		||||
                            COMMAND ${{CUDAToolkit_NVCC_EXECUTABLE}} --fatbin ${{PTX_FILE}}
 | 
			
		||||
                                    -o ${{FATBIN_FILE}} ${{NVCC_GENCODE_FLAGS}}
 | 
			
		||||
                                    -gencode arch=compute_80,code=compute_80
 | 
			
		||||
                                    -gencode arch=compute_{current_arch},code=sm_{current_arch}
 | 
			
		||||
                            DEPENDS ${{PTX_FILE}}
 | 
			
		||||
                        )
 | 
			
		||||
 | 
			
		||||
                        # --- FATBIN to Object File (.o) Command ---
 | 
			
		||||
                        add_custom_command(
 | 
			
		||||
                            OUTPUT ${{OBJECT_FILE}}
 | 
			
		||||
                            COMMAND ${{CMAKE_LINKER}} -r -b binary -z noexecstack -o ${{OBJECT_FILE}} ${{FATBIN_FILE}}
 | 
			
		||||
                            COMMAND ${{OBJCOPY_EXECUTABLE}} --rename-section .data=.rodata,alloc,load,readonly,data,contents
 | 
			
		||||
                                    ${{OBJECT_FILE}}
 | 
			
		||||
                            COMMAND ${{OBJCOPY_EXECUTABLE}}
 | 
			
		||||
                                    --redefine-sym ${{OBJCOPY_START_SYM}}=${{SYMBOL_START}}
 | 
			
		||||
                                    --redefine-sym ${{OBJCOPY_END_SYM}}=${{SYMBOL_END}}
 | 
			
		||||
                                    --redefine-sym ${{OBJCOPY_SIZE_SYM}}=${{SYMBOL_SIZE}}
 | 
			
		||||
                                    ${{OBJECT_FILE}}
 | 
			
		||||
                            DEPENDS ${{FATBIN_FILE}}
 | 
			
		||||
                        )
 | 
			
		||||
                        add_custom_target(build_kernel_object_${{KERNEL_NAME}} DEPENDS ${{OBJECT_FILE}})
 | 
			
		||||
 | 
			
		||||
                        # --- Add to a list for linking later ---
 | 
			
		||||
                        set(KERNEL_TARGETS ${{KERNEL_TARGETS}} build_kernel_object_${{KERNEL_NAME}} PARENT_SCOPE)
 | 
			
		||||
                        set(KERNEL_OBJECT_FILES ${{KERNEL_OBJECT_FILES}} ${{OBJECT_FILE}} PARENT_SCOPE)
 | 
			
		||||
                    endfunction()
 | 
			
		||||
 | 
			
		||||
                    """
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        with open(cmake_path, "w") as f:
 | 
			
		||||
            f.write(contents)
 | 
			
		||||
@ -2196,15 +2274,15 @@ class CppBuilder:
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
    def save_link_cmd_to_cmake(self, cmake_path: str) -> None:
 | 
			
		||||
        if (
 | 
			
		||||
            config.aot_inductor.compile_standalone
 | 
			
		||||
            and not config.test_configs.use_libtorch
 | 
			
		||||
        ):
 | 
			
		||||
            # When compile_standalone is True, do not link with libtorch
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        lflags = " ".join(self._build_option.get_ldflags())
 | 
			
		||||
        libs = " ".join(self._build_option.get_libraries())
 | 
			
		||||
 | 
			
		||||
        if (
 | 
			
		||||
            config.aot_inductor.compile_with_torchstandalone
 | 
			
		||||
            and not config.test_configs.use_libtorch
 | 
			
		||||
        ):
 | 
			
		||||
            libs += " ${TorchStandalone_LIBRARIES}"
 | 
			
		||||
 | 
			
		||||
        contents = textwrap.dedent(
 | 
			
		||||
            f"""
 | 
			
		||||
            # Add linker flags
 | 
			
		||||
@ -2224,7 +2302,7 @@ class CppBuilder:
 | 
			
		||||
 | 
			
		||||
def run_asm_build_object(src: str, target: str, cwd: str) -> None:
 | 
			
		||||
    def get_asm_compiler() -> str:
 | 
			
		||||
        if _IS_WINDOWS:
 | 
			
		||||
        if is_target_windows():
 | 
			
		||||
            ASM_CC = "ml64"
 | 
			
		||||
        else:
 | 
			
		||||
            ASM_CC = get_cpp_compiler()
 | 
			
		||||
@ -2234,7 +2312,7 @@ def run_asm_build_object(src: str, target: str, cwd: str) -> None:
 | 
			
		||||
        return ASM_CC
 | 
			
		||||
 | 
			
		||||
    def get_command_line(asm_cc: str, src: str, target: str) -> str:
 | 
			
		||||
        if _IS_WINDOWS:
 | 
			
		||||
        if is_target_windows():
 | 
			
		||||
            # Format reference:
 | 
			
		||||
            # https://learn.microsoft.com/en-us/cpp/assembler/masm/ml-and-ml64-command-line-reference?view=msvc-170
 | 
			
		||||
            cmd = f"{asm_cc} {src} /c /Fo {target}"  # codespell:ignore /Fo
 | 
			
		||||
 | 
			
		||||
@ -3487,10 +3487,18 @@ def maybe_aoti_standalone_config(config_patches: dict[str, Any]) -> dict[str, An
 | 
			
		||||
    """
 | 
			
		||||
    Ensures the configuration is internally consistent for standalone AOTInductor.
 | 
			
		||||
 | 
			
		||||
    If `aot_inductor.compile_standalone` is set to True in the provided
 | 
			
		||||
    If `aot_inductor_mode.compile_standalone` is set to True in the provided
 | 
			
		||||
    `config_patches` (or falls back to the global config), this function ensures
 | 
			
		||||
    that the following configs are also enabled:
 | 
			
		||||
    that the following configs are also set:
 | 
			
		||||
        - `aot_inductor.package_cpp_only`
 | 
			
		||||
        - `aot_inductor.embed_kernel_binary`
 | 
			
		||||
        - `aot_inductor.emit_multi_arch_kernel`
 | 
			
		||||
        - `aot_inductor.link_libtorch=False`
 | 
			
		||||
        - `aot_inductor.dynamic_linkage=False`
 | 
			
		||||
 | 
			
		||||
    If `aot_inductor.dynamic_linkage` is set to False in the provided
 | 
			
		||||
    `config_patches` (or falls back to the global config):
 | 
			
		||||
        - `aot_inductor.model_name_for_generated_files` defaults to "aoti_model" if not set.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        config_patches (dict[str, Any]): A dictionary of user-provided config
 | 
			
		||||
@ -3512,7 +3520,8 @@ def maybe_aoti_standalone_config(config_patches: dict[str, Any]) -> dict[str, An
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    compile_standalone = config_patches.get(
 | 
			
		||||
        "aot_inductor.compile_standalone", config.aot_inductor.compile_standalone
 | 
			
		||||
        "aot_inductor_mode.compile_standalone",
 | 
			
		||||
        config.aot_inductor_mode.compile_standalone,
 | 
			
		||||
    )
 | 
			
		||||
    # Make a copy of the config_patches to avoid modifying the original dictionary, needed for testing
 | 
			
		||||
    config_patches = config_patches.copy()
 | 
			
		||||
@ -3525,6 +3534,13 @@ def maybe_aoti_standalone_config(config_patches: dict[str, Any]) -> dict[str, An
 | 
			
		||||
        patch_config(
 | 
			
		||||
            config_patches, "aot_inductor.emit_multi_arch_kernel", not torch.version.hip
 | 
			
		||||
        )
 | 
			
		||||
        patch_config(config_patches, "aot_inductor.dynamic_linkage", False)
 | 
			
		||||
        patch_config(config_patches, "aot_inductor.compile_with_torchstandalone", True)
 | 
			
		||||
 | 
			
		||||
    dynamic_linkage = config_patches.get(
 | 
			
		||||
        "aot_inductor.dynamic_linkage", config.aot_inductor.dynamic_linkage
 | 
			
		||||
    )
 | 
			
		||||
    if not dynamic_linkage:
 | 
			
		||||
        patch_config(
 | 
			
		||||
            config_patches, "aot_inductor.model_name_for_generated_files", "aoti_model"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										31
									
								
								torch/csrc/inductor/aoti_runtime/windows_symbol_exports.def
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								torch/csrc/inductor/aoti_runtime/windows_symbol_exports.def
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,31 @@
 | 
			
		||||
LIBRARY {config.aot_inductor.model_name_for_generated_files}
 | 
			
		||||
EXPORTS
 | 
			
		||||
    AOTInductorModelContainerCreate
 | 
			
		||||
    AOTInductorModelContainerCreateWithDevice
 | 
			
		||||
    AOTInductorModelContainerRun
 | 
			
		||||
    AOTInductorModelContainerDelete
 | 
			
		||||
    AOTInductorModelContainerRunSingleThreaded
 | 
			
		||||
    AOTInductorModelContainerGetNumConstants
 | 
			
		||||
    AOTInductorModelContainerGetConstantName
 | 
			
		||||
    AOTInductorModelContainerGetConstantOriginalFQN
 | 
			
		||||
    AOTInductorModelContainerGetConstantFromFolded
 | 
			
		||||
    AOTInductorModelContainerGetConstantType
 | 
			
		||||
    AOTInductorModelContainerGetConstantDtype
 | 
			
		||||
    AOTInductorModelContainerGetConstantDataSize
 | 
			
		||||
    AOTInductorModelContainerExtractConstantsMap
 | 
			
		||||
    AOTInductorModelContainerUpdateUserManagedConstantBuffer
 | 
			
		||||
    AOTInductorModelContainerUpdateConstantBuffer
 | 
			
		||||
    AOTInductorModelContainerUpdateInactiveConstantBuffer
 | 
			
		||||
    AOTInductorModelContainerFreeInactiveConstantBuffer
 | 
			
		||||
    AOTInductorModelContainerRunConstantFolding
 | 
			
		||||
    AOTInductorModelContainerSwapConstantBuffer
 | 
			
		||||
    AOTInductorModelContainerGetNumInputs
 | 
			
		||||
    AOTInductorModelContainerGetInputName
 | 
			
		||||
    AOTInductorModelContainerGetNumOutputs
 | 
			
		||||
    AOTInductorModelContainerGetOutputName
 | 
			
		||||
    AOTInductorModelCreate
 | 
			
		||||
    AOTInductorModelRun
 | 
			
		||||
    AOTInductorModelUpdateConstantsMap
 | 
			
		||||
    AOTInductorModelDelete
 | 
			
		||||
    AOTInductorModelGetNumOutputs
 | 
			
		||||
    AOTInductorModelContainerGetCallSpec
 | 
			
		||||
@ -362,7 +362,7 @@ class _ExportPackage:
 | 
			
		||||
            "always_keep_tensor_constants": True,
 | 
			
		||||
            # we'll change this back to False once we enable weight deduping for standalone mode
 | 
			
		||||
            "aot_inductor.package_constants_in_so": standalone,
 | 
			
		||||
            "aot_inductor.compile_standalone": standalone,
 | 
			
		||||
            "aot_inductor_mode.compile_standalone": standalone,
 | 
			
		||||
        }
 | 
			
		||||
        aoti_files_map = {}
 | 
			
		||||
        model_names = []
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user