mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Change vml.h to support sizes greater than 2**32 - 1
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17280 Differential Revision: D14154997 Pulled By: cpuhrsch fbshipit-source-id: c19b15d18da59c9ee87e82765d3244d2a4ef6729
This commit is contained in:
committed by
Facebook Github Bot
parent
2336f0ba06
commit
43f94077d8
@ -28,6 +28,7 @@
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <type_traits>
|
||||
|
||||
#if AT_MKL_ENABLED() && !defined(__APPLE__)
|
||||
#include <mkl.h>
|
||||
@ -125,16 +126,42 @@ IMPLEMENT_VML_BUG(trunc)
|
||||
|
||||
#if AT_MKL_ENABLED() && !defined(__APPLE__)
|
||||
|
||||
#define IMPLEMENT_VML_MKL(op, mklop) \
|
||||
template <> \
|
||||
inline void v##op(float* out, const float* in, int64_t size) { \
|
||||
vms##mklop(size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
|
||||
} \
|
||||
template <> \
|
||||
inline void v##op(double* out, const double* in, int64_t size) { \
|
||||
vmd##mklop(size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
|
||||
// NB: LP64 MKL is the most commonly used and thus we assume it here. That means
|
||||
// we need to expect MKL_INT to be of type int, which implies int32_t in most
|
||||
// cases.
|
||||
static_assert(
|
||||
std::is_same<MKL_INT, int32_t>::value,
|
||||
"MKL_INT is assumed to be int32_t");
|
||||
#define IMPLEMENT_VML_MKL_STUB(op, mklop, type, mkltype) \
|
||||
template <> \
|
||||
inline void v##op(type * out, const type * in, int64_t size) { \
|
||||
int64_t max_mkl_ind = std::numeric_limits<MKL_INT>::max(); \
|
||||
if (size <= static_cast<int64_t>(max_mkl_ind)) { \
|
||||
vm##mkltype##mklop( \
|
||||
size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
|
||||
} else { \
|
||||
MKL_INT ind = 0; \
|
||||
int64_t chunks = size / max_mkl_ind; \
|
||||
int64_t rest = size % max_mkl_ind; \
|
||||
for (; ind < chunks; ind++) { \
|
||||
vm##mkltype##mklop( \
|
||||
max_mkl_ind, \
|
||||
in + ind * max_mkl_ind, \
|
||||
out + ind * max_mkl_ind, \
|
||||
VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
|
||||
} \
|
||||
vm##mkltype##mklop( \
|
||||
rest, \
|
||||
in + ind * max_mkl_ind, \
|
||||
out + ind * max_mkl_ind, \
|
||||
VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define IMPLEMENT_VML_MKL(op, mklop) \
|
||||
IMPLEMENT_VML_MKL_STUB(op, mklop, float, s) \
|
||||
IMPLEMENT_VML_MKL_STUB(op, mklop, double, d)
|
||||
|
||||
// NB: abs, cosh and sinh were temporarily disabled due to issues with Apple clang
|
||||
|
||||
IMPLEMENT_VML_MKL(abs, Abs)
|
||||
|
Reference in New Issue
Block a user