Change vml.h to support sizes greater than 2**31 - 1 (the largest value a 32-bit MKL_INT can hold)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17280

Differential Revision: D14154997

Pulled By: cpuhrsch

fbshipit-source-id: c19b15d18da59c9ee87e82765d3244d2a4ef6729
This commit is contained in:
Christian Puhrsch
2019-03-01 16:53:23 -08:00
committed by Facebook Github Bot
parent 2336f0ba06
commit 43f94077d8

View File

@ -28,6 +28,7 @@
#include <cstdint>
#include <cstring>
#include <iostream>
#include <type_traits>
#if AT_MKL_ENABLED() && !defined(__APPLE__)
#include <mkl.h>
@ -125,16 +126,42 @@ IMPLEMENT_VML_BUG(trunc)
#if AT_MKL_ENABLED() && !defined(__APPLE__)
#define IMPLEMENT_VML_MKL(op, mklop) \
template <> \
inline void v##op(float* out, const float* in, int64_t size) { \
vms##mklop(size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
} \
template <> \
inline void v##op(double* out, const double* in, int64_t size) { \
vmd##mklop(size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
// NB: This code assumes the LP64 flavor of MKL, since it is the one most
// commonly used. Under LP64, MKL_INT is a plain int, which in most cases means
// int32_t; the assertion below turns a violation of that assumption into a
// compile-time error.
static_assert(std::is_same<MKL_INT, int32_t>::value, "MKL_INT is assumed to be int32_t");
// Defines the MKL-backed specialization of v<op> for element type `type`.
// `mkltype` is the type letter pasted between the `vm` prefix and `mklop` to
// form the name of the MKL routine that is called (vm<mkltype><mklop>).
//
// When `size` does not exceed std::numeric_limits<MKL_INT>::max() the whole
// range is handed to MKL in a single call. Larger inputs are processed as
// consecutive slices of exactly that maximum length, followed by one final
// call covering the remaining `size % max` elements.
#define IMPLEMENT_VML_MKL_STUB(op, mklop, type, mkltype)                    \
  template <>                                                               \
  inline void v##op(type * out, const type * in, int64_t size) {            \
    const int64_t mkl_limit = std::numeric_limits<MKL_INT>::max();          \
    const auto vml_mode = VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE;     \
    if (size > mkl_limit) {                                                 \
      /* Too many elements for one MKL_INT count: walk the buffers. */      \
      const int64_t full_slices = size / mkl_limit;                         \
      const int64_t tail = size % mkl_limit;                                \
      const type * src = in;                                                \
      type * dst = out;                                                     \
      for (int64_t left = full_slices; left > 0; --left) {                  \
        vm##mkltype##mklop(mkl_limit, src, dst, vml_mode);                  \
        src += mkl_limit;                                                   \
        dst += mkl_limit;                                                   \
      }                                                                     \
      /* Leftover elements that do not fill a whole slice. */               \
      vm##mkltype##mklop(tail, src, dst, vml_mode);                         \
    } else {                                                                \
      vm##mkltype##mklop(size, in, out, vml_mode);                          \
    }                                                                       \
  }
// Emits the MKL-backed v<op> specializations for both element types handled
// here: float (MKL routine vms<mklop>) and double (MKL routine vmd<mklop>).
#define IMPLEMENT_VML_MKL(op, mklop)          \
  IMPLEMENT_VML_MKL_STUB(op, mklop, float, s) \
  IMPLEMENT_VML_MKL_STUB(op, mklop, double, d)
// NB: abs, cosh and sinh were temporarily disabled due to issues with Apple clang
IMPLEMENT_VML_MKL(abs, Abs)