[CD] Fix slim-wheel nvjit-link import problem (#141063)

When other toolkit (say CUDA-12.3)  is installed and `LD_LIBRARY_PATH` points to there, import torch will fail with
```
ImportError: /usr/local/lib/python3.10/dist-packages/torch/lib/../../nvidia/cusparse/lib/libcusparse.so.12: undefined symbol: __nvJitLinkComplete_12_4, version libnvJitLink.so.12
```
It could not be worked around by tweaking rpath, as it also depends on the library load order, which are not guaranteed by any linker. Instead solve this by preloading `nvjitlink` right after global deps are loaded, by running something along the lines of the following
```python
        if version.cuda in ["12.4", "12.6"]:
            with open("/proc/self/maps") as f:
                _maps = f.read()
            # libtorch_global_deps.so always depends in cudart, check if its installed via wheel
            if "nvidia/cuda_runtime/lib/libcudart.so" in _maps:
                # If all abovementioned conditions are met, preload nvjitlink
                _preload_cuda_deps("nvjitlink", "libnvJitLink.so.*[0-9]")
```

Fixes https://github.com/pytorch/pytorch/issues/140797

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141063
Approved by: https://github.com/kit1980

Co-authored-by: Sergii Dymchenko <sdym@meta.com>
This commit is contained in:
Nikita Shulga
2025-01-14 17:33:07 +00:00
committed by PyTorch MergeBot
parent 5c727d5679
commit f2975717f3

View File

@ -316,6 +316,24 @@ def _load_global_deps() -> None:
try:
ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL)
# Workaround slim-wheel CUDA-12.4+ dependency bug in libcusparse by preloading nvjitlink
# In those versions of cuda cusparse depends on nvjitlink, but does not have rpath when
# shipped as wheel, which results in OS picking wrong/older version of nvjitlink library
# if `LD_LIBRARY_PATH` is defined
# See https://github.com/pytorch/pytorch/issues/138460
if version.cuda not in ["12.4", "12.6"]: # type: ignore[name-defined]
return
try:
with open("/proc/self/maps") as f:
_maps = f.read()
# libtorch_global_deps.so always depends in cudart, check if its installed via wheel
if "nvidia/cuda_runtime/lib/libcudart.so" not in _maps:
return
# If all abovementioned conditions are met, preload nvjitlink
_preload_cuda_deps("nvjitlink", "libnvJitLink.so.*[0-9]")
except Exception:
pass
except OSError as err:
# Can only happen for wheel with cuda libs as PYPI deps
# As PyTorch is not purelib, but nvidia-*-cu12 is