Allow one to set the BLAS backend, while optionally choosing to use Eigen for the whole numerical computation (for example, on a platform where no optimized BLAS library is present, or where Eigen is already the fastest numerical library available). The paths I have tested are Eigen and ATLAS; I have not tested MKL yet.
build.py
@@ -41,6 +41,18 @@ class Config(object):
  # compiling) you may want to set USE_SYSTEM_EIGEN to False.
  USE_SYSTEM_EIGEN = False

  # BLAS functions: whether to use Eigen for all the BLAS calls. On platforms
  # that do not have an optimized BLAS library, this usually solves the
  # problem, since Eigen will generate all the code. Optionally, one can
  # specify a separately compiled BLAS library (such as MKL) and we will use
  # that library for all BLAS calls.
  USE_EIGEN_FOR_BLAS = True
  # If you have set the above flag to False, you should specify a BLAS backend
  # here. Note that if the BLAS backend is MKL, we will also assume that the
  # MKL VSL library is present, and we will use the VSL function calls as
  # well. If USE_EIGEN_FOR_BLAS is True, this config has no effect.
  BLAS_BACKEND = "atlas"

  # google-glog: Caffe can choose to use google glog, which will allow a more
  # sophisticated logging scheme. It also comes with a minimal logging tool
  # that does not depend on glog. If you wish to use glog, set USE_GLOG to
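To make the routing concrete, here is a minimal stand-alone C++ sketch (not part of the commit) of how these two Python settings are assumed to surface in the C++ sources as the preprocessor defines CAFFE2_USE_EIGEN_FOR_BLAS and CAFFE2_USE_MKL that the rest of this diff branches on; the exact plumbing from build.py to compiler flags is an assumption here.

#include <cstdio>

int main() {
#if defined(CAFFE2_USE_EIGEN_FOR_BLAS)
  // USE_EIGEN_FOR_BLAS = True: Gemm/Gemv/Scale/Dot/Axpy compile to Eigen
  // expressions; no external BLAS library is linked.
  std::printf("BLAS routed to Eigen\n");
#elif defined(CAFFE2_USE_MKL)
  // BLAS_BACKEND = "mkl": cblas_* calls resolve to MKL, and the VSL/VML
  // entry points are assumed to be available as well.
  std::printf("BLAS routed to MKL (VSL/VML available)\n");
#else
  // BLAS_BACKEND = "atlas" or "openblas": plain CBLAS declarations from
  // caffe2/utils/cblas.h, linked against the chosen library.
  std::printf("BLAS routed to a generic CBLAS backend\n");
#endif
  return 0;
}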
@@ -25,7 +25,7 @@ class FullyConnectedOp final : public Operator<Context> {
    CAFFE_CHECK_GE(W.ndim(), 2);
    if (X.ndim() > 2 || W.ndim() > 2) {
      CAFFE_VLOG(1) << "Using legacy support for arbitrary input and weight "
-                   << "dimensions.";
+                      "dimensions.";
    }
    CAFFE_CHECK_EQ(b.ndim(), 1);
    // batch size
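The one-line change in this hunk swaps an explicit operator<< for C++ adjacent string-literal concatenation: the compiler joins the two literals into a single string at compile time. A stand-alone illustration (not from the commit):

#include <iostream>

int main() {
  // Adjacent string literals are concatenated by the compiler, so both
  // statements print the same message; the first needs one fewer
  // operator<< call, which is all the change above amounts to.
  std::cout << "Using legacy support for arbitrary input and weight "
               "dimensions." << std::endl;
  std::cout << "Using legacy support for arbitrary input and weight "
            << "dimensions." << std::endl;
  return 0;
}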
@@ -6,10 +6,10 @@ cc_library(
    hdrs = [
        "cblas.h",
        "math.h",
-       "mkl_alternate.h",
    ],
    deps = [
        "//third_party/eigen3:eigen",
+       "//third_party/blas:blas",
        "//caffe2/core:core",
    ],
)
@@ -1,14 +1,31 @@
// Implements the math functions for CPU.
// The implementation in this file allows us to route the underlying numerical
// computation library to different backends. Notably:
// (1) For all BLAS-related functions, one can explicitly request a BLAS backend
//     such as MKL, OpenBLAS or ATLAS. To see the set of supported backends
//     currently provided, check //third_party/blas/.
// (2) If one chooses to link against MKL, we utilize MKL's vector math library
//     (VML) for a few functions such as Exp and Log.
// (3) Fallback implementations are provided in Eigen for cross-platform
//     support. Since Eigen is a header-only library and supports a number of
//     platforms, it allows one to quickly port Caffe2 to different platforms
//     where BLAS may not be present.

#include <random>

#include "caffe2/utils/math.h"
-#include "caffe2/utils/mkl_alternate.h"
#include "caffe2/core/context.h"
#include "eigen3/Eigen/Core"
#include "eigen3/Eigen/Dense"

+#ifdef CAFFE2_USE_MKL
+#include <mkl.h>
+#else  // CAFFE2_USE_MKL
+#include "caffe2/utils/cblas.h"
+#endif  // CAFFE2_USE_MKL

// Common Eigen types that we will often use
namespace {
template <typename T>
using EigenMatrixMap =
    Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> >;
@@ -25,82 +42,13 @@ using ConstEigenVectorMap =
namespace caffe2 {
namespace math {

#define DELEGATE_SIMPLE_UNARY_FUNCTION(T, Funcname, OriginalFunc)             \
template <>                                                                   \
void Funcname<T, CPUContext>(                                                 \
    const int N, const T* x, T* y,                                            \
    CPUContext* context) {                                                    \
  OriginalFunc(N, x, y);                                                      \
}
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Exp, vsExp)
DELEGATE_SIMPLE_UNARY_FUNCTION(double, Exp, vdExp)
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Log, vsLn)
DELEGATE_SIMPLE_UNARY_FUNCTION(double, Log, vdLn)
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Sqr, vsSqr)
DELEGATE_SIMPLE_UNARY_FUNCTION(double, Sqr, vdSqr)

template <>
void Powx<float, CPUContext>(
    const int N, const float* a, float b, float* y, CPUContext* context) {
  vsPowx(N, a, b, y);
}

template <>
void Powx<double, CPUContext>(
    const int N, const double* a, double b, double* y, CPUContext* context) {
  vdPowx(N, a, b, y);
}

#define DELEGATE_SIMPLE_BINARY_FUNCTION(T, Funcname, OriginalFunc)            \
template <>                                                                   \
void Funcname<T, CPUContext>(                                                 \
    const int N, const T* a, const T* b, T* y,                                \
    CPUContext* context) {                                                    \
  OriginalFunc(N, a, b, y);                                                   \
}

DELEGATE_SIMPLE_BINARY_FUNCTION(float, Add, vsAdd)
DELEGATE_SIMPLE_BINARY_FUNCTION(double, Add, vdAdd)
DELEGATE_SIMPLE_BINARY_FUNCTION(float, Sub, vsSub)
DELEGATE_SIMPLE_BINARY_FUNCTION(double, Sub, vdSub)
DELEGATE_SIMPLE_BINARY_FUNCTION(float, Mul, vsMul)
DELEGATE_SIMPLE_BINARY_FUNCTION(double, Mul, vdMul)
DELEGATE_SIMPLE_BINARY_FUNCTION(float, Div, vsDiv)
DELEGATE_SIMPLE_BINARY_FUNCTION(double, Div, vdDiv)
#undef DELEGATE_SIMPLE_BINARY_FUNCTION
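The DELEGATE_* macros above stamp out one explicit template specialization per (type, function) pair, each simply forwarding to the named backend routine such as MKL's vsExp. A self-contained sketch of the same pattern, with a local loop standing in for the MKL call:

#include <cmath>
#include <cstdio>

// Local stand-in for MKL's vsExp: y[i] = exp(x[i]).
static void my_exp_f(const int n, const float* x, float* y) {
  for (int i = 0; i < n; ++i) y[i] = std::exp(x[i]);
}

// Same shape as DELEGATE_SIMPLE_UNARY_FUNCTION, minus the Caffe2 template
// machinery: the macro generates a function that forwards to the backend.
#define DELEGATE_UNARY(T, Funcname, OriginalFunc)                \
  void Funcname(const int N, const T* x, T* y) {                 \
    OriginalFunc(N, x, y);                                       \
  }
DELEGATE_UNARY(float, ExpF, my_exp_f)
#undef DELEGATE_UNARY

int main() {
  float x[3] = {0.f, 1.f, 2.f}, y[3];
  ExpF(3, x, y);
  std::printf("%f %f %f\n", y[0], y[1], y[2]);  // 1.000000 2.718282 7.389056
  return 0;
}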

#define CAFFE2_SPECIALIZED_ROWWISEMAX(T)                                      \
template <> void RowwiseMax<T, CPUContext>(                                   \
    const int N, const int D, const T* x, T* y, CPUContext* context) {        \
  EigenVectorMap<T>(y, N) =                                                   \
      ConstEigenMatrixMap<T>(x, D, N).colwise().maxCoeff();                   \
}
CAFFE2_SPECIALIZED_ROWWISEMAX(float)

#define CAFFE2_SPECIALIZED_COLWISEMAX(T)                                      \
template <> void ColwiseMax<T, CPUContext>(                                   \
    const int N, const int D, const T* x, T* y, CPUContext* context) {        \
  EigenVectorMap<T>(y, D) =                                                   \
      ConstEigenMatrixMap<T>(x, D, N).rowwise().maxCoeff();                   \
}
CAFFE2_SPECIALIZED_COLWISEMAX(float)

// AddToRow and AddToCol add the corresponding row/col vector x to the matrix y
// of shape M x N. The actual implementation uses Eigen, which is column major,
// so notice the row/column swap in the implementation.
template <>
void AddToRow<float, CPUContext>(
    const int M, const int N, const float* x, float* y, CPUContext* context) {
  EigenMatrixMap<float>(y, N, M).colwise() += ConstEigenVectorMap<float>(x, N);
}
template <>
void AddToCol<float, CPUContext>(
    const int M, const int N, const float* x, float* y, CPUContext* context) {
  EigenMatrixMap<float>(y, N, M).rowwise() +=
      ConstEigenVectorMap<float>(x, M).transpose();
}
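The row/column swap that the comment above describes follows from Eigen's default column-major storage: mapping a row-major M x N buffer as an N x M Eigen matrix turns each logical row into an Eigen column. A small stand-alone demonstration (assuming Eigen is available under the same eigen3/ include prefix used in this file):

#include <cstdio>
#include "eigen3/Eigen/Core"

int main() {
  // y is a 2 x 3 row-major matrix stored flat, as Caffe2 stores tensors.
  float y[6] = {1, 2, 3,
                4, 5, 6};
  float x[3] = {10, 20, 30};
  // Viewed as a column-major 3 x 2 matrix, each logical row of y becomes an
  // Eigen column, so "add x to every row" is written colwise() += x.
  Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic> > Y(y, 3, 2);
  Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, 1> > X(x, 3);
  Y.colwise() += X;
  std::printf("%g %g %g | %g %g %g\n", y[0], y[1], y[2], y[3], y[4], y[5]);
  // Prints: 11 22 33 | 14 25 36
  return 0;
}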

////////////////////////////////////////////////////////////////////////////////
// BLAS alternatives.
// Depending on whether we have specified an external BLAS library or not, we
// will delegate the Caffe math functions that are BLAS-related to either the
// CBLAS call or the Eigen implementation.
////////////////////////////////////////////////////////////////////////////////
#ifdef CAFFE2_USE_EIGEN_FOR_BLAS

// Caffe2 gemm provides a simpler interface to the gemm functions, with the
// limitation that the data has to be contiguous in memory.
@@ -185,11 +133,290 @@ void Gemv<float, CPUContext>(
  }
}

-#define CAFFE2_SPECIALIZED_SET(T)                                            \
+#define CAFFE2_SPECIALIZED_SCALE(T)                                          \
template <>                                                                   \
-void Set<T, CPUContext>(const int N, const T alpha, T *Y,                    \
+void Scale<T, CPUContext>(                                                   \
+    const int n, const T alpha, const T* x, T* y,                            \
    CPUContext* context) {                                                    \
  EigenVectorMap<T>(y, n) = ConstEigenVectorMap<T>(x, n) * alpha;             \
}                                                                             \
template <>                                                                   \
void Scale<T, CPUContext>(                                                    \
    const int n, const T* alpha, const T* x, T* y,                            \
    CPUContext* context) {                                                    \
  EigenVectorMap<T>(y, n) = ConstEigenVectorMap<T>(x, n) * (*alpha);          \
}
CAFFE2_SPECIALIZED_SCALE(float)
CAFFE2_SPECIALIZED_SCALE(double)
#undef CAFFE2_SPECIALIZED_SCALE

#define CAFFE2_SPECIALIZED_DOT(T)                                             \
template <>                                                                   \
void Dot<T, CPUContext>(                                                      \
    const int N, const T* a, const T* b, T* y,                                \
    CPUContext* context) {                                                    \
  *y = ConstEigenVectorMap<T>(a, N).dot(ConstEigenVectorMap<T>(b, N));        \
}
CAFFE2_SPECIALIZED_DOT(float)
CAFFE2_SPECIALIZED_DOT(double)
#undef CAFFE2_SPECIALIZED_DOT

#define CAFFE2_SPECIALIZED_AXPY(T)                                            \
template <>                                                                   \
void Axpy<T, CPUContext>(const int N, const T alpha, const T* x,              \
                         T* Y, CPUContext* context) {                         \
  EigenVectorMap<T>(Y, N) += ConstEigenVectorMap<T>(x, N) * alpha;            \
}                                                                             \
template <>                                                                   \
void Axpy<T, CPUContext>(const int N, const T* alpha, const T* x,             \
                         T* Y, CPUContext* context) {                         \
  EigenVectorMap<T>(Y, N) += ConstEigenVectorMap<T>(x, N) * (*alpha);         \
}
CAFFE2_SPECIALIZED_AXPY(float)
CAFFE2_SPECIALIZED_AXPY(double)
#undef CAFFE2_SPECIALIZED_AXPY

#define CAFFE2_SPECIALIZED_AXPBY(T)                                           \
template <>                                                                   \
void Axpby<T, CPUContext>(const int N, const T alpha, const T* x,             \
                          const T beta, T* y, CPUContext* context) {          \
  EigenVectorMap<T> y_vec(y, N);                                              \
  y_vec = y_vec * beta + ConstEigenVectorMap<T>(x, N) * alpha;                \
}
CAFFE2_SPECIALIZED_AXPBY(float)
CAFFE2_SPECIALIZED_AXPBY(double)
#undef CAFFE2_SPECIALIZED_AXPBY
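Before the CBLAS branch below, a dependency-free reference for the contracts these Eigen specializations implement (Scale: y = alpha*x; Dot: *y = sum_i a[i]*b[i]; Axpy: y += alpha*x; Axpby: y = alpha*x + beta*y), using a plain loop for the Axpby case with a worked result:

#include <cstdio>

int main() {
  const int N = 3;
  const float alpha = 2.f, beta = 0.5f;
  float x[3] = {1, 2, 3};
  float y[3] = {10, 20, 30};
  // Axpby contract: y := alpha * x + beta * y, computed element-wise.
  for (int i = 0; i < N; ++i) y[i] = alpha * x[i] + beta * y[i];
  std::printf("%g %g %g\n", y[0], y[1], y[2]);  // 7 14 21
  return 0;
}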
#else  // CAFFE2_USE_EIGEN_FOR_BLAS

template <>
void Gemm<float, CPUContext>(
    const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
    const int M, const int N, const int K, const float alpha, const float* A,
    const float* B, const float beta, float* C, CPUContext* context) {
  int lda = (TransA == CblasNoTrans) ? K : M;
  int ldb = (TransB == CblasNoTrans) ? N : K;
  cblas_sgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B, ldb,
              beta, C, N);
}

template <>
void Gemv<float, CPUContext>(
    const CBLAS_TRANSPOSE TransA, const int M, const int N, const float alpha,
    const float* A, const float* x, const float beta, float* y,
    CPUContext* context) {
  cblas_sgemv(CblasRowMajor, TransA, M, N, alpha, A, N, x, 1, beta, y, 1);
}
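The leading dimensions above follow from row-major storage: a non-transposed A is M x K so lda = K (the stored column count), a transposed A is stored K x M so lda = M, and ldb is N or K likewise; C is always M x N with ldc = N. A naive reference implementation of the same contract (illustrative only, not part of the commit), handy for cross-checking a linked BLAS:

#include <cstdio>

// C = alpha * op(A) * op(B) + beta * C, row-major and contiguous, where
// transA/transB play the role of CBLAS_TRANSPOSE.
void gemm_ref(bool transA, bool transB, int M, int N, int K, float alpha,
              const float* A, const float* B, float beta, float* C) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        float a = transA ? A[k * M + m] : A[m * K + k];  // lda = M or K
        float b = transB ? B[n * K + k] : B[k * N + n];  // ldb = K or N
        acc += a * b;
      }
      C[m * N + n] = alpha * acc + beta * C[m * N + n];
    }
  }
}

int main() {
  float A[6] = {1, 2, 3, 4, 5, 6};  // 2 x 3
  float B[3] = {1, 1, 1};           // 3 x 1
  float C[2] = {0, 0};              // 2 x 1
  gemm_ref(false, false, 2, 1, 3, 1.f, A, B, 0.f, C);
  std::printf("%g %g\n", C[0], C[1]);  // 6 15
  return 0;
}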
#define CAFFE2_SPECIALIZED_SCALE(T, prefix)                                   \
template <>                                                                   \
void Scale<T, CPUContext>(const int n, const T alpha, const T* x, T* y,       \
                          CPUContext* context) {                              \
  if (y != x) cblas_##prefix##copy(n, x, 1, y, 1);                            \
  cblas_##prefix##scal(n, alpha, y, 1);                                       \
}                                                                             \
template <>                                                                   \
void Scale<T, CPUContext>(const int n, const T* alpha, const T* x, T* y,      \
                          CPUContext* context) {                              \
  if (y != x) cblas_##prefix##copy(n, x, 1, y, 1);                            \
  cblas_##prefix##scal(n, *alpha, y, 1);                                      \
}
CAFFE2_SPECIALIZED_SCALE(float, s)
CAFFE2_SPECIALIZED_SCALE(double, d)
#undef CAFFE2_SPECIALIZED_SCALE

#define CAFFE2_SPECIALIZED_DOT(T, prefix)                                     \
template <>                                                                   \
void Dot<T, CPUContext>(                                                      \
    const int N, const T* a, const T* b, T* y,                                \
    CPUContext* context) {                                                    \
  *y = cblas_##prefix##dot(N, a, 1, b, 1);                                    \
}
CAFFE2_SPECIALIZED_DOT(float, s)
CAFFE2_SPECIALIZED_DOT(double, d)
#undef CAFFE2_SPECIALIZED_DOT

#define CAFFE2_SPECIALIZED_AXPY(T, prefix)                                    \
template <>                                                                   \
void Axpy<T, CPUContext>(const int N, const T alpha, const T* x,              \
                         T* y, CPUContext* context) {                         \
  cblas_##prefix##axpy(N, alpha, x, 1, y, 1);                                 \
}                                                                             \
template <>                                                                   \
void Axpy<T, CPUContext>(const int N, const T* alpha, const T* x,             \
                         T* y, CPUContext* context) {                         \
  cblas_##prefix##axpy(N, *alpha, x, 1, y, 1);                                \
}
CAFFE2_SPECIALIZED_AXPY(float, s)
CAFFE2_SPECIALIZED_AXPY(double, d)
#undef CAFFE2_SPECIALIZED_AXPY

// cblas_[sd]axpby is not a standard BLAS function, so if MKL is not present
// we need to implement it ourselves.
#ifdef CAFFE2_USE_MKL
#define CAFFE2_SPECIALIZED_AXPBY(T, prefix)                                   \
template <>                                                                   \
void Axpby<T, CPUContext>(const int N, const T alpha, const T* x,             \
                          const T beta, T* y, CPUContext* context) {          \
  cblas_##prefix##axpby(N, alpha, x, 1, beta, y, 1);                          \
}
#else  // CAFFE2_USE_MKL
#define CAFFE2_SPECIALIZED_AXPBY(T, prefix)                                   \
template <>                                                                   \
void Axpby<T, CPUContext>(const int N, const T alpha, const T* x,             \
                          const T beta, T* y, CPUContext* context) {          \
  cblas_##prefix##scal(N, beta, y, 1);                                        \
  cblas_##prefix##axpy(N, alpha, x, 1, y, 1);                                 \
}
#endif  // CAFFE2_USE_MKL
CAFFE2_SPECIALIZED_AXPBY(float, s)
CAFFE2_SPECIALIZED_AXPBY(double, d)
#undef CAFFE2_SPECIALIZED_AXPBY

#endif  // CAFFE2_USE_EIGEN_FOR_BLAS

////////////////////////////////////////////////////////////////////////////////
// MKL VML alternatives.
// Depending on whether we are using MKL, we will delegate the Caffe math
// functions that are VML-related to either the VML call or the Eigen
// implementation. If you set the flags (such as AVX) right for your CPU
// architecture, Eigen will usually deliver throughput as fast as the VML
// functions.
////////////////////////////////////////////////////////////////////////////////
#ifdef CAFFE2_USE_MKL

#define DELEGATE_SIMPLE_UNARY_FUNCTION(T, Funcname, OriginalFunc)             \
template <>                                                                   \
void Funcname<T, CPUContext>(                                                 \
    const int N, const T* x, T* y,                                            \
    CPUContext* context) {                                                    \
  OriginalFunc(N, x, y);                                                      \
}
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Exp, vsExp)
DELEGATE_SIMPLE_UNARY_FUNCTION(double, Exp, vdExp)
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Log, vsLn)
DELEGATE_SIMPLE_UNARY_FUNCTION(double, Log, vdLn)
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Sqr, vsSqr)
DELEGATE_SIMPLE_UNARY_FUNCTION(double, Sqr, vdSqr)
#undef DELEGATE_SIMPLE_UNARY_FUNCTION

#define DELEGATE_POWX_FUNCTION(T, OriginalFunc)                               \
template <>                                                                   \
void Powx<T, CPUContext>(                                                     \
    const int N, const T* a, T b, T* y, CPUContext* context) {                \
  OriginalFunc(N, a, b, y);                                                   \
}
DELEGATE_POWX_FUNCTION(float, vsPowx)
DELEGATE_POWX_FUNCTION(double, vdPowx)
#undef DELEGATE_POWX_FUNCTION

#define DELEGATE_SIMPLE_BINARY_FUNCTION(T, Funcname, OriginalFunc)            \
template <>                                                                   \
void Funcname<T, CPUContext>(                                                 \
    const int N, const T* a, const T* b, T* y,                                \
    CPUContext* context) {                                                    \
  OriginalFunc(N, a, b, y);                                                   \
}
DELEGATE_SIMPLE_BINARY_FUNCTION(float, Add, vsAdd)
DELEGATE_SIMPLE_BINARY_FUNCTION(double, Add, vdAdd)
DELEGATE_SIMPLE_BINARY_FUNCTION(float, Sub, vsSub)
DELEGATE_SIMPLE_BINARY_FUNCTION(double, Sub, vdSub)
DELEGATE_SIMPLE_BINARY_FUNCTION(float, Mul, vsMul)
DELEGATE_SIMPLE_BINARY_FUNCTION(double, Mul, vdMul)
DELEGATE_SIMPLE_BINARY_FUNCTION(float, Div, vsDiv)
DELEGATE_SIMPLE_BINARY_FUNCTION(double, Div, vdDiv)
#undef DELEGATE_SIMPLE_BINARY_FUNCTION

#else  // CAFFE2_USE_MKL

#define DELEGATE_SIMPLE_UNARY_FUNCTION(T, Funcname, expr)                     \
template <>                                                                   \
void Funcname<T, CPUContext>(const int N, const T* x, T* y,                   \
                             CPUContext* context) {                           \
  EigenVectorMap<T>(y, N) = ConstEigenVectorMap<T>(x, N).array().expr();      \
}
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Exp, exp)
DELEGATE_SIMPLE_UNARY_FUNCTION(double, Exp, exp)
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Log, log)
DELEGATE_SIMPLE_UNARY_FUNCTION(double, Log, log)
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Sqr, square)
DELEGATE_SIMPLE_UNARY_FUNCTION(double, Sqr, square)
#undef DELEGATE_SIMPLE_UNARY_FUNCTION

#define DELEGATE_POWX_FUNCTION(T)                                             \
template <>                                                                   \
void Powx<T, CPUContext>(                                                     \
    const int N, const T* a, T b, T* y, CPUContext* context) {                \
  EigenVectorMap<T>(y, N) = ConstEigenVectorMap<T>(a, N).array().pow(b);      \
}
DELEGATE_POWX_FUNCTION(float)
DELEGATE_POWX_FUNCTION(double)
#undef DELEGATE_POWX_FUNCTION

#define DELEGATE_SIMPLE_BINARY_FUNCTION(T, Funcname, expr)                    \
template <>                                                                   \
void Funcname<T, CPUContext>(                                                 \
    const int N, const T* a, const T* b, T* y,                                \
    CPUContext* context) {                                                    \
  EigenVectorMap<T>(y, N) =                                                   \
      ConstEigenVectorMap<T>(a, N).array() expr                               \
      ConstEigenVectorMap<T>(b, N).array();                                   \
}
DELEGATE_SIMPLE_BINARY_FUNCTION(float, Add, +)
DELEGATE_SIMPLE_BINARY_FUNCTION(double, Add, +)
DELEGATE_SIMPLE_BINARY_FUNCTION(float, Sub, -)
DELEGATE_SIMPLE_BINARY_FUNCTION(double, Sub, -)
DELEGATE_SIMPLE_BINARY_FUNCTION(float, Mul, *)
DELEGATE_SIMPLE_BINARY_FUNCTION(double, Mul, *)
DELEGATE_SIMPLE_BINARY_FUNCTION(float, Div, /)
DELEGATE_SIMPLE_BINARY_FUNCTION(double, Div, /)
#undef DELEGATE_SIMPLE_BINARY_FUNCTION

#endif  // CAFFE2_USE_MKL
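Spelled out, the Eigen fallback branch turns each delegated function into a mapped array expression: DELEGATE_SIMPLE_BINARY_FUNCTION(float, Add, +), for instance, becomes an element-wise add over two Eigen maps, with no BLAS involved. A self-contained equivalent using local names (and the same eigen3/ include prefix as the file above):

#include <cstdio>
#include "eigen3/Eigen/Core"

using VecMap = Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, 1> >;
using ConstVecMap = Eigen::Map<const Eigen::Matrix<float, Eigen::Dynamic, 1> >;

// Element-wise add via mapped array expressions; Eigen permits assigning an
// array expression to a matrix (map), which is what the macro relies on.
void add(const int N, const float* a, const float* b, float* y) {
  VecMap(y, N) = ConstVecMap(a, N).array() + ConstVecMap(b, N).array();
}

int main() {
  float a[3] = {1, 2, 3}, b[3] = {10, 20, 30}, y[3];
  add(3, a, b, y);
  std::printf("%g %g %g\n", y[0], y[1], y[2]);  // 11 22 33
  return 0;
}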

////////////////////////////////////////////////////////////////////////////////
// Common math functions being used in Caffe that do not have a BLAS or MKL
// equivalent. For all these functions, we will simply implement them either via
// Eigen or via custom code.
////////////////////////////////////////////////////////////////////////////////

#define CAFFE2_SPECIALIZED_ROWWISEMAX(T)                                      \
template <> void RowwiseMax<T, CPUContext>(                                   \
    const int N, const int D, const T* x, T* y, CPUContext* context) {        \
  EigenVectorMap<T>(y, N) =                                                   \
      ConstEigenMatrixMap<T>(x, D, N).colwise().maxCoeff();                   \
}
CAFFE2_SPECIALIZED_ROWWISEMAX(float)

#define CAFFE2_SPECIALIZED_COLWISEMAX(T)                                      \
template <> void ColwiseMax<T, CPUContext>(                                   \
    const int N, const int D, const T* x, T* y, CPUContext* context) {        \
  EigenVectorMap<T>(y, D) =                                                   \
      ConstEigenMatrixMap<T>(x, D, N).rowwise().maxCoeff();                   \
}
CAFFE2_SPECIALIZED_COLWISEMAX(float)

// AddToRow and AddToCol add the corresponding row/col vector x to the matrix y
// of shape M x N. The actual implementation uses Eigen, which is column major,
// so notice the row/column swap in the implementation.
template <>
void AddToRow<float, CPUContext>(
    const int M, const int N, const float* x, float* y, CPUContext* context) {
  EigenMatrixMap<float>(y, N, M).colwise() += ConstEigenVectorMap<float>(x, N);
}
template <>
void AddToCol<float, CPUContext>(
    const int M, const int N, const float* x, float* y, CPUContext* context) {
  EigenMatrixMap<float>(y, N, M).rowwise() +=
      ConstEigenVectorMap<float>(x, M).transpose();
}

#define CAFFE2_SPECIALIZED_SET(T)                                             \
template <>                                                                   \
void Set<T, CPUContext>(const int N, const T alpha, T *Y,                     \
                        CPUContext* context) {                                \
  EigenVectorMap<T>(Y, N).setConstant(alpha);                                 \
}

CAFFE2_SPECIALIZED_SET(float);
@@ -228,20 +455,6 @@ void RandGaussian<float, CPUContext>(
  }
}

template <>
void Dot<float, CPUContext>(
    const int N, const float* a, const float* b, float* y,
    CPUContext* context) {
  *y = ConstEigenVectorMap<float>(a, N).dot(ConstEigenVectorMap<float>(b, N));
}

template <>
void Dot<double, CPUContext>(
    const int N, const double* a, const double* b, double* y,
    CPUContext* context) {
  *y = ConstEigenVectorMap<double>(a, N).dot(ConstEigenVectorMap<double>(b, N));
}

template <>
void Sum<float, CPUContext>(
    const int N, const float* x, float* y,
@@ -266,75 +479,6 @@ void Select<float, CPUContext>(
  }
}

template <>
void Scale<float, CPUContext>(
    const int n, const float alpha, const float* x, float* y,
    CPUContext* context) {
  EigenVectorMap<float>(y, n) = ConstEigenVectorMap<float>(x, n) * alpha;
}

template <>
void Scale<double, CPUContext>(
    const int n, const double alpha, const double* x, double* y,
    CPUContext* context) {
  EigenVectorMap<double>(y, n) = ConstEigenVectorMap<double>(x, n) * alpha;
}

template <>
void Scale<float, CPUContext>(
    const int n, const float* alpha, const float* x, float* y,
    CPUContext* context) {
  EigenVectorMap<float>(y, n) = ConstEigenVectorMap<float>(x, n) * (*alpha);
}

template <>
void Scale<double, CPUContext>(
    const int n, const double* alpha, const double* x, double* y,
    CPUContext* context) {
  EigenVectorMap<double>(y, n) = ConstEigenVectorMap<double>(x, n) * (*alpha);
}

template <>
void Axpy<float, CPUContext>(const int N, const float alpha, const float* x,
                             float* Y, CPUContext* context) {
  EigenVectorMap<float>(Y, N) += ConstEigenVectorMap<float>(x, N) * alpha;
}

template <>
void Axpy<double, CPUContext>(const int N, const double alpha, const double* x,
                              double* Y, CPUContext* context) {
  EigenVectorMap<double>(Y, N) += ConstEigenVectorMap<double>(x, N) * alpha;
}

template <>
void Axpy<float, CPUContext>(const int N, const float* alpha, const float* x,
                             float* Y, CPUContext* context) {
  EigenVectorMap<float>(Y, N) += ConstEigenVectorMap<float>(x, N) * (*alpha);
}

template <>
void Axpy<double, CPUContext>(const int N, const double* alpha, const double* x,
                              double* Y, CPUContext* context) {
  EigenVectorMap<double>(Y, N) += ConstEigenVectorMap<double>(x, N) * (*alpha);
}

template <>
void Axpby<float, CPUContext>(const int N, const float alpha, const float* x,
                              const float beta, float* y,
                              CPUContext* context) {
  EigenVectorMap<float> y_vec(y, N);
  y_vec = y_vec * beta + ConstEigenVectorMap<float>(x, N) * alpha;
}

template <>
void Axpby<double, CPUContext>(const int N, const double alpha,
                               const double* x, const double beta, double* y,
                               CPUContext* context) {
  EigenVectorMap<double> y_vec(y, N);
  y_vec = y_vec * beta + ConstEigenVectorMap<double>(x, N) * alpha;
}

template <>
void Im2col<float, CPUContext, StorageOrder::NCHW>(
    const float* data_im, const int channels,
@@ -671,19 +671,6 @@ void Col2im<float, CUDAContext, StorageOrder::NHWC>(
      pad_t, pad_l, stride_h, stride_w, height_col, width_col, data_im);
}

namespace {
template <typename T>
__global__ void CopyMatrixKernel(
    const int M, const int N, const T* A, const int lda,
    T* B, const int ldb) {
  CUDA_1D_KERNEL_LOOP(i, M * N) {
    int r = i / N;
    int c = i % N;
    B[r * ldb + c] = A[r * lda + c];
  }
}
}  // namespace

template <>
void CopyMatrix<CUDAContext>(
    const size_t itemsize, const int M, const int N, const void* A,
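CopyMatrixKernel copies an M x N block between two row-major buffers whose rows are strided by lda and ldb respectively, one element per CUDA thread. A host-side reference of the same operation (illustrative, not from the commit):

#include <cstdio>
#include <cstring>

// Copy an M x N block from A (leading dimension lda) to B (leading
// dimension ldb), one row at a time; the kernel above does the same with
// one GPU thread per element instead.
void copy_matrix_ref(int M, int N, const float* A, int lda,
                     float* B, int ldb) {
  for (int r = 0; r < M; ++r) {
    std::memcpy(B + r * ldb, A + r * lda, N * sizeof(float));
  }
}

int main() {
  // Copy a 2 x 2 block out of a 2 x 3 matrix (lda = 3) into a packed
  // 2 x 2 destination (ldb = 2).
  float A[6] = {1, 2, 3, 4, 5, 6};
  float B[4];
  copy_matrix_ref(2, 2, A, 3, B, 2);
  std::printf("%g %g %g %g\n", B[0], B[1], B[2], B[3]);  // 1 2 4 5
  return 0;
}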
@@ -1,83 +0,0 @@
// This file implements a set of MKL functions when MKL is not available.
#ifndef CAFFE2_UTILS_MKL_ALTERNATE_H_
#define CAFFE2_UTILS_MKL_ALTERNATE_H_

#ifdef USE_MKL

// If we use MKL, simply include the MKL header.
#include <mkl.h>

#else  // !USE_MKL: provide the functions on top of CBLAS and <cmath>.

#include <cmath>
extern "C" {
#include "caffe2/utils/cblas.h"
}
#include "caffe2/core/logging.h"

// Functions that Caffe uses but that are not present when MKL is not linked.

// A simple way to define the vsl unary functions. The operation should
// be in the form e.g. y[i] = sqrt(a[i])
#define DEFINE_VSL_UNARY_FUNC(name, operation)                                \
  template <typename Dtype>                                                   \
  inline void v##name(const int n, const Dtype* a, Dtype* y) {                \
    CAFFE_DCHECK_GT(n, 0); CAFFE_DCHECK(a); CAFFE_DCHECK(y);                  \
    for (int i = 0; i < n; ++i) { operation; }                                \
  }                                                                           \
  inline void vs##name(                                                       \
      const int n, const float* a, float* y) {                                \
    v##name<float>(n, a, y);                                                  \
  }                                                                           \
  inline void vd##name(                                                       \
      const int n, const double* a, double* y) {                              \
    v##name<double>(n, a, y);                                                 \
  }

DEFINE_VSL_UNARY_FUNC(Sqr, y[i] = a[i] * a[i]);
DEFINE_VSL_UNARY_FUNC(Exp, y[i] = std::exp(a[i]));
DEFINE_VSL_UNARY_FUNC(Ln, y[i] = std::log(a[i]));
DEFINE_VSL_UNARY_FUNC(Abs, y[i] = std::fabs(a[i]));

// A simple way to define the vsl unary functions with a scalar parameter b.
// The operation should be in the form e.g. y[i] = pow(a[i], b)
#define DEFINE_VSL_UNARY_FUNC_WITH_PARAM(name, operation)                     \
  template <typename Dtype>                                                   \
  inline void v##name(const int n, const Dtype* a, const Dtype b, Dtype* y) { \
    CAFFE_DCHECK_GT(n, 0); CAFFE_DCHECK(a); CAFFE_DCHECK(y);                  \
    for (int i = 0; i < n; ++i) { operation; }                                \
  }                                                                           \
  inline void vs##name(                                                       \
      const int n, const float* a, const float b, float* y) {                 \
    v##name<float>(n, a, b, y);                                               \
  }                                                                           \
  inline void vd##name(                                                       \
      const int n, const double* a, const double b, double* y) {              \
    v##name<double>(n, a, b, y);                                              \
  }

DEFINE_VSL_UNARY_FUNC_WITH_PARAM(Powx, y[i] = std::pow(a[i], b));

// A simple way to define the vsl binary functions. The operation should
// be in the form e.g. y[i] = a[i] + b[i]
#define DEFINE_VSL_BINARY_FUNC(name, operation)                               \
  template <typename Dtype>                                                   \
  inline void v##name(const int n, const Dtype* a, const Dtype* b, Dtype* y) {\
    CAFFE_DCHECK_GT(n, 0); CAFFE_DCHECK(a); CAFFE_DCHECK(b); CAFFE_DCHECK(y); \
    for (int i = 0; i < n; ++i) { operation; }                                \
  }                                                                           \
  inline void vs##name(                                                       \
      const int n, const float* a, const float* b, float* y) {                \
    v##name<float>(n, a, b, y);                                               \
  }                                                                           \
  inline void vd##name(                                                       \
      const int n, const double* a, const double* b, double* y) {             \
    v##name<double>(n, a, b, y);                                              \
  }

DEFINE_VSL_BINARY_FUNC(Add, y[i] = a[i] + b[i]);
DEFINE_VSL_BINARY_FUNC(Sub, y[i] = a[i] - b[i]);
DEFINE_VSL_BINARY_FUNC(Mul, y[i] = a[i] * b[i]);
DEFINE_VSL_BINARY_FUNC(Div, y[i] = a[i] / b[i]);

#endif  // USE_MKL
#endif  // CAFFE2_UTILS_MKL_ALTERNATE_H_
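The DEFINE_VSL_* macros in the deleted header generate, from a single line each, a typed template plus the float/double wrappers whose names (vsSqr, vdSqr, ...) mimic MKL's vector entry points, so callers work the same whether or not MKL is linked. A stand-alone sketch of the pattern with the CAFFE_DCHECK calls omitted:

#include <cstdio>

#define DEFINE_VSL_UNARY_FUNC(name, operation)                     \
  template <typename Dtype>                                        \
  inline void v##name(const int n, const Dtype* a, Dtype* y) {     \
    for (int i = 0; i < n; ++i) { operation; }                     \
  }                                                                \
  inline void vs##name(const int n, const float* a, float* y) {    \
    v##name<float>(n, a, y);                                       \
  }                                                                \
  inline void vd##name(const int n, const double* a, double* y) {  \
    v##name<double>(n, a, y);                                      \
  }

DEFINE_VSL_UNARY_FUNC(Sqr, y[i] = a[i] * a[i]);

int main() {
  const float a[3] = {1.f, 2.f, 3.f};
  float y[3];
  vsSqr(3, a, y);  // dispatches to vSqr<float>, i.e. the loop above
  std::printf("%g %g %g\n", y[0], y[1], y[2]);  // 1 4 9
  return 0;
}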
third_party/blas/BREW (new file, vendored)
@@ -0,0 +1,33 @@
# This BREW file is intended to be the central location that hosts all possible
# BLAS backends. Note that all of these are only linking flags, so if one of
# the libraries is not used, don't bother installing it - Caffe2 will still
# build normally.

# A catch-all target: all the targets should link to this instead of the
# specific libraries below.
cc_library(
  name = "blas",
  srcs = [],
  deps = ([] if Brewery.Env.Config.USE_EIGEN_FOR_BLAS
          else [":" + Brewery.Env.Config.BLAS_BACKEND]),
)

# ATLAS
cc_thirdparty_target(
  name = "atlas",
  cc_obj_files = [ "-lcblas -latlas" ],
)

# Intel MKL.
cc_thirdparty_target(
  name = "mkl",
  cc_obj_files = [ "-lmkl_rt" ],
)

# OpenBLAS
cc_thirdparty_target(
  name = "openblas",
  cc_obj_files = [ "-lopenblas" ],
)

# TODO: add the OS X veclib/Accelerate framework backend.