Compare commits

...

1 Commit

SHA1        Message                                 Date
6d9f2792a8  Enable multi-arch unit tests for ROCm   2025-10-28 01:36:17 +00:00
5 changed files with 63 additions and 32 deletions

View File

@@ -238,12 +238,11 @@ class AOTInductorTestsTemplate:
         "toolchain doesn't support ptx to fatbin",
     )
     @skipIfMPS
-    @skipIfRocm
     # Skip embed_kernel_binary == True for now as it shows random
     # failure on CI
     @common_utils.parametrize("embed_kernel_binary", [False])
     @unittest.skipIf(
-        _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+"
+        torch.version.hip is None and _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+"
     )
     def test_simple_multi_arch(self, embed_kernel_binary):
         if self.device != GPU_TYPE:
@@ -273,7 +272,14 @@ class AOTInductorTestsTemplate:
             _, code = run_and_get_cpp_code(
                 AOTIRunnerUtil.compile, model, example_inputs
             )
-            file_extension = ".spv" if self.device == "xpu" else ".fatbin"
+            file_extension = ""
+            if self.device == "xpu":
+                file_extension = ".spv"
+            elif self.device == "cuda" and torch.version.hip is not None:
+                file_extension = ".hsaco"
+            else:
+                file_extension = ".fatbin"
             FileCheck().check(file_extension).run(code)

     def test_small_constant(self):
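Note: on ROCm wheels torch.version.hip is set while the device string is still "cuda", which is why the new elif keys on both. A minimal standalone sketch of the same dispatch (the helper name expected_kernel_ext is ours, not from the patch):

import torch

def expected_kernel_ext(device: str) -> str:
    # Mirrors the test's dispatch: XPU emits SPIR-V, ROCm emits an HSACO
    # code object, and plain CUDA emits a fatbin. ROCm builds of PyTorch
    # still report the device type as "cuda".
    if device == "xpu":
        return ".spv"
    if device == "cuda" and torch.version.hip is not None:
        return ".hsaco"
    return ".fatbin"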

View File

@@ -320,10 +320,9 @@ class TestAOTInductorPackage(TestCase):
         self.assertTrue(torch.allclose(actual, expected))

     @unittest.skipIf(
-        _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+"
+        torch.version.hip is None and _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+"
     )
     @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
-    @skipIfRocm  # doesn't support multi-arch binary
     @skipIfXpu  # doesn't support multi-arch binary
     def test_compile_after_package_multi_arch(self):
         if self.device != GPU_TYPE:
@@ -462,10 +461,9 @@ class TestAOTInductorPackage(TestCase):
         self.assertTrue(a_path.exists())

     @unittest.skipIf(
-        _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+"
+        torch.version.hip is None and _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+"
     )
     @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
-    @skipIfRocm  # doesn't support multi-arch binary
     @skipIfXpu  # doesn't support multi-arch binary
     @torch._inductor.config.patch("test_configs.use_libtorch", True)
     def test_compile_with_exporter(self):
@@ -520,10 +518,9 @@ class TestAOTInductorPackage(TestCase):
     )

     @unittest.skipIf(
-        _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+"
+        torch.version.hip is None and _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+"
     )
     @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
-    @skipIfRocm  # doesn't support multi-arch binary
     @skipIfXpu  # doesn't support multi-arch binary
     @torch._inductor.config.patch("test_configs.use_libtorch", True)
     def test_compile_with_exporter_weights(self):
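Note: the reworked skip condition gates on the CUDA toolkit version only for non-ROCm builds; on ROCm, torch.version.hip is non-None, the conjunction is False, and these tests now run. A small sketch of the predicate with illustrative inputs (should_skip is a hypothetical helper, not part of the patch):

def should_skip(hip_version, cuda_version):
    # Same shape as: torch.version.hip is None and _get_torch_cuda_version() < (12, 6)
    return hip_version is None and cuda_version < (12, 6)

assert should_skip(None, (12, 4))        # old CUDA toolkit: skip
assert not should_skip(None, (12, 6))    # CUDA 12.6+: run
assert not should_skip("6.2", (12, 4))   # ROCm: run regardless of the CUDA check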

View File

@@ -1680,9 +1680,9 @@ class CudaKernelParamCache:
         basename, _ = get_name_and_dir_from_output_file_path(bin_path)

         if config.aot_inductor.emit_multi_arch_kernel:
-            bin_type_to_ext = {"cubin": ".fatbin", "spv": ".spv"}
+            bin_type_to_ext = {"cubin": ".fatbin", "spv": ".spv", "hsaco": ".hsaco"}
             assert bin_type in bin_type_to_ext.keys(), (
-                "multi_arch_kernel_binary only supported in CUDA/XPU"
+                "multi_arch_kernel_binary only supported in CUDA/XPU/ROCm"
             )
             base_path, _ = os.path.splitext(bin_path)
             bin_path = base_path + bin_type_to_ext[bin_type]
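Note: bin_path arrives with Triton's own suffix and only the extension is rewritten for packaging; "hsaco" is the new ROCm entry and maps to itself. A runnable sketch of just that rewrite (rewrite_bin_path is our name for it):

import os

def rewrite_bin_path(bin_path: str, bin_type: str) -> str:
    # Swap the Triton binary's suffix for the packaging extension.
    bin_type_to_ext = {"cubin": ".fatbin", "spv": ".spv", "hsaco": ".hsaco"}
    base_path, _ = os.path.splitext(bin_path)
    return base_path + bin_type_to_ext[bin_type]

assert rewrite_bin_path("/tmp/add_kernel.cubin", "cubin") == "/tmp/add_kernel.fatbin"
assert rewrite_bin_path("/tmp/add_kernel.hsaco", "hsaco") == "/tmp/add_kernel.hsaco"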
@@ -1692,18 +1692,26 @@ class CudaKernelParamCache:
             config.aot_inductor.emit_multi_arch_kernel
             or config.aot_inductor.package_cpp_only
         ):
-            assert asm, "Missing kernel assembly code"
-            assert asm_type, "Missing kernel assembly type"
-            _, asm_path = write(
-                asm,
-                asm_type,
-                hash_type=asm_type,
-                specified_dir=split_aot_inductor_output_path(
-                    config.aot_inductor.output_path
-                )[0],
-                # make sure asm file has the same basename
-                key=basename,
-            )
+            # CUDA/XPU: require 'asm' (PTX/SPV). ROCm: allow proceeding without
+            # 'asm' and default to using the binary file provided by Triton.
+            if torch.version.hip is None or (asm and asm_type):
+                assert asm, "Missing kernel assembly code"
+                assert asm_type, "Missing kernel assembly type"
+                # Hashing only recognizes {"amdgcn", "ptx", "spv"}. For other
+                # text types (e.g., "ll", "bc"), fall back to "code".
+                hash_kind = asm_type if asm_type in {"amdgcn", "ptx", "spv"} else "code"
+                _, asm_path = write(
+                    asm,
+                    asm_type,
+                    hash_type=hash_kind,
+                    specified_dir=split_aot_inductor_output_path(
+                        config.aot_inductor.output_path
+                    )[0],
+                    # make sure asm file has the same basename
+                    key=basename,
+                )

         params[get_cpp_wrapper_cubin_path_name()] = bin_path
         params["asm"] = asm_path
@@ -2358,6 +2366,7 @@ end
         if (
             config.aot_inductor.emit_multi_arch_kernel
             and device_type == "cuda"
+            and torch.version.hip is None
         ):
             current_arch = _nvcc_arch_as_compile_option()
             cmd = (
@@ -2381,6 +2390,15 @@ end
                 )
                 raise
+        elif (
+            config.aot_inductor.emit_multi_arch_kernel
+            and device_type == "cuda"
+            and torch.version.hip is not None
+        ):
+            # ROCm: there is no fatbin analog; we deliberately do nothing here.
+            # Triton already produced the HSACO; we just embed it below.
+            pass

         if config.aot_inductor.embed_kernel_binary:
             # Embed cubin files into model.so using objcopy
             cubins_o.append(
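Note: the explicit elif/pass documents the asymmetry rather than silently falling through. A CUDA fatbin is a multi-arch container assembled by nvcc, while a ROCm code object targets exactly one gfx ISA, so there is no extra compile step to run. To see which ISA a given ROCm device uses (a sketch; gcnArchName is exposed on ROCm builds only):

import torch

if torch.version.hip is not None and torch.cuda.is_available():
    # Prints something like "gfx90a:sramecc+:xnack-"
    print(torch.cuda.get_device_properties(0).gcnArchName)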
@@ -2446,10 +2464,17 @@ end
             generated_files.append(consts_o)
             so_builder.save_src_to_cmake(cmake_path, consts_o)

-        if config.aot_inductor.emit_multi_arch_kernel:
+        # ROCm: do NOT recompile asm in CMake.
+        # CUDA multi-arch -> embed_gpu_kernel()
+        if (
+            config.aot_inductor.emit_multi_arch_kernel
+            and torch.version.hip is None
+        ):
             so_builder.save_kernel_asm_to_cmake(cmake_path, asm_files)
             generated_files.extend(asm_files)
-        else:
+        # ROCm -> just link prebuilt objects
+        elif torch.version.hip is not None:
             obj_srcs = [*gpu_kernels_o, *cubins_o]
             generated_files.extend(obj_srcs)
             for obj in obj_srcs:

View File

@@ -417,13 +417,15 @@ class _ExportPackage:
                 path = Path(base_directory) / f"{name}_input_{i}.pt"
                 torch.save(t, path)

-        cmake_file_str = _get_make_file(package_name, model_names, use_cuda)
+        # Detect if ROCm is being used
+        is_hip = bool(getattr(torch.version, "hip", None))
+        cmake_file_str = _get_make_file(package_name, model_names, use_cuda, is_hip)
         with open(Path(base_directory) / "CMakeLists.txt", "w") as file:
             file.write(cmake_file_str)

         main_file_str = _get_main_cpp_file(
-            package_name, model_names, use_cuda, example_inputs_map
+            package_name, model_names, use_cuda, example_inputs_map, is_hip
         )
         with open(Path(base_directory) / "main.cpp", "w") as file:
             file.write(main_file_str)
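Note: the getattr here is defensive; on current builds torch.version always carries a hip attribute (None on CUDA/CPU wheels), so the shorter probe below is equivalent (a sketch, not from the patch):

import torch

# Equivalent on builds where torch.version.hip exists:
is_hip = torch.version.hip is not None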

View File

@@ -13,6 +13,7 @@ def _get_main_cpp_file(
     model_names: list[str],
     cuda: bool,
     example_inputs_map: typing.Optional[dict[str, int]],
+    is_hip: bool = False,
 ) -> str:
     """
     Generates a main.cpp file for AOTInductor standalone models in the specified package.
@@ -42,7 +43,7 @@ def _get_main_cpp_file(
             "#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>",
         ]
     )
-    if cuda:
+    if cuda and not is_hip:
         ib.writelines(
             [
                 "#include <cuda.h>",
@@ -181,7 +182,7 @@ def _get_main_cpp_file(
     return ib.getvalue()


-def _get_make_file(package_name: str, model_names: list[str], cuda: bool) -> str:
+def _get_make_file(package_name: str, model_names: list[str], cuda: bool, is_hip: bool = False) -> str:
     ib = IndentedBuffer()

     ib.writelines(
@@ -199,7 +200,7 @@ def _get_make_file(package_name: str, model_names: list[str], cuda: bool) -> str
     if test_configs.use_libtorch:
         ib.writeline("find_package(Torch REQUIRED)")

-    if cuda:
+    if cuda and not is_hip:
         ib.writeline("find_package(CUDA REQUIRED)")

     ib.newline()
@@ -207,13 +208,13 @@ def _get_make_file(package_name: str, model_names: list[str], cuda: bool) -> str
         ib.writeline(f"add_subdirectory({package_name}/data/aotinductor/{model_name}/)")

     ib.writeline("\nadd_executable(main main.cpp)")
-    if cuda:
+    if cuda and not is_hip:
         ib.writeline("target_compile_definitions(main PRIVATE USE_CUDA)")

     model_libs = " ".join(model_names)
     ib.writeline(f"target_link_libraries(main PRIVATE torch {model_libs})")
-    if cuda:
+    if cuda and not is_hip:
         ib.writeline("target_link_libraries(main PRIVATE cuda ${CUDA_LIBRARIES})")

     return ib.getvalue()
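Note: with is_hip=True the three CUDA-specific lines simply drop out of the generated CMakeLists.txt; linking against torch alone is presumably sufficient on ROCm, where find_package(Torch) already pulls in the HIP runtime. A quick way to see the difference (hypothetical usage of the patched helper):

cuda_cmake = _get_make_file("pkg", ["model"], cuda=True)
rocm_cmake = _get_make_file("pkg", ["model"], cuda=True, is_hip=True)

assert "find_package(CUDA REQUIRED)" in cuda_cmake
assert "find_package(CUDA REQUIRED)" not in rocm_cmake
assert "USE_CUDA" not in rocm_cmake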