mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Improve cuDNN detection at build time
This commit is contained in:
10
setup.py
10
setup.py
@ -9,10 +9,10 @@ import shutil
|
||||
import sys
|
||||
import os
|
||||
|
||||
CUDA_HOME = os.getenv('CUDA_HOME', '/usr/local/cuda')
|
||||
WITH_CUDA = os.path.exists(CUDA_HOME)
|
||||
WITH_CUDNN = WITH_CUDA
|
||||
DEBUG = False
|
||||
from tools.setup_helpers.env import check_env_flag
|
||||
from tools.setup_helpers.cuda import WITH_CUDA, CUDA_HOME
|
||||
from tools.setup_helpers.cudnn import WITH_CUDNN, CUDNN_LIB_DIR, CUDNN_INCLUDE_DIR
|
||||
DEBUG = check_env_flag('DEBUG')
|
||||
|
||||
################################################################################
|
||||
# Monkey-patch setuptools to compile in parallel
|
||||
@ -200,6 +200,8 @@ if WITH_CUDA:
|
||||
|
||||
if WITH_CUDNN:
|
||||
main_libraries += ['cudnn']
|
||||
include_dirs.append(CUDNN_INCLUDE_DIR)
|
||||
extra_link_args.append('-L' + CUDNN_LIB_DIR)
|
||||
main_sources += [
|
||||
"torch/csrc/cudnn/Module.cpp",
|
||||
"torch/csrc/cudnn/Conv.cpp",
|
||||
|
||||
0
tools/setup_helpers/__init__.py
Normal file
0
tools/setup_helpers/__init__.py
Normal file
8
tools/setup_helpers/cuda.py
Normal file
8
tools/setup_helpers/cuda.py
Normal file
@ -0,0 +1,8 @@
|
||||
import os
|
||||
|
||||
from .env import check_env_flag
|
||||
|
||||
CUDA_HOME = os.getenv('CUDA_HOME', '/usr/local/cuda')
|
||||
WITH_CUDA = not check_env_flag('NO_CUDA') and os.path.exists(CUDA_HOME)
|
||||
if not WITH_CUDA:
|
||||
CUDA_HOME = None
|
||||
35
tools/setup_helpers/cudnn.py
Normal file
35
tools/setup_helpers/cudnn.py
Normal file
@ -0,0 +1,35 @@
|
||||
import os
|
||||
import glob
|
||||
|
||||
from .env import check_env_flag
|
||||
from .cuda import WITH_CUDA, CUDA_HOME
|
||||
|
||||
WITH_CUDNN = False
|
||||
CUDNN_LIB_DIR = None
|
||||
CUDNN_INCLUDE_DIR = None
|
||||
if WITH_CUDA and not check_env_flag('WITH_CUDNN'):
|
||||
lib_paths = list(filter(bool, [
|
||||
os.getenv('CUDNN_LIB_DIR'),
|
||||
os.path.join(CUDA_HOME, 'lib'),
|
||||
os.path.join(CUDA_HOME, 'lib64')
|
||||
]))
|
||||
include_paths = list(filter(bool, [
|
||||
os.getenv('CUDNN_INCLUDE_DIR'),
|
||||
os.path.join(CUDA_HOME, 'include'),
|
||||
]))
|
||||
for path in lib_paths:
|
||||
if path is None or not os.path.exists(path):
|
||||
continue
|
||||
if glob.glob(os.path.join(path, 'libcudnn*')):
|
||||
CUDNN_LIB_DIR = path
|
||||
break
|
||||
for path in include_paths:
|
||||
if path is None or not os.path.exists(path):
|
||||
continue
|
||||
if os.path.exists((os.path.join(path, 'cudnn.h'))):
|
||||
CUDNN_INCLUDE_DIR = path
|
||||
break
|
||||
if not CUDNN_LIB_DIR or not CUDNN_INCLUDE_DIR:
|
||||
CUDNN_LIB_DIR = CUDNN_INCLUDE_DIR = None
|
||||
else:
|
||||
WITH_CUDNN = True
|
||||
4
tools/setup_helpers/env.py
Normal file
4
tools/setup_helpers/env.py
Normal file
@ -0,0 +1,4 @@
|
||||
import os
|
||||
|
||||
def check_env_flag(name):
|
||||
return os.getenv(name) in {'ON', '1', 'YES', 'TRUE', 'Y'}
|
||||
Reference in New Issue
Block a user