/*******************************************************************************
* Copyright 2018-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef TEST_GEMM_COMMON_H
#define TEST_GEMM_COMMON_H

#include <cstdint>
#include <utility>
#include <vector>
#include <type_traits>

#include "test_gemm_data_preparation.hpp"
#include "test_gemm_params.hpp"
#include "test_gemm_validation.hpp"

#include "common/dnnl_thread.hpp"
#include "dnnl_test_common.hpp"
#include "gtest/gtest.h"

#include "oneapi/dnnl/dnnl.h"
#include "oneapi/dnnl/dnnl_types.h"

#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
#include "oneapi/dnnl/dnnl_ocl.hpp"
#endif

#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL
#include "oneapi/dnnl/dnnl_sycl.hpp"
#endif

#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
#include "oneapi/dnnl/dnnl_threadpool.hpp"
#include "tests/test_thread.hpp"
#endif

#include "tests/test_isa_common.hpp"

#define CONCAT_WITH_UNDERSCORE_(a, b) a##_##b
#define CONCAT_WITH_UNDERSCORE(a, b) CONCAT_WITH_UNDERSCORE_(a, b)

#define INST_TEST_CASE_(str, ...) \
    INSTANTIATE_TEST_SUITE_P(str, gemm_test, ::testing::Values(__VA_ARGS__))
#define INST_TEST_CASE(str, ...) \
    INST_TEST_CASE_( \
            CONCAT_WITH_UNDERSCORE(str, TEST_CASE_NAME_PREFIX), __VA_ARGS__)

#define CPU_INST_TEST_CASE_(str, ...) \
    CPU_INSTANTIATE_TEST_SUITE_P(str, gemm_test, ::testing::Values(__VA_ARGS__))
#define CPU_INST_TEST_CASE(str, ...) \
    CPU_INST_TEST_CASE_( \
            CONCAT_WITH_UNDERSCORE(str, TEST_CASE_NAME_PREFIX), __VA_ARGS__)
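
// Illustrative usage sketch only (names and parameter lists are hypothetical;
// a real test translation unit, e.g. test_gemm_f32.cpp, defines
// TEST_CASE_NAME_PREFIX and supplies the actual test_params_t values):
//
//     using gemm_test = gemm_test_common_t<float, float, float>;
//     TEST_P(gemm_test, TestGEMM) {}
//     INST_TEST_CASE(SimpleSmallMatrices,
//             /* test_params_t {...}, test_params_t {...}, ... */);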

// Declare bfloat16 GEMM interfaces for testing
extern "C" {
dnnl_status_t dnnl_gemm_bf16bf16f32(char transa, char transb, dnnl_dim_t M,
        dnnl_dim_t N, dnnl_dim_t K, float alpha, const bfloat16_t *A,
        dnnl_dim_t lda, const bfloat16_t *B, dnnl_dim_t ldb, float beta,
        float *C, dnnl_dim_t ldc);
}

// Declare packed GEMM interfaces for testing
#include "src/cpu/gemm/gemm_pack.hpp"

namespace dnnl {

#if defined(DNNL_WITH_SYCL)
bool is_memory_kind_buffer(const test_memory &mem) {
    return sycl_interop::get_memory_kind(mem.get())
            == sycl_interop::memory_kind::buffer;
}
#endif

/* Test implementation description.
 * The testing steps look as follows:
 * 0. Prepare mapper_m and mapper_n <- details in test_gemm_data_preparation.hpp
 * 1.a Generate random matrices A', B', C'
 * 1.b Prepare matrices A, B, C based on A', B', and C' respectively
 * 2. Compute C_calc := Op(M, N, K, A, B, C)
 * 3. Compute C'_ref := Op_REF(M_test, N_test, K, A', B', C')
 * 4. Expand C'_ref to C_ref, by applying mapper_m and mapper_n
 * 5. Compare C_calc and C_ref
 */
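
/* For instance (illustrative values only; the real mapping is built in
 * test_gemm_data_preparation.hpp): with M = 4 and M_test = 2, mapper_m may map
 * rows {0, 1, 2, 3} -> {0, 1, 0, 1}, i.e. rows 2 and 3 of A are copies of rows
 * 0 and 1. The reference in step 3 then only computes the 2 distinct rows, and
 * step 4 replicates them back to all 4 rows of C_ref before the comparison. */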

template <typename a_dt, typename b_dt, typename c_dt>
struct dnnl_gemm_t {
    static dnnl_status_t call(test_params_t &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem) {
        throw error(dnnl_runtime_error, "unknown gemm");
    }
};

template <>
struct dnnl_gemm_t<float16_t, float16_t, float16_t> {
    static dnnl_status_t call(const test_params_t &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem,
            const test_memory &) {
        throw error(dnnl_runtime_error, "unknown gemm");
    }
};

template <>
struct dnnl_gemm_t<float, float, float> {
    static dnnl_status_t call_packed(const test_params_t &p,
            const test_memory &a_mem, const test_memory &b_mem,
            const test_memory &c_mem) {
        /* Alas, the internal API still uses Fortran notation.
         * So in addition to the changes for pack API, we also need to take
         * care of conversions and layouts */
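        /* Concretely: a column-major (Fortran) routine asked to compute
         * op(B)^T * op(A)^T produces C^T, whose memory layout is exactly the
         * row-major C = op(A) * op(B) we want. Hence A and B swap roles below,
         * and M/N (and lda/ldb) swap along with them. */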

        using namespace dnnl::impl::cpu;

        assert(p.alpha == 1.f);

        /* Prepare for Fortran style, hence A <-> B */
        char trans_a = p.transB, trans_b = p.transA;

        int64_t m = p.N, n = p.M, k = p.K;
        int64_t lda = p.ldb, ldb = p.lda, ldc = p.ldc;

        std::vector<float> a_pack_buf, b_pack_buf;
        float *A = map_memory<float>(b_mem), *a_eff = A;
        float *B = map_memory<float>(a_mem), *b_eff = B;
        float *C = map_memory<float>(c_mem);

        bool pack_a = p.pack_params.pack_b;
        bool pack_b = p.pack_params.pack_a;

        dnnl_status_t status = dnnl_success;

        if (pack_a) {
            size_t a_sz;
            status = sgemm_pack_get_size("A", &trans_a, &trans_b, &m, &n, &k,
                    &lda, &ldb, &a_sz, &pack_a);
            if (status != dnnl_success) return status;

            if (pack_a) {
                a_pack_buf.resize(a_sz / sizeof(float));
                a_eff = a_pack_buf.data();

                status = sgemm_pack("A", &trans_a, &trans_b, &m, &n, &k, &lda,
                        &ldb, A, a_eff);
                if (status != dnnl_success) return status;
            }
        }

        if (pack_b) {
            size_t b_sz;
            status = sgemm_pack_get_size("B", &trans_a, &trans_b, &m, &n, &k,
                    &lda, &ldb, &b_sz, &pack_b);
            if (status != dnnl_success) return status;

            if (pack_b) {
                b_pack_buf.resize(b_sz / sizeof(float));
                b_eff = b_pack_buf.data();

                status = sgemm_pack("B", &trans_a, &trans_b, &m, &n, &k, &lda,
                        &ldb, B, b_eff);
                if (status != dnnl_success) return status;
            }
        }

        if (pack_a) trans_a = 'P';
        if (pack_b) trans_b = 'P';

        status = sgemm_compute(&trans_a, &trans_b, &m, &n, &k, a_eff, &lda,
                b_eff, &ldb, &p.beta, C, &ldc);

        return status;
    }

    static dnnl_status_t call(const test_params_t &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem,
            const test_memory &) {

        if (p.pack_params.pack_a || p.pack_params.pack_b)
            return call_packed(p, a_mem, b_mem, c_mem);

        auto A = map_memory<float>(a_mem);
        auto B = map_memory<float>(b_mem);
        auto C = map_memory<float>(c_mem);

#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
        static auto *st
                = impl::testing_threadpool_utils::get_active_threadpool();
        testing::scoped_tp_deactivation_t std;
        return static_cast<dnnl_status_t>(dnnl::threadpool_interop::sgemm(
                p.transA, p.transB, p.M, p.N, p.K, p.alpha, A, p.lda, B, p.ldb,
                p.beta, C, p.ldc, st));
#else
        return dnnl_sgemm(p.transA, p.transB, p.M, p.N, p.K, p.alpha, A, p.lda,
                B, p.ldb, p.beta, C, p.ldc);
#endif
    }
};
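
// For reference, a minimal standalone call of the kind the f32 specialization
// above exercises (illustrative only, not part of the test; 2x2 row-major
// matrices with B set to the identity, so the expected result is C == A):
//
//     float A[4] = {1, 2, 3, 4}, B[4] = {1, 0, 0, 1}, C[4] = {0};
//     dnnl_status_t st
//             = dnnl_sgemm('N', 'N', 2, 2, 2, 1.f, A, 2, B, 2, 0.f, C, 2);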

template <>
struct dnnl_gemm_t<int8_t, int8_t, int32_t> {
    static dnnl_status_t call_packed(const test_params_t &p,
            const test_memory &a_mem, const test_memory &b_mem,
            const test_memory &c_mem, const test_memory &oc_mem) {
        /* Alas, the internal API still uses Fortran notation.
         * So in addition to the changes for pack API, we also need to take
         * care of conversions and layouts */

        using namespace dnnl::impl::cpu;

        assert(p.alpha == 1.f);
        assert(p.igemm_params.oa() == 0);
        assert(p.igemm_params.ob() == 0);

        /* Prepare for Fortran style, hence A <-> B */
        char trans_a = p.transB, trans_b = p.transA;

        int64_t m = p.N, n = p.M, k = p.K;
        int64_t lda = p.ldb, ldb = p.lda, ldc = p.ldc;

        int8_t *A = map_memory<int8_t>(b_mem), *a_eff = A;
        int8_t *B = map_memory<int8_t>(a_mem), *b_eff = B;

        auto C = map_memory<int32_t>(c_mem);
        auto oc = map_memory<int32_t>(oc_mem);

        char offset_c = '\0';
        switch (p.igemm_params.offsetc) {
            case 'R': offset_c = 'C'; break;
            case 'r': offset_c = 'c'; break;
            case 'C': offset_c = 'R'; break;
            case 'c': offset_c = 'r'; break;
            default: offset_c = p.igemm_params.offsetc;
        }
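        /* Because the Fortran-notation call computes C^T, a per-row offset on
         * C corresponds to a per-column offset on C^T (and vice versa), hence
         * the 'R' <-> 'C' flip above; any other value (e.g. the fixed case) is
         * passed through unchanged. */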

        std::vector<int8_t> a_pack_buf;
        std::vector<int8_t> b_pack_buf;
        bool pack_a = p.pack_params.pack_b;
        bool pack_b = p.pack_params.pack_a;

        dnnl_status_t status = dnnl_success;

        if (pack_a) {
            size_t a_sz;
            status = gemm_s8s8s32_pack_get_size(
                    "A", &trans_a, &trans_b, &m, &n, &k, &lda, &ldb, &a_sz);
            if (status != dnnl_success) return status;

            if (pack_a) {
                a_pack_buf.resize(a_sz);
                a_eff = a_pack_buf.data();

                status = gemm_s8s8s32_pack("A", &trans_a, &trans_b, &m, &n, &k,
                        &lda, &ldb, A, a_eff);
                if (status != dnnl_success) return status;
            }
        }

        if (pack_b) {
            size_t b_sz;

            status = gemm_s8s8s32_pack_get_size(
                    "B", &trans_a, &trans_b, &m, &n, &k, &lda, &ldb, &b_sz);
            if (status != dnnl_success) return status;

            if (pack_b) {
                b_pack_buf.resize(b_sz);
                b_eff = b_pack_buf.data();

                status = gemm_s8s8s32_pack("B", &trans_a, &trans_b, &m, &n, &k,
                        &lda, &ldb, B, b_eff);
                if (status != dnnl_success) return status;
            }
        }

        if (pack_a) trans_a = 'P';
        if (pack_b) trans_b = 'P';

        status = gemm_s8s8s32_compute(&trans_a, &trans_b, &offset_c, &m, &n, &k,
                a_eff, &lda, b_eff, &ldb, &p.beta, C, &ldc, oc);

        return status;
    }

    static dnnl_status_t call(const test_params_t &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem,
            const test_memory &oc_mem) {

        if (p.pack_params.pack_a || p.pack_params.pack_b)
            return call_packed(p, a_mem, b_mem, c_mem, oc_mem);

        auto A = map_memory<int8_t>(a_mem);
        auto B = map_memory<int8_t>(b_mem);
        auto C = map_memory<int32_t>(c_mem);
        auto oc = map_memory<int32_t>(oc_mem);
        int8_t oa = p.igemm_params.oa();
        int8_t ob = p.igemm_params.ob();

#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
        static auto *st
                = impl::testing_threadpool_utils::get_active_threadpool();
        testing::scoped_tp_deactivation_t std;
        return static_cast<dnnl_status_t>(
                dnnl::threadpool_interop::gemm_s8s8s32(p.transA, p.transB,
                        p.igemm_params.offsetc, p.M, p.N, p.K, p.alpha, A,
                        p.lda, oa, B, p.ldb, ob, p.beta, C, p.ldc, oc, st));
#else
        return dnnl_gemm_s8s8s32(p.transA, p.transB, p.igemm_params.offsetc,
                p.M, p.N, p.K, p.alpha, A, p.lda, oa, B, p.ldb, ob, p.beta, C,
                p.ldc, oc);
#endif
    }
};

template <>
struct dnnl_gemm_t<int8_t, uint8_t, int32_t> {
    static dnnl_status_t call(const test_params_t &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem,
            const test_memory &oc_mem) {
        throw error(dnnl_runtime_error, "unknown gemm");
    }
};

template <>
struct dnnl_gemm_t<uint8_t, uint8_t, int32_t> {
    static dnnl_status_t call(const test_params_t &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem,
            const test_memory &oc_mem) {

        throw error(dnnl_runtime_error, "unknown gemm");
    }
};

template <>
struct dnnl_gemm_t<uint8_t, int8_t, int32_t> {
    static dnnl_status_t call_packed(const test_params_t &p,
            const test_memory &a_mem, const test_memory &b_mem,
            const test_memory &c_mem, const test_memory &oc_mem) {
        /* Alas, the internal API still uses Fortran notation.
         * So in addition to the changes for pack API, we also need to take
         * care of conversions and layouts */

        using namespace dnnl::impl::cpu;

        assert(p.alpha == 1.f);
        assert(p.igemm_params.oa() == 0);
        assert(p.igemm_params.ob() == 0);

        /* Prepare for Fortran style, hence A <-> B */
        char trans_a = p.transB, trans_b = p.transA;

        int64_t m = p.N, n = p.M, k = p.K;
        int64_t lda = p.ldb, ldb = p.lda, ldc = p.ldc;

        int8_t *A = map_memory<int8_t>(b_mem), *a_eff = A;
        uint8_t *B = map_memory<uint8_t>(a_mem), *b_eff = B;

        auto C = map_memory<int32_t>(c_mem);
        auto oc = map_memory<int32_t>(oc_mem);

        char offset_c = '\0';
        switch (p.igemm_params.offsetc) {
            case 'R': offset_c = 'C'; break;
            case 'r': offset_c = 'c'; break;
            case 'C': offset_c = 'R'; break;
            case 'c': offset_c = 'r'; break;
            default: offset_c = p.igemm_params.offsetc;
        }

        std::vector<int8_t> a_pack_buf;
        std::vector<uint8_t> b_pack_buf;
        bool pack_a = p.pack_params.pack_b;
        bool pack_b = p.pack_params.pack_a;

        dnnl_status_t status = dnnl_success;

        if (pack_a) {
            size_t a_sz;
            status = gemm_s8u8s32_pack_get_size(
                    "A", &trans_a, &trans_b, &m, &n, &k, &lda, &ldb, &a_sz);
            if (status != dnnl_success) return status;

            if (pack_a) {
                a_pack_buf.resize(a_sz);
                a_eff = a_pack_buf.data();

                status = gemm_s8u8s32_pack("A", &trans_a, &trans_b, &m, &n, &k,
                        &lda, &ldb, A, a_eff);
                if (status != dnnl_success) return status;
            }
        }

        if (pack_b) {
            size_t b_sz;

            status = gemm_s8u8s32_pack_get_size(
                    "B", &trans_a, &trans_b, &m, &n, &k, &lda, &ldb, &b_sz);
            if (status != dnnl_success) return status;

            if (pack_b) {
                b_pack_buf.resize(b_sz);
                b_eff = b_pack_buf.data();

                status = gemm_s8u8s32_pack("B", &trans_a, &trans_b, &m, &n, &k,
                        &lda, &ldb, B, b_eff);
                if (status != dnnl_success) return status;
            }
        }

        if (pack_a) trans_a = 'P';
        if (pack_b) trans_b = 'P';

        status = gemm_s8u8s32_compute(&trans_a, &trans_b, &offset_c, &m, &n, &k,
                a_eff, &lda, b_eff, &ldb, &p.beta, C, &ldc, oc);

        return status;
    }

    static dnnl_status_t call(const test_params_t &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem,
            const test_memory &oc_mem) {
        assert(p.igemm_params.oa() >= 0);

        if (p.pack_params.pack_a || p.pack_params.pack_b)
            return call_packed(p, a_mem, b_mem, c_mem, oc_mem);

        auto A = map_memory<uint8_t>(a_mem);
        auto B = map_memory<int8_t>(b_mem);
        auto C = map_memory<int32_t>(c_mem);
        auto oc = map_memory<int32_t>(oc_mem);
        uint8_t oa = (uint8_t)p.igemm_params.oa();
        int8_t ob = p.igemm_params.ob();

#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
        static auto *st
                = impl::testing_threadpool_utils::get_active_threadpool();
        testing::scoped_tp_deactivation_t std;
        return static_cast<dnnl_status_t>(
                dnnl::threadpool_interop::gemm_u8s8s32(p.transA, p.transB,
                        p.igemm_params.offsetc, p.M, p.N, p.K, p.alpha, A,
                        p.lda, oa, B, p.ldb, ob, p.beta, C, p.ldc, oc, st));
#else
        return dnnl_gemm_u8s8s32(p.transA, p.transB, p.igemm_params.offsetc,
                p.M, p.N, p.K, p.alpha, A, p.lda, oa, B, p.ldb, ob, p.beta, C,
                p.ldc, oc);
#endif
    }
};

template <>
struct dnnl_gemm_t<float16_t, float16_t, float> {
    static dnnl_status_t call(const test_params_t &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem,
            const test_memory &) {
        return dnnl_unimplemented;
    }
};

template <>
struct dnnl_gemm_t<bfloat16_t, bfloat16_t, float> {
    static dnnl_status_t call_packed(const test_params_t &p,
            const test_memory &a_mem, const test_memory &b_mem,
            const test_memory &c_mem) {
        /* Alas, the internal API still uses Fortran notation.
         * So in addition to the changes for pack API, we also need to take
         * care of conversions and layouts */

        using namespace dnnl::impl::cpu;

        assert(p.alpha == 1.f);

        /* Prepare for Fortran style, hence A <-> B */
        char trans_a = p.transB, trans_b = p.transA;

        int64_t m = p.N, n = p.M, k = p.K;
        int64_t lda = p.ldb, ldb = p.lda, ldc = p.ldc;

        std::vector<bfloat16_t> a_pack_buf, b_pack_buf;
        bfloat16_t *A = map_memory<bfloat16_t>(b_mem), *a_eff = A;
        bfloat16_t *B = map_memory<bfloat16_t>(a_mem), *b_eff = B;
        float *C = map_memory<float>(c_mem);

        bool pack_a = p.pack_params.pack_b;
        bool pack_b = p.pack_params.pack_a;

        dnnl_status_t status = dnnl_success;

        if (pack_a) {
            size_t a_sz;
            status = gemm_bf16bf16f32_pack_get_size("A", &trans_a, &trans_b, &m,
                    &n, &k, &lda, &ldb, &a_sz, &pack_a);
            if (status != dnnl_success) return status;

            if (pack_a) {
                a_pack_buf.resize(a_sz / sizeof(*a_eff));
                a_eff = a_pack_buf.data();

                status = gemm_bf16bf16f32_pack("A", &trans_a, &trans_b, &m, &n,
                        &k, &lda, &ldb, A, a_eff);
                if (status != dnnl_success) return status;
            }
        }

        if (pack_b) {
            size_t b_sz;
            status = gemm_bf16bf16f32_pack_get_size("B", &trans_a, &trans_b, &m,
                    &n, &k, &lda, &ldb, &b_sz, &pack_b);
            if (status != dnnl_success) return status;

            if (pack_b) {
                b_pack_buf.resize(b_sz / sizeof(*b_eff));
                b_eff = b_pack_buf.data();

                status = gemm_bf16bf16f32_pack("B", &trans_a, &trans_b, &m, &n,
                        &k, &lda, &ldb, B, b_eff);
                if (status != dnnl_success) return status;
            }
        }

        if (pack_a) trans_a = 'P';
        if (pack_b) trans_b = 'P';

        status = gemm_bf16bf16f32_compute(&trans_a, &trans_b, &m, &n, &k, a_eff,
                &lda, b_eff, &ldb, &p.beta, C, &ldc);

        return status;
    }

    static dnnl_status_t call(const test_params_t &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem,
            const test_memory &) {
        if (p.pack_params.pack_a || p.pack_params.pack_b)
            return call_packed(p, a_mem, b_mem, c_mem);

        auto A = map_memory<bfloat16_t>(a_mem);
        auto B = map_memory<bfloat16_t>(b_mem);
        auto C = map_memory<float>(c_mem);
        return dnnl_gemm_bf16bf16f32(p.transA, p.transB, p.M, p.N, p.K, p.alpha,
                A, p.lda, B, p.ldb, p.beta, C, p.ldc);
    }
};

template <>
struct dnnl_gemm_t<bfloat16_t, bfloat16_t, bfloat16_t> {
    static dnnl_status_t call(const test_params_t &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem,
            const test_memory &) {
        return dnnl_unimplemented;
    }
};

template <typename a_dt, typename b_dt, typename c_dt>
struct run_test_gemm_t {
    static void call(const test_params_t &p) {
        if (p.expect_to_fail) {
            engine eng = get_test_engine();
            test_memory zero_mem({}, eng);
            auto status = dnnl_gemm_t<a_dt, b_dt, c_dt>::call(
                    p, zero_mem, zero_mem, zero_mem, zero_mem);
            if (status != dnnl_success)
                throw error(status, "oneDNN gemm returned error");
            return;
        }

        engine eng = get_test_engine();
        test_gemm_data_t gemm_data;
        prepare_data_for_gemm_testing<a_dt, b_dt, c_dt>(p, gemm_data, eng);

        auto status = dnnl_gemm_t<a_dt, b_dt, c_dt>::call(p, *gemm_data.a_mem,
                *gemm_data.b_mem, *gemm_data.c_mem, *gemm_data.oc_mem);

        if (status == dnnl_success) {
            validate<a_dt, b_dt, c_dt>(p, gemm_data);
        }

        if (status != dnnl_success)
            throw error(status, "oneDNN gemm returned error");
    }
};

template <typename a_dt, typename b_dt, typename c_dt>
class gemm_test_common_t : public ::testing::TestWithParam<test_params_t> {
protected:
    void SetUp() override {
        const auto &p = ::testing::TestWithParam<test_params_t>::GetParam();

        SKIP_IF(get_test_engine_kind() == engine::kind::gpu,
                "GPU GEMM not implemented.");

#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_SYCL
        SKIP_IF(get_test_engine_kind() == engine::kind::cpu,
                "SYCL CPU GEMM not implemented.");
#endif

        bool zero_off = (p.off.a == 0 && p.off.b == 0 && p.off.c == 0);
        SKIP_IF(!zero_off && get_test_engine_kind() == engine::kind::cpu,
                "CPU does not support non-zero offsets.");

        SKIP_IF(unsupported_data_type(data_traits_t<a_dt>::data_type),
                "Engine does not support this data type.");

        bool is_f16
                = (data_traits_t<a_dt>::data_type == memory::data_type::f16);
        SKIP_IF(is_f16 && get_test_engine_kind() == engine::kind::cpu,
                "CPU does not support f16 data type.");

#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_SYCL
        SKIP_IF(get_test_engine_kind() == engine::kind::cpu,
                "SYCL CPU GEMM not implemented.");
#endif
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL
        SKIP_IF(get_test_engine_kind() == engine::kind::gpu
                        && (data_traits_t<a_dt>::data_type
                                        == memory::data_type::u8
                                || data_traits_t<a_dt>::data_type
                                        == memory::data_type::s8),
                "SYCL GPU int GEMM not implemented.");
        SKIP_IF_CUDA(true, "Test not supported in CUDA backend");
#endif

#if DNNL_X64
        bool is_bf16bf16f32 = true
                && data_traits_t<a_dt>::data_type == memory::data_type::bf16
                && data_traits_t<b_dt>::data_type == memory::data_type::bf16
                && data_traits_t<c_dt>::data_type == memory::data_type::f32;

        SKIP_IF(is_bf16bf16f32 && get_test_engine_kind() == engine::kind::cpu
                        && !dnnl::mayiuse(cpu_isa::avx512_core),
                "Skip test for systems that do not support avx512_core.");
#endif

        bool pack = (p.pack_params.pack_a || p.pack_params.pack_b);
        SKIP_IF(!DNNL_X64 && pack,
                "Packed GEMM does not support non-x64 CPUs.");
        SKIP_IF((p.alpha != 1.f || p.igemm_params.oa() != 0
                        || p.igemm_params.ob() != 0)
                        && pack,
                "Packed GEMM doesn't support alpha or non-zero offset{A,B}.");
        SKIP_IF(data_traits_t<b_dt>::data_type == memory::data_type::u8
                        && get_test_engine_kind() == engine::kind::cpu,
                "CPU does not support s8u8s32 and u8u8s32 GEMM.");
        SKIP_IF(data_traits_t<c_dt>::data_type == memory::data_type::bf16
                        && get_test_engine_kind() == engine::kind::cpu,
                "CPU does not support bf16bf16bf16 GEMM.");

        catch_expected_failures(
                [&]() { Test(); }, p.expect_to_fail, p.expected_status, false);
    }
    void Test() {
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
        testing::scoped_tp_activation_t sta;
#endif
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL
        if (get_test_engine_kind() == engine::kind::gpu) {
            const auto &p = ::testing::TestWithParam<test_params_t>::GetParam();

#if defined(TEST_DNNL_DPCPP_BUFFER)
            // Test SYCL buffer interfaces
            run_test_gemm_t<a_dt, b_dt, c_dt>::call(p);
#else
            // Test SYCL USM interfaces
            bool zero_off = (p.off.a == 0 && p.off.b == 0 && p.off.c == 0);
            SKIP_IF(!zero_off, "USM interfaces do not support offsets.");

            run_test_gemm_t<a_dt, b_dt, c_dt>::call(p);
#endif

            return;
        }
#endif
        const auto &p = ::testing::TestWithParam<test_params_t>::GetParam();
        run_test_gemm_t<a_dt, b_dt, c_dt>::call(p);
    }
};
} // namespace dnnl

#endif