[fft][1 of 3] build system and helpers to support cuFFT and MKL (#5855)

This is the first of three PRs that #5537 will be split into.

This PR adds the MKL headers to the build's include path and provides helper functions for MKL FFT and cuFFT.
In particular, on POSIX the headers come from the mkl-include conda package, and on Windows they come from a new archive that @yf225 and I built and uploaded to S3.

* add mkl-include to required packages

* include MKL headers; add AT_MKL_ENABLED flag; add a method to query MKL availability

* Add MKL and CUFFT helpers
Tongzhou Wang
2018-03-19 15:43:14 -04:00
committed by Edward Z. Yang
parent d11b7fbd1c
commit 22ef8e5654
20 changed files with 219 additions and 19 deletions

View File

@ -9,7 +9,7 @@ rm -rf $PWD/miniconda3
bash $PWD/miniconda3.sh -b -p $PWD/miniconda3
export PATH="$PWD/miniconda3/bin:$PATH"
source $PWD/miniconda3/bin/activate
conda install -y numpy pyyaml setuptools cmake cffi ninja
conda install -y mkl mkl-include numpy pyyaml setuptools cmake cffi ninja
# Build and test PyTorch
git submodule update --init --recursive

View File

@ -32,8 +32,9 @@ cat >ci_scripts/build_pytorch.bat <<EOL
set PATH=C:\\Program Files\\CMake\\bin;C:\\Program Files\\7-Zip;C:\\curl-7.57.0-win64-mingw\\bin;C:\\Program Files\\Git\\cmd;C:\\Program Files\\Amazon\\AWSCLI;%PATH%
:: Install MKL
aws s3 cp s3://ossci-windows/mkl.7z mkl.7z --quiet && 7z x -aoa mkl.7z -omkl
set LIB=%cd%\\mkl;%LIB%
aws s3 cp s3://ossci-windows/mkl_with_headers.7z mkl.7z --quiet && 7z x -aoa mkl.7z -omkl
set CMAKE_INCLUDE_PATH=%cd%\\mkl\\include
set LIB=%cd%\\mkl\\lib;%LIB%
:: Install MAGMA
aws s3 cp s3://ossci-windows/magma_cuda90_release.7z magma_cuda90_release.7z --quiet && 7z x -aoa magma_cuda90_release.7z -omagma_cuda90_release
@ -47,7 +48,7 @@ IF EXIST C:\\Jenkins\\Miniconda3 ( rd /s /q C:\\Jenkins\\Miniconda3 )
curl https://repo.continuum.io/miniconda/Miniconda3-latest-Windows-x86_64.exe -O
.\Miniconda3-latest-Windows-x86_64.exe /InstallationType=JustMe /RegisterPython=0 /S /AddToPath=0 /D=C:\\Jenkins\\Miniconda3
call C:\\Jenkins\\Miniconda3\\Scripts\\activate.bat C:\\Jenkins\\Miniconda3
call conda install -y -q numpy mkl cffi pyyaml boto3
call conda install -y -q numpy cffi pyyaml boto3
:: Install ninja
pip install ninja

View File

@ -14,11 +14,11 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
chmod +x ~/miniconda.sh && \
~/miniconda.sh -b -p /opt/conda && \
~/miniconda.sh -b -p /opt/conda && \
rm ~/miniconda.sh && \
/opt/conda/bin/conda install numpy pyyaml scipy ipython mkl && \
/opt/conda/bin/conda install numpy pyyaml scipy ipython mkl mkl-include && \
/opt/conda/bin/conda install -c soumith magma-cuda90 && \
/opt/conda/bin/conda clean -ya
/opt/conda/bin/conda clean -ya
ENV PATH /opt/conda/bin:$PATH
# This must be done before pip so that requirements.txt is available
WORKDIR /opt/pytorch

View File

@ -171,7 +171,7 @@ On Linux
export CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" # [anaconda root directory]
# Install basic dependencies
conda install numpy pyyaml mkl setuptools cmake cffi typing
conda install numpy pyyaml mkl mkl-include setuptools cmake cffi typing
# Add LAPACK support for the GPU
conda install -c pytorch magma-cuda80 # or magma-cuda90 if CUDA 9

View File

@ -380,10 +380,13 @@ MACRO(Install_Required_Library ln)
ENDMACRO(Install_Required_Library libname)
FIND_PACKAGE(BLAS)
SET(AT_MKL_ENABLED 0)
IF(BLAS_FOUND)
SET(USE_BLAS 1)
IF(BLAS_INFO STREQUAL "mkl")
ADD_DEFINITIONS(-DTH_BLAS_MKL)
INCLUDE_DIRECTORIES(${BLAS_INCLUDE_DIR}) # include MKL headers
SET(AT_MKL_ENABLED 1)
ENDIF()
ENDIF(BLAS_FOUND)

View File

@ -128,12 +128,14 @@ FILE(GLOB base_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp")
FILE(GLOB native_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "native/*.cpp")
FILE(GLOB native_cudnn_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "native/cudnn/*.cpp")
FILE(GLOB native_cuda_cu RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "native/cuda/*.cu")
FILE(GLOB native_mkl_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "native/mkl/*.cpp")
FILE(GLOB_RECURSE cuda_h
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"cuda/*.cuh" "cuda/*.h" "cudnn/*.cuh" "cudnn/*.h")
FILE(GLOB cudnn_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cudnn/*.cpp")
FILE(GLOB mkl_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "mkl/*.cpp")
FILE(GLOB all_python RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.py")
@ -179,8 +181,7 @@ ADD_CUSTOM_TARGET(aten_files_are_generated
)
SET(all_cpp ${base_cpp} ${native_cpp} ${native_cudnn_cpp} ${generated_cpp} ${ATen_CPU_SRCS} ${cpu_kernel_cpp})
SET(all_cpp ${base_cpp} ${native_cpp} ${native_cudnn_cpp} ${native_mkl_cpp} ${generated_cpp} ${ATen_CPU_SRCS} ${cpu_kernel_cpp})
INCLUDE_DIRECTORIES(${ATen_CPU_INCLUDE})
IF(NOT NO_CUDA)
@ -192,6 +193,9 @@ IF(NOT NO_CUDA)
IF(CUDNN_FOUND)
SET(all_cpp ${all_cpp} ${cudnn_cpp})
ENDIF()
IF(AT_MKL_ENABLED)
SET(all_cpp ${all_cpp} ${mkl_cpp})
ENDIF()
endif()
filter_list(generated_h generated_cpp "\\.h$")
@ -309,6 +313,7 @@ IF(CUDA_FOUND)
${CUDA_cusparse_LIBRARY}
${CUDA_curand_LIBRARY})
CUDA_ADD_CUBLAS_TO_TARGET(ATen)
CUDA_ADD_CUFFT_TO_TARGET(ATen)
if(CUDNN_FOUND)
target_link_libraries(ATen ${CUDNN_LIBRARIES})

View File

@ -1,12 +1,13 @@
#pragma once
// Test these using #if AT_CUDA_ENABLED()(), not #ifdef, so that it's
// Test these using #if AT_CUDA_ENABLED(), not #ifdef, so that it's
// obvious if you forgot to include Config.h
// c.f. https://stackoverflow.com/questions/33759787/generating-an-error-if-checked-boolean-macro-is-not-defined
#define AT_CUDA_ENABLED() @AT_CUDA_ENABLED@
#define AT_CUDNN_ENABLED() @AT_CUDNN_ENABLED@
#define AT_NNPACK_ENABLED() @AT_NNPACK_ENABLED@
#define AT_MKL_ENABLED() @AT_MKL_ENABLED@
#if !AT_CUDA_ENABLED() && AT_CUDNN_ENABLED()
#error "Cannot enable CuDNN without CUDA"

View File

@ -89,6 +89,14 @@ void Context::setBenchmarkCuDNN(bool b) {
benchmark_cudnn = b;
}
bool Context::hasMKL() const {
#if AT_MKL_ENABLED()
return true;
#else
return false;
#endif
}
bool Context::hasCUDA() const {
#if AT_CUDA_ENABLED()
int count;

View File

@ -43,6 +43,7 @@ public:
runtime_error("%s backend type not enabled.",toString(p));
return *generator;
}
bool hasMKL() const;
bool hasCUDA() const;
int64_t current_device() const;
// defined in header so that getType has ability to inline
@ -103,7 +104,7 @@ static inline void init() {
}
static inline Type& getType(Backend p, ScalarType s) {
return globalContext().getType(p,s);
return globalContext().getType(p, s);
}
static inline Type& CPU(ScalarType s) {
@ -118,6 +119,10 @@ static inline bool hasCUDA() {
return globalContext().hasCUDA();
}
static inline bool hasMKL() {
return globalContext().hasMKL();
}
static inline int64_t current_device() {
return globalContext().current_device();
}
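
For illustration only (not part of the diff): with the declaration above, callers can branch on MKL availability at runtime. A minimal sketch, assuming the usual ATen umbrella header; maybe_use_mkl is a hypothetical caller:

#include <ATen/ATen.h>

void maybe_use_mkl() {
  if (at::hasMKL()) {
    // dispatch to an MKL-backed implementation
  } else {
    // fall back to a generic CPU path
  }
}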

View File

@ -0,0 +1,44 @@
#pragma once
#include "Exceptions.h"
#include <mkl_dfti.h>
#include <ATen/Tensor.h>
namespace at { namespace native {
struct DftiDescriptorDeleter {
void operator()(DFTI_DESCRIPTOR* desc) {
if (desc != nullptr) {
MKL_DFTI_CHECK(DftiFreeDescriptor(&desc));
}
}
};
class DftiDescriptor {
public:
void init(DFTI_CONFIG_VALUE precision, DFTI_CONFIG_VALUE signal_type, MKL_LONG signal_ndim, MKL_LONG* sizes) {
if (desc_ != nullptr) {
throw std::runtime_error("DFTI DESCRIPTOR can only be initialized once");
}
DFTI_DESCRIPTOR *raw_desc;
if (signal_ndim == 1) {
MKL_DFTI_CHECK(DftiCreateDescriptor(&raw_desc, precision, signal_type, 1, sizes[0]));
} else {
MKL_DFTI_CHECK(DftiCreateDescriptor(&raw_desc, precision, signal_type, signal_ndim, sizes));
}
desc_.reset(raw_desc);
}
DFTI_DESCRIPTOR *get() const {
if (desc_ == nullptr) {
throw std::runtime_error("DFTI DESCRIPTOR has not been initialized");
}
return desc_.get();
}
private:
std::unique_ptr<DFTI_DESCRIPTOR, DftiDescriptorDeleter> desc_;
};
}} // at::native
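
For illustration only (not part of this file): a minimal sketch of how DftiDescriptor and MKL_DFTI_CHECK might be combined for a 1-d double-precision complex-to-complex transform; in_data and out_data are hypothetical buffers:

DftiDescriptor desc;
MKL_LONG n = 1024;
desc.init(DFTI_DOUBLE, DFTI_COMPLEX, /*signal_ndim=*/1, &n);
MKL_DFTI_CHECK(DftiSetValue(desc.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE));
MKL_DFTI_CHECK(DftiCommitDescriptor(desc.get()));
MKL_DFTI_CHECK(DftiComputeForward(desc.get(), in_data, out_data));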

View File

@ -0,0 +1,19 @@
#pragma once
#include <string>
#include <stdexcept>
#include <sstream>
#include <mkl_dfti.h>
namespace at { namespace native {
static inline void MKL_DFTI_CHECK(MKL_INT status)
{
if (status && !DftiErrorClass(status, DFTI_NO_ERROR)) {
std::ostringstream ss;
ss << "MKL FFT error: " << DftiErrorMessage(status);
throw std::runtime_error(ss.str());
}
}
}} // namespace at::native

View File

@ -0,0 +1,11 @@
#pragma once
#include <mkl_types.h>
namespace at { namespace native {
// Since the size of MKL_LONG varies across platforms (64 bit on Linux, 32 bit on
// Windows), we need to compute its maximum value programmatically.
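// For example, with an 8-byte MKL_LONG this evaluates to ((2^62 - 1) * 2) + 1 = 2^63 - 1
// without the 1LL << 62 shift overflowing; with a 4-byte MKL_LONG it is 2^31 - 1.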
static int64_t MKL_LONG_MAX = ((1LL << (sizeof(MKL_LONG) * 8 - 2)) - 1) * 2 + 1;
}} // namespace

View File

@ -0,0 +1,4 @@
All files living in this directory are written with the assumption that MKL is available,
which means that the code here is not guarded by `#if AT_MKL_ENABLED()`. Therefore, whenever
you need to use definitions from here, please guard both the `#include <ATen/mkl/*.h>` and the
uses of those definitions with the `#if AT_MKL_ENABLED()` macro, e.g. [SpectralOps.cpp](native/mkl/SpectralOps.cpp).
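
For illustration, a minimal sketch of the guard pattern described above (fft_with_mkl is a hypothetical consumer; the real one, SpectralOps.cpp, lands in a follow-up PR):

#include <ATen/Config.h>
#include <stdexcept>
#if AT_MKL_ENABLED()
#include <ATen/mkl/Descriptors.h>
#include <ATen/mkl/Exceptions.h>
#endif

void fft_with_mkl() {
#if AT_MKL_ENABLED()
  // safe to use the MKL helpers here
#else
  throw std::runtime_error("fft: ATen not compiled with MKL support");
#endif
}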

View File

@ -0,0 +1,84 @@
#pragma once
#include "ATen/ATen.h"
#include "ATen/Config.h"
#include <string>
#include <stdexcept>
#include <sstream>
#include <cufft.h>
#include <cufftXt.h>
namespace at { namespace native {
static inline std::string _cudaGetErrorEnum(cufftResult error)
{
switch (error)
{
case CUFFT_SUCCESS:
return "CUFFT_SUCCESS";
case CUFFT_INVALID_PLAN:
return "CUFFT_INVALID_PLAN";
case CUFFT_ALLOC_FAILED:
return "CUFFT_ALLOC_FAILED";
case CUFFT_INVALID_TYPE:
return "CUFFT_INVALID_TYPE";
case CUFFT_INVALID_VALUE:
return "CUFFT_INVALID_VALUE";
case CUFFT_INTERNAL_ERROR:
return "CUFFT_INTERNAL_ERROR";
case CUFFT_EXEC_FAILED:
return "CUFFT_EXEC_FAILED";
case CUFFT_SETUP_FAILED:
return "CUFFT_SETUP_FAILED";
case CUFFT_INVALID_SIZE:
return "CUFFT_INVALID_SIZE";
case CUFFT_UNALIGNED_DATA:
return "CUFFT_UNALIGNED_DATA";
case CUFFT_INCOMPLETE_PARAMETER_LIST:
return "CUFFT_INCOMPLETE_PARAMETER_LIST";
case CUFFT_INVALID_DEVICE:
return "CUFFT_INVALID_DEVICE";
case CUFFT_PARSE_ERROR:
return "CUFFT_PARSE_ERROR";
case CUFFT_NO_WORKSPACE:
return "CUFFT_NO_WORKSPACE";
case CUFFT_NOT_IMPLEMENTED:
return "CUFFT_NOT_IMPLEMENTED";
case CUFFT_LICENSE_ERROR:
return "CUFFT_LICENSE_ERROR";
case CUFFT_NOT_SUPPORTED:
return "CUFFT_NOT_SUPPORTED";
default:
std::ostringstream ss;
ss << "unknown error " << error;
return ss.str();
}
}
static inline void CUFFT_CHECK(cufftResult error)
{
if (error != CUFFT_SUCCESS) {
std::ostringstream ss;
ss << "cuFFT error: " << _cudaGetErrorEnum(error);
throw std::runtime_error(ss.str());
}
}
class CufftHandle {
public:
explicit CufftHandle() {
CUFFT_CHECK(cufftCreate(&raw_plan));
}
const cufftHandle &get() const { return raw_plan; }
~CufftHandle() {
CUFFT_CHECK(cufftDestroy(raw_plan));
}
private:
cufftHandle raw_plan;
};
}} // at::native
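
For illustration only (not part of this file): a minimal sketch of how CufftHandle and CUFFT_CHECK might be used for a 1-d single-precision complex-to-complex transform; n, idata, and odata are hypothetical:

CufftHandle plan;  // cufftCreate() in the constructor, cufftDestroy() in the destructor
size_t workspace_size = 0;
CUFFT_CHECK(cufftMakePlan1d(plan.get(), n, CUFFT_C2C, /*batch=*/1, &workspace_size));
CUFFT_CHECK(cufftExecC2C(plan.get(), idata, odata, CUFFT_FORWARD));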

View File

@ -8,4 +8,8 @@
#error "AT_CUDNN_ENABLED should not be visible in public headers"
#endif
#ifdef AT_MKL_ENABLED
#error "AT_MKL_ENABLED should not be visible in public headers"
#endif
auto main() -> int {}

View File

@ -180,7 +180,7 @@ IF (NOT LAPACK_FOUND AND LAPACK_FIND_REQUIRED)
ENDIF (NOT LAPACK_FOUND AND LAPACK_FIND_REQUIRED)
IF(NOT LAPACK_FIND_QUIETLY)
IF(LAPACK_FOUND)
MESSAGE(STATUS "Found a library with LAPACK API. (${LAPACK_INFO})")
MESSAGE(STATUS "Found a library with LAPACK API (${LAPACK_INFO}).")
ELSE(LAPACK_FOUND)
MESSAGE(STATUS "Cannot find a library with LAPACK API. Not using LAPACK.")
ENDIF(LAPACK_FOUND)

View File

@ -175,7 +175,7 @@ FOREACH(mklrtl ${mklrtls} "")
IF (NOT MKL_LIBRARIES AND NOT INTEL_MKL_SEQUENTIAL)
CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm
"mkl_${mkliface}${mkl64};${mklthread};mkl_core;${mklrtl};pthread;${mkl_m};${mkl_dl}" "")
ENDIF (NOT MKL_LIBRARIES AND NOT INTEL_MKL_SEQUENTIAL)
ENDIF (NOT MKL_LIBRARIES AND NOT INTEL_MKL_SEQUENTIAL)
ENDFOREACH(mklthread)
ENDFOREACH(mkl64)
ENDFOREACH(mkliface)
@ -200,7 +200,7 @@ FOREACH(mklrtl ${mklrtls} "")
IF (NOT MKL_LIBRARIES)
CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm
"mkl_${mkliface}${mkl64};${mklthread};mkl_core;${mklrtl};pthread;${mkl_m};${mkl_dl}" "")
ENDIF (NOT MKL_LIBRARIES)
ENDIF (NOT MKL_LIBRARIES)
ENDFOREACH(mklthread)
ENDFOREACH(mkl64)
ENDFOREACH(mkliface)
@ -211,7 +211,7 @@ IF (NOT MKL_LIBRARIES)
SET(MKL_VERSION 900)
CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm
"mkl;guide;pthread;m" "")
ENDIF (NOT MKL_LIBRARIES)
ENDIF (NOT MKL_LIBRARIES)
# Include files
IF (MKL_LIBRARIES)
@ -228,7 +228,7 @@ IF (MKL_LIBRARIES)
MARK_AS_ADVANCED(MKL_LAPACK_LIBRARIES)
ENDIF (NOT MKL_LAPACK_LIBRARIES)
IF (NOT MKL_SCALAPACK_LIBRARIES)
FIND_LIBRARY(MKL_SCALAPACK_LIBRARIES NAMES "mkl_scalapack${mkl64}${mkls}")
FIND_LIBRARY(MKL_SCALAPACK_LIBRARIES NAMES "mkl_scalapack${mkl64}${mkls}")
MARK_AS_ADVANCED(MKL_SCALAPACK_LIBRARIES)
ENDIF (NOT MKL_SCALAPACK_LIBRARIES)
IF (NOT MKL_SOLVER_LIBRARIES)
@ -243,7 +243,7 @@ IF (MKL_LIBRARIES)
ENDFOREACH(mkl64)
ENDIF (MKL_LIBRARIES)
# LibIRC: intel compiler always links this;
# LibIRC: intel compiler always links this;
# gcc does not; but mkl kernels sometimes need it.
IF (MKL_LIBRARIES)
IF (CMAKE_COMPILER_IS_GNUCC)
@ -269,7 +269,7 @@ ENDIF (MKL_LIBRARIES)
# Standard termination
IF(NOT MKL_FOUND AND MKL_FIND_REQUIRED)
MESSAGE(FATAL_ERROR "MKL library not found. Please specify library location")
MESSAGE(FATAL_ERROR "MKL library not found. Please specify library location")
ENDIF(NOT MKL_FOUND AND MKL_FIND_REQUIRED)
IF(NOT MKL_FIND_QUIETLY)
IF(MKL_FOUND)

View File

@ -19,6 +19,7 @@ import torch.cuda
from torch.autograd import Variable
from torch._six import string_classes
import torch.backends.cudnn
import torch.backends.mkl
torch.set_default_tensor_type('torch.DoubleTensor')
@ -54,6 +55,8 @@ try:
except ImportError:
TEST_SCIPY = False
TEST_MKL = torch.backends.mkl.is_available()
def skipIfNoLapack(fn):
@wraps(fn)

View File

@ -0,0 +1,6 @@
import torch
def is_available():
r"""Returns whether PyTorch is built with MKL support."""
return torch._C.has_mkl

View File

@ -505,6 +505,8 @@ static PyObject* initModule() {
// setting up TH Errors so that they throw C++ exceptions
at::init();
ASSERT_TRUE(PyModule_AddObject(module, "has_mkl", at::hasMKL() ? Py_True : Py_False) == 0);
auto& defaultGenerator = at::globalContext().defaultGenerator(at::kCPU);
THPDefaultGenerator = (THPGenerator*)THPGenerator_NewWithGenerator(
defaultGenerator);