Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

Step 1: uncomment lines in the pyrefly.toml file
Step 2: run pyrefly check
Step 3: add suppressions, clean up unused suppressions

Before: https://gist.github.com/maggiemoss/bb31574ac8a59893c9cf52189e67bb2d
After: 0 errors (1,970 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164588
Approved by: https://github.com/oulgen
530 lines · 19 KiB · Python

import ctypes
import sys
from typing import Any, Optional, Union

import torch

# The _get_device_index has been moved to torch.utils._get_device_index
from torch._utils import _get_device_index as _torch_get_device_index


def _get_hip_runtime_library() -> ctypes.CDLL:
    if sys.platform == "win32":
        lib = ctypes.CDLL(f"amdhip64_{torch.version.hip[0]}.dll")
    else:  # Unix-based systems
        lib = ctypes.CDLL("libamdhip64.so")
    # Alias the HIP driver entry points under their CUDA names so the rest of
    # this module can call one driver API regardless of backend.
    lib.cuGetErrorString = lib.hipGetErrorString  # type: ignore[attr-defined]
    lib.cuModuleLoadData = lib.hipModuleLoadData  # type: ignore[attr-defined]
    lib.cuModuleGetFunction = lib.hipModuleGetFunction  # type: ignore[attr-defined]
    lib.cuLaunchKernel = lib.hipModuleLaunchKernel  # type: ignore[attr-defined]
    lib.cuFuncSetAttribute = lib.hipFuncSetAttribute  # type: ignore[attr-defined]
    return lib


def _get_cuda_runtime_library() -> ctypes.CDLL:
    if sys.platform == "win32":
        return ctypes.CDLL("nvcuda.dll")
    else:  # Unix-based systems
        return ctypes.CDLL("libcuda.so.1")


# Load GPU driver runtime
def _get_gpu_runtime_library() -> ctypes.CDLL:
    if torch.version.hip:
        return _get_hip_runtime_library()
    else:
        return _get_cuda_runtime_library()


# Helper: check CUDA errors
def _check_cuda(result: int) -> None:
    if result == 0:
        return
    err_str = ctypes.c_char_p()
    libcuda = _get_gpu_runtime_library()  # Get reference to CUDA library
    libcuda.cuGetErrorString(result, ctypes.byref(err_str))
    error_message = (
        err_str.value.decode() if err_str.value is not None else "Unknown CUDA error"
    )
    raise RuntimeError(f"CUDA error: {error_message}")
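
# Illustrative note (not part of the upstream module): _check_cuda wraps raw
# driver-API return codes, so calls into the driver library elsewhere in this
# file follow the pattern sketched below. "cuCtxSynchronize" is only a
# stand-in example symbol from the CUDA driver API.
#
#   libcuda = _get_gpu_runtime_library()
#   _check_cuda(libcuda.cuCtxSynchronize())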


def _get_hiprtc_library() -> ctypes.CDLL:
    if sys.platform == "win32":
        version_str = "".join(["0", torch.version.hip[0], "0", torch.version.hip[2]])
        lib = ctypes.CDLL(f"hiprtc{version_str}.dll")
    else:
        lib = ctypes.CDLL("libhiprtc.so")

    # Provide aliases for HIP RTC functions to match NVRTC API
    lib.nvrtcGetErrorString = lib.hiprtcGetErrorString  # type: ignore[attr-defined]
    lib.nvrtcCreateProgram = lib.hiprtcCreateProgram  # type: ignore[attr-defined]
    lib.nvrtcDestroyProgram = lib.hiprtcDestroyProgram  # type: ignore[attr-defined]
    lib.nvrtcCompileProgram = lib.hiprtcCompileProgram  # type: ignore[attr-defined]
    lib.nvrtcGetPTXSize = lib.hiprtcGetCodeSize  # type: ignore[attr-defined]
    lib.nvrtcGetPTX = lib.hiprtcGetCode  # type: ignore[attr-defined]
    lib.nvrtcGetProgramLogSize = lib.hiprtcGetProgramLogSize  # type: ignore[attr-defined]
    lib.nvrtcGetProgramLog = lib.hiprtcGetProgramLog  # type: ignore[attr-defined]
    lib.nvrtcAddNameExpression = lib.hiprtcAddNameExpression  # type: ignore[attr-defined]
    lib.nvrtcGetLoweredName = lib.hiprtcGetLoweredName  # type: ignore[attr-defined]
    return lib


def _get_nvrtc_library() -> ctypes.CDLL:
    major_version = int(torch.version.cuda.split(".")[0])  # type: ignore[union-attr]
    if sys.platform == "win32":
        nvrtc_libs = [
            f"nvrtc64_{major_version}0_0.dll",
        ]
    else:
        nvrtc_libs = [
            f"libnvrtc.so.{major_version}",
            "libnvrtc.so",  # Fallback to unversioned
        ]
    for lib_name in nvrtc_libs:
        try:
            return ctypes.CDLL(lib_name)
        except OSError:
            continue
    raise OSError("Could not find any NVRTC library")


def _get_gpu_rtc_library() -> ctypes.CDLL:
    # Since PyTorch already loads the GPU RTC library, we can use the system library
    # which should be compatible with PyTorch's version
    if torch.version.hip:
        return _get_hiprtc_library()
    else:
        return _get_nvrtc_library()


def _get_gpu_rtc_compatible_flags() -> list[str]:
    """
    Get HIPCC/NVCC flags that are compatible with NVRTC compilation.

    Returns:
        List of HIPCC/NVCC flags that can be safely used with NVRTC.
    """
    from torch.utils.cpp_extension import COMMON_HIPCC_FLAGS, COMMON_NVCC_FLAGS

    nvrtc_unsupported_flags = {
        "--expt-relaxed-constexpr",
    }

    # Filter out unsupported flags
    compatible_flags = [
        flag for flag in COMMON_NVCC_FLAGS if flag not in nvrtc_unsupported_flags
    ]

    if torch.version.hip:
        compatible_flags.extend(COMMON_HIPCC_FLAGS)

    return compatible_flags
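
# Illustrative note (not part of the upstream module): assuming the usual
# COMMON_NVCC_FLAGS defined in torch.utils.cpp_extension, the helper above
# simply drops the entries NVRTC rejects, e.g.:
#
#   flags = _get_gpu_rtc_compatible_flags()
#   assert "--expt-relaxed-constexpr" not in flags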


def _nvrtc_compile(
    kernel_source: str,
    kernel_name: str,
    compute_capability: Optional[str] = None,
    cuda_include_dirs: Optional[list] = None,
    nvcc_options: Optional[list] = None,
    auto_pch: bool = False,
) -> tuple[bytes, str]:
    """
    Compiles a CUDA kernel using NVRTC and returns the PTX code.

    Args:
        kernel_source (str): The CUDA kernel source code as a string
        kernel_name (str): The name of the kernel function to compile
        compute_capability (str, None): The compute capability to target (e.g., "86").
            If None, will detect from current device.
        cuda_include_dirs (list, None): List of directories containing CUDA headers
        nvcc_options (list, None): Additional options to pass to NVRTC
        auto_pch (bool): Enable automatic precompiled headers (CUDA 12.8+)

    Returns:
        Tuple[bytes, str]: The compiled PTX code and mangled kernel name
    """
    # Ensure CUDA is initialized
    import torch.cuda

    # Load NVRTC library
    libnvrtc = _get_gpu_rtc_library()

    # NVRTC constants
    NVRTC_SUCCESS = 0

    # Helper: check NVRTC errors
    def check_nvrtc(result: int) -> None:
        if result != NVRTC_SUCCESS:
            err_str = ctypes.c_char_p()
            libnvrtc.nvrtcGetErrorString(result, ctypes.byref(err_str))
            error_message = (
                err_str.value.decode()
                if err_str.value is not None
                else "Unknown CUDA error"
            )
            raise RuntimeError(f"CUDA error: {error_message}")

    # Convert source to bytes
    source_bytes = kernel_source.encode("utf-8")

    # Get compute capability if not provided
    if compute_capability is None:
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        if torch.version.hip:
            compute_capability = f"{props.gcnArchName}"
        else:
            compute_capability = f"{props.major}{props.minor}"

    # Prepare compilation options
    options = []
    if torch.version.hip:
        options.append(f"--offload-arch={compute_capability}".encode())
    else:
        options.append(f"--gpu-architecture=sm_{compute_capability}".encode())

    # Auto-detect and add CUDA include paths
    from torch.utils.cpp_extension import include_paths

    cuda_include_paths = include_paths("cuda")
    for cuda_path in cuda_include_paths:
        options.append(f"-I{cuda_path}".encode())

    # Add custom include directories
    if cuda_include_dirs:
        for directory in cuda_include_dirs:
            options.append(f"-I{directory}".encode())

    # Enable automatic precompiled headers (CUDA 12.8+)
    if auto_pch:
        assert str(torch.version.cuda) >= "12.8", "PCH requires CUDA 12.8+"
        if nvcc_options is None:
            nvcc_options = []
        nvcc_options.append("--pch")

    # Add custom NVCC options
    if nvcc_options:
        for option in nvcc_options:
            options.append(option.encode("utf-8"))

    nvrtc_compatible_flags = _get_gpu_rtc_compatible_flags()
    options.extend([flag.encode("utf-8") for flag in nvrtc_compatible_flags])

    # Convert options to C array
    num_options = len(options)
    options_array = (ctypes.c_char_p * num_options)(*options)

    # Create program
    prog = ctypes.c_void_p()
    check_nvrtc(
        libnvrtc.nvrtcCreateProgram(
            ctypes.byref(prog),
            source_bytes,
            f"{kernel_name}.cu".encode(),
            0,
            None,
            None,
        )
    )

    # Add kernel name, which can be a template expression
    c_kernel_name = kernel_name.encode("utf-8")
    check_nvrtc(libnvrtc.nvrtcAddNameExpression(prog, c_kernel_name))

    # Compile program
    res = libnvrtc.nvrtcCompileProgram(prog, num_options, options_array)

    # Handle compilation errors
    if res != NVRTC_SUCCESS:
        # Get log
        log_size = ctypes.c_size_t()
        libnvrtc.nvrtcGetProgramLogSize(prog, ctypes.byref(log_size))
        log = ctypes.create_string_buffer(log_size.value)
        libnvrtc.nvrtcGetProgramLog(prog, log)
        raise RuntimeError(f"Kernel compilation failed:\n{log.value.decode()}")

    # Get PTX
    ptx_size = ctypes.c_size_t()
    check_nvrtc(libnvrtc.nvrtcGetPTXSize(prog, ctypes.byref(ptx_size)))
    ptx = ctypes.create_string_buffer(ptx_size.value)
    check_nvrtc(libnvrtc.nvrtcGetPTX(prog, ptx))

    # Get mangled name
    c_mangled_name = ctypes.c_char_p()
    check_nvrtc(
        libnvrtc.nvrtcGetLoweredName(prog, c_kernel_name, ctypes.byref(c_mangled_name))
    )
    if c_mangled_name.value is not None:
        mangled_name = c_mangled_name.value.decode()  # make a copy
    else:
        mangled_name = ""

    libnvrtc.nvrtcDestroyProgram(ctypes.byref(prog))

    # For HIP, hipRTC generates raw CO binaries instead of PTX,
    # and for some reason, ".value" causes the string to be truncated,
    # likely due to the presence of '\0' in the string. So we use .raw instead.
    ptx_bytes = ptx.raw if torch.version.hip else ptx.value
    return ptx_bytes, mangled_name
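
# Example (illustrative sketch, not part of the upstream module): compiling a
# trivial elementwise kernel. The kernel source and the name "add_one" are
# made up for demonstration; a CUDA device and a discoverable NVRTC library
# are assumed.
#
#   _src = r"""
#   extern "C" __global__ void add_one(float* x, int n) {
#       int i = blockIdx.x * blockDim.x + threadIdx.x;
#       if (i < n) x[i] += 1.0f;
#   }
#   """
#   ptx, mangled_name = _nvrtc_compile(_src, "add_one")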


class _CudaModule:
    def __init__(self, module: ctypes.c_void_p) -> None:
        self._module = module
        self._kernels: dict[str, _CudaKernel] = {}

    def __getattr__(self, name: str) -> "_CudaKernel":
        if name in self._kernels:
            return self._kernels[name]

        # Import the CUDA library inside the method
        # pyrefly: ignore # missing-module-attribute
        from torch.cuda._utils import _get_gpu_runtime_library

        libcuda = _get_gpu_runtime_library()

        func = ctypes.c_void_p()
        try:
            _check_cuda(
                libcuda.cuModuleGetFunction(
                    ctypes.byref(func), self._module, name.encode("utf-8")
                )
            )
            kernel = _CudaKernel(func, self._module)
            self._kernels[name] = kernel
            return kernel

        except RuntimeError as err:
            raise AttributeError(f"No kernel named '{name}' in this module") from err


class _CudaKernel:
    """
    Represents a compiled CUDA kernel that can be called with PyTorch tensors.
    """

    def __init__(self, func: ctypes.c_void_p, module: ctypes.c_void_p) -> None:
        self.func = func
        self.module = module
        self._max_shared_mem_bytes = 0

    def __call__(
        self,
        grid: tuple[int, int, int] = (1, 1, 1),
        block: tuple[int, int, int] = (1, 1, 1),
        args: Optional[list] = None,
        shared_mem: int = 0,
        stream: Optional[Any] = None,
    ) -> None:
        """
        Call the compiled CUDA kernel

        Args:
            grid (tuple): Grid dimensions (grid_x, grid_y, grid_z)
            block (tuple): Block dimensions (block_x, block_y, block_z)
            args (list): List of arguments to pass to the kernel.
                PyTorch tensor arguments will be automatically converted to pointers.
            shared_mem (int): Shared memory size in bytes
            stream (torch.cuda.Stream): CUDA stream to use. If None, uses current stream.
        """
        import torch

        libcuda = torch.cuda._utils._get_gpu_runtime_library()

        if not args:
            args = []

        # Process arguments and convert tensors to pointers
        processed_args: list[ctypes.c_void_p] = []
        c_args = []

        for arg in args:
            if isinstance(arg, torch.Tensor):
                if not arg.is_cuda and not (arg.is_cpu and arg.is_pinned()):
                    raise ValueError(
                        "All tensor arguments must be CUDA tensors or pinned CPU tensors"
                    )
                # Get pointer to tensor data
                ptr = ctypes.c_void_p(arg.data_ptr())
                processed_args.append(ptr)
                c_args.append(ctypes.byref(ptr))
            elif isinstance(arg, int):
                # Convert integers to C int
                c_int = ctypes.c_int(arg)
                # Store the C int for reference keeping, not in processed_args
                c_args.append(ctypes.byref(c_int))
            elif isinstance(arg, float):
                # Python floats are doubles - use double by default
                c_double = ctypes.c_double(arg)
                # Store the C double for reference keeping, not in processed_args
                c_args.append(ctypes.byref(c_double))
            else:
                raise TypeError(f"Unsupported argument type: {type(arg)}")

        # Convert to array of void pointers
        c_args_array = (ctypes.c_void_p * len(c_args))()
        for i, arg in enumerate(c_args):
            c_args_array[i] = ctypes.cast(arg, ctypes.c_void_p)

        # Get the stream
        if stream is None:
            # Defer import to avoid circular imports
            import torch.cuda

            stream = torch.cuda.current_stream()

        # Check if kernel requires large shared memory but hasn't been configured
        if shared_mem >= 48 * 1024 and (
            self._max_shared_mem_bytes == 0 or shared_mem > self._max_shared_mem_bytes
        ):
            configured_msg = (
                "not configured"
                if self._max_shared_mem_bytes == 0
                else f"only {self._max_shared_mem_bytes} bytes configured"
            )
            raise RuntimeError(
                f"Kernel requires {shared_mem} bytes of shared memory (>= 48KB), "
                f"but {configured_msg}. "
                "Call kernel.set_shared_memory_config(shared_mem) after compilation "
                "and before launching the kernel."
            )

        _check_cuda(
            libcuda.cuLaunchKernel(
                self.func,
                grid[0],
                grid[1],
                grid[2],
                block[0],
                block[1],
                block[2],
                shared_mem,
                stream._as_parameter_,
                c_args_array,
                None,
            )
        )

    def set_shared_memory_config(self, shared_mem_bytes: int) -> None:
        if shared_mem_bytes < 48 * 1024:
            # No configuration needed for <= 48KB, just update the value
            self._max_shared_mem_bytes = shared_mem_bytes
            return

        libcuda = _get_gpu_runtime_library()

        # Get device properties to validate against limits
        device_props = torch.cuda.get_device_properties()
        # HIP doesn't have shared_memory_per_block_optin in device properties, so we hard-code it here
        if torch.version.hip:
            # navi, CDNA1-CDNA3 allows a max of 64KB shared memory
            # CDNA4 allows a max of 160KB shared memory
            max_shared_mem = (
                65536 if device_props.gcnArchName != "gfx950" else 160 * 1024
            )
        else:
            max_shared_mem = getattr(
                device_props, "shared_memory_per_block_optin", 49152
            )

        if shared_mem_bytes > max_shared_mem:
            raise RuntimeError(
                f"Requested shared memory ({shared_mem_bytes} bytes) exceeds "
                f"device limit ({max_shared_mem} bytes). "
                "Consider reducing block size or shared memory usage."
            )

        # Set the function attribute once
        # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
        cudaFuncAttributeMaxDynamicSharedMemorySize = 8
        _check_cuda(
            libcuda.cuFuncSetAttribute(
                self.func,
                cudaFuncAttributeMaxDynamicSharedMemorySize,
                shared_mem_bytes,
            )
        )

        self._max_shared_mem_bytes = shared_mem_bytes
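
# Example (illustrative sketch, not part of the upstream module): a kernel that
# needs more than 48 KiB of dynamic shared memory must be configured once after
# loading and before launch. `kernel` is assumed to be a _CudaKernel obtained
# via _cuda_load_module below, and `buf` a CUDA tensor passed to it.
#
#   kernel.set_shared_memory_config(64 * 1024)
#   kernel(grid=(1, 1, 1), block=(256, 1, 1), args=[buf], shared_mem=64 * 1024)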


def _cuda_load_module(
    ptx: Union[str, bytes], kernel_names: Optional[list[str]] = None
) -> Union[_CudaModule, dict[str, "_CudaKernel"]]:
    """
    Loads a CUDA module from PTX code and returns a module object that can access kernels.

    Args:
        ptx (bytes or str): The PTX code to load
        kernel_names (list, optional): List of kernel names to extract from the module.
            If None, will return a module object with __getattr__.

    Returns:
        object: If kernel_names is None, returns a module object with __getattr__ to access kernels.
            If kernel_names is provided, returns a dict mapping kernel names to _CudaKernel objects.
    """
    # Ensure CUDA is initialized
    import torch.cuda

    # Load CUDA driver library
    libcuda = _get_gpu_runtime_library()

    # Convert PTX to bytes if it's a string
    if isinstance(ptx, str):
        ptx = ptx.encode("utf-8")

    # Load PTX module
    module = ctypes.c_void_p()
    # Get the current stream without directly importing torch.cuda at module level
    stream = torch.cuda.current_stream()
    with stream:
        _check_cuda(libcuda.cuModuleLoadData(ctypes.byref(module), ptx))

    if not kernel_names:
        return _CudaModule(module)

    # Return specific kernels
    kernels = {}
    for name in kernel_names:
        func = ctypes.c_void_p()
        _check_cuda(
            libcuda.cuModuleGetFunction(
                ctypes.byref(func), module, name.encode("utf-8")
            )
        )
        kernels[name] = _CudaKernel(func, module)
    return kernels
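
# Example (illustrative sketch, not part of the upstream module): loading the
# PTX produced by the _nvrtc_compile sketch above and launching the kernel on
# the current stream. Assumes the made-up "add_one" kernel and a CUDA device.
#
#   kernels = _cuda_load_module(ptx, kernel_names=[mangled_name])
#   x = torch.zeros(1024, device="cuda")
#   kernels[mangled_name](grid=(4, 1, 1), block=(256, 1, 1), args=[x, x.numel()])
#
#   # Attribute access works too when no kernel names are given:
#   mod = _cuda_load_module(ptx)
#   mod.add_one(grid=(4, 1, 1), block=(256, 1, 1), args=[x, x.numel()])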


def _get_device_index(
    device: Any, optional: bool = False, allow_cpu: bool = False
) -> int:
    r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``.

    If :attr:`device` is a torch.device object, returns the device index if it
    is a CUDA device. Note that for a CUDA device without a specified index,
    i.e., ``torch.device('cuda')``, this will return the current default CUDA
    device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
    CPU devices will be accepted and ``-1`` will be returned in this case.

    If :attr:`device` is a Python integer, it is returned as is.

    If :attr:`device` is ``None``, this will return the current default CUDA
    device if :attr:`optional` is ``True``.
    """
    if isinstance(device, int):
        return device
    if isinstance(device, str):
        device = torch.device(device)
    if isinstance(device, torch.device):
        if allow_cpu:
            if device.type not in ["cuda", "cpu"]:
                raise ValueError(f"Expected a cuda or cpu device, but got: {device}")
        elif device.type != "cuda":
            raise ValueError(f"Expected a cuda device, but got: {device}")
    if not torch.jit.is_scripting():
        if isinstance(device, torch.cuda.device):
            return device.idx
    return _torch_get_device_index(device, optional, allow_cpu)
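
# Example (illustrative sketch, not part of the upstream module): the kinds of
# inputs _get_device_index accepts, assuming at least three CUDA devices.
#
#   _get_device_index(1)                      # -> 1 (ints pass through)
#   _get_device_index("cuda:2")               # -> 2
#   _get_device_index(None, optional=True)    # -> current default CUDA device
#   _get_device_index("cpu", allow_cpu=True)  # -> -1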