Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[AOTI] Add more default options to compile_standalone (#158560)

Summary: When compiling for standalone, make embed_kernel_binary and emit_multi_arch_kernel default to True, and add a default name for model_name_for_generated_files so the generated cpp project is easier to understand. Also improve the weights object-file naming to be more readable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158560
Approved by: https://github.com/yushangdi
This commit is contained in:
Committed by: PyTorch MergeBot
Parent: 9e0473b566
Commit: a991e285ae
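
Before the diff itself, a minimal sketch (not part of the commit) of the behavior it introduces, mirroring the updated test below; the import path for maybe_aoti_standalone_config is an assumption based on the torch._inductor.utils hunk at the end of this diff:

    # Sketch only: exercises the new standalone defaults.
    from torch._inductor.utils import maybe_aoti_standalone_config  # assumed location

    patches = maybe_aoti_standalone_config({"aot_inductor.compile_standalone": True})
    # Standalone mode now implies these defaults unless explicitly overridden:
    assert patches["aot_inductor.package_cpp_only"] is True
    assert patches["aot_inductor.embed_kernel_binary"] is True
    assert patches["aot_inductor.emit_multi_arch_kernel"] is True
    assert patches["aot_inductor.model_name_for_generated_files"] == "aoti_model"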
@@ -6646,11 +6646,19 @@ class TestAOTInductorConfig(TestCase):
         result = maybe_aoti_standalone_config({"aot_inductor.compile_standalone": True})
         self.assertEqual(result["aot_inductor.package_cpp_only"], True)
         self.assertEqual(result["aot_inductor.compile_standalone"], True)
+        self.assertEqual(result["aot_inductor.embed_kernel_binary"], True)
+        self.assertEqual(result["aot_inductor.emit_multi_arch_kernel"], True)
+        self.assertEqual(
+            result["aot_inductor.model_name_for_generated_files"], "aoti_model"
+        )
 
-    def test_compile_standalone_package_cpp_already_true(self):
+    def test_compile_standalone_explicit_set(self):
         patches = {
             "aot_inductor.compile_standalone": True,
             "aot_inductor.package_cpp_only": True,
+            "aot_inductor.embed_kernel_binary": True,
+            "aot_inductor.emit_multi_arch_kernel": True,
+            "aot_inductor.model_name_for_generated_files": "aoti_model",
         }
         result = maybe_aoti_standalone_config(patches)
         self.assertEqual(result, patches)
@@ -15,6 +15,7 @@ from typing import Callable
 from parameterized import parameterized_class
 
 import torch
 import torch._inductor.config
+from torch._inductor.codecache import get_kernel_bin_format
 from torch._inductor.package import load_package, package_aoti
 from torch._inductor.test_case import TestCase
@@ -363,6 +364,7 @@ class TestAOTInductorPackage(TestCase):
         )
 
     @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
     @skipIfXpu  # build system may be different
+    @torch._inductor.config.patch("test_configs.use_libtorch", True)
     def test_compile_after_package_static(self):
         # compile_standalone will set package_cpp_only=True
         self.check_package_cpp_only()
@@ -419,12 +421,46 @@ class TestAOTInductorPackage(TestCase):
         with self.assertRaisesRegex(Exception, "Invalid AOTI model name"):
             self.cmake_compile(model, example_inputs, options, "")
 
+    @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
+    @skipIfXpu  # build system may be different
+    @torch._inductor.config.patch("test_configs.use_libtorch", True)
+    def test_compile_standalone_cos(self):
+        # compile_standalone will set package_cpp_only=True
+        self.check_package_cpp_only()
+
+        class Model(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def forward(self, x):
+                return torch.cos(x)
+
+        with torch.no_grad():
+            example_inputs = (torch.randn(8, 32, device=self.device),)
+            model = Model().to(device=self.device)
+
+            # Test compilation when model name is passed in
+            options = {
+                "aot_inductor.compile_standalone": True,
+                "aot_inductor.model_name_for_generated_files": "cos",
+            }
+            with (
+                tempfile.TemporaryDirectory() as tmp_dir,
+            ):
+                build_path, _ = self.cmake_compile(
+                    model, example_inputs, options, tmp_dir
+                )
+                # Check if the .a file was built successfully
+                a_path = build_path / "libcos.a"
+                self.assertTrue(a_path.exists())
+
     @unittest.skipIf(
         _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+"
     )
     @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
     @skipIfRocm  # doesn't support multi-arch binary
     @skipIfXpu  # doesn't support multi-arch binary
+    @torch._inductor.config.patch("test_configs.use_libtorch", True)
     def test_compile_with_exporter(self):
         self.check_package_cpp_only()
@@ -1674,12 +1674,6 @@ class AotCodeCompiler:
             wrapper_code = "\n".join((wrapper_code, kernel_code))
             kernel_code = ""
 
-        from .utils import aoti_model_name_from_config
-
-        model_class_name = ""
-        if config.aot_inductor.compile_standalone:
-            model_class_name = aoti_model_name_from_config()
-
         wrapper_key, wrapper_path = write(
             wrapper_code,
             "wrapper.cpp",
@@ -1712,6 +1706,8 @@ class AotCodeCompiler:
                     "model.h",
                 )
             ) as f:
+                # model_name_for_generated_files is guaranteed to be non-empty when compile_standalone is True
+                model_class_name = config.aot_inductor.model_name_for_generated_files
                 class_name = f"AOTInductorModel{model_class_name}"
                 header_code = f.read()
@@ -1726,7 +1722,7 @@ class AotCodeCompiler:
             header_code,
             "h",
             specified_dir=specified_output_path,
-            key=f"{model_class_name}",
+            key=model_class_name,
         )
 
         # Log the AOTInductor wrapper and kernel code, if needed.
@@ -1840,7 +1836,7 @@ class AotCodeCompiler:
                 consts_asm += f"\t.space {len(consts) - 8}\n"
             consts_asm += f".globl\t{symbol_prefix}_binary_constants_bin_end\n"
             consts_asm += f"{symbol_prefix}_binary_constants_bin_end:\n"
-            return consts_asm, "S"
+            return consts_asm, "weights.S"
 
         # Use c++ to convert consts to object file can support more compilers, such as msvc and icx.
         def format_consts_to_cpp(
@@ -1865,7 +1861,7 @@ class AotCodeCompiler:
             const_cpp += "\t\n"
             const_cpp += "};\t\n"
             const_cpp += f"alignas({align_bytes}) extern unsigned char * {symbol_prefix}_binary_constants_bin_end;\t\n"
-            return const_cpp, "cpp"
+            return const_cpp, "weights.cpp"
 
         if use_asm_build:
             consts_code, code_ext = format_consts_to_asm(
@@ -1880,6 +1876,7 @@ class AotCodeCompiler:
                 consts_code,
                 code_ext,
                 specified_dir=str(specified_sub_dir),
+                key=config.aot_inductor.model_name_for_generated_files,
             )
             consts_s = Path(consts_s)
             object_build_options = CppTorchDeviceOptions(
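
The net effect of the two "weights" renames above, combined with the new key argument: the constants source file is now named after the model rather than by an opaque hash. A hedged reading, assuming write() derives the file stem from key plus the extension string (write()'s exact naming scheme is not shown in this diff):

    # Before: constants source named by content hash, e.g. <hash>.S / <hash>.cpp
    # After:  named from the model, e.g. aoti_model.weights.S / aoti_model.weights.cpp
    # This is the "more readable" weights object file naming the summary mentions.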
@@ -2173,7 +2170,13 @@ class AotCodeCompiler:
         asm_files = []
         if not _IS_WINDOWS:
             ld, objcopy = get_ld_and_objcopy(use_relative_path)
+            kernels = getattr(V.graph.wrapper_code, "_kernel_name_to_body", {})
             for kernel_name, value in CudaKernelParamCache.cache.items():
+                if kernel_name not in kernels:
+                    # It is possible that CudaKernelParamCache contains more Triton kernels
+                    # than what the current graph uses
+                    continue
+
                 if asm_file := value["asm"]:
                     asm_files.append(asm_file)
@@ -22,13 +22,7 @@ from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.symbol import symbol_is_type, SymT
 
 from .. import config, ir
-from ..utils import (
-    _align,
-    aoti_model_name_from_config,
-    DeferredLineBase,
-    LineContext,
-    normalize_name,
-)
+from ..utils import _align, DeferredLineBase, LineContext, normalize_name
 from ..virtualized import V
 from .aoti_hipify_utils import maybe_hipify_code_wrapper
 from .common import get_device_op_overrides, IndentedBuffer, Kernel
@@ -64,11 +58,15 @@ class CppWrapperCpu(PythonWrapperCodegen):
         self.device = "cpu"
         # must be initialized prior to calling super().__init__()
         self.included_devices: OrderedSet[str] = OrderedSet()
-        self.model_class_name_suffix = ""
-        if config.aot_inductor.compile_standalone:
-            self.model_class_name_suffix = aoti_model_name_from_config()
+        self.model_class_name_suffix = (
+            config.aot_inductor.model_name_for_generated_files
+            if config.aot_inductor.compile_standalone
+            else ""
+        )
         self.aoti_model_class_name = f"AOTInductorModel{self.model_class_name_suffix}"
 
         super().__init__()
 
         self.declare = "auto "
         self.declare_maybe_reference = "decltype(auto) "
         self.ending = ";"
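
A small illustration of the class-name construction above; "aoti_model" is the new default stem, and the result is a literal concatenation:

    # Suffix is the configured model name when compiling standalone, else "".
    model_class_name_suffix = "aoti_model"  # model_name_for_generated_files
    aoti_model_class_name = f"AOTInductorModel{model_class_name_suffix}"
    print(aoti_model_class_name)  # AOTInductorModelaoti_model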
@@ -4479,6 +4479,11 @@ class TritonScheduling(SIMDScheduling):
         kernel_name = "_".join(
             ["triton", kernel_category, fused_name, wrapper.next_kernel_suffix()]
         )
+        if config.aot_inductor.model_name_for_generated_files:
+            # When AOTI compiles multiple submodules, we need to use the model name to
+            # distinguish kernel related symbols.
+            kernel_name = f"{config.aot_inductor.model_name_for_generated_files}_{kernel_name}"
+
         # use the original src_code as the key
         wrapper.src_to_kernel[src_code] = kernel_name
         subs_name = kernel_name if config.triton.unique_kernel_names else "triton_"
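
To make the prefixing concrete, a sketch with illustrative values (the kernel-category and fused-name parts are invented for the example, not taken from this diff):

    kernel_category, fused_name, suffix = "poi", "fused_cos", "0"  # illustrative
    kernel_name = "_".join(["triton", kernel_category, fused_name, suffix])
    model_name = "aoti_model"  # config.aot_inductor.model_name_for_generated_files
    print(f"{model_name}_{kernel_name}")  # aoti_model_triton_poi_fused_cos_0
    # Two AOTI-compiled submodules with different model names can no longer
    # collide on kernel symbol names.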
@@ -1450,12 +1450,12 @@ class aot_inductor:
     precompile_headers: bool = not is_fbcode()
 
     # Embed generated kernel binary files into model.so
-    embed_kernel_binary: bool = False
+    embed_kernel_binary: Optional[bool] = None
 
     # Generate kernel files that support multiple archs
     # For CUDA, this means generating fatbin files for kernels; the fatbin files
     # contain PTX and SASS for the current architecture.
-    emit_multi_arch_kernel: bool = False
+    emit_multi_arch_kernel: Optional[bool] = None
 
     # If not None, the generated files will use this name as the file stem.
     # If None, we will use a hash to name files.
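
The switch from bool = False to Optional[bool] = None gives each flag three observable states. A sketch of the resolution rule; it mirrors the patch_config helper in the utils hunk at the end of this diff rather than a separate API:

    def resolve_standalone_flag(value):
        if value is None:   # unset: standalone mode fills in True
            return True
        if not value:       # explicit False: conflicts with standalone, error
            raise RuntimeError("conflicts with aot_inductor.compile_standalone=True")
        return value        # explicit True: kept as-is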
@@ -1842,6 +1842,10 @@ class test_configs:
 
     graphsafe_rng_func_ignores_fallback_random = False
 
+    # If set to True, AOTI-generated CMakeLists.txt will still use libtorch
+    # for unit testing
+    use_libtorch = False
+
 
 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403
@@ -28,7 +28,6 @@ from torch._dynamo.utils import dynamo_timed
 from torch._inductor import config, exc
 from torch._inductor.cpu_vec_isa import invalid_vec_isa, VecISA
 from torch._inductor.runtime.runtime_utils import cache_dir
-from torch._inductor.utils import aoti_model_name_from_config
 from torch.torch_version import TorchVersion
@@ -1545,7 +1544,9 @@ class CppBuilder:
         self._aot_mode: bool = False
 
         self._name = name
-        self._target_name = aoti_model_name_from_config()
+        self._target_name = (
+            config.aot_inductor.model_name_for_generated_files or "aoti_model"
+        )
 
         # Code starts here: initialize internal variables first.
         self._build_option = BuildOption
@@ -1771,9 +1772,13 @@ class CppBuilder:
         """
 
         definitions = " ".join(self._build_option.get_definitions())
-        target_library_type = (
-            "STATIC" if config.aot_inductor.compile_standalone else "SHARED"
-        )
+        if config.aot_inductor.compile_standalone:
+            if config.test_configs.use_libtorch:
+                add_target = f"add_library({self._target_name} STATIC)"
+            else:
+                add_target = f"add_executable({self._target_name} ${{CMAKE_CURRENT_SOURCE_DIR}}/main.cpp)"
+        else:
+            add_target = f"add_library({self._target_name} SHARED)"
 
         contents = textwrap.dedent(
             f"""
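
A recap of the target-selection logic just introduced, pulled out into a standalone function (a sketch; the names mirror the diff):

    def pick_add_target(target_name: str, compile_standalone: bool, use_libtorch: bool) -> str:
        if compile_standalone:
            if use_libtorch:
                # unit-test path: a static library, still linked against libtorch
                return f"add_library({target_name} STATIC)"
            # true standalone path: an executable built around a generated main.cpp
            return f"add_executable({target_name} ${{CMAKE_CURRENT_SOURCE_DIR}}/main.cpp)"
        # default AOTI packaging path: a shared library loaded at runtime
        return f"add_library({target_name} SHARED)"

    assert pick_add_target("aoti_model", True, True) == "add_library(aoti_model STATIC)"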
@@ -1781,22 +1786,54 @@ class CppBuilder:
             project({self._target_name} LANGUAGES CXX)
             set(CMAKE_CXX_STANDARD 17)
 
-            # May need to point CMAKE_PREFIX_PATH to the right torch location
-            find_package(Torch REQUIRED)
-
-            # Set a shared library target
-            add_library({self._target_name} {target_library_type})
-
-            # Add macro definitions
-            target_compile_definitions({self._target_name} PRIVATE {definitions})
-
-            # Add compile flags
-            target_compile_options({self._target_name} PRIVATE {self._cflags_args})
-            # Backend specific flags
-            target_compile_options({self._target_name} PRIVATE {self._passthrough_parameters_args} -c)
+            # Set target
+            {add_target}
 
             """
         )
 
+        if (
+            not config.aot_inductor.compile_standalone
+            or config.test_configs.use_libtorch
+        ):
+            # When compile_standalone is True, the generated cpp project should
+            # not use Torch. But for unit testing purposes, we need to use Torch here.
+            contents += textwrap.dedent(
+                """
+                # May need to point CMAKE_PREFIX_PATH to the right torch location
+                find_package(Torch REQUIRED)
+
+                """
+            )
+            # flags and macros here are mostly CPU specific. Not emitting them for GPU models
+            # will make the generated CMake file more portable and won't really hurt performance.
+            # NOTE: standalone focuses on GPU for now. For CPU, some of the flags and macros may
+            # still be needed.
+            contents += textwrap.dedent(
+                f"""
+                # Add macro definitions
+                target_compile_definitions({self._target_name} PRIVATE {definitions})
+
+                # Add compile flags
+                target_compile_options({self._target_name} PRIVATE {self._cflags_args})
+
+                # Backend-specific flags
+                target_compile_options({self._target_name} PRIVATE {self._passthrough_parameters_args} -c)
+
+                """
+            )
+        else:
+            # When compile_standalone is True, use TorchStandalone instead of Torch
+            contents += textwrap.dedent(
+                f"""
+                find_package(TorchStandalone REQUIRED)
+                # Set up include directories to find headers at the correct paths
+                target_include_directories({self._target_name} PRIVATE ${{TorchStandalone_INCLUDE_DIRS}})
+                target_include_directories({self._target_name} PRIVATE ${{TorchStandalone_INCLUDE_DIRS}}/standalone)
+
+                """
+            )
+
         if device_type == "cuda" and torch.version.hip is None:
             from torch._inductor.codecache import _nvcc_arch_as_compile_option
@@ -1804,7 +1841,11 @@ class CppBuilder:
             contents += textwrap.dedent(
                 f"""
                 enable_language(CUDA)
+                set(CMAKE_CUDA_STANDARD 17)
                 find_package(CUDAToolkit REQUIRED)
+                target_include_directories({self._target_name} PRIVATE ${{CUDAToolkit_INCLUDE_DIRS}})
+                target_compile_definitions({self._target_name} PRIVATE USE_CUDA)
                 target_link_libraries({self._target_name} PRIVATE cuda CUDA::cudart_static)
 
                 find_program(OBJCOPY_EXECUTABLE objcopy)
                 if(NOT OBJCOPY_EXECUTABLE)
@@ -1833,7 +1874,7 @@ class CppBuilder:
                 add_custom_command(
                     OUTPUT ${{FATBIN_FILE}}
                     COMMAND ${{CUDAToolkit_NVCC_EXECUTABLE}} --fatbin ${{PTX_FILE}} -o ${{FATBIN_FILE}} ${{NVCC_GENCODE_FLAGS}}
-                        -gencode arch=compute_80,code=compute_80
+                        -gencode arch=compute_{current_arch},code=compute_{current_arch}
                         -gencode arch=compute_{current_arch},code=sm_{current_arch}
                     DEPENDS ${{PTX_FILE}}
                 )
@@ -1882,12 +1923,20 @@ class CppBuilder:
                 """
             )
             f.write(contents)
-            f.write(f"add_dependencies({self._target_name} ${{KERNEL_TARGETS}})\n")
-            f.write(
-                f"target_link_libraries({self._target_name} PRIVATE ${{KERNEL_OBJECT_FILES}})\n"
-            )
+            if asm_files:
+                f.write(f"add_dependencies({self._target_name} ${{KERNEL_TARGETS}})\n")
+                f.write(
+                    f"target_link_libraries({self._target_name} PRIVATE ${{KERNEL_OBJECT_FILES}})\n"
+                )
 
     def save_link_cmd_to_cmake(self, cmake_path: str) -> None:
+        if (
+            config.aot_inductor.compile_standalone
+            and not config.test_configs.use_libtorch
+        ):
+            # When compile_standalone is True, do not link with libtorch
+            return
+
         lflags = " ".join(self._build_option.get_ldflags())
         libs = " ".join(self._build_option.get_libraries())
         contents = textwrap.dedent(
@@ -3309,20 +3309,34 @@ def maybe_aoti_standalone_config(config_patches: dict[str, Any]) -> dict[str, Any]:
     Returns:
         dict[str, Any]: The possibly-updated `config_patches` dictionary.
     """
 
+    def patch_config(
+        config_patches: dict[str, Any], config_name: str, config_value: Any
+    ) -> None:
+        value = config_patches.get(config_name, getattr(config, config_name))
+        if value is None:
+            config_patches[config_name] = config_value
+        elif not value:
+            raise RuntimeError(
+                f"Invalid config: {config_name}={config_value} when aot_inductor.compile_standalone is True."
+            )
+
     compile_standalone = config_patches.get(
         "aot_inductor.compile_standalone", config.aot_inductor.compile_standalone
     )
+    # Make a copy of the config_patches to avoid modifying the original dictionary, needed for testing
+    config_patches = config_patches.copy()
     if compile_standalone:
-        package_cpp_only = config_patches.get(
-            "aot_inductor.package_cpp_only", config.aot_inductor.package_cpp_only
-        )
-        if package_cpp_only is None:
-            config_patches = {**config_patches, "aot_inductor.package_cpp_only": True}
-        elif not package_cpp_only:
-            raise RuntimeError(
-                "compile_standalone=True requires package_cpp_only=True. "
-                "Please set aot_inductor.package_cpp_only=True in your inductor config."
-            )
+        # Standalone AOTInductor means we only generate a cpp project for building a standalone binary
+        patch_config(config_patches, "aot_inductor.package_cpp_only", True)
+        # Standalone AOTInductor needs to embed the kernel code in the binary
+        patch_config(config_patches, "aot_inductor.embed_kernel_binary", True)
+        # Default to using multi-arch kernel codegen
+        patch_config(config_patches, "aot_inductor.emit_multi_arch_kernel", True)
+        patch_config(
+            config_patches, "aot_inductor.model_name_for_generated_files", "aoti_model"
+        )
 
     return config_patches
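
And the failure mode patch_config guards against, sketched below; note that the exception's f-string reports the required value (config_value), not the value the user supplied:

    patches = {
        "aot_inductor.compile_standalone": True,
        "aot_inductor.embed_kernel_binary": False,  # explicit False conflicts
    }
    try:
        maybe_aoti_standalone_config(patches)
    except RuntimeError as e:
        print(e)
        # Invalid config: aot_inductor.embed_kernel_binary=True
        # when aot_inductor.compile_standalone is True.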
@@ -3351,11 +3365,3 @@ def is_valid_aoti_model_name() -> bool:
         )
 
     return True
-
-
-def aoti_model_name_from_config() -> str:
-    from torch._inductor import config
-
-    model_name = config.aot_inductor.model_name_for_generated_files
-    model_name = "aoti_model" if model_name is None else model_name
-    return model_name