Enabling StaticCudaLauncher for ROCm

Chinmay Kuchinad
2025-10-16 00:45:38 +00:00
committed by PyTorch MergeBot
parent 341e924981
commit 2691b25b6a
6 changed files with 161 additions and 43 deletions

View File

@ -12,7 +12,6 @@ from torch._inductor.runtime.static_cuda_launcher import StaticallyLaunchedCudaK
from torch._inductor.runtime.triton_compat import CompiledKernel, tl, triton
from torch._inductor.runtime.triton_helpers import libdevice
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.triton_utils import requires_cuda_and_triton
@ -39,8 +38,9 @@ class TestStaticCudaLauncher(TestCase):
# Just used by tests for now.
# TODO: derive cubin_path from wherever triton stores the cubin file on disk.
tmp_file = tempfile.NamedTemporaryFile(mode="wb", delete=False)
binary_key = "hsaco" if torch.version.hip else "cubin"
with tmp_file:
tmp_file.write(kernel.asm["cubin"])
tmp_file.write(kernel.asm[binary_key])
self.tmp_files.append(tmp_file)
return tmp_file.name
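For context, here is a minimal standalone sketch of the same backend-keyed dump. It assumes a compiled Triton kernel object `kernel` is already in scope; everything else is illustrative and not part of the patch:

import tempfile

import torch

# Triton stores the device binary under a backend-specific key in kernel.asm:
# "hsaco" on the ROCm/HIP backend, "cubin" on the NVIDIA backend.
binary_key = "hsaco" if torch.version.hip else "cubin"

# Dump the raw binary to disk so it can later be handed to the static
# launcher by path, mirroring the test helper above.
with tempfile.NamedTemporaryFile(mode="wb", delete=False) as tmp_file:
    tmp_file.write(kernel.asm[binary_key])
binary_path = tmp_file.name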
@ -64,7 +64,6 @@ class TestStaticCudaLauncher(TestCase):
result.load_kernel(device_interface.current_device())
return result
@skipIfRocm
def test_basic(self):
@triton.jit
def simple_kernel(arg0, arg1):
@ -91,7 +90,6 @@ class TestStaticCudaLauncher(TestCase):
# 2. triton relies on inspect.get_source to get the type annotations
# so I can't even use exec() to generate the test cases.
# So we'll just make a few kernels by hand
@skipIfRocm
def test_unsigned_integers(self):
@triton.jit
def unsigned_integers(
@ -115,7 +113,6 @@ class TestStaticCudaLauncher(TestCase):
launcher.run(1, 1, 1, stream, new_arg0, 50, 50, 50, 50)
self.assertEqual(new_arg0, arg0)
@skipIfRocm
def test_signed_integers(self):
@triton.jit
def signed_integers(
@ -139,7 +136,6 @@ class TestStaticCudaLauncher(TestCase):
launcher.run(1, 1, 1, stream, new_arg0, 50, 50, 50, 50)
self.assertEqual(new_arg0, arg0)
@skipIfRocm
def test_basic_1arg(self):
@triton.jit
def simple_kernel_1_arg(arg0):
@ -164,7 +160,6 @@ class TestStaticCudaLauncher(TestCase):
)
self.assertEqual(new_arg0, arg0)
@skipIfRocm
def test_constexpr(self):
# Constexprs are compiled directly into the cubin file,
# so we never need to pass them to StaticCudaLauncher.
@ -193,7 +188,6 @@ class TestStaticCudaLauncher(TestCase):
)
self.assertEqual(new_arg0, arg0)
@skipIfRocm
def test_implied_constant(self):
"""xnumel is unused in this kernel, but isn't explicitly marked as a constexpr"""
@ -246,7 +240,6 @@ class TestStaticCudaLauncher(TestCase):
launcher.run(1, 1, 1, stream, arg0, arg2, 128)
self.assertEqual(arg1, arg2)
@skipIfRocm
def test_kernel_no_args(self):
# Just an easy way to test incompatible number of arguments
@triton.jit
@ -259,7 +252,6 @@ class TestStaticCudaLauncher(TestCase):
stream = device_interface.get_raw_stream(device_interface.current_device())
launcher.run(1, 1, 1, stream)
@skipIfRocm
def test_high_shared_mem(self):
@triton.jit
def simple_kernel(arg0, arg1):
@ -283,7 +275,6 @@ class TestStaticCudaLauncher(TestCase):
launcher.run(1, 1, 1, stream, new_arg0, arg1)
self.assertEqual(new_arg0, arg0)
@skipIfRocm
def test_too_high_shared_mem(self):
@triton.jit
def simple_kernel(arg0, arg1):
@ -303,7 +294,6 @@ class TestStaticCudaLauncher(TestCase):
lambda: self._make_launcher(compiled_kernel),
)
@skipIfRocm
def test_kernel_empty_tensor(self):
# Triton kernel generated by torch.compile of the following:
# @torch.compile()
@ -364,7 +354,6 @@ class TestStaticCudaLauncher(TestCase):
launcher.run(1, 1, 1, stream, arg1, arg2, buf1, arg0, xnumel)
self.assertEqual(buf0, buf1)
@skipIfRocm
def test_kernel_many_args(self):
N = 200
# Make 200 arguments
@ -405,7 +394,6 @@ class TestStaticTritonCompileResult(TestCase):
Tests static cuda launcher with torch.compile()
"""
@skipIfRocm
def test_basic_compile(self):
@torch.compile
def foo(x, y):
@ -415,7 +403,6 @@ class TestStaticTritonCompileResult(TestCase):
y = torch.randn(10, device="cuda")
self.assertEqual(foo(x, y), x + y)
@skipIfRocm
# The error gets raised on a worker, so we don't want to use a separate process
@torch._inductor.config.patch("compile_threads", 1)
def test_incompatible_code(self):
@ -438,7 +425,6 @@ class TestStaticTritonCompileResult(TestCase):
lambda: foo(x),
)
@skipIfRocm
# The error gets raised on a worker, so we don't want to use a separate process
@torch._inductor.config.patch(
{"compile_threads": 1, "static_launch_user_defined_triton_kernels": True}
@ -460,7 +446,6 @@ class TestStaticTritonCompileResult(TestCase):
x2 = x.clone().detach_()
self.assertEqual(foo(x), x2 + 5)
@skipIfRocm
def test_empty_tensor(self):
@torch.compile()
def foo(x, y):
@ -472,7 +457,6 @@ class TestStaticTritonCompileResult(TestCase):
result = foo(x, y)
self.assertEqual(result, torch.cat(((x * 4), y + 10)))
@skipIfRocm
def test_any(self):
def fn(x):
return (
@ -492,7 +476,6 @@ class TestStaticTritonCompileResult(TestCase):
compiled_result = compiled_fn(arg)
self.assertEqual(eager_result, compiled_result)
@skipIfRocm
def test_disable_static_cuda_launcher(self):
@torch.compile
def fn(x, y):

View File

@ -37,9 +37,17 @@ class StaticallyLaunchedCudaKernel:
def __init__(self, kernel: CompiledKernel) -> None:
# pyrefly: ignore [missing-attribute]
self.name = kernel.src.fn.__name__
# pyrefly: ignore [missing-attribute]
self.cubin_raw = kernel.asm.get("cubin", None)
# pyrefly: ignore [missing-attribute]
if "hsaco" in kernel.asm:
self.cubin_raw = kernel.asm["hsaco"]
self.is_rocm = True
elif "cubin" in kernel.asm:
self.cubin_raw = kernel.asm["cubin"]
self.is_rocm = False
else:
raise RuntimeError(
"Expected either 'hsaco' (ROCm) or 'cubin' (CUDA) in kernel.asm"
)
self.cubin_path = kernel._cubin_path
# Used by torch.compile to filter constants in older triton versions
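The same selection, condensed into a free function for illustration (the function name is hypothetical; the real logic lives in `__init__` above and also records the backend in `self.is_rocm`):

def pick_device_binary(asm: dict) -> tuple[bytes, bool]:
    # Return (raw device binary, is_rocm) from a Triton kernel's asm dict.
    if "hsaco" in asm:  # ROCm/HIP backend
        return asm["hsaco"], True
    if "cubin" in asm:  # NVIDIA backend
        return asm["cubin"], False
    raise RuntimeError(
        "Expected either 'hsaco' (ROCm) or 'cubin' (CUDA) in kernel.asm"
    )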
@ -245,13 +253,42 @@ class StaticallyLaunchedCudaKernel:
# thing, it should always match.
# Get rid of constants before passing to cubin launcher
# Add a None if triton wants extra parameters for scratch spaces
arg_tys = self.arg_tys
for has_scratch in [self.has_global_scratch, self.has_profile_scratch]:
if has_scratch:
arg_tys = arg_tys + "O"
args = (*args, None)
# pyrefly: ignore [bad-argument-type]
if self.is_rocm:
# ROCm/HIP kernel ABI: The Triton HIP backend ALWAYS includes both
# global_scratch and profile_scratch parameters in the kernel signature,
# even when the kernel doesn't use them (i.e., when has_*_scratch is False).
#
# This differs fundamentally from CUDA, where these parameters are only
# present in the signature if the corresponding has_*_scratch flag is True.
#
# The flags indicate whether memory will be allocated/used:
# - has_global_scratch: Whether global scratch workspace is needed
# - has_profile_scratch: Whether profiling instrumentation is enabled
#
# However, regardless of flag values, we MUST always pass both parameters
# to match the HIP kernel ABI. Passing None is safe:
#
# - If scratch is not needed (has_*_scratch=False or scratch_size=0):
# The None becomes nullptr, which the kernel never dereferences
#
# - If scratch is needed (has_*_scratch=True and scratch_size>0):
# The None becomes nullptr initially, but the HIP runtime intercepts
# the kernel launch, allocates the required scratch memory based on
# kernel metadata, and replaces the nullptr with a valid pointer before
# the kernel actually executes
#
# Not passing both parameters causes segmentation faults because the kernel
# expects them at specific positions in the argument array.
arg_tys = arg_tys + "OO"
args = (*args, None, None)
else:
for has_scratch in [self.has_global_scratch, self.has_profile_scratch]:
if has_scratch:
arg_tys = arg_tys + "O"
args = (*args, None)
assert len(args) == len(arg_tys)
# TODO: can handle grid functions here or in C++, so
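A condensed sketch of the scratch-argument handling above, written as a standalone function (the name is hypothetical; "O" is the type code the launcher appends for each extra opaque pointer slot, matched by a trailing None argument):

def append_scratch_args(arg_tys, args, is_rocm, has_global_scratch, has_profile_scratch):
    if is_rocm:
        # ROCm: the HIP kernel ABI always declares both scratch parameters,
        # so always pass two extra pointer slots; None is launched as nullptr.
        return arg_tys + "OO", (*args, None, None)
    # CUDA: the signature only contains the scratch parameters whose
    # has_*_scratch flag is set, so append them conditionally.
    for has_scratch in (has_global_scratch, has_profile_scratch):
        if has_scratch:
            arg_tys += "O"
            args = (*args, None)
    return arg_tys, args

On ROCm the result always grows by exactly two "O" slots and two trailing None arguments, independent of the flags; on CUDA it grows by zero, one, or two depending on them.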

View File

@ -1593,9 +1593,8 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
return None
def check_can_launch() -> StaticallyLaunchedCudaKernel:
if triton_meta.get("device_type") != "cuda":
# Only cuda kernels
raise CannotStaticallyLaunchKernel("Non-cuda device")
if triton_meta.get("device_type") not in ("cuda", "hip"):
raise CannotStaticallyLaunchKernel("Non-cuda/ROCm device")
if torch._inductor.config.cpp_wrapper:
# If we're running with cpp wrapper, it doesn't
@ -1621,10 +1620,11 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
"static launch does not support launch attributes"
)
binary_ext = "hsaco" if triton_meta.get("device_type") == "hip" else "cubin"
cubin_location = os.path.join(
triton_cache_dir(triton_meta.get("device", 0)),
triton_hash_to_path_key(kernel.hash),
f"{kernel.src.fn.__name__}.cubin",
f"{kernel.src.fn.__name__}.{binary_ext}",
)
if not os.path.exists(cubin_location):
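Only the file extension differs between backends; a small illustration of the naming scheme (the kernel name below is made up, and the directory components come from the `triton_cache_dir` and `triton_hash_to_path_key` helpers used above):

def expected_binary_name(fn_name: str, device_type: str) -> str:
    # Triton names the on-disk artifact after the kernel function, with a
    # backend-specific extension: .hsaco on ROCm/HIP, .cubin on NVIDIA.
    ext = "hsaco" if device_type == "hip" else "cubin"
    return f"{fn_name}.{ext}"

assert expected_binary_name("triton_poi_fused_add_0", "hip") == "triton_poi_fused_add_0.hsaco"
assert expected_binary_name("triton_poi_fused_add_0", "cuda") == "triton_poi_fused_add_0.cubin"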

View File

@ -2159,7 +2159,7 @@ PyObject* initModule() {
#ifdef USE_CUDA
torch::cuda::initModule(module);
#endif
#if defined(USE_CUDA) && !defined(USE_ROCM)
#if defined(USE_CUDA)
ASSERT_TRUE(StaticCudaLauncher_init(module));
#endif
#ifdef USE_MPS
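A Python-visible consequence of dropping the `!defined(USE_ROCM)` guard is that ROCm builds now register the launcher bindings as well. A quick, hedged way to probe this (the `_StaticCudaLauncher` attribute name is assumed from `StaticCudaLauncher_init` above, hence the hasattr check rather than a direct import):

import torch

# Expected True on CUDA builds; with this change, also True on ROCm builds
# (attribute name assumed, not confirmed by the diff).
print(hasattr(torch._C, "_StaticCudaLauncher"))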

View File

@ -1,7 +1,4 @@
#if defined(USE_CUDA) && !defined(USE_ROCM)
// We disable this file from being hipified because there are CUDA drivers hip
// has not implemented yet. Also, we're passing in a cubin file directly, so it
// would take more work to support ROCM anyway.
#if defined(USE_CUDA) || defined(USE_ROCM)
#include <torch/csrc/utils/pythoncapi_compat.h>
#include <ATen/Context.h>
@ -16,6 +13,11 @@
#include <torch/csrc/utils/python_numbers.h>
#include <filesystem>
#include <optional>
#if defined(USE_ROCM)
#include <hip/hip_runtime_api.h>
#endif
/**
Implements a static launcher for triton compiled CUDA kernels.
Given a path to a cubin file, a function name, and some metadata,
@ -56,8 +58,15 @@ const at::cuda::NVRTC& nvrtc() {
CUdeviceptr getPointer(PyObject* obj) {
CUdeviceptr data_ptr = 0;
if (THPUtils_checkLong(obj)) {
data_ptr = THPUtils_unpackUInt64(obj);
#if defined(USE_ROCM)
data_ptr = reinterpret_cast<hipDeviceptr_t>(THPUtils_unpackUInt64(obj));
#else
data_ptr = THPUtils_unpackUInt64(obj);
#endif
return data_ptr;
}
if (obj == Py_None) {
@ -73,13 +82,25 @@ CUdeviceptr getPointer(PyObject* obj) {
TORCH_CHECK(
THPUtils_checkLong(ret),
"data_ptr method of Pointer object must return 64-bit int");
#if defined(USE_ROCM)
data_ptr = reinterpret_cast<hipDeviceptr_t>(THPUtils_unpackUInt64(ret));
#else
data_ptr = THPUtils_unpackUInt64(ret);
#endif
if (!data_ptr)
return data_ptr;
CUdeviceptr dev_ptr = 0;
#if defined(USE_ROCM)
AT_CUDA_DRIVER_CHECK(hipPointerGetAttribute(
&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, data_ptr));
#else
AT_CUDA_DRIVER_CHECK(nvrtc().cuPointerGetAttribute(
&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, data_ptr));
#endif
return dev_ptr;
}
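On the Python side, three kinds of argument map onto `getPointer()` above: a plain 64-bit integer address, any object exposing a `data_ptr()` method, or `None` for a null pointer. A small sketch of what callers can pass:

import torch

t = torch.empty(16, device="cuda")  # on ROCm builds, "cuda" maps to the HIP device

raw_address = t.data_ptr()  # a plain int: unpacked directly as the device pointer
tensor_like = t             # an object with data_ptr(): the method is called to get the address
null_arg = None             # becomes a nullptr (used e.g. for unused scratch slots)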
@ -98,6 +119,18 @@ CUfunction loadKernel(
}
CUmodule mod = nullptr;
CUfunction func = nullptr;
#if defined(USE_ROCM)
AT_CUDA_DRIVER_CHECK(hipModuleLoad(&mod, filePath.c_str()));
AT_CUDA_DRIVER_CHECK(
hipModuleGetFunction(&func, mod, funcName.c_str()));
int shared_optin = 0;
AT_CUDA_DRIVER_CHECK(hipDeviceGetAttribute(
&shared_optin,
hipDeviceAttributeSharedMemPerBlockOptin,
device));
#else
AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoad(&mod, filePath.c_str()));
AT_CUDA_DRIVER_CHECK(
nvrtc().cuModuleGetFunction(&func, mod, funcName.c_str()));
@ -106,6 +139,9 @@ CUfunction loadKernel(
&shared_optin,
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));
#endif
// Shared memory logic from triton/third-party/nvidia/backend/driver.c
// If we're using more than 48 KB of shared memory, and we have
// access to more than 48 KB of shared memory on the device,
@ -124,6 +160,23 @@ CUfunction loadKernel(
" Reducing block sizes or `num_stages` may help.");
if (sharedMemBytes > SHARED_MEM_STATIC_MAX &&
shared_optin > SHARED_MEM_STATIC_MAX) {
#if defined(USE_ROCM)
AT_CUDA_DRIVER_CHECK(
hipFuncSetCacheConfig(func, hipFuncCachePreferShared));
int shared_total = 0, shared_static = 0;
AT_CUDA_DRIVER_CHECK(hipDeviceGetAttribute(
&shared_total,
hipDeviceAttributeMaxSharedMemoryPerMultiprocessor,
device));
AT_CUDA_DRIVER_CHECK(hipFuncGetAttribute(
&shared_static, HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func));
AT_CUDA_DRIVER_CHECK(hipFuncSetAttribute(
func,
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_optin - shared_static));
#else
AT_CUDA_DRIVER_CHECK(
nvrtc().cuFuncSetCacheConfig(func, CU_FUNC_CACHE_PREFER_SHARED));
int shared_total = 0, shared_static = 0;
@ -136,7 +189,8 @@ CUfunction loadKernel(
AT_CUDA_DRIVER_CHECK(nvrtc().cuFuncSetAttribute(
func,
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_optin - shared_static));
shared_optin - shared_static));
#endif
}
return func;
}
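The shared-memory opt-in above boils down to a small piece of arithmetic; a sketch of the decision, with `SHARED_MEM_STATIC_MAX` taken to be the 48 KB static limit mentioned in the comment (the function name is hypothetical):

SHARED_MEM_STATIC_MAX = 48 * 1024  # 48 KB static shared-memory limit (assumed value)

def dynamic_shared_opt_in(shared_mem_bytes, shared_optin, shared_static):
    # Return the max-dynamic-shared-size to set on the function, or None to
    # leave the driver default. Only opt in when the request exceeds the
    # 48 KB static limit and the device actually offers more via opt-in.
    if shared_mem_bytes > SHARED_MEM_STATIC_MAX and shared_optin > SHARED_MEM_STATIC_MAX:
        # Grant everything the device allows beyond the kernel's static usage.
        return shared_optin - shared_static
    return None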
@ -152,6 +206,27 @@ inline void launchKernel(
cudaStream_t stream) {
// cta_args is always 1 for inductor generated triton kernels,
// so we don't need to figure out grid dimension here
#if defined(USE_ROCM)
int device = 0;
AT_CUDA_DRIVER_CHECK(hipGetDevice(&device));
int warp_size = 0;
AT_CUDA_DRIVER_CHECK(
hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, device));
AT_CUDA_DRIVER_CHECK(hipModuleLaunchKernel(
func,
gridX,
gridY,
gridZ,
warp_size * numWarps, // blockDim.x
1, // blockDim.y
1, // blockDim.z
sharedMemBytes,
stream,
args,
nullptr));
#else
AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
func,
gridX,
@ -164,6 +239,7 @@ inline void launchKernel(
stream,
args,
nullptr));
#endif
}
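The ROCm launch path queries the device wavefront size instead of assuming 32-wide warps, since most AMD GPUs execute 64-wide wavefronts and the same num_warps therefore maps to a larger blockDim.x. A sketch of that arithmetic from Python, assuming a recent PyTorch that exposes `warp_size` on the device properties:

import torch

num_warps = 4
warp_size = torch.cuda.get_device_properties(0).warp_size  # 32 on NVIDIA, typically 64 on AMD
block_dim_x = warp_size * num_warps  # 128 on a warp-32 device, 256 on a wave-64 device
print(block_dim_x)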
template <typename FINAL, typename F>
@ -269,11 +345,21 @@ PyObject* load_kernel(PyObject* self, PyObject* args) {
CUdevice device = static_cast<CUdevice>(device_ptr); // NOLINT
CUfunction func = nullptr;
func = loadKernel(filePath, funcName, sharedMemBytes, device);
// Taken from triton/nvidia/backend/driver.c
#if defined(USE_ROCM)
AT_CUDA_DRIVER_CHECK(
hipFuncGetAttribute(&n_regs, HIP_FUNC_ATTRIBUTE_NUM_REGS, func));
AT_CUDA_DRIVER_CHECK(hipFuncGetAttribute(
&n_spills, HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, func));
#else
AT_CUDA_DRIVER_CHECK(
nvrtc().cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, func));
AT_CUDA_DRIVER_CHECK(nvrtc().cuFuncGetAttribute(
&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, func));
#endif
n_spills /= 4;
// Return a tuple of CUFunction, n_regs, n_spills
return Py_BuildValue(
@ -299,7 +385,6 @@ PyObject* launch_kernel_inner(
std::array<uint64_t, MAX_ARGS> argStorage = {};
std::array<void*, MAX_ARGS> kernelArgs = {};
parseKernelArgs(varArgs, argTypes, argStorage.data(), kernelArgs.data());
launchKernel(
func,
gridX,
@ -386,13 +471,26 @@ PyObject* launch_kernel(PyObject* self, PyObject* args) {
Py_RETURN_NONE;
}
CUcontext pctx = nullptr;
AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx));
#if defined(USE_ROCM)
AT_CUDA_DRIVER_CHECK(hipCtxGetCurrent(&pctx));
#else
AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx));
#endif
if (!pctx) {
// Ensure device context exists
CUdevice device = 0;
#if defined(USE_ROCM)
AT_CUDA_DRIVER_CHECK(hipDeviceGet(&device, 0));
AT_CUDA_DRIVER_CHECK(hipDevicePrimaryCtxRetain(&pctx, device));
AT_CUDA_DRIVER_CHECK(hipCtxSetCurrent(pctx));
#else
AT_CUDA_DRIVER_CHECK(nvrtc().cuDeviceGet(&device, 0));
AT_CUDA_DRIVER_CHECK(nvrtc().cuDevicePrimaryCtxRetain(&pctx, device));
AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxSetCurrent(pctx));
#endif
}
CUfunction func = reinterpret_cast<CUfunction>(func_ptr); // NOLINT
cudaStream_t cudaStream = reinterpret_cast<cudaStream_t>(stream); // NOLINT

View File

@ -1,5 +1,5 @@
#pragma once
#if defined(USE_CUDA) && !defined(USE_ROCM)
#if defined(USE_CUDA)
#include <torch/csrc/inductor/cpp_wrapper/device_internal/cuda.h>
#include <torch/csrc/python_headers.h>