Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
We previously asked users to separate these because we didn't have any way of adding extern "C" declarations. That is no longer the case, so we don't need this confusing flag anymore. This is BC-breaking, but that is fine for this API since it doesn't have major users yet. Please put all your code in `kernel_source` moving forward.

## BC note

The `header_code` parameter has been removed from `torch.cuda._compile_kernel`. Previously, users could pass separate header code that would be prepended to the kernel source. Now, header code must be included directly in the `kernel_source` parameter. Note this only affects `torch.cuda._compile_kernel`, which is a private API.

Example:

Before

```python
kernel = _compile_kernel(
    kernel_source="__global__ void my_kernel() { ... }",
    kernel_name="my_kernel",
    header_code="#define SCALE 2.0f\n__device__ float scale(float x) { return x * SCALE; }",
)
```

After

```python
kernel_source = """
#define SCALE 2.0f
__device__ float scale(float x) { return x * SCALE; }

__global__ void my_kernel() { ... }
"""
kernel = _compile_kernel(kernel_source, "my_kernel")
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163165
Approved by: https://github.com/janeyx99, https://github.com/albanD
517 lines · 18 KiB · Python
import ctypes
import sys
from typing import Any, Optional, Union

import torch

# The _get_device_index has been moved to torch.utils._get_device_index
from torch._utils import _get_device_index as _torch_get_device_index


def _get_hip_runtime_library() -> ctypes.CDLL:
    if sys.platform == "win32":
        lib = ctypes.CDLL(f"amdhip64_{torch.version.hip[0]}.dll")
    else:  # Unix-based systems
        lib = ctypes.CDLL("libamdhip64.so")
    lib.cuGetErrorString = lib.hipGetErrorString  # type: ignore[attr-defined]
    lib.cuModuleLoadData = lib.hipModuleLoadData  # type: ignore[attr-defined]
    lib.cuModuleGetFunction = lib.hipModuleGetFunction  # type: ignore[attr-defined]
    lib.cuLaunchKernel = lib.hipModuleLaunchKernel  # type: ignore[attr-defined]
    lib.cuFuncSetAttribute = lib.hipFuncSetAttribute  # type: ignore[attr-defined]
    return lib


def _get_cuda_runtime_library() -> ctypes.CDLL:
    if sys.platform == "win32":
        return ctypes.CDLL("nvcuda.dll")
    else:  # Unix-based systems
        return ctypes.CDLL("libcuda.so.1")


# Load GPU driver runtime
def _get_gpu_runtime_library() -> ctypes.CDLL:
    if torch.version.hip:
        return _get_hip_runtime_library()
    else:
        return _get_cuda_runtime_library()


# Helper: check CUDA errors
def _check_cuda(result: int) -> None:
    if result == 0:
        return
    err_str = ctypes.c_char_p()
    libcuda = _get_gpu_runtime_library()  # Get reference to CUDA library
    libcuda.cuGetErrorString(result, ctypes.byref(err_str))
    error_message = (
        err_str.value.decode() if err_str.value is not None else "Unknown CUDA error"
    )
    raise RuntimeError(f"CUDA error: {error_message}")


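# Illustrative sketch (not part of the module): _check_cuda is a no-op for a
# zero (success) status and raises RuntimeError with the driver's error string
# otherwise. The nonzero status value below is just an example.
#
#     _check_cuda(0)    # success: returns None
#     _check_cuda(201)  # nonzero status -> RuntimeError("CUDA error: ...")
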
def _get_hiprtc_library() -> ctypes.CDLL:
    if sys.platform == "win32":
        version_str = "".join(["0", torch.version.hip[0], "0", torch.version.hip[2]])
        lib = ctypes.CDLL(f"hiprtc{version_str}.dll")
    else:
        lib = ctypes.CDLL("libhiprtc.so")

    # Provide aliases for HIP RTC functions to match NVRTC API
    lib.nvrtcGetErrorString = lib.hiprtcGetErrorString  # type: ignore[attr-defined]
    lib.nvrtcCreateProgram = lib.hiprtcCreateProgram  # type: ignore[attr-defined]
    lib.nvrtcDestroyProgram = lib.hiprtcDestroyProgram  # type: ignore[attr-defined]
    lib.nvrtcCompileProgram = lib.hiprtcCompileProgram  # type: ignore[attr-defined]
    lib.nvrtcGetPTXSize = lib.hiprtcGetCodeSize  # type: ignore[attr-defined]
    lib.nvrtcGetPTX = lib.hiprtcGetCode  # type: ignore[attr-defined]
    lib.nvrtcGetProgramLogSize = lib.hiprtcGetProgramLogSize  # type: ignore[attr-defined]
    lib.nvrtcGetProgramLog = lib.hiprtcGetProgramLog  # type: ignore[attr-defined]
    lib.nvrtcAddNameExpression = lib.hiprtcAddNameExpression  # type: ignore[attr-defined]
    lib.nvrtcGetLoweredName = lib.hiprtcGetLoweredName  # type: ignore[attr-defined]
    return lib


def _get_nvrtc_library() -> ctypes.CDLL:
    if sys.platform == "win32":
        return ctypes.CDLL("nvrtc64_120_0.dll")
    else:
        return ctypes.CDLL("libnvrtc.so")


def _get_gpu_rtc_library() -> ctypes.CDLL:
    # Since PyTorch already loads the GPU RTC library, we can use the system library
    # which should be compatible with PyTorch's version
    if torch.version.hip:
        return _get_hiprtc_library()
    else:
        return _get_nvrtc_library()


def _get_gpu_rtc_compatible_flags() -> list[str]:
    """
    Get HIPCC/NVCC flags that are compatible with NVRTC compilation.

    Returns:
        List of HIPCC/NVCC flags that can be safely used with NVRTC.
    """
    from torch.utils.cpp_extension import COMMON_HIPCC_FLAGS, COMMON_NVCC_FLAGS

    nvrtc_unsupported_flags = {
        "--expt-relaxed-constexpr",
    }

    # Filter out unsupported flags
    compatible_flags = [
        flag for flag in COMMON_NVCC_FLAGS if flag not in nvrtc_unsupported_flags
    ]

    if torch.version.hip:
        compatible_flags.extend(COMMON_HIPCC_FLAGS)

    return compatible_flags


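# Illustrative sketch (not part of the module): on a CUDA build the returned
# list is COMMON_NVCC_FLAGS with the NVRTC-unsupported flags filtered out; on
# a ROCm build COMMON_HIPCC_FLAGS are appended as well.
#
#     flags = _get_gpu_rtc_compatible_flags()
#     assert "--expt-relaxed-constexpr" not in flags
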
def _nvrtc_compile(
    kernel_source: str,
    kernel_name: str,
    compute_capability: Optional[str] = None,
    cuda_include_dirs: Optional[list] = None,
    nvcc_options: Optional[list] = None,
    auto_pch: bool = False,
) -> tuple[bytes, str]:
    """
    Compiles a CUDA kernel using NVRTC and returns the PTX code.

    Args:
        kernel_source (str): The CUDA kernel source code as a string
        kernel_name (str): The name of the kernel function to compile
        compute_capability (str, None): The compute capability to target (e.g., "86").
            If None, will detect from current device.
        cuda_include_dirs (list, None): List of directories containing CUDA headers
        nvcc_options (list, None): Additional options to pass to NVRTC
        auto_pch (bool): Enable automatic precompiled headers (CUDA 12.8+)

    Returns:
        Tuple[bytes, str]: The compiled PTX code and mangled kernel name
    """
    # Ensure CUDA is initialized
    import torch.cuda

    # Load NVRTC library
    libnvrtc = _get_gpu_rtc_library()

    # NVRTC constants
    NVRTC_SUCCESS = 0

    # Helper: check NVRTC errors
    def check_nvrtc(result: int) -> None:
        if result != NVRTC_SUCCESS:
            err_str = ctypes.c_char_p()
            libnvrtc.nvrtcGetErrorString(result, ctypes.byref(err_str))
            error_message = (
                err_str.value.decode()
                if err_str.value is not None
                else "Unknown CUDA error"
            )
            raise RuntimeError(f"CUDA error: {error_message}")

    # Convert source to bytes
    source_bytes = kernel_source.encode("utf-8")

    # Get compute capability if not provided
    if compute_capability is None:
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        if torch.version.hip:
            compute_capability = f"{props.gcnArchName}"
        else:
            compute_capability = f"{props.major}{props.minor}"

    # Prepare compilation options
    options = []
    if torch.version.hip:
        options.append(f"--offload-arch={compute_capability}".encode())
    else:
        options.append(f"--gpu-architecture=sm_{compute_capability}".encode())

    # Auto-detect and add CUDA include paths
    from torch.utils.cpp_extension import include_paths

    cuda_include_paths = include_paths("cuda")
    for cuda_path in cuda_include_paths:
        options.append(f"-I{cuda_path}".encode())

    # Add custom include directories
    if cuda_include_dirs:
        for directory in cuda_include_dirs:
            options.append(f"-I{directory}".encode())

    # Enable automatic precompiled headers (CUDA 12.8+)
    if auto_pch:
        # Compare numerically: a plain string comparison would mis-order
        # versions such as "12.10" vs "12.8".
        cuda_version = tuple(int(v) for v in str(torch.version.cuda).split(".")[:2])
        assert cuda_version >= (12, 8), "PCH requires CUDA 12.8+"
        if nvcc_options is None:
            nvcc_options = []
        nvcc_options.append("--pch")

    # Add custom NVCC options
    if nvcc_options:
        for option in nvcc_options:
            options.append(option.encode("utf-8"))

    nvrtc_compatible_flags = _get_gpu_rtc_compatible_flags()
    options.extend([flag.encode("utf-8") for flag in nvrtc_compatible_flags])

    # Convert options to C array
    num_options = len(options)
    options_array = (ctypes.c_char_p * num_options)(*options)

    # Create program
    prog = ctypes.c_void_p()
    check_nvrtc(
        libnvrtc.nvrtcCreateProgram(
            ctypes.byref(prog),
            source_bytes,
            f"{kernel_name}.cu".encode(),
            0,
            None,
            None,
        )
    )

    # Add kernel name, which can be a template expression
    c_kernel_name = kernel_name.encode("utf-8")
    check_nvrtc(libnvrtc.nvrtcAddNameExpression(prog, c_kernel_name))

    # Compile program
    res = libnvrtc.nvrtcCompileProgram(prog, num_options, options_array)

    # Handle compilation errors
    if res != NVRTC_SUCCESS:
        # Get log
        log_size = ctypes.c_size_t()
        libnvrtc.nvrtcGetProgramLogSize(prog, ctypes.byref(log_size))
        log = ctypes.create_string_buffer(log_size.value)
        libnvrtc.nvrtcGetProgramLog(prog, log)
        raise RuntimeError(f"Kernel compilation failed:\n{log.value.decode()}")

    # Get PTX
    ptx_size = ctypes.c_size_t()
    check_nvrtc(libnvrtc.nvrtcGetPTXSize(prog, ctypes.byref(ptx_size)))
    ptx = ctypes.create_string_buffer(ptx_size.value)
    check_nvrtc(libnvrtc.nvrtcGetPTX(prog, ptx))

    # Get mangled name
    c_mangled_name = ctypes.c_char_p()
    check_nvrtc(
        libnvrtc.nvrtcGetLoweredName(prog, c_kernel_name, ctypes.byref(c_mangled_name))
    )
    if c_mangled_name.value is not None:
        mangled_name = c_mangled_name.value.decode()  # make a copy
    else:
        mangled_name = ""

    libnvrtc.nvrtcDestroyProgram(ctypes.byref(prog))

    # For HIP, hipRTC generates raw CO binaries instead of PTX,
    # and for some reason, ".value" causes the string to be truncated,
    # likely due to the presence of '\0' in the string. So we use .raw instead.
    ptx_bytes = ptx.raw if torch.version.hip else ptx.value
    return ptx_bytes, mangled_name


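# Illustrative sketch (not part of the module): compiling a trivial kernel
# with _nvrtc_compile and loading it with _cuda_load_module (defined below).
# The kernel source and names here are made up for the example.
#
#     src = r"""
#     extern "C" __global__ void add_one(float* x, int n) {
#         int i = blockIdx.x * blockDim.x + threadIdx.x;
#         if (i < n) x[i] += 1.0f;
#     }
#     """
#     ptx, mangled = _nvrtc_compile(src, "add_one")  # extern "C": mangled == "add_one"
#     kernels = _cuda_load_module(ptx, kernel_names=[mangled])
#     add_one = kernels[mangled]
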
class _CudaModule:
    def __init__(self, module: ctypes.c_void_p) -> None:
        self._module = module
        self._kernels: dict[str, _CudaKernel] = {}

    def __getattr__(self, name: str) -> "_CudaKernel":
        if name in self._kernels:
            return self._kernels[name]

        # Import the CUDA library inside the method
        from torch.cuda._utils import _get_gpu_runtime_library

        libcuda = _get_gpu_runtime_library()

        func = ctypes.c_void_p()
        try:
            _check_cuda(
                libcuda.cuModuleGetFunction(
                    ctypes.byref(func), self._module, name.encode("utf-8")
                )
            )
            kernel = _CudaKernel(func, self._module)
            self._kernels[name] = kernel
            return kernel
        except RuntimeError as err:
            raise AttributeError(f"No kernel named '{name}' in this module") from err


class _CudaKernel:
    """
    Represents a compiled CUDA kernel that can be called with PyTorch tensors.
    """

    def __init__(self, func: ctypes.c_void_p, module: ctypes.c_void_p) -> None:
        self.func = func
        self.module = module
        self._max_shared_mem_bytes = 0

    def __call__(
        self,
        grid: tuple[int, int, int] = (1, 1, 1),
        block: tuple[int, int, int] = (1, 1, 1),
        args: Optional[list] = None,
        shared_mem: int = 0,
        stream: Optional[Any] = None,
    ) -> None:
        """
        Call the compiled CUDA kernel

        Args:
            grid (tuple): Grid dimensions (grid_x, grid_y, grid_z)
            block (tuple): Block dimensions (block_x, block_y, block_z)
            args (list): List of arguments to pass to the kernel.
                PyTorch tensor arguments will be automatically converted to pointers.
            shared_mem (int): Shared memory size in bytes
            stream (torch.cuda.Stream): CUDA stream to use. If None, uses current stream.
        """
        import torch

        libcuda = torch.cuda._utils._get_gpu_runtime_library()

        if not args:
            args = []

        # Process arguments and convert tensors to pointers
        processed_args: list[ctypes.c_void_p] = []
        c_args = []

        for arg in args:
            if isinstance(arg, torch.Tensor):
                if not arg.is_cuda and not (arg.is_cpu and arg.is_pinned()):
                    raise ValueError(
                        "All tensor arguments must be CUDA tensors or pinned CPU tensors"
                    )
                # Get pointer to tensor data
                ptr = ctypes.c_void_p(arg.data_ptr())
                processed_args.append(ptr)
                c_args.append(ctypes.byref(ptr))
            elif isinstance(arg, int):
                # Convert integers to C int
                c_int = ctypes.c_int(arg)
                # Store the C int for reference keeping, not in processed_args
                c_args.append(ctypes.byref(c_int))
            elif isinstance(arg, float):
                # Python floats are doubles - use double by default
                c_double = ctypes.c_double(arg)
                # Store the C double for reference keeping, not in processed_args
                c_args.append(ctypes.byref(c_double))
            else:
                raise TypeError(f"Unsupported argument type: {type(arg)}")

        # Convert to array of void pointers
        c_args_array = (ctypes.c_void_p * len(c_args))()
        for i, arg in enumerate(c_args):
            c_args_array[i] = ctypes.cast(arg, ctypes.c_void_p)

        # Get the stream
        if stream is None:
            # Defer import to avoid circular imports
            import torch.cuda

            stream = torch.cuda.current_stream()

        # Check if kernel requires large shared memory but hasn't been configured
        if shared_mem >= 48 * 1024 and (
            self._max_shared_mem_bytes == 0 or shared_mem > self._max_shared_mem_bytes
        ):
            configured_msg = (
                "not configured"
                if self._max_shared_mem_bytes == 0
                else f"only {self._max_shared_mem_bytes} bytes configured"
            )
            raise RuntimeError(
                f"Kernel requires {shared_mem} bytes of shared memory (>= 48KB), "
                f"but {configured_msg}. "
                "Call kernel.set_shared_memory_config(shared_mem) after compilation "
                "and before launching the kernel."
            )

        _check_cuda(
            libcuda.cuLaunchKernel(
                self.func,
                grid[0],
                grid[1],
                grid[2],
                block[0],
                block[1],
                block[2],
                shared_mem,
                stream._as_parameter_,
                c_args_array,
                None,
            )
        )

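    # Illustrative sketch (not part of the module): launching a compiled
    # kernel over a 1-D range. The tensor, kernel, and sizes are made up.
    #
    #     x = torch.zeros(1024, device="cuda")
    #     add_one(
    #         grid=((x.numel() + 255) // 256, 1, 1),
    #         block=(256, 1, 1),
    #         args=[x, x.numel()],
    #     )
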
    def set_shared_memory_config(self, shared_mem_bytes: int) -> None:
        if shared_mem_bytes < 48 * 1024:
            # No configuration needed below 48KB; just update the value
            self._max_shared_mem_bytes = shared_mem_bytes
            return

        libcuda = _get_gpu_runtime_library()

        # Get device properties to validate against limits
        device_props = torch.cuda.get_device_properties()
        # HIP doesn't have shared_memory_per_block_optin in device properties, so we hard-code it here
        if torch.version.hip:
            # navi and CDNA1-CDNA3 allow a max of 64KB shared memory;
            # CDNA4 allows a max of 160KB shared memory
            max_shared_mem = (
                65536 if device_props.gcnArchName not in ["gfx950"] else 160 * 1024
            )
        else:
            max_shared_mem = getattr(
                device_props, "shared_memory_per_block_optin", 49152
            )

        if shared_mem_bytes > max_shared_mem:
            raise RuntimeError(
                f"Requested shared memory ({shared_mem_bytes} bytes) exceeds "
                f"device limit ({max_shared_mem} bytes). "
                "Consider reducing block size or shared memory usage."
            )

        # Set the function attribute once
        # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
        cudaFuncAttributeMaxDynamicSharedMemorySize = 8
        _check_cuda(
            libcuda.cuFuncSetAttribute(
                self.func,
                cudaFuncAttributeMaxDynamicSharedMemorySize,
                shared_mem_bytes,
            )
        )

        self._max_shared_mem_bytes = shared_mem_bytes


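# Illustrative sketch (not part of the module): opting a kernel into more than
# 48KB of dynamic shared memory before launching it. The size is made up and
# must not exceed the device limit checked above.
#
#     smem = 96 * 1024
#     kernel.set_shared_memory_config(smem)
#     kernel(grid=(1, 1, 1), block=(256, 1, 1), args=[x], shared_mem=smem)
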
def _cuda_load_module(
    ptx: Union[str, bytes], kernel_names: Optional[list[str]] = None
) -> Union[_CudaModule, dict[str, "_CudaKernel"]]:
    """
    Loads a CUDA module from PTX code and returns a module object that can access kernels.

    Args:
        ptx (bytes or str): The PTX code to load
        kernel_names (list, optional): List of kernel names to extract from the module.
            If None, will return a module object with __getattr__.

    Returns:
        object: If kernel_names is None, returns a module object with __getattr__ to access kernels.
            If kernel_names is provided, returns a dict mapping kernel names to _CudaKernel objects.
    """
    # Ensure CUDA is initialized
    import torch.cuda

    # Load CUDA driver library
    libcuda = _get_gpu_runtime_library()

    # Convert PTX to bytes if it's a string
    if isinstance(ptx, str):
        ptx = ptx.encode("utf-8")

    # Load PTX module
    module = ctypes.c_void_p()
    # Get the current stream without directly importing torch.cuda at module level
    stream = torch.cuda.current_stream()
    with stream:
        _check_cuda(libcuda.cuModuleLoadData(ctypes.byref(module), ptx))

    if not kernel_names:
        return _CudaModule(module)

    # Return specific kernels
    kernels = {}
    for name in kernel_names:
        func = ctypes.c_void_p()
        _check_cuda(
            libcuda.cuModuleGetFunction(
                ctypes.byref(func), module, name.encode("utf-8")
            )
        )
        kernels[name] = _CudaKernel(func, module)
    return kernels


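# Illustrative sketch (not part of the module): the two return shapes of
# _cuda_load_module. The kernel name is made up.
#
#     mod = _cuda_load_module(ptx)  # no kernel_names: module with __getattr__
#     mod.add_one(grid=(1, 1, 1), args=[x, x.numel()])
#
#     kernels = _cuda_load_module(ptx, kernel_names=["add_one"])  # dict form
#     kernels["add_one"](grid=(1, 1, 1), args=[x, x.numel()])
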
def _get_device_index(
    device: Any, optional: bool = False, allow_cpu: bool = False
) -> int:
    r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``.

    If :attr:`device` is a torch.device object, returns the device index if it
    is a CUDA device. Note that for a CUDA device without a specified index,
    i.e., ``torch.device('cuda')``, this will return the current default CUDA
    device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
    CPU devices will be accepted and ``-1`` will be returned in this case.

    If :attr:`device` is a Python integer, it is returned as is.

    If :attr:`device` is ``None``, this will return the current default CUDA
    device if :attr:`optional` is ``True``.
    """
    if isinstance(device, int):
        return device
    if isinstance(device, str):
        device = torch.device(device)
    if isinstance(device, torch.device):
        if allow_cpu:
            if device.type not in ["cuda", "cpu"]:
                raise ValueError(f"Expected a cuda or cpu device, but got: {device}")
        elif device.type != "cuda":
            raise ValueError(f"Expected a cuda device, but got: {device}")
    if not torch.jit.is_scripting():
        if isinstance(device, torch.cuda.device):
            return device.idx
    return _torch_get_device_index(device, optional, allow_cpu)
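

# Illustrative sketch (not part of the module): example inputs to
# _get_device_index and the indices they resolve to (assuming CUDA is
# available and device 2 exists).
#
#     _get_device_index(1)                       # 1
#     _get_device_index(torch.device("cuda:2"))  # 2
#     _get_device_index("cpu", allow_cpu=True)   # -1
#     _get_device_index(None, optional=True)     # current device index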