mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[CD] Fix slim-wheel nvjit-link import problem (#141063)
When other toolkit (say CUDA-12.3) is installed and `LD_LIBRARY_PATH` points to there, import torch will fail with ``` ImportError: /usr/local/lib/python3.10/dist-packages/torch/lib/../../nvidia/cusparse/lib/libcusparse.so.12: undefined symbol: __nvJitLinkComplete_12_4, version libnvJitLink.so.12 ``` It could not be worked around by tweaking rpath, as it also depends on the library load order, which are not guaranteed by any linker. Instead solve this by preloading `nvjitlink` right after global deps are loaded, by running something along the lines of the following ```python if version.cuda in ["12.4", "12.6"]: with open("/proc/self/maps") as f: _maps = f.read() # libtorch_global_deps.so always depends in cudart, check if its installed via wheel if "nvidia/cuda_runtime/lib/libcudart.so" in _maps: # If all abovementioned conditions are met, preload nvjitlink _preload_cuda_deps("nvjitlink", "libnvJitLink.so.*[0-9]") ``` Fixes https://github.com/pytorch/pytorch/issues/140797 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141063 Approved by: https://github.com/kit1980 Co-authored-by: Sergii Dymchenko <sdym@meta.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
5c727d5679
commit
f2975717f3
@ -316,6 +316,24 @@ def _load_global_deps() -> None:
|
||||
|
||||
try:
|
||||
ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL)
|
||||
# Workaround slim-wheel CUDA-12.4+ dependency bug in libcusparse by preloading nvjitlink
|
||||
# In those versions of cuda cusparse depends on nvjitlink, but does not have rpath when
|
||||
# shipped as wheel, which results in OS picking wrong/older version of nvjitlink library
|
||||
# if `LD_LIBRARY_PATH` is defined
|
||||
# See https://github.com/pytorch/pytorch/issues/138460
|
||||
if version.cuda not in ["12.4", "12.6"]: # type: ignore[name-defined]
|
||||
return
|
||||
try:
|
||||
with open("/proc/self/maps") as f:
|
||||
_maps = f.read()
|
||||
# libtorch_global_deps.so always depends in cudart, check if its installed via wheel
|
||||
if "nvidia/cuda_runtime/lib/libcudart.so" not in _maps:
|
||||
return
|
||||
# If all abovementioned conditions are met, preload nvjitlink
|
||||
_preload_cuda_deps("nvjitlink", "libnvJitLink.so.*[0-9]")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except OSError as err:
|
||||
# Can only happen for wheel with cuda libs as PYPI deps
|
||||
# As PyTorch is not purelib, but nvidia-*-cu12 is
|
||||
|
Reference in New Issue
Block a user