mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[fft][1 of 3] build system and helpers to support cuFFT and MKL (#5855)
This is the first of three PRs that #5537 will be split into. This PR adds mkl headers to included files, and provides helper functions for MKL fft and cuFFT. In particular, on POSIX, headers are using mkl-include from conda, and on Windows, it is from a new file @yf225 and I made and uploaded to s3. * add mkl-include to required packages * include MKL headers; add AT_MKL_ENABLED flag; add a method to query MKL availability * Add MKL and CUFFT helpers
This commit is contained in:
committed by
Edward Z. Yang
parent
d11b7fbd1c
commit
22ef8e5654
@ -9,7 +9,7 @@ rm -rf $PWD/miniconda3
|
|||||||
bash $PWD/miniconda3.sh -b -p $PWD/miniconda3
|
bash $PWD/miniconda3.sh -b -p $PWD/miniconda3
|
||||||
export PATH="$PWD/miniconda3/bin:$PATH"
|
export PATH="$PWD/miniconda3/bin:$PATH"
|
||||||
source $PWD/miniconda3/bin/activate
|
source $PWD/miniconda3/bin/activate
|
||||||
conda install -y numpy pyyaml setuptools cmake cffi ninja
|
conda install -y mkl mkl-include numpy pyyaml setuptools cmake cffi ninja
|
||||||
|
|
||||||
# Build and test PyTorch
|
# Build and test PyTorch
|
||||||
git submodule update --init --recursive
|
git submodule update --init --recursive
|
||||||
|
@ -32,8 +32,9 @@ cat >ci_scripts/build_pytorch.bat <<EOL
|
|||||||
set PATH=C:\\Program Files\\CMake\\bin;C:\\Program Files\\7-Zip;C:\\curl-7.57.0-win64-mingw\\bin;C:\\Program Files\\Git\\cmd;C:\\Program Files\\Amazon\\AWSCLI;%PATH%
|
set PATH=C:\\Program Files\\CMake\\bin;C:\\Program Files\\7-Zip;C:\\curl-7.57.0-win64-mingw\\bin;C:\\Program Files\\Git\\cmd;C:\\Program Files\\Amazon\\AWSCLI;%PATH%
|
||||||
|
|
||||||
:: Install MKL
|
:: Install MKL
|
||||||
aws s3 cp s3://ossci-windows/mkl.7z mkl.7z --quiet && 7z x -aoa mkl.7z -omkl
|
aws s3 cp s3://ossci-windows/mkl_with_headers.7z mkl.7z --quiet && 7z x -aoa mkl.7z -omkl
|
||||||
set LIB=%cd%\\mkl;%LIB%
|
set CMAKE_INCLUDE_PATH=%cd%\\mkl\\include
|
||||||
|
:: Prepend the unpacked MKL import libraries to the linker search path.
:: The trailing percent sign is required so the existing LIB value is expanded and kept.
set LIB=%cd%\\mkl\\lib;%LIB%
|
||||||
|
|
||||||
:: Install MAGMA
|
:: Install MAGMA
|
||||||
aws s3 cp s3://ossci-windows/magma_cuda90_release.7z magma_cuda90_release.7z --quiet && 7z x -aoa magma_cuda90_release.7z -omagma_cuda90_release
|
aws s3 cp s3://ossci-windows/magma_cuda90_release.7z magma_cuda90_release.7z --quiet && 7z x -aoa magma_cuda90_release.7z -omagma_cuda90_release
|
||||||
@ -47,7 +48,7 @@ IF EXIST C:\\Jenkins\\Miniconda3 ( rd /s /q C:\\Jenkins\\Miniconda3 )
|
|||||||
curl https://repo.continuum.io/miniconda/Miniconda3-latest-Windows-x86_64.exe -O
|
curl https://repo.continuum.io/miniconda/Miniconda3-latest-Windows-x86_64.exe -O
|
||||||
.\Miniconda3-latest-Windows-x86_64.exe /InstallationType=JustMe /RegisterPython=0 /S /AddToPath=0 /D=C:\\Jenkins\\Miniconda3
|
.\Miniconda3-latest-Windows-x86_64.exe /InstallationType=JustMe /RegisterPython=0 /S /AddToPath=0 /D=C:\\Jenkins\\Miniconda3
|
||||||
call C:\\Jenkins\\Miniconda3\\Scripts\\activate.bat C:\\Jenkins\\Miniconda3
|
call C:\\Jenkins\\Miniconda3\\Scripts\\activate.bat C:\\Jenkins\\Miniconda3
|
||||||
call conda install -y -q numpy mkl cffi pyyaml boto3
|
call conda install -y -q numpy cffi pyyaml boto3
|
||||||
|
|
||||||
:: Install ninja
|
:: Install ninja
|
||||||
pip install ninja
|
pip install ninja
|
||||||
|
@ -14,11 +14,11 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||||||
|
|
||||||
RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
|
RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
|
||||||
chmod +x ~/miniconda.sh && \
|
chmod +x ~/miniconda.sh && \
|
||||||
~/miniconda.sh -b -p /opt/conda && \
|
~/miniconda.sh -b -p /opt/conda && \
|
||||||
rm ~/miniconda.sh && \
|
rm ~/miniconda.sh && \
|
||||||
/opt/conda/bin/conda install numpy pyyaml scipy ipython mkl && \
|
/opt/conda/bin/conda install numpy pyyaml scipy ipython mkl mkl-include && \
|
||||||
/opt/conda/bin/conda install -c soumith magma-cuda90 && \
|
/opt/conda/bin/conda install -c soumith magma-cuda90 && \
|
||||||
/opt/conda/bin/conda clean -ya
|
/opt/conda/bin/conda clean -ya
|
||||||
ENV PATH /opt/conda/bin:$PATH
|
ENV PATH /opt/conda/bin:$PATH
|
||||||
# This must be done before pip so that requirements.txt is available
|
# This must be done before pip so that requirements.txt is available
|
||||||
WORKDIR /opt/pytorch
|
WORKDIR /opt/pytorch
|
||||||
|
@ -171,7 +171,7 @@ On Linux
|
|||||||
export CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" # [anaconda root directory]
|
export CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" # [anaconda root directory]
|
||||||
|
|
||||||
# Install basic dependencies
|
# Install basic dependencies
|
||||||
conda install numpy pyyaml mkl setuptools cmake cffi typing
|
conda install numpy pyyaml mkl mkl-include setuptools cmake cffi typing
|
||||||
|
|
||||||
# Add LAPACK support for the GPU
|
# Add LAPACK support for the GPU
|
||||||
conda install -c pytorch magma-cuda80 # or magma-cuda90 if CUDA 9
|
conda install -c pytorch magma-cuda80 # or magma-cuda90 if CUDA 9
|
||||||
|
@ -380,10 +380,13 @@ MACRO(Install_Required_Library ln)
|
|||||||
ENDMACRO(Install_Required_Library libname)
|
ENDMACRO(Install_Required_Library libname)
|
||||||
|
|
||||||
FIND_PACKAGE(BLAS)
|
FIND_PACKAGE(BLAS)
|
||||||
|
SET(AT_MKL_ENABLED 0)
|
||||||
IF(BLAS_FOUND)
|
IF(BLAS_FOUND)
|
||||||
SET(USE_BLAS 1)
|
SET(USE_BLAS 1)
|
||||||
IF(BLAS_INFO STREQUAL "mkl")
|
IF(BLAS_INFO STREQUAL "mkl")
|
||||||
ADD_DEFINITIONS(-DTH_BLAS_MKL)
|
ADD_DEFINITIONS(-DTH_BLAS_MKL)
|
||||||
|
INCLUDE_DIRECTORIES(${BLAS_INCLUDE_DIR}) # include MKL headers
|
||||||
|
SET(AT_MKL_ENABLED 1)
|
||||||
ENDIF()
|
ENDIF()
|
||||||
ENDIF(BLAS_FOUND)
|
ENDIF(BLAS_FOUND)
|
||||||
|
|
||||||
|
@ -128,12 +128,14 @@ FILE(GLOB base_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp")
|
|||||||
FILE(GLOB native_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "native/*.cpp")
|
FILE(GLOB native_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "native/*.cpp")
|
||||||
FILE(GLOB native_cudnn_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "native/cudnn/*.cpp")
|
FILE(GLOB native_cudnn_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "native/cudnn/*.cpp")
|
||||||
FILE(GLOB native_cuda_cu RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "native/cuda/*.cu")
|
FILE(GLOB native_cuda_cu RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "native/cuda/*.cu")
|
||||||
|
FILE(GLOB native_mkl_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "native/mkl/*.cpp")
|
||||||
|
|
||||||
FILE(GLOB_RECURSE cuda_h
|
FILE(GLOB_RECURSE cuda_h
|
||||||
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
|
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
|
||||||
"cuda/*.cuh" "cuda/*.h" "cudnn/*.cuh" "cudnn/*.h")
|
"cuda/*.cuh" "cuda/*.h" "cudnn/*.cuh" "cudnn/*.h")
|
||||||
|
|
||||||
FILE(GLOB cudnn_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cudnn/*.cpp")
|
FILE(GLOB cudnn_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cudnn/*.cpp")
|
||||||
|
FILE(GLOB mkl_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "mkl/*.cpp")
|
||||||
|
|
||||||
FILE(GLOB all_python RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.py")
|
FILE(GLOB all_python RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.py")
|
||||||
|
|
||||||
@ -179,8 +181,7 @@ ADD_CUSTOM_TARGET(aten_files_are_generated
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
SET(all_cpp ${base_cpp} ${native_cpp} ${native_cudnn_cpp} ${generated_cpp} ${ATen_CPU_SRCS} ${cpu_kernel_cpp})
|
SET(all_cpp ${base_cpp} ${native_cpp} ${native_cudnn_cpp} ${native_mkl_cpp} ${generated_cpp} ${ATen_CPU_SRCS} ${cpu_kernel_cpp})
|
||||||
|
|
||||||
|
|
||||||
INCLUDE_DIRECTORIES(${ATen_CPU_INCLUDE})
|
INCLUDE_DIRECTORIES(${ATen_CPU_INCLUDE})
|
||||||
IF(NOT NO_CUDA)
|
IF(NOT NO_CUDA)
|
||||||
@ -192,6 +193,9 @@ IF(NOT NO_CUDA)
|
|||||||
IF(CUDNN_FOUND)
|
IF(CUDNN_FOUND)
|
||||||
SET(all_cpp ${all_cpp} ${cudnn_cpp})
|
SET(all_cpp ${all_cpp} ${cudnn_cpp})
|
||||||
ENDIF()
|
ENDIF()
|
||||||
|
IF(AT_MKL_ENABLED)
|
||||||
|
SET(all_cpp ${all_cpp} ${mkl_cpp})
|
||||||
|
ENDIF()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
filter_list(generated_h generated_cpp "\\.h$")
|
filter_list(generated_h generated_cpp "\\.h$")
|
||||||
@ -309,6 +313,7 @@ IF(CUDA_FOUND)
|
|||||||
${CUDA_cusparse_LIBRARY}
|
${CUDA_cusparse_LIBRARY}
|
||||||
${CUDA_curand_LIBRARY})
|
${CUDA_curand_LIBRARY})
|
||||||
CUDA_ADD_CUBLAS_TO_TARGET(ATen)
|
CUDA_ADD_CUBLAS_TO_TARGET(ATen)
|
||||||
|
CUDA_ADD_CUFFT_TO_TARGET(ATen)
|
||||||
|
|
||||||
if(CUDNN_FOUND)
|
if(CUDNN_FOUND)
|
||||||
target_link_libraries(ATen ${CUDNN_LIBRARIES})
|
target_link_libraries(ATen ${CUDNN_LIBRARIES})
|
||||||
|
@ -1,12 +1,13 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
// Test these using #if AT_CUDA_ENABLED()(), not #ifdef, so that it's
|
// Test these using #if AT_CUDA_ENABLED(), not #ifdef, so that it's
|
||||||
// obvious if you forgot to include Config.h
|
// obvious if you forgot to include Config.h
|
||||||
// c.f. https://stackoverflow.com/questions/33759787/generating-an-error-if-checked-boolean-macro-is-not-defined
|
// c.f. https://stackoverflow.com/questions/33759787/generating-an-error-if-checked-boolean-macro-is-not-defined
|
||||||
|
|
||||||
#define AT_CUDA_ENABLED() @AT_CUDA_ENABLED@
|
#define AT_CUDA_ENABLED() @AT_CUDA_ENABLED@
|
||||||
#define AT_CUDNN_ENABLED() @AT_CUDNN_ENABLED@
|
#define AT_CUDNN_ENABLED() @AT_CUDNN_ENABLED@
|
||||||
#define AT_NNPACK_ENABLED() @AT_NNPACK_ENABLED@
|
#define AT_NNPACK_ENABLED() @AT_NNPACK_ENABLED@
|
||||||
|
#define AT_MKL_ENABLED() @AT_MKL_ENABLED@
|
||||||
|
|
||||||
#if !AT_CUDA_ENABLED() && AT_CUDNN_ENABLED()
|
#if !AT_CUDA_ENABLED() && AT_CUDNN_ENABLED()
|
||||||
#error "Cannot enable CuDNN without CUDA"
|
#error "Cannot enable CuDNN without CUDA"
|
||||||
|
@ -89,6 +89,14 @@ void Context::setBenchmarkCuDNN(bool b) {
|
|||||||
benchmark_cudnn = b;
|
benchmark_cudnn = b;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Context::hasMKL() const {
  // AT_MKL_ENABLED() is a function-like configuration macro, so the answer is
  // fixed when the library is compiled; nothing is probed at runtime.
#if AT_MKL_ENABLED()
  const bool built_with_mkl = true;
#else
  const bool built_with_mkl = false;
#endif
  return built_with_mkl;
}
|
||||||
|
|
||||||
bool Context::hasCUDA() const {
|
bool Context::hasCUDA() const {
|
||||||
#if AT_CUDA_ENABLED()
|
#if AT_CUDA_ENABLED()
|
||||||
int count;
|
int count;
|
||||||
|
@ -43,6 +43,7 @@ public:
|
|||||||
runtime_error("%s backend type not enabled.",toString(p));
|
runtime_error("%s backend type not enabled.",toString(p));
|
||||||
return *generator;
|
return *generator;
|
||||||
}
|
}
|
||||||
|
bool hasMKL() const;
|
||||||
bool hasCUDA() const;
|
bool hasCUDA() const;
|
||||||
int64_t current_device() const;
|
int64_t current_device() const;
|
||||||
// defined in header so that getType has ability to inline
|
// defined in header so that getType has ability to inline
|
||||||
@ -103,7 +104,7 @@ static inline void init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static inline Type& getType(Backend p, ScalarType s) {
|
static inline Type& getType(Backend p, ScalarType s) {
|
||||||
return globalContext().getType(p,s);
|
return globalContext().getType(p, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline Type& CPU(ScalarType s) {
|
static inline Type& CPU(ScalarType s) {
|
||||||
@ -118,6 +119,10 @@ static inline bool hasCUDA() {
|
|||||||
return globalContext().hasCUDA();
|
return globalContext().hasCUDA();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline bool hasMKL() {
|
||||||
|
return globalContext().hasMKL();
|
||||||
|
}
|
||||||
|
|
||||||
static inline int64_t current_device() {
|
static inline int64_t current_device() {
|
||||||
return globalContext().current_device();
|
return globalContext().current_device();
|
||||||
}
|
}
|
||||||
|
44
aten/src/ATen/mkl/Descriptors.h
Normal file
44
aten/src/ATen/mkl/Descriptors.h
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "Exceptions.h"
|
||||||
|
#include <mkl_dfti.h>
|
||||||
|
#include <ATen/Tensor.h>
|
||||||
|
|
||||||
|
namespace at { namespace native {
|
||||||
|
|
||||||
|
struct DftiDescriptorDeleter {
|
||||||
|
void operator()(DFTI_DESCRIPTOR* desc) {
|
||||||
|
if (desc != nullptr) {
|
||||||
|
MKL_DFTI_CHECK(DftiFreeDescriptor(&desc));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class DftiDescriptor {
|
||||||
|
public:
|
||||||
|
void init(DFTI_CONFIG_VALUE precision, DFTI_CONFIG_VALUE signal_type, MKL_LONG signal_ndim, MKL_LONG* sizes) {
|
||||||
|
if (desc_ != nullptr) {
|
||||||
|
throw std::runtime_error("DFTI DESCRIPTOR can only be initialized once");
|
||||||
|
}
|
||||||
|
DFTI_DESCRIPTOR *raw_desc;
|
||||||
|
if (signal_ndim == 1) {
|
||||||
|
MKL_DFTI_CHECK(DftiCreateDescriptor(&raw_desc, precision, signal_type, 1, sizes[0]));
|
||||||
|
} else {
|
||||||
|
MKL_DFTI_CHECK(DftiCreateDescriptor(&raw_desc, precision, signal_type, signal_ndim, sizes));
|
||||||
|
}
|
||||||
|
desc_.reset(raw_desc);
|
||||||
|
}
|
||||||
|
|
||||||
|
DFTI_DESCRIPTOR *get() const {
|
||||||
|
if (desc_ == nullptr) {
|
||||||
|
throw std::runtime_error("DFTI DESCRIPTOR has not been initialized");
|
||||||
|
}
|
||||||
|
return desc_.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<DFTI_DESCRIPTOR, DftiDescriptorDeleter> desc_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
}} // at::native
|
19
aten/src/ATen/mkl/Exceptions.h
Normal file
19
aten/src/ATen/mkl/Exceptions.h
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <sstream>
|
||||||
|
#include <mkl_dfti.h>
|
||||||
|
|
||||||
|
namespace at { namespace native {
|
||||||
|
|
||||||
|
// Throws std::runtime_error carrying the DFTI error message when `status` is
// non-zero and is not classified as DFTI_NO_ERROR; otherwise does nothing.
static inline void MKL_DFTI_CHECK(MKL_INT status)
{
  // Zero is treated as success without consulting DftiErrorClass.
  const bool is_error = status && !DftiErrorClass(status, DFTI_NO_ERROR);
  if (!is_error) {
    return;
  }
  std::ostringstream message;
  message << "MKL FFT error: " << DftiErrorMessage(status);
  throw std::runtime_error(message.str());
}
|
||||||
|
|
||||||
|
}} // namespace at::native
|
11
aten/src/ATen/mkl/Limits.h
Normal file
11
aten/src/ATen/mkl/Limits.h
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <mkl_types.h>
|
||||||
|
|
||||||
|
namespace at { namespace native {
|
||||||
|
|
||||||
|
// Since size of MKL_LONG varies on different platforms (linux 64 bit, windows
|
||||||
|
// 32 bit), we need to programmatically calculate the max.
|
||||||
|
static int64_t MKL_LONG_MAX = ((1LL << (sizeof(MKL_LONG) * 8 - 2)) - 1) * 2 + 1;
|
||||||
|
|
||||||
|
}} // namespace
|
4
aten/src/ATen/mkl/README.md
Normal file
4
aten/src/ATen/mkl/README.md
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
All files living in this directory are written with the assumption that MKL is available,
|
||||||
|
which means that this code is not guarded by `#if AT_MKL_ENABLED()`. Therefore, whenever
|
||||||
|
you need to use definitions from here, please guard the `#include<ATen/mkl/*.h>` and
|
||||||
|
definition usages with `#if AT_MKL_ENABLED()` macro, e.g. [SpectralOps.cpp](native/mkl/SpectralOps.cpp).
|
84
aten/src/ATen/native/cuda/CuFFTUtils.h
Normal file
84
aten/src/ATen/native/cuda/CuFFTUtils.h
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ATen/ATen.h"
|
||||||
|
#include "ATen/Config.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <sstream>
|
||||||
|
#include <cufft.h>
|
||||||
|
#include <cufftXt.h>
|
||||||
|
|
||||||
|
|
||||||
|
namespace at { namespace native {
|
||||||
|
|
||||||
|
static inline std::string _cudaGetErrorEnum(cufftResult error)
|
||||||
|
{
|
||||||
|
switch (error)
|
||||||
|
{
|
||||||
|
case CUFFT_SUCCESS:
|
||||||
|
return "CUFFT_SUCCESS";
|
||||||
|
case CUFFT_INVALID_PLAN:
|
||||||
|
return "CUFFT_INVALID_PLAN";
|
||||||
|
case CUFFT_ALLOC_FAILED:
|
||||||
|
return "CUFFT_ALLOC_FAILED";
|
||||||
|
case CUFFT_INVALID_TYPE:
|
||||||
|
return "CUFFT_INVALID_TYPE";
|
||||||
|
case CUFFT_INVALID_VALUE:
|
||||||
|
return "CUFFT_INVALID_VALUE";
|
||||||
|
case CUFFT_INTERNAL_ERROR:
|
||||||
|
return "CUFFT_INTERNAL_ERROR";
|
||||||
|
case CUFFT_EXEC_FAILED:
|
||||||
|
return "CUFFT_EXEC_FAILED";
|
||||||
|
case CUFFT_SETUP_FAILED:
|
||||||
|
return "CUFFT_SETUP_FAILED";
|
||||||
|
case CUFFT_INVALID_SIZE:
|
||||||
|
return "CUFFT_INVALID_SIZE";
|
||||||
|
case CUFFT_UNALIGNED_DATA:
|
||||||
|
return "CUFFT_UNALIGNED_DATA";
|
||||||
|
case CUFFT_INCOMPLETE_PARAMETER_LIST:
|
||||||
|
return "CUFFT_INCOMPLETE_PARAMETER_LIST";
|
||||||
|
case CUFFT_INVALID_DEVICE:
|
||||||
|
return "CUFFT_INVALID_DEVICE";
|
||||||
|
case CUFFT_PARSE_ERROR:
|
||||||
|
return "CUFFT_PARSE_ERROR";
|
||||||
|
case CUFFT_NO_WORKSPACE:
|
||||||
|
return "CUFFT_NO_WORKSPACE";
|
||||||
|
case CUFFT_NOT_IMPLEMENTED:
|
||||||
|
return "CUFFT_NOT_IMPLEMENTED";
|
||||||
|
case CUFFT_LICENSE_ERROR:
|
||||||
|
return "CUFFT_LICENSE_ERROR";
|
||||||
|
case CUFFT_NOT_SUPPORTED:
|
||||||
|
return "CUFFT_NOT_SUPPORTED";
|
||||||
|
default:
|
||||||
|
std::ostringstream ss;
|
||||||
|
ss << "unknown error " << error;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Throws std::runtime_error naming the failing status when `error` is anything
// other than CUFFT_SUCCESS; returns normally on success.
static inline void CUFFT_CHECK(cufftResult error)
{
  if (error == CUFFT_SUCCESS) {
    return;
  }
  throw std::runtime_error("cuFFT error: " + _cudaGetErrorEnum(error));
}
|
||||||
|
|
||||||
|
class CufftHandle {
|
||||||
|
public:
|
||||||
|
explicit CufftHandle() {
|
||||||
|
CUFFT_CHECK(cufftCreate(&raw_plan));
|
||||||
|
}
|
||||||
|
|
||||||
|
const cufftHandle &get() const { return raw_plan; }
|
||||||
|
|
||||||
|
~CufftHandle() {
|
||||||
|
CUFFT_CHECK(cufftDestroy(raw_plan));
|
||||||
|
}
|
||||||
|
private:
|
||||||
|
cufftHandle raw_plan;
|
||||||
|
};
|
||||||
|
|
||||||
|
}} // at::native
|
@ -8,4 +8,8 @@
|
|||||||
#error "AT_CUDNN_ENABLED should not be visible in public headers"
|
#error "AT_CUDNN_ENABLED should not be visible in public headers"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifdef AT_MKL_ENABLED
|
||||||
|
#error "AT_MKL_ENABLED should not be visible in public headers"
|
||||||
|
#endif
|
||||||
|
|
||||||
auto main() -> int {}
|
auto main() -> int {}
|
||||||
|
@ -180,7 +180,7 @@ IF (NOT LAPACK_FOUND AND LAPACK_FIND_REQUIRED)
|
|||||||
ENDIF (NOT LAPACK_FOUND AND LAPACK_FIND_REQUIRED)
|
ENDIF (NOT LAPACK_FOUND AND LAPACK_FIND_REQUIRED)
|
||||||
IF(NOT LAPACK_FIND_QUIETLY)
|
IF(NOT LAPACK_FIND_QUIETLY)
|
||||||
IF(LAPACK_FOUND)
|
IF(LAPACK_FOUND)
|
||||||
MESSAGE(STATUS "Found a library with LAPACK API. (${LAPACK_INFO})")
|
MESSAGE(STATUS "Found a library with LAPACK API (${LAPACK_INFO}).")
|
||||||
ELSE(LAPACK_FOUND)
|
ELSE(LAPACK_FOUND)
|
||||||
MESSAGE(STATUS "Cannot find a library with LAPACK API. Not using LAPACK.")
|
MESSAGE(STATUS "Cannot find a library with LAPACK API. Not using LAPACK.")
|
||||||
ENDIF(LAPACK_FOUND)
|
ENDIF(LAPACK_FOUND)
|
||||||
|
@ -175,7 +175,7 @@ FOREACH(mklrtl ${mklrtls} "")
|
|||||||
IF (NOT MKL_LIBRARIES AND NOT INTEL_MKL_SEQUENTIAL)
|
IF (NOT MKL_LIBRARIES AND NOT INTEL_MKL_SEQUENTIAL)
|
||||||
CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm
|
CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm
|
||||||
"mkl_${mkliface}${mkl64};${mklthread};mkl_core;${mklrtl};pthread;${mkl_m};${mkl_dl}" "")
|
"mkl_${mkliface}${mkl64};${mklthread};mkl_core;${mklrtl};pthread;${mkl_m};${mkl_dl}" "")
|
||||||
ENDIF (NOT MKL_LIBRARIES AND NOT INTEL_MKL_SEQUENTIAL)
|
ENDIF (NOT MKL_LIBRARIES AND NOT INTEL_MKL_SEQUENTIAL)
|
||||||
ENDFOREACH(mklthread)
|
ENDFOREACH(mklthread)
|
||||||
ENDFOREACH(mkl64)
|
ENDFOREACH(mkl64)
|
||||||
ENDFOREACH(mkliface)
|
ENDFOREACH(mkliface)
|
||||||
@ -200,7 +200,7 @@ FOREACH(mklrtl ${mklrtls} "")
|
|||||||
IF (NOT MKL_LIBRARIES)
|
IF (NOT MKL_LIBRARIES)
|
||||||
CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm
|
CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm
|
||||||
"mkl_${mkliface}${mkl64};${mklthread};mkl_core;${mklrtl};pthread;${mkl_m};${mkl_dl}" "")
|
"mkl_${mkliface}${mkl64};${mklthread};mkl_core;${mklrtl};pthread;${mkl_m};${mkl_dl}" "")
|
||||||
ENDIF (NOT MKL_LIBRARIES)
|
ENDIF (NOT MKL_LIBRARIES)
|
||||||
ENDFOREACH(mklthread)
|
ENDFOREACH(mklthread)
|
||||||
ENDFOREACH(mkl64)
|
ENDFOREACH(mkl64)
|
||||||
ENDFOREACH(mkliface)
|
ENDFOREACH(mkliface)
|
||||||
@ -211,7 +211,7 @@ IF (NOT MKL_LIBRARIES)
|
|||||||
SET(MKL_VERSION 900)
|
SET(MKL_VERSION 900)
|
||||||
CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm
|
CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm
|
||||||
"mkl;guide;pthread;m" "")
|
"mkl;guide;pthread;m" "")
|
||||||
ENDIF (NOT MKL_LIBRARIES)
|
ENDIF (NOT MKL_LIBRARIES)
|
||||||
|
|
||||||
# Include files
|
# Include files
|
||||||
IF (MKL_LIBRARIES)
|
IF (MKL_LIBRARIES)
|
||||||
@ -228,7 +228,7 @@ IF (MKL_LIBRARIES)
|
|||||||
MARK_AS_ADVANCED(MKL_LAPACK_LIBRARIES)
|
MARK_AS_ADVANCED(MKL_LAPACK_LIBRARIES)
|
||||||
ENDIF (NOT MKL_LAPACK_LIBRARIES)
|
ENDIF (NOT MKL_LAPACK_LIBRARIES)
|
||||||
IF (NOT MKL_SCALAPACK_LIBRARIES)
|
IF (NOT MKL_SCALAPACK_LIBRARIES)
|
||||||
FIND_LIBRARY(MKL_SCALAPACK_LIBRARIES NAMES "mkl_scalapack${mkl64}${mkls}")
|
FIND_LIBRARY(MKL_SCALAPACK_LIBRARIES NAMES "mkl_scalapack${mkl64}${mkls}")
|
||||||
MARK_AS_ADVANCED(MKL_SCALAPACK_LIBRARIES)
|
MARK_AS_ADVANCED(MKL_SCALAPACK_LIBRARIES)
|
||||||
ENDIF (NOT MKL_SCALAPACK_LIBRARIES)
|
ENDIF (NOT MKL_SCALAPACK_LIBRARIES)
|
||||||
IF (NOT MKL_SOLVER_LIBRARIES)
|
IF (NOT MKL_SOLVER_LIBRARIES)
|
||||||
@ -243,7 +243,7 @@ IF (MKL_LIBRARIES)
|
|||||||
ENDFOREACH(mkl64)
|
ENDFOREACH(mkl64)
|
||||||
ENDIF (MKL_LIBRARIES)
|
ENDIF (MKL_LIBRARIES)
|
||||||
|
|
||||||
# LibIRC: intel compiler always links this;
|
# LibIRC: intel compiler always links this;
|
||||||
# gcc does not; but mkl kernels sometimes need it.
|
# gcc does not; but mkl kernels sometimes need it.
|
||||||
IF (MKL_LIBRARIES)
|
IF (MKL_LIBRARIES)
|
||||||
IF (CMAKE_COMPILER_IS_GNUCC)
|
IF (CMAKE_COMPILER_IS_GNUCC)
|
||||||
@ -269,7 +269,7 @@ ENDIF (MKL_LIBRARIES)
|
|||||||
|
|
||||||
# Standard termination
|
# Standard termination
|
||||||
IF(NOT MKL_FOUND AND MKL_FIND_REQUIRED)
|
IF(NOT MKL_FOUND AND MKL_FIND_REQUIRED)
|
||||||
MESSAGE(FATAL_ERROR "MKL library not found. Please specify library location")
|
MESSAGE(FATAL_ERROR "MKL library not found. Please specify library location")
|
||||||
ENDIF(NOT MKL_FOUND AND MKL_FIND_REQUIRED)
|
ENDIF(NOT MKL_FOUND AND MKL_FIND_REQUIRED)
|
||||||
IF(NOT MKL_FIND_QUIETLY)
|
IF(NOT MKL_FIND_QUIETLY)
|
||||||
IF(MKL_FOUND)
|
IF(MKL_FOUND)
|
||||||
|
@ -19,6 +19,7 @@ import torch.cuda
|
|||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
from torch._six import string_classes
|
from torch._six import string_classes
|
||||||
import torch.backends.cudnn
|
import torch.backends.cudnn
|
||||||
|
import torch.backends.mkl
|
||||||
|
|
||||||
|
|
||||||
torch.set_default_tensor_type('torch.DoubleTensor')
|
torch.set_default_tensor_type('torch.DoubleTensor')
|
||||||
@ -54,6 +55,8 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
TEST_SCIPY = False
|
TEST_SCIPY = False
|
||||||
|
|
||||||
|
TEST_MKL = torch.backends.mkl.is_available()
|
||||||
|
|
||||||
|
|
||||||
def skipIfNoLapack(fn):
|
def skipIfNoLapack(fn):
|
||||||
@wraps(fn)
|
@wraps(fn)
|
||||||
|
6
torch/backends/mkl/__init__.py
Normal file
6
torch/backends/mkl/__init__.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def is_available():
    r"""Report whether this PyTorch build includes MKL support.

    The value is read from the ``has_mkl`` attribute of ``torch._C``.
    """
    mkl_built_in = torch._C.has_mkl
    return mkl_built_in
|
@ -505,6 +505,8 @@ static PyObject* initModule() {
|
|||||||
// setting up TH Errors so that they throw C++ exceptions
|
// setting up TH Errors so that they throw C++ exceptions
|
||||||
at::init();
|
at::init();
|
||||||
|
|
||||||
|
ASSERT_TRUE(PyModule_AddObject(module, "has_mkl", at::hasMKL() ? Py_True : Py_False) == 0);
|
||||||
|
|
||||||
auto& defaultGenerator = at::globalContext().defaultGenerator(at::kCPU);
|
auto& defaultGenerator = at::globalContext().defaultGenerator(at::kCPU);
|
||||||
THPDefaultGenerator = (THPGenerator*)THPGenerator_NewWithGenerator(
|
THPDefaultGenerator = (THPGenerator*)THPGenerator_NewWithGenerator(
|
||||||
defaultGenerator);
|
defaultGenerator);
|
||||||
|
Reference in New Issue
Block a user