mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add cuda 11.8 guard for cufile preload (#148184)
Follow up after https://github.com/pytorch/pytorch/pull/148137 Make sure we don't try to load cufile on CUDA 11.8 Test: ``` >>> import torch /usr/local/lib64/python3.9/site-packages/torch/_subclasses/functional_tensor.py:276: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:81.) cpu = _conversion_method_template(device=torch.device("cpu")) >>> torch.__version__ '2.7.0.dev20250227+cu118' >>> ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/148184 Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
committed by
PyTorch MergeBot
parent
2544afaa1a
commit
230a3b0f83
@ -334,6 +334,8 @@ def _load_global_deps() -> None:
|
||||
except OSError as err:
|
||||
# Can only happen for wheel with cuda libs as PYPI deps
|
||||
# As PyTorch is not purelib, but nvidia-*-cu12 is
|
||||
from torch.version import cuda as cuda_version
|
||||
|
||||
cuda_libs: dict[str, str] = {
|
||||
"cublas": "libcublas.so.*[0-9]",
|
||||
"cudnn": "libcudnn.so.*[0-9]",
|
||||
@ -343,13 +345,20 @@ def _load_global_deps() -> None:
|
||||
"cufft": "libcufft.so.*[0-9]",
|
||||
"curand": "libcurand.so.*[0-9]",
|
||||
"nvjitlink": "libnvJitLink.so.*[0-9]",
|
||||
"cufile": "libcufile.so.*[0-9]",
|
||||
"cusparse": "libcusparse.so.*[0-9]",
|
||||
"cusparselt": "libcusparseLt.so.*[0-9]",
|
||||
"cusolver": "libcusolver.so.*[0-9]",
|
||||
"nccl": "libnccl.so.*[0-9]",
|
||||
"nvtx": "libnvToolsExt.so.*[0-9]",
|
||||
}
|
||||
# cufiile is only available on cuda 12+
|
||||
# TODO: Remove once CUDA 11.8 binaries are deprecated
|
||||
if cuda_version is not None:
|
||||
t_version = cuda_version.split(".")
|
||||
t_major = int(t_version[0]) # type: ignore[operator]
|
||||
if t_major >= 12:
|
||||
cuda_libs["cufile"] = "libcufile.so.*[0-9]"
|
||||
|
||||
is_cuda_lib_err = [
|
||||
lib for lib in cuda_libs.values() if lib.split(".")[0] in err.args[0]
|
||||
]
|
||||
|
Reference in New Issue
Block a user