Compare commits

...

5 Commits

4 changed files with 33 additions and 8 deletions

View File

@ -4,7 +4,7 @@
set -ex
cd /
git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION:-v0.3.30}" --depth 1 --shallow-submodules
git clone https://github.com/Mousius/OpenBLAS.git -b "bgemm-optimisation" --depth 1 --shallow-submodules
OPENBLAS_CHECKOUT_DIR="OpenBLAS"
OPENBLAS_BUILD_FLAGS="

View File

@ -20,6 +20,14 @@ extern "C" void dgemm_(char *transa, char *transb, int *m, int *n, int *k, doubl
extern "C" void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, const float *a, int *lda, const float *b, int *ldb, float *beta, float *c, int *ldc);
extern "C" void cgemm_(char *transa, char *transb, int *m, int *n, int *k, void *alpha, const void *a, int *lda, const void *b, int *ldb, void *beta, void *c, int *ldc);
extern "C" void zgemm_(char *transa, char *transb, int *m, int *n, int *k, void *alpha, const void *a, int *lda, const void *b, int *ldb, void *beta, void *c, int *ldc);
#ifdef BLAS_HAS_BGEMM
// Declaration of the external BLAS symbol `bgemm_`, compiled in only when
// BLAS_HAS_BGEMM is defined. Every argument is passed by pointer, and the
// scalars (alpha, beta) as well as all three matrices (a, b, c) are
// at::BFloat16.
// NOTE(review): presumably follows the usual BLAS GEMM contract
// (transa/transb select the operand layout, lda/ldb/ldc are leading
// dimensions) -- confirm against the BLAS library that provides the symbol.
extern "C" void bgemm_(char *transa, char *transb, int *m, int *n, int *k,
const at::BFloat16 *alpha,
const at::BFloat16 *a, int *lda,
const at::BFloat16 *b, int *ldb,
const at::BFloat16 *beta,
at::BFloat16 *c, int *ldc);
#endif // BLAS_HAS_BGEMM
#ifdef BLAS_HAS_SBGEMM
extern "C" void sbgemm_(char *transa, char *transb, int *m, int *n, int *k,
float *alpha,
@ -339,7 +347,7 @@ void gemm(
#ifdef __aarch64__
// MKLDNN also supports ARM for bf16, and the bypass is only
// currently intended for x86/x86_64.
const bool use_bf16_gemv_trans = false;
const bool use_bf16_gemv_trans = (m == 1 || n == 1);
#elif defined(__powerpc__)
const bool use_bf16_gemv_trans = false;
#else
@ -353,19 +361,30 @@ void gemm(
return;
}
#endif
#if AT_BUILD_WITH_BLAS() && defined(BLAS_HAS_SBGEMM)
#if AT_BUILD_WITH_BLAS() && (defined(BLAS_HAS_SBGEMM) || defined(BLAS_HAS_BGEMM))
if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
char transa_ = to_blas(transa), transb_ = to_blas(transb);
float alpha_ = alpha, beta_ = beta;
int c_size = n_ * m_;
#if defined(BLAS_HAS_BGEMM)
at::BFloat16 alpha_ = c10::convert<at::BFloat16>(alpha);
at::BFloat16 beta_ = c10::convert<at::BFloat16>(beta);
bgemm_(&transa_, &transb_,
&m_, &n_, &k_,
&alpha_,
a, &lda_,
b, &ldb_,
&beta_,
c, &ldc_);
#else
// The C matrix in OpenBLAS sbgemm is of type "float", so we have to convert, copy and copy back.
int c_size = n_ * m_;
std::vector<float> float_v(c_size, 0.0f);
for (const auto j : c10::irange(n)) {
for (const auto i : c10::irange(m)) {
float_v[j * m_ + i] = c10::convert<float>(c[j * ldc_ + i]);
}
}
float alpha_ = alpha, beta_ = beta;
sbgemm_(&transa_, &transb_,
&m_, &n_, &k_,
&alpha_,
@ -378,6 +397,7 @@ void gemm(
c[j * ldc_ + i] = c10::convert<at::BFloat16>(float_v[j * m_ + i]);
}
}
#endif // end of defined(BLAS_HAS_BGEMM)
return;
}
#endif

View File

@ -157,7 +157,7 @@ mkldnn_gemm(
bool bf32_usable = std::is_same_v<scalar_t, float> && use_mkldnn_bf32_matmul();
bool tf32_usable = std::is_same_v<scalar_t, float> && use_mkldnn_tf32_matmul();
if ( !(bf16_usable || fp16_usable || bf32_usable || tf32_usable) ||
(m * n * k <= 16 * 16 * 16) || (alpha == 0.0f)) {
((n * m <= 16384) && (n * k >= 4 * 2048 / m)) || (m * n * k <= 16 * 16 * 16) || (alpha == 0.0f)) {
return false;
}

View File

@ -336,13 +336,18 @@ ENDIF(NOT BLAS_FIND_QUIETLY)
# Do nothing if BLAS was found before
ENDIF(NOT BLAS_FOUND)
# Blas has bfloat16 support?
# Probe the detected BLAS for bfloat16 GEMM entry points ("bgemm_" and
# "sbgemm_"). Each symbol that is found adds a matching -DBLAS_HAS_* option
# via add_compile_options. Skipped entirely when BLAS_LIBRARIES is not set.
IF(BLAS_LIBRARIES)
INCLUDE(CheckFunctionExists)
# check_function_exists links its probe against CMAKE_REQUIRED_LIBRARIES, so
# point it at the BLAS libraries for the duration of the checks.
SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES})
# NOTE(review): same assignment as the line above, differing only in command
# case -- looks like both sides of a diff are present; confirm which to keep.
set(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES})
# Result is stored in BLAS_HAS_BGEMM.
check_function_exists("bgemm_" BLAS_HAS_BGEMM)
IF(BLAS_HAS_BGEMM)
add_compile_options(-DBLAS_HAS_BGEMM)
ENDIF(BLAS_HAS_BGEMM)
# Result is stored in BLAS_HAS_SBGEMM.
check_function_exists("sbgemm_" BLAS_HAS_SBGEMM)
# Clear the probe's link libraries once both checks have run.
set(CMAKE_REQUIRED_LIBRARIES)
IF(BLAS_HAS_SBGEMM)
add_compile_options(-DBLAS_HAS_SBGEMM)
ENDIF(BLAS_HAS_SBGEMM)
# NOTE(review): CMAKE_REQUIRED_LIBRARIES was already cleared above; this second
# reset appears redundant -- confirm which of the two is intended.
set(CMAKE_REQUIRED_LIBRARIES)
ENDIF(BLAS_LIBRARIES)