Compare commits

...

14 Commits

Author SHA1 Message Date
e59a086a12 lint 2025-09-23 07:04:48 -07:00
c4d29f9ef5 refactor 2025-09-22 22:46:10 -07:00
8b0fca4dc6 update 2025-09-22 22:42:24 -07:00
225d61fc69 lint 2025-09-22 21:43:54 -07:00
37ccb79993 fix auto_pch call 2025-09-22 18:40:29 -07:00
fac59e4aab Merge branch 'main' into fix_nvrtc_discovery 2025-09-22 18:11:24 -07:00
60ec1a1f7a Update _utils.py 2025-09-22 18:01:07 -07:00
5ed4624298 Simplify NVRTC library loading in _get_nvrtc_library 2025-09-22 18:00:53 -07:00
3478f8d7e4 Remove nvrtc64_112_0.dll from Windows libraries 2025-09-22 17:11:49 -07:00
69b9a7f2e6 Update _utils.py 2025-09-22 17:11:32 -07:00
eaa4392429 allow users to specify nvrtc path 2025-09-16 07:43:05 -07:00
037f43ba45 add link 2025-09-10 21:47:06 -07:00
f39083423b lint 2025-09-10 21:44:55 -07:00
5e7e562cda update 2025-09-10 21:42:00 -07:00
2 changed files with 44 additions and 6 deletions

View File

@ -1735,6 +1735,8 @@ def _compile_kernel(
compute_capability: Optional[str] = None,
cuda_include_dirs: Optional[list] = None,
nvcc_options: Optional[list] = None,
rtc_path: Optional[str] = None,
auto_pch: bool = False,
):
"""
Compiles a CUDA kernel using NVRTC and returns a callable function.
@ -1751,6 +1753,9 @@ def _compile_kernel(
If None, will detect from current device.
cuda_include_dirs (list, optional): List of directories containing CUDA headers
nvcc_options (list, optional): Additional options to pass to NVRTC
rtc_path (str, optional): Path to the RTC library (NVRTC/HIPRTC). If provided, this will skip the
automatic discovery logic and use the specified library directly.
auto_pch (bool, optional): Whether to automatically use precompiled headers. Default is False.
Returns:
callable: A Python function that can be called with PyTorch tensor arguments to execute the kernel
@ -1780,6 +1785,8 @@ def _compile_kernel(
compute_capability,
cuda_include_dirs,
nvcc_options,
rtc_path,
auto_pch,
)
# Load the module and get the kernel

View File

@ -72,14 +72,36 @@ def _get_hiprtc_library() -> ctypes.CDLL:
def _get_nvrtc_library() -> ctypes.CDLL:
    """Load the first NVRTC shared library the dynamic loader can find.

    Candidate library names are tried in the order listed below (the
    13-series names before the 12-series ones, and on non-Windows platforms
    the unversioned ``libnvrtc.so`` last). The first one that loads wins.

    Returns:
        ctypes.CDLL: Handle to the loaded NVRTC library.

    Raises:
        OSError: If none of the candidate library names can be loaded.
    """
    if sys.platform == "win32":
        nvrtc_libs = [
            "nvrtc64_130_0.dll",
            "nvrtc64_120_0.dll",
        ]
    else:
        nvrtc_libs = [
            "libnvrtc.so.13",
            "libnvrtc.so.12",
            "libnvrtc.so",
        ]
    for lib_name in nvrtc_libs:
        try:
            return ctypes.CDLL(lib_name)
        except OSError:
            # This candidate is not installed / not loadable; try the next one.
            continue
    raise OSError("Could not find any NVRTC library")
def _get_gpu_rtc_library(rtc_path: Optional[str] = None) -> ctypes.CDLL:
    """Load the GPU runtime-compilation library (NVRTC or HIPRTC).

    Args:
        rtc_path (str, optional): Path to the RTC library. If provided, it is
            loaded directly and the automatic discovery logic is skipped.

    Returns:
        ctypes.CDLL: Handle to the loaded RTC library.

    Raises:
        OSError: If ``rtc_path`` is given but cannot be loaded. The loader's
            original error is chained as the cause.
    """
    # If custom path provided, use it directly
    if rtc_path:
        try:
            return ctypes.CDLL(rtc_path)
        except OSError as err:
            # An explicitly requested library must not silently fall back to
            # auto-discovery; surface the failure to the caller instead.
            raise OSError(
                f"Could not load RTC library from specified path {rtc_path}"
            ) from err
    # Otherwise use auto-discovery based on GPU vendor.
    # Since PyTorch already loads the GPU RTC library, we can use the system library
    # which should be compatible with PyTorch's version
    if torch.version.hip:
        return _get_hiprtc_library()
    else:
        # NOTE(review): this branch's body lies outside the visible diff hunk;
        # presumably it delegates to _get_nvrtc_library() - confirm.
        return _get_nvrtc_library()
@ -116,6 +138,7 @@ def _nvrtc_compile(
compute_capability: Optional[str] = None,
cuda_include_dirs: Optional[list] = None,
nvcc_options: Optional[list] = None,
rtc_path: Optional[str] = None,
auto_pch: bool = False,
) -> tuple[bytes, str]:
"""
@ -128,6 +151,8 @@ def _nvrtc_compile(
If None, will detect from current device.
cuda_include_dirs (list, None): List of directories containing CUDA headers
nvcc_options (list, None): Additional options to pass to NVRTC
rtc_path (str, optional): Path to the RTC library (NVRTC/HIPRTC). If provided, this will skip the
automatic discovery logic and use the specified library directly.
auto_pch (bool): Enable automatic precompiled headers (CUDA 12.8+)
Returns:
@ -136,8 +161,11 @@ def _nvrtc_compile(
# Ensure CUDA is initialized
import torch.cuda
if not torch.cuda.is_initialized():
torch.cuda.init()
# Load NVRTC library
libnvrtc = _get_gpu_rtc_library()
libnvrtc = _get_gpu_rtc_library(rtc_path)
# NVRTC constants
NVRTC_SUCCESS = 0
@ -454,6 +482,9 @@ def _cuda_load_module(
# Ensure CUDA is initialized
import torch.cuda
if not torch.cuda.is_initialized():
torch.cuda.init()
# Load CUDA driver library
libcuda = _get_gpu_runtime_library()