[AOTI] Embed cubin files into .so (#150739)

Summary: Embed cubin files into the generated .so so that AOTI is one step closer to producing a single binary. This is controlled by a flag and is off by default.

Differential Revision: [D72535357](https://our.internmc.facebook.com/intern/diff/D72535357)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150739
Approved by: https://github.com/angelayi
Author: Bin Bao
Date: 2025-05-18 13:24:40 -07:00
Committed by: PyTorch MergeBot
Parent: a8986963da
Commit: a2d0ef242d
10 changed files with 170 additions and 35 deletions
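For orientation, a minimal usage sketch, assuming the flag is passed through the standard inductor_configs argument of aoti_compile_and_package; the model, inputs, and device below are illustrative and not part of this PR:

import torch

class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y

model = Add().cuda()
example_inputs = (torch.randn(8, device="cuda"), torch.randn(8, device="cuda"))
ep = torch.export.export(model, example_inputs)

# "aot_inductor.embed_cubin" is the flag added by this change; it defaults to False.
package_path = torch._inductor.aoti_compile_and_package(
    ep,
    inductor_configs={"aot_inductor.embed_cubin": True},
)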


@@ -124,7 +124,8 @@ except (unittest.SkipTest, ImportError):
class AOTInductorTestsTemplate:
def test_simple(self):
@common_utils.parametrize("embed_cubin", [False, True])
def test_simple(self, embed_cubin):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@@ -138,7 +139,18 @@ class AOTInductorTestsTemplate:
torch.randn(10, 10, device=self.device),
)
model = Model()
self.check_model(model, example_inputs)
with config.patch({"aot_inductor.embed_cubin": embed_cubin}):
self.check_model(model, example_inputs)
_, code = run_and_get_cpp_code(
AOTIRunnerUtil.compile, model, example_inputs
)
if self.device == GPU_TYPE:
FileCheck().check("launchKernel(").run(code)
if config.aot_inductor.embed_cubin:
# With an embedded cubin, we do not expect to see launchKernel("CUBIN_FILE_NAME", ...)
FileCheck().check_not('launchKernel("').run(code)
if self.use_minimal_arrayref_interface:
self.code_check_count(
model, example_inputs, "AOTInductorModelRunMinimalArrayrefInterface(", 1
@@ -3234,7 +3246,8 @@ class AOTInductorTestsTemplate:
self.check_model(Model(), inputs)
def test_repeated_user_defined_triton_kernel(self):
@common_utils.parametrize("embed_cubin", [False, True])
def test_repeated_user_defined_triton_kernel(self, embed_cubin):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
@@ -3248,7 +3261,14 @@ class AOTInductorTestsTemplate:
return x
inputs = (torch.randn(4, 4, device=self.device),)
self.check_model(Model(), inputs)
with config.patch({"aot_inductor.embed_cubin": embed_cubin}):
model = Model()
self.check_model(model, inputs)
_, code = run_and_get_cpp_code(AOTIRunnerUtil.compile, model, inputs)
FileCheck().check("launchKernel(").run(code)
if config.aot_inductor.embed_cubin:
# With an embedded cubin, we do not expect to see launchKernel("CUBIN_FILE_NAME", ...)
FileCheck().check_not('launchKernel("').run(code)
@unittest.skipIf(
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"

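A standalone sketch of the FileCheck pattern used in the tests above; the "generated" line is illustrative, modeled on the two loadKernel forms emitted by the GPU cpp wrapper codegen later in this diff:

from torch.testing import FileCheck

# With embedding enabled, the wrapper references an in-binary symbol instead of a
# quoted cubin path, so asserting the absence of a quoted first argument is enough.
snippet = 'kernels_.add = loadKernel(__add_kernel_start, "add_kernel_mangled", 0);'
FileCheck().check("loadKernel(").run(snippet)       # a loadKernel call is present
FileCheck().check_not('loadKernel("').run(snippet)  # but no cubin path string literal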

@@ -205,6 +205,8 @@ class TestAOTInductorPackage(TestCase):
options = {
"aot_inductor.package_cpp_only": self.package_cpp_only,
# Require kernels to be compiled into .o files
"aot_inductor.embed_cubin": True,
}
ep = torch.export.export(model, example_inputs, strict=True)
package_path = torch._inductor.aoti_compile_and_package(
@@ -216,6 +218,10 @@ class TestAOTInductorPackage(TestCase):
zip_ref.extractall(tmp_dir)
tmp_path = Path(tmp_dir) / "data" / "aotinductor" / "model"
self.assertTrue(tmp_path.exists())
if self.device == GPU_TYPE:
self.assertTrue(not list(tmp_path.glob("*.cubin")))
self.assertTrue(list(tmp_path.glob("*.cubin.o")))
build_path = tmp_path / "build"
self.assertTrue(not build_path.exists())
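The packaging behavior asserted above can also be checked directly on the produced .pt2 archive, since it is a zip file; package_path here is assumed to come from a compile with package_cpp_only and embed_cubin enabled, as in the test:

import zipfile

names = zipfile.ZipFile(package_path).namelist()
# Raw .cubin files are no longer packaged; their relocatable .cubin.o objects are.
assert not any(n.endswith(".cubin") for n in names)
assert any(n.endswith(".cubin.o") for n in names)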


@@ -79,6 +79,21 @@ class TestCppWrapperHipify(TestCase):
return func;
}
static inline hipFunction_t loadKernel(const void* start, const std::string &funcName, uint32_t sharedMemBytes) {
hipModule_t mod;
hipFunction_t func;
CUDA_DRIVER_CHECK(hipModuleLoadData(&mod, start));
CUDA_DRIVER_CHECK(hipModuleGetFunction(&func, mod, funcName.c_str()));
if (sharedMemBytes > 0) {
CUDA_DRIVER_CHECK(hipFuncSetAttribute(
func,
hipFuncAttributeMaxDynamicSharedMemorySize,
sharedMemBytes
))
}
return func;
}
static inline void launchKernel(
hipFunction_t func,
uint32_t gridX,

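For reference, the CUDA-to-HIP correspondence the hipify test exercises, collected from the new loadKernel overloads shown in this diff:

# Driver API names in the CUDA template and their HIP counterparts.
cuda_to_hip = {
    "CUmodule": "hipModule_t",
    "CUfunction": "hipFunction_t",
    "cuModuleLoadData": "hipModuleLoadData",
    "cuModuleGetFunction": "hipModuleGetFunction",
    "cuFuncSetAttribute": "hipFuncSetAttribute",
    "CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES": "hipFuncAttributeMaxDynamicSharedMemorySize",
}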

@@ -62,10 +62,12 @@ from torch._inductor.cpp_builder import (
_set_gpu_runtime_env,
_TORCH_PATH,
_transform_cuda_paths,
convert_cubin_to_obj,
CppBuilder,
CppOptions,
CppTorchDeviceOptions,
get_compiler_version_info,
get_ld_and_objcopy,
get_name_and_dir_from_output_file_path,
normalize_path_separator,
)
@@ -1960,7 +1962,16 @@ class AotCodeCompiler:
for entry in gpu_codecache.cache.values()
if entry.output_path.endswith(".o")
]
gpu_kernels_o = " ".join(gpu_kernels_o)
cubins_o = []
if config.aot_inductor.embed_cubin:
# Embed cubin files into .so using objcopy
ld, objcopy = get_ld_and_objcopy(use_relative_path)
for kernel_name, value in CudaKernelParamCache.cache.items():
cubin_file = value[get_cpp_wrapper_cubin_path_name()]
cubins_o.append(
convert_cubin_to_obj(cubin_file, kernel_name, ld, objcopy)
)
output_name, output_dir = get_name_and_dir_from_output_file_path(output_so)
so_build_options = CppTorchDeviceOptions(
@@ -1970,11 +1981,10 @@ class AotCodeCompiler:
use_relative_path=use_relative_path,
)
obj_srcs = [wrapper_o, kernel_o, consts_o, *gpu_kernels_o, *cubins_o]
so_builder = CppBuilder(
name=output_name,
sources=[wrapper_o, kernel_o, consts_o, gpu_kernels_o]
if gpu_kernels_o
else [wrapper_o, kernel_o, consts_o],
sources=obj_srcs,
output_dir=output_dir,
BuildOption=so_build_options,
)
@@ -2019,17 +2029,14 @@ class AotCodeCompiler:
generated_files.append(weight_file)
generated_files.append(consts_o)
generated_files.append(gpu_kernels_o)
so_builder.save_src_to_cmake(cmake_path, consts_o)
for gpu_o in gpu_kernels_o.split():
so_builder.save_src_to_cmake(cmake_path, gpu_o)
obj_srcs = [consts_o, *gpu_kernels_o, *cubins_o]
generated_files.extend(obj_srcs)
for obj in obj_srcs:
so_builder.save_src_to_cmake(cmake_path, obj)
so_builder.save_link_cmd_to_cmake(cmake_path)
else:
so_builder.build()
for o_file in [wrapper_o, kernel_o, consts_o]:
for o_file in obj_srcs:
# Remove these as they are not needed anymore
os.remove(o_file)
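One way to confirm that the cubin blobs actually land in the linked library is to list its symbols; this sketch assumes a POSIX nm on PATH and an unstripped output named model.so, both of which are illustrative:

import subprocess

# Embedded kernels appear as __<kernel_name>_start/_end/_size data symbols,
# the names established by convert_cubin_to_obj further down in this diff.
out = subprocess.run(["nm", "model.so"], capture_output=True, text=True, check=True)
print([line for line in out.stdout.splitlines() if line.endswith("_start")])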


@@ -666,7 +666,15 @@ class CppWrapperCpu(PythonWrapperCodegen):
signature = kernel.get_signature().replace(name, kernel_ptr)
self.prefix.writeline(f" {signature} = torch::aot_inductor::{name};")
self.prefix.writeline("};")
self.prefix.writeline("} // namespace")
self.prefix.writeline("} // namespace\n\n")
if config.aot_inductor.embed_cubin:
self.prefix.writeline('extern "C" {')
for name in sorted(declare_kernel):
self.prefix.writeline(
f" extern const unsigned char __{name}_start[];"
)
self.prefix.writeline("}")
def codegen_model_constructor(self):
"""


@@ -58,6 +58,9 @@ class DeferredTritonCallWrapper:
arg_types: list[Any]
def generate(self, wrapper: CppWrapperGpu):
"""
Generate the GPU kernel definition, as well as load and launch code.
"""
prefix = wrapper.prefix
if self.kernel_name.startswith("multi_kernel_"):
# MultiKernel will select one kernel after running the autotune block
@@ -132,10 +135,12 @@ class DeferredTritonCallWrapper:
self.generate_load_kernel(prefix, kernel_var_name, params)
self.generate_launch_kernel(prefix, wrapper, kernel_var_name, params)
prefix.writeline("}")
# Ensure the cubin file is included in the package
V.graph.wrapper_code.additional_files.append(
params[get_cpp_wrapper_cubin_path_name()]
)
if not config.aot_inductor.embed_cubin:
# Ensure the cubin file is included in the package
V.graph.wrapper_code.additional_files.append(
params[get_cpp_wrapper_cubin_path_name()]
)
def generate_grid(
self,
@@ -160,12 +165,20 @@ class DeferredTritonCallWrapper:
def generate_load_kernel(self, prefix, kernel_var_name, params):
prefix.writeline(f"if ({kernel_var_name} == nullptr) {{")
with prefix.indent():
load_kernel_args = [
cpp_string_literal(params[get_cpp_wrapper_cubin_path_name()]),
cpp_string_literal(params["mangled_name"]),
str(params["shared_mem"]),
"cubin_dir_",
]
load_kernel_args = (
[
f"__{params['inductor_meta']['kernel_name']}_start",
cpp_string_literal(params["mangled_name"]),
str(params["shared_mem"]),
]
if V.graph.aot_mode and config.aot_inductor.embed_cubin
else [
cpp_string_literal(params[get_cpp_wrapper_cubin_path_name()]),
cpp_string_literal(params["mangled_name"]),
str(params["shared_mem"]),
"cubin_dir_",
]
)
prefix.writeline(
f"{kernel_var_name} = loadKernel({', '.join(load_kernel_args)}); "
)


@@ -88,6 +88,21 @@ class CUDADeviceOpOverrides(DeviceOpOverrides):
return func;
}
static inline CUfunction loadKernel(const void* start, const std::string &funcName, uint32_t sharedMemBytes) {
CUmodule mod;
CUfunction func;
CUDA_DRIVER_CHECK(cuModuleLoadData(&mod, start));
CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
if (sharedMemBytes > 0) {
CUDA_DRIVER_CHECK(cuFuncSetAttribute(
func,
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
sharedMemBytes
))
}
return func;
}
static inline void launchKernel(
CUfunction func,
uint32_t gridX,


@@ -1295,6 +1295,9 @@ class aot_inductor:
# Experimental. Controls automatic precompiling of common AOTI include files.
precompile_headers: bool = not is_fbcode()
# Embed generated .cubin files into the .so
embed_cubin: bool = False
class cuda:
"""Settings for cuda backend, today this consists of cutlass"""

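For completeness, the ways the new option is toggled elsewhere in this diff, plus a plain global assignment; at compile time the same key can also be supplied in the options dict used in the package test above (a sketch, not part of the PR):

from torch._inductor import config

# Process-wide toggle.
config.aot_inductor.embed_cubin = True

# Scoped toggle, as used in the AOTInductor tests above.
with config.patch({"aot_inductor.embed_cubin": True}):
    ...  # compile and check generated code here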

@@ -157,6 +157,51 @@ def get_cpp_compiler() -> str:
return compiler
def get_ld_and_objcopy(use_relative_path: bool) -> tuple[str, str]:
if _IS_WINDOWS:
raise RuntimeError("Windows is not supported yet.")
else:
if config.is_fbcode():
ld = build_paths.ld
objcopy = (
build_paths.objcopy_fallback
if use_relative_path
else build_paths.objcopy
)
else:
ld = "ld"
objcopy = "objcopy"
return ld, objcopy
def convert_cubin_to_obj(
cubin_file: str,
kernel_name: str,
ld: str,
objcopy: str,
) -> str:
obj_file = cubin_file + ".o"
# Convert .cubin to .o
cmd = f"{ld} -r -b binary -z noexecstack -o {obj_file} {cubin_file}"
subprocess.run(cmd.split(), capture_output=True, text=True)
os.remove(cubin_file)
# Rename .data to .rodata
cmd = f"{objcopy} --rename-section .data=.rodata,alloc,load,readonly,data,contents {obj_file}"
subprocess.run(cmd.split(), capture_output=True, text=True)
# ld -r -b binary names the *_start, *_size, *_end symbols after the full input path.
# Rename them to use the unique kernel name.
file_name = re.sub(r"[\W]", "_", cubin_file)
cmd = (
objcopy
+ f" --redefine-sym _binary_{file_name}_start=__{kernel_name}_start "
+ f"--redefine-sym _binary_{file_name}_size=__{kernel_name}_size "
+ f"--redefine-sym _binary_{file_name}_end=__{kernel_name}_end "
+ obj_file
)
subprocess.run(cmd.split(), capture_output=True, text=True)
return obj_file
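For intuition about the --redefine-sym step above: ld -r -b binary derives its symbol names from the input path with non-alphanumeric characters replaced by underscores, and objcopy then renames them to kernel-specific names. A small sketch of that naming logic (an illustrative helper mirroring the code above, not part of the PR):

import re

def embedded_symbol_map(cubin_file: str, kernel_name: str) -> dict[str, str]:
    # `ld -r -b binary /path/to/k.cubin` emits _binary__path_to_k_cubin_{start,size,end};
    # objcopy --redefine-sym rewrites them to the stable __<kernel_name>_* names that the
    # generated C++ wrapper declares and passes to loadKernel.
    mangled = re.sub(r"[\W]", "_", cubin_file)
    return {
        f"_binary_{mangled}_{suffix}": f"__{kernel_name}_{suffix}"
        for suffix in ("start", "size", "end")
    }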
@functools.lru_cache(None)
def _is_apple_clang(cpp_compiler: str) -> bool:
version_string = subprocess.check_output([cpp_compiler, "--version"]).decode("utf8")


@@ -264,7 +264,7 @@ bool recursive_rmdir(const std::string& path) {
std::string compile_so(
const std::string& cpp_filename,
const std::string& consts_filename) {
std::vector<std::string>& obj_filenames) {
// Compile the cpp file into a .so
size_t lastindex = cpp_filename.find_last_of('.');
@@ -280,8 +280,9 @@ std::string compile_so(
cpp_filename.substr(0, lastindex) + "_linker_flags.json";
const nlohmann::json linker_flags = load_json_file(linker_flags_path);
auto [link_cmd, output_so] = get_cpp_compile_command(
filename, {output_o, consts_filename}, linker_flags);
obj_filenames.push_back(output_o);
auto [link_cmd, output_so] =
get_cpp_compile_command(filename, obj_filenames, linker_flags);
// Run the commands to generate a .so file
int status = system(compile_cmd.c_str());
@@ -369,7 +370,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
temp_dir_ = create_temp_dir();
std::string so_filename;
std::string cpp_filename;
std::string consts_filename;
std::vector<std::string> obj_filenames;
std::string found_filenames; // Saving for bookkeeping
std::string model_directory =
"data" + k_separator + "aotinductor" + k_separator + model_name;
@@ -408,8 +409,10 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
if (lastSlash != std::string::npos) {
filename = filename_str.substr(lastSlash + 1);
}
output_path_str +=
k_separator + model_directory + k_separator + filename;
output_path_str.append(k_separator)
.append(model_directory)
.append(k_separator)
.append(filename);
}
LOG(INFO) << "Extract file: " << filename_str << " to "
@@ -440,7 +443,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
if (filename_extension == ".cpp") {
cpp_filename = output_path_str;
} else if (filename_extension == ".o") {
consts_filename = output_path_str;
obj_filenames.push_back(output_path_str);
} else if (filename_extension == ".so") {
so_filename = output_path_str;
}
@@ -465,7 +468,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
// Compile the .so
std::string so_path = !so_filename.empty()
? so_filename
: compile_so(cpp_filename, consts_filename);
: compile_so(cpp_filename, obj_filenames);
// Load metadata which can be queried by user
load_metadata(cpp_filename);
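On the load side the embedded kernels are transparent to callers; a sketch assuming the standard aoti_load_package entry point, with an illustrative package path and inputs:

import torch

compiled = torch._inductor.aoti_load_package("model.pt2")
out = compiled(torch.randn(8, device="cuda"), torch.randn(8, device="cuda"))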