Improve cuDNN detection at build time

This commit is contained in:
Adam Paszke
2016-12-01 19:50:16 +01:00
parent 1f5951693a
commit cb849524f3
5 changed files with 53 additions and 4 deletions

View File

@ -9,10 +9,10 @@ import shutil
import sys
import os
CUDA_HOME = os.getenv('CUDA_HOME', '/usr/local/cuda')
WITH_CUDA = os.path.exists(CUDA_HOME)
WITH_CUDNN = WITH_CUDA
DEBUG = False
from tools.setup_helpers.env import check_env_flag
from tools.setup_helpers.cuda import WITH_CUDA, CUDA_HOME
from tools.setup_helpers.cudnn import WITH_CUDNN, CUDNN_LIB_DIR, CUDNN_INCLUDE_DIR
DEBUG = check_env_flag('DEBUG')
################################################################################
# Monkey-patch setuptools to compile in parallel
@ -200,6 +200,8 @@ if WITH_CUDA:
if WITH_CUDNN:
main_libraries += ['cudnn']
include_dirs.append(CUDNN_INCLUDE_DIR)
extra_link_args.append('-L' + CUDNN_LIB_DIR)
main_sources += [
"torch/csrc/cudnn/Module.cpp",
"torch/csrc/cudnn/Conv.cpp",

View File

View File

@ -0,0 +1,8 @@
import os
from .env import check_env_flag
CUDA_HOME = os.getenv('CUDA_HOME', '/usr/local/cuda')
WITH_CUDA = not check_env_flag('NO_CUDA') and os.path.exists(CUDA_HOME)
if not WITH_CUDA:
CUDA_HOME = None

View File

@ -0,0 +1,35 @@
import os
import glob
from .env import check_env_flag
from .cuda import WITH_CUDA, CUDA_HOME
WITH_CUDNN = False
CUDNN_LIB_DIR = None
CUDNN_INCLUDE_DIR = None
if WITH_CUDA and not check_env_flag('WITH_CUDNN'):
lib_paths = list(filter(bool, [
os.getenv('CUDNN_LIB_DIR'),
os.path.join(CUDA_HOME, 'lib'),
os.path.join(CUDA_HOME, 'lib64')
]))
include_paths = list(filter(bool, [
os.getenv('CUDNN_INCLUDE_DIR'),
os.path.join(CUDA_HOME, 'include'),
]))
for path in lib_paths:
if path is None or not os.path.exists(path):
continue
if glob.glob(os.path.join(path, 'libcudnn*')):
CUDNN_LIB_DIR = path
break
for path in include_paths:
if path is None or not os.path.exists(path):
continue
if os.path.exists((os.path.join(path, 'cudnn.h'))):
CUDNN_INCLUDE_DIR = path
break
if not CUDNN_LIB_DIR or not CUDNN_INCLUDE_DIR:
CUDNN_LIB_DIR = CUDNN_INCLUDE_DIR = None
else:
WITH_CUDNN = True

View File

@ -0,0 +1,4 @@
import os
def check_env_flag(name):
return os.getenv(name) in {'ON', '1', 'YES', 'TRUE', 'Y'}