Files
oneDNN/tests/gtests/test_gemm_common.hpp
2025-03-18 15:13:13 -07:00

686 lines
24 KiB
C++

/*******************************************************************************
* Copyright 2018-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef TEST_GEMM_COMMON_H
#define TEST_GEMM_COMMON_H
#include <cstdint>
#include <utility>
#include <vector>
#include <type_traits>
#include "test_gemm_data_preparation.hpp"
#include "test_gemm_params.hpp"
#include "test_gemm_validation.hpp"
#include "common/dnnl_thread.hpp"
#include "dnnl_test_common.hpp"
#include "gtest/gtest.h"
#include "oneapi/dnnl/dnnl.h"
#include "oneapi/dnnl/dnnl_types.h"
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
#include "oneapi/dnnl/dnnl_ocl.hpp"
#endif
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL
#include "oneapi/dnnl/dnnl_sycl.hpp"
#endif
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
#include "oneapi/dnnl/dnnl_threadpool.hpp"
#include "tests/test_thread.hpp"
#endif
#include "tests/test_isa_common.hpp"
#define CONCAT_WITH_UNDERSCORE_(a, b) a##_##b
#define CONCAT_WITH_UNDERSCORE(a, b) CONCAT_WITH_UNDERSCORE_(a, b)
#define INST_TEST_CASE_(str, ...) \
INSTANTIATE_TEST_SUITE_P(str, gemm_test, ::testing::Values(__VA_ARGS__))
#define INST_TEST_CASE(str, ...) \
INST_TEST_CASE_( \
CONCAT_WITH_UNDERSCORE(str, TEST_CASE_NAME_PREFIX), __VA_ARGS__)
#define CPU_INST_TEST_CASE_(str, ...) \
CPU_INSTANTIATE_TEST_SUITE_P(str, gemm_test, ::testing::Values(__VA_ARGS__))
#define CPU_INST_TEST_CASE(str, ...) \
CPU_INST_TEST_CASE_( \
CONCAT_WITH_UNDERSCORE(str, TEST_CASE_NAME_PREFIX), __VA_ARGS__)
// Declare bfloat16 GEMM interfaces for testing
extern "C" {
dnnl_status_t dnnl_gemm_bf16bf16f32(char transa, char transb, dnnl_dim_t M,
dnnl_dim_t N, dnnl_dim_t K, float alpha, const bfloat16_t *A,
dnnl_dim_t lda, const bfloat16_t *B, dnnl_dim_t ldb, float beta,
float *C, dnnl_dim_t ldc);
}
// Declare packed GEMM interfaces for testing
#include "src/cpu/gemm/gemm_pack.hpp"
namespace dnnl {
#if defined(DNNL_WTIH_SYCL)
bool is_memory_kind_buffer(const test_memory &mem) {
return sycl_interop::get_memory_kind(mem.get())
== sycl_interop::memory_kind::buffer;
}
#endif
/* Test implementation description.
* The testing steps looks as follows:
* 0. Prepare mapper_m and mapper_n <- details in test_gemm_data_preparation.hpp
* 1.a Generate random matrices A', B', C'
* 1.b Prepare matrices A, B, C based on A', B', and C' respectively
* 2. Compute C_calc := Op(M, N, K, A, B, C)
* 3. Compute C'_ref := Op_REF(M_test, N_test, K, A', B', C')
* 4. Expand C'_ref to C_ref, by applying mapper_m and mapper_n
* 5. Compare C_calc and C_ref
*/
template <typename a_dt, typename b_dt, typename c_dt>
struct dnnl_gemm_t {
static dnnl_status_t call(test_params_t &p, const test_memory &a_mem,
const test_memory &b_mem, const test_memory &c_mem) {
throw error(dnnl_runtime_error, "unknown gemm");
}
};
template <>
struct dnnl_gemm_t<float16_t, float16_t, float16_t> {
static dnnl_status_t call(const test_params_t &p, const test_memory &a_mem,
const test_memory &b_mem, const test_memory &c_mem,
const test_memory &) {
throw error(dnnl_runtime_error, "unknown gemm");
}
};
template <>
struct dnnl_gemm_t<float, float, float> {
static dnnl_status_t call_packed(const test_params_t &p,
const test_memory &a_mem, const test_memory &b_mem,
const test_memory &c_mem) {
/* Alas, the internal API still uses Fortran notation.
* So in addition to the changes for pack API, we also need to take
* care of conversions and layouts */
using namespace dnnl::impl::cpu;
assert(p.alpha == 1.f);
/* Prepare for Fortran style, hence A <-> B */
char trans_a = p.transB, trans_b = p.transA;
int64_t m = p.N, n = p.M, k = p.K;
int64_t lda = p.ldb, ldb = p.lda, ldc = p.ldc;
std::vector<float> a_pack_buf, b_pack_buf;
float *A = map_memory<float>(b_mem), *a_eff = A;
float *B = map_memory<float>(a_mem), *b_eff = B;
float *C = map_memory<float>(c_mem);
bool pack_a = p.pack_params.pack_b;
bool pack_b = p.pack_params.pack_a;
dnnl_status_t status = dnnl_success;
if (pack_a) {
size_t a_sz;
status = sgemm_pack_get_size("A", &trans_a, &trans_b, &m, &n, &k,
&lda, &ldb, &a_sz, &pack_a);
if (status != dnnl_success) return status;
if (pack_a) {
a_pack_buf.resize(a_sz / sizeof(float));
a_eff = a_pack_buf.data();
status = sgemm_pack("A", &trans_a, &trans_b, &m, &n, &k, &lda,
&ldb, A, a_eff);
if (status != dnnl_success) return status;
}
}
if (pack_b) {
size_t b_sz;
status = sgemm_pack_get_size("B", &trans_a, &trans_b, &m, &n, &k,
&lda, &ldb, &b_sz, &pack_b);
if (status != dnnl_success) return status;
if (pack_b) {
b_pack_buf.resize(b_sz / sizeof(float));
b_eff = b_pack_buf.data();
status = sgemm_pack("B", &trans_a, &trans_b, &m, &n, &k, &lda,
&ldb, B, b_eff);
if (status != dnnl_success) return status;
}
}
if (pack_a) trans_a = 'P';
if (pack_b) trans_b = 'P';
status = sgemm_compute(&trans_a, &trans_b, &m, &n, &k, a_eff, &lda,
b_eff, &ldb, &p.beta, C, &ldc);
return status;
}
static dnnl_status_t call(const test_params_t &p, const test_memory &a_mem,
const test_memory &b_mem, const test_memory &c_mem,
const test_memory &) {
if (p.pack_params.pack_a || p.pack_params.pack_b)
return call_packed(p, a_mem, b_mem, c_mem);
auto A = map_memory<float>(a_mem);
auto B = map_memory<float>(b_mem);
auto C = map_memory<float>(c_mem);
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
static auto *st
= impl::testing_threadpool_utils::get_active_threadpool();
testing::scoped_tp_deactivation_t std;
return static_cast<dnnl_status_t>(dnnl::threadpool_interop::sgemm(
p.transA, p.transB, p.M, p.N, p.K, p.alpha, A, p.lda, B, p.ldb,
p.beta, C, p.ldc, st));
#else
return dnnl_sgemm(p.transA, p.transB, p.M, p.N, p.K, p.alpha, A, p.lda,
B, p.ldb, p.beta, C, p.ldc);
#endif
}
};
template <>
struct dnnl_gemm_t<int8_t, int8_t, int32_t> {
static dnnl_status_t call_packed(const test_params_t &p,
const test_memory &a_mem, const test_memory &b_mem,
const test_memory &c_mem, const test_memory &oc_mem) {
/* Alas, the internal API still uses Fortran notation.
* So in addition to the changes for pack API, we also need to take
* care of conversions and layouts */
using namespace dnnl::impl::cpu;
assert(p.alpha == 1.f);
assert(p.igemm_params.oa() == 0);
assert(p.igemm_params.ob() == 0);
/* Prepare for Fortran style, hence A <-> B */
char trans_a = p.transB, trans_b = p.transA;
int64_t m = p.N, n = p.M, k = p.K;
int64_t lda = p.ldb, ldb = p.lda, ldc = p.ldc;
int8_t *A = map_memory<int8_t>(b_mem), *a_eff = A;
int8_t *B = map_memory<int8_t>(a_mem), *b_eff = B;
auto C = map_memory<int32_t>(c_mem);
auto oc = map_memory<int32_t>(oc_mem);
char offset_c = '\0';
switch (p.igemm_params.offsetc) {
case 'R': offset_c = 'C'; break;
case 'r': offset_c = 'c'; break;
case 'C': offset_c = 'R'; break;
case 'c': offset_c = 'r'; break;
default: offset_c = p.igemm_params.offsetc;
}
std::vector<int8_t> a_pack_buf;
std::vector<int8_t> b_pack_buf;
bool pack_a = p.pack_params.pack_b;
bool pack_b = p.pack_params.pack_a;
dnnl_status_t status = dnnl_success;
if (pack_a) {
size_t a_sz;
status = gemm_s8s8s32_pack_get_size(
"A", &trans_a, &trans_b, &m, &n, &k, &lda, &ldb, &a_sz);
if (status != dnnl_success) return status;
if (pack_a) {
a_pack_buf.resize(a_sz);
a_eff = a_pack_buf.data();
status = gemm_s8s8s32_pack("A", &trans_a, &trans_b, &m, &n, &k,
&lda, &ldb, A, a_eff);
if (status != dnnl_success) return status;
}
}
if (pack_b) {
size_t b_sz;
status = gemm_s8s8s32_pack_get_size(
"B", &trans_a, &trans_b, &m, &n, &k, &lda, &ldb, &b_sz);
if (status != dnnl_success) return status;
if (pack_b) {
b_pack_buf.resize(b_sz);
b_eff = b_pack_buf.data();
status = gemm_s8s8s32_pack("B", &trans_a, &trans_b, &m, &n, &k,
&lda, &ldb, B, b_eff);
if (status != dnnl_success) return status;
}
}
if (pack_a) trans_a = 'P';
if (pack_b) trans_b = 'P';
status = gemm_s8s8s32_compute(&trans_a, &trans_b, &offset_c, &m, &n, &k,
a_eff, &lda, b_eff, &ldb, &p.beta, C, &ldc, oc);
return status;
}
static dnnl_status_t call(const test_params_t &p, const test_memory &a_mem,
const test_memory &b_mem, const test_memory &c_mem,
const test_memory &oc_mem) {
if (p.pack_params.pack_a || p.pack_params.pack_b)
return call_packed(p, a_mem, b_mem, c_mem, oc_mem);
auto A = map_memory<int8_t>(a_mem);
auto B = map_memory<int8_t>(b_mem);
auto C = map_memory<int32_t>(c_mem);
auto oc = map_memory<int32_t>(oc_mem);
int8_t oa = p.igemm_params.oa();
int8_t ob = p.igemm_params.ob();
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
static auto *st
= impl::testing_threadpool_utils::get_active_threadpool();
testing::scoped_tp_deactivation_t std;
return static_cast<dnnl_status_t>(
dnnl::threadpool_interop::gemm_s8s8s32(p.transA, p.transB,
p.igemm_params.offsetc, p.M, p.N, p.K, p.alpha, A,
p.lda, oa, B, p.ldb, ob, p.beta, C, p.ldc, oc, st));
#else
return dnnl_gemm_s8s8s32(p.transA, p.transB, p.igemm_params.offsetc,
p.M, p.N, p.K, p.alpha, A, p.lda, oa, B, p.ldb, ob, p.beta, C,
p.ldc, oc);
#endif
}
};
template <>
struct dnnl_gemm_t<int8_t, uint8_t, int32_t> {
static dnnl_status_t call(const test_params_t &p, const test_memory &a_mem,
const test_memory &b_mem, const test_memory &c_mem,
const test_memory &oc_mem) {
throw error(dnnl_runtime_error, "unknown gemm");
}
};
template <>
struct dnnl_gemm_t<uint8_t, uint8_t, int32_t> {
static dnnl_status_t call(const test_params_t &p, const test_memory &a_mem,
const test_memory &b_mem, const test_memory &c_mem,
const test_memory &oc_mem) {
throw error(dnnl_runtime_error, "unknown gemm");
}
};
template <>
struct dnnl_gemm_t<uint8_t, int8_t, int32_t> {
static dnnl_status_t call_packed(const test_params_t &p,
const test_memory &a_mem, const test_memory &b_mem,
const test_memory &c_mem, const test_memory &oc_mem) {
/* Alas, the internal API still uses Fortran notation.
* So in addition to the changes for pack API, we also need to take
* care of conversions and layouts */
using namespace dnnl::impl::cpu;
assert(p.alpha == 1.f);
assert(p.igemm_params.oa() == 0);
assert(p.igemm_params.ob() == 0);
/* Prepare for Fortran style, hence A <-> B */
char trans_a = p.transB, trans_b = p.transA;
int64_t m = p.N, n = p.M, k = p.K;
int64_t lda = p.ldb, ldb = p.lda, ldc = p.ldc;
int8_t *A = map_memory<int8_t>(b_mem), *a_eff = A;
uint8_t *B = map_memory<uint8_t>(a_mem), *b_eff = B;
auto C = map_memory<int32_t>(c_mem);
auto oc = map_memory<int32_t>(oc_mem);
char offset_c = '\0';
switch (p.igemm_params.offsetc) {
case 'R': offset_c = 'C'; break;
case 'r': offset_c = 'c'; break;
case 'C': offset_c = 'R'; break;
case 'c': offset_c = 'r'; break;
default: offset_c = p.igemm_params.offsetc;
}
std::vector<int8_t> a_pack_buf;
std::vector<uint8_t> b_pack_buf;
bool pack_a = p.pack_params.pack_b;
bool pack_b = p.pack_params.pack_a;
dnnl_status_t status = dnnl_success;
if (pack_a) {
size_t a_sz;
status = gemm_s8u8s32_pack_get_size(
"A", &trans_a, &trans_b, &m, &n, &k, &lda, &ldb, &a_sz);
if (status != dnnl_success) return status;
if (pack_a) {
a_pack_buf.resize(a_sz);
a_eff = a_pack_buf.data();
status = gemm_s8u8s32_pack("A", &trans_a, &trans_b, &m, &n, &k,
&lda, &ldb, A, a_eff);
if (status != dnnl_success) return status;
}
}
if (pack_b) {
size_t b_sz;
status = gemm_s8u8s32_pack_get_size(
"B", &trans_a, &trans_b, &m, &n, &k, &lda, &ldb, &b_sz);
if (status != dnnl_success) return status;
if (pack_b) {
b_pack_buf.resize(b_sz);
b_eff = b_pack_buf.data();
status = gemm_s8u8s32_pack("B", &trans_a, &trans_b, &m, &n, &k,
&lda, &ldb, B, b_eff);
if (status != dnnl_success) return status;
}
}
if (pack_a) trans_a = 'P';
if (pack_b) trans_b = 'P';
status = gemm_s8u8s32_compute(&trans_a, &trans_b, &offset_c, &m, &n, &k,
a_eff, &lda, b_eff, &ldb, &p.beta, C, &ldc, oc);
return status;
}
static dnnl_status_t call(const test_params_t &p, const test_memory &a_mem,
const test_memory &b_mem, const test_memory &c_mem,
const test_memory &oc_mem) {
assert(p.igemm_params.oa() >= 0);
if (p.pack_params.pack_a || p.pack_params.pack_b)
return call_packed(p, a_mem, b_mem, c_mem, oc_mem);
auto A = map_memory<uint8_t>(a_mem);
auto B = map_memory<int8_t>(b_mem);
auto C = map_memory<int32_t>(c_mem);
auto oc = map_memory<int32_t>(oc_mem);
uint8_t oa = (uint8_t)p.igemm_params.oa();
int8_t ob = p.igemm_params.ob();
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
static auto *st
= impl::testing_threadpool_utils::get_active_threadpool();
testing::scoped_tp_deactivation_t std;
return static_cast<dnnl_status_t>(
dnnl::threadpool_interop::gemm_u8s8s32(p.transA, p.transB,
p.igemm_params.offsetc, p.M, p.N, p.K, p.alpha, A,
p.lda, oa, B, p.ldb, ob, p.beta, C, p.ldc, oc, st));
#else
return dnnl_gemm_u8s8s32(p.transA, p.transB, p.igemm_params.offsetc,
p.M, p.N, p.K, p.alpha, A, p.lda, oa, B, p.ldb, ob, p.beta, C,
p.ldc, oc);
#endif
}
};
template <>
struct dnnl_gemm_t<float16_t, float16_t, float> {
static dnnl_status_t call(const test_params_t &p, const test_memory &a_mem,
const test_memory &b_mem, const test_memory &c_mem,
const test_memory &) {
return dnnl_unimplemented;
}
};
template <>
struct dnnl_gemm_t<bfloat16_t, bfloat16_t, float> {
static dnnl_status_t call_packed(const test_params_t &p,
const test_memory &a_mem, const test_memory &b_mem,
const test_memory &c_mem) {
/* Alas, the internal API still uses Fortran notation.
* So in addition to the changes for pack API, we also need to take
* care of conversions and layouts */
using namespace dnnl::impl::cpu;
assert(p.alpha == 1.f);
/* Prepare for Fortran style, hence A <-> B */
char trans_a = p.transB, trans_b = p.transA;
int64_t m = p.N, n = p.M, k = p.K;
int64_t lda = p.ldb, ldb = p.lda, ldc = p.ldc;
std::vector<bfloat16_t> a_pack_buf, b_pack_buf;
bfloat16_t *A = map_memory<bfloat16_t>(b_mem), *a_eff = A;
bfloat16_t *B = map_memory<bfloat16_t>(a_mem), *b_eff = B;
float *C = map_memory<float>(c_mem);
bool pack_a = p.pack_params.pack_b;
bool pack_b = p.pack_params.pack_a;
dnnl_status_t status = dnnl_success;
if (pack_a) {
size_t a_sz;
status = gemm_bf16bf16f32_pack_get_size("A", &trans_a, &trans_b, &m,
&n, &k, &lda, &ldb, &a_sz, &pack_a);
if (status != dnnl_success) return status;
if (pack_a) {
a_pack_buf.resize(a_sz / sizeof(*a_eff));
a_eff = a_pack_buf.data();
status = gemm_bf16bf16f32_pack("A", &trans_a, &trans_b, &m, &n,
&k, &lda, &ldb, A, a_eff);
if (status != dnnl_success) return status;
}
}
if (pack_b) {
size_t b_sz;
status = gemm_bf16bf16f32_pack_get_size("B", &trans_a, &trans_b, &m,
&n, &k, &lda, &ldb, &b_sz, &pack_b);
if (status != dnnl_success) return status;
if (pack_b) {
b_pack_buf.resize(b_sz / sizeof(*b_eff));
b_eff = b_pack_buf.data();
status = gemm_bf16bf16f32_pack("B", &trans_a, &trans_b, &m, &n,
&k, &lda, &ldb, B, b_eff);
if (status != dnnl_success) return status;
}
}
if (pack_a) trans_a = 'P';
if (pack_b) trans_b = 'P';
status = gemm_bf16bf16f32_compute(&trans_a, &trans_b, &m, &n, &k, a_eff,
&lda, b_eff, &ldb, &p.beta, C, &ldc);
return status;
}
static dnnl_status_t call(const test_params_t &p, const test_memory &a_mem,
const test_memory &b_mem, const test_memory &c_mem,
const test_memory &) {
if (p.pack_params.pack_a || p.pack_params.pack_b)
return call_packed(p, a_mem, b_mem, c_mem);
auto A = map_memory<bfloat16_t>(a_mem);
auto B = map_memory<bfloat16_t>(b_mem);
auto C = map_memory<float>(c_mem);
return dnnl_gemm_bf16bf16f32(p.transA, p.transB, p.M, p.N, p.K, p.alpha,
A, p.lda, B, p.ldb, p.beta, C, p.ldc);
}
};
template <>
struct dnnl_gemm_t<bfloat16_t, bfloat16_t, bfloat16_t> {
static dnnl_status_t call(const test_params_t &p, const test_memory &a_mem,
const test_memory &b_mem, const test_memory &c_mem,
const test_memory &) {
return dnnl_unimplemented;
}
};
template <typename a_dt, typename b_dt, typename c_dt>
struct run_test_gemm_t {
static void call(const test_params_t &p) {
if (p.expect_to_fail) {
engine eng = get_test_engine();
test_memory zero_mem({}, eng);
auto status = dnnl_gemm_t<a_dt, b_dt, c_dt>::call(
p, zero_mem, zero_mem, zero_mem, zero_mem);
if (status != dnnl_success)
throw error(status, "oneDNN gemm returned error");
return;
}
engine eng = get_test_engine();
test_gemm_data_t gemm_data;
prepare_data_for_gemm_testing<a_dt, b_dt, c_dt>(p, gemm_data, eng);
auto status = dnnl_gemm_t<a_dt, b_dt, c_dt>::call(p, *gemm_data.a_mem,
*gemm_data.b_mem, *gemm_data.c_mem, *gemm_data.oc_mem);
if (status == dnnl_success) {
validate<a_dt, b_dt, c_dt>(p, gemm_data);
}
if (status != dnnl_success)
throw error(status, "oneDNN gemm returned error");
}
};
template <typename a_dt, typename b_dt, typename c_dt>
class gemm_test_common_t : public ::testing::TestWithParam<test_params_t> {
protected:
void SetUp() override {
const auto &p = ::testing::TestWithParam<test_params_t>::GetParam();
SKIP_IF(get_test_engine_kind() == engine::kind::gpu,
"GPU GEMM not implemented.");
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_SYCL
SKIP_IF(get_test_engine_kind() == engine::kind::cpu,
"SYCL CPU GEMM not implemented.");
#endif
bool zero_off = (p.off.a == 0 && p.off.b == 0 && p.off.c == 0);
SKIP_IF(!zero_off && get_test_engine_kind() == engine::kind::cpu,
"CPU does not support non-zero offsets.");
SKIP_IF(unsupported_data_type(data_traits_t<a_dt>::data_type),
"Engine does not support this data type.");
bool is_f16
= (data_traits_t<a_dt>::data_type == memory::data_type::f16);
SKIP_IF(is_f16 && get_test_engine_kind() == engine::kind::cpu,
"CPU does not support f16 data type.");
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_SYCL
SKIP_IF(get_test_engine_kind() == engine::kind::cpu,
"SYCL CPU GEMM not implemented.");
#endif
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL
SKIP_IF(get_test_engine_kind() == engine::kind::gpu
&& (data_traits_t<a_dt>::data_type
== memory::data_type::u8
|| data_traits_t<a_dt>::data_type
== memory::data_type::s8),
"SYCL GPU int GEMM not implemented.");
SKIP_IF_CUDA(true, "Test not supported in CUDA backend");
#endif
#if DNNL_X64
bool is_bf16bf16f32 = true
&& data_traits_t<a_dt>::data_type == memory::data_type::bf16
&& data_traits_t<b_dt>::data_type == memory::data_type::bf16
&& data_traits_t<c_dt>::data_type == memory::data_type::f32;
SKIP_IF(is_bf16bf16f32 && get_test_engine_kind() == engine::kind::cpu
&& !dnnl::mayiuse(cpu_isa::avx512_core),
"Skip test for systems that do not support avx512_core.");
#endif
bool pack = (p.pack_params.pack_a || p.pack_params.pack_b);
SKIP_IF(!DNNL_X64 && pack,
"Packed GEMM does not support non-x64 CPUs.");
SKIP_IF((p.alpha != 1.f || p.igemm_params.oa() != 0
|| p.igemm_params.ob() != 0)
&& pack,
"Packed GEMM doesn't support alpha or non-zero offset{A,B}.");
SKIP_IF(data_traits_t<b_dt>::data_type == memory::data_type::u8
&& get_test_engine_kind() == engine::kind::cpu,
"CPU does not support s8u8s32 and u8u8s32 GEMM.");
SKIP_IF(data_traits_t<c_dt>::data_type == memory::data_type::bf16
&& get_test_engine_kind() == engine::kind::cpu,
"CPU does not support bf16bf16bf16 GEMM.");
catch_expected_failures(
[&]() { Test(); }, p.expect_to_fail, p.expected_status, false);
}
void Test() {
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
testing::scoped_tp_activation_t sta;
#endif
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL
if (get_test_engine_kind() == engine::kind::gpu) {
const auto &p = ::testing::TestWithParam<test_params_t>::GetParam();
#if defined(TEST_DNNL_DPCPP_BUFFER)
// Test SYCL buffer interfaces
run_test_gemm_t<a_dt, b_dt, c_dt>::call(p);
#else
// Test SYCL USM interfaces
bool zero_off = (p.off.a == 0 && p.off.b == 0 && p.off.c == 0);
SKIP_IF(!zero_off, "USM interfaces do not support offsets.");
run_test_gemm_t<a_dt, b_dt, c_dt>::call(p);
#endif
return;
}
#endif
const auto &p = ::testing::TestWithParam<test_params_t>::GetParam();
run_test_gemm_t<a_dt, b_dt, c_dt>::call(p);
}
};
} // namespace dnnl
#endif