Files
oneDNN/tests/gtests/test_ip_formats.cpp
Siddhartha Menon bcc0ca0084 tests: fix clang-tidy failures (#4082)
Signed-off-by: Siddhartha Menon <siddhartha.menon@arm.com>
2025-10-08 10:29:27 +01:00

209 lines
8.1 KiB
C++

/*******************************************************************************
* Copyright 2020-2023 Intel Corporation
* Copyright 2025 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "dnnl_test_common.hpp"
#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
#include "tests/test_isa_common.hpp"
#endif
#include "gtest/gtest.h"
#include "oneapi/dnnl/dnnl.hpp"
namespace dnnl {
// Memory-related shorthands used throughout this test.
using md = memory::desc;
using dt = memory::data_type;
using tag = memory::format_tag;
// Inner-product primitives, one per propagation direction.
using ip_fwd = inner_product_forward;
using ip_bwd_d = inner_product_backward_data;
using ip_bwd_w = inner_product_backward_weights;
class ip_formats_test_t : public ::testing::Test {
public:
engine e;
protected:
void SetUp() override {
e = get_test_engine();
SKIP_IF(get_test_engine_kind() == engine::kind::gpu,
"GPU takes a lot of time to complete this test.");
#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
const bool supports_bf16
= dnnl::impl::cpu::platform::has_data_type_support(dnnl_bf16);
const bool supports_f16
= dnnl::impl::cpu::platform::has_data_type_support(dnnl_f16);
#else
const bool supports_bf16 = false;
const bool supports_f16 = false;
#endif
bool is_cpu = get_test_engine_kind() == engine::kind::cpu;
memory::dims SP1D = {2};
memory::dims SP2D = {2, 2};
memory::dims SP3D = {2, 2, 2};
memory::dims SP4D = {2, 2, 2, 2};
memory::dims SP5D = {2, 2, 2, 2, 2};
memory::dims SP6D = {2, 2, 2, 2, 2, 2};
memory::dims SP7D = {2, 2, 2, 2, 2, 2, 2};
memory::dims SP8D = {2, 2, 2, 2, 2, 2, 2, 2};
memory::dims SP9D = {2, 2, 2, 2, 2, 2, 2, 2, 2};
memory::dims SP10D = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2};
memory::dims SP11D = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2};
memory::dims SP12D = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2};
std::vector<memory::dims> v_dims = {SP1D, SP2D, SP3D, SP4D, SP5D, SP6D,
SP7D, SP8D, SP9D, SP10D, SP11D, SP12D};
std::vector<memory::dims> unsup_dims
= {SP1D, SP6D, SP7D, SP8D, SP9D, SP10D, SP11D, SP12D};
unsigned start_tag = static_cast<unsigned>(tag::any);
unsigned end_tag = static_cast<unsigned>(tag::format_tag_last);
md src_md, wei_md, dst_md;
tag src_tag {tag::any}, wei_tag {tag::any}, dst_tag {tag::any};
std::vector<std::vector<dt>> cfg {
{dt::f32, dt::f32, dt::f32},
{dt::bf16, dt::bf16, dt::bf16},
{dt::u8, dt::s8, dt::u8},
};
for (const auto &i_cfg : cfg) {
if (i_cfg[0] == dt::f16 && is_cpu) continue;
bool cfg_has_bf16 = i_cfg[0] == dt::bf16 || i_cfg[1] == dt::bf16
|| i_cfg[2] == dt::bf16;
if (cfg_has_bf16 && !supports_bf16) continue;
bool cfg_has_f16 = i_cfg[0] == dt::f16 || i_cfg[1] == dt::f16
|| i_cfg[2] == dt::f16;
if (cfg_has_f16 && !supports_f16) continue;
for (unsigned stag = start_tag; stag < end_tag; stag++) {
src_tag = static_cast<tag>(stag);
// ip does not support 1D and 6D-12D cases
bool skip_tag = false;
for (const auto &i_dims : unsup_dims) {
src_md = md(i_dims, i_cfg[0], src_tag, true);
if (src_md) {
skip_tag = true;
break;
}
}
if (skip_tag) continue;
memory::dims cur_dims {};
for (const auto &i_dims : v_dims) {
src_md = md(i_dims, i_cfg[0], src_tag, true);
if (src_md) {
cur_dims = i_dims;
break;
}
}
ASSERT_TRUE(src_md);
for (unsigned wtag = start_tag; wtag < end_tag; wtag++) {
wei_tag = static_cast<tag>(wtag);
wei_md = md(cur_dims, i_cfg[1], wei_tag, true);
if (!wei_md) continue;
dst_md = md(SP2D, i_cfg[2], dst_tag);
ASSERT_TRUE(dst_md);
catch_expected_failures(
[&]() {
TestFormat(src_md, wei_md, dst_md, i_cfg);
},
false, dnnl_success);
}
}
}
}
void TestFormat(const md &src_md, const md &wei_md, const md &dst_md,
const std::vector<dt> &i_cfg) const {
ip_fwd::primitive_desc ip_fwd_pd(e, prop_kind::forward_training, src_md,
wei_md, dst_md, {}, true);
if (ip_fwd_pd) {
// Filter-out reference implementation since it doesn't have any
// format restrictions and too slow.
std::string pd_name = ip_fwd_pd.impl_info_str();
if (pd_name.find("ref") != std::string::npos) return;
auto ip_fwd_prim = ip_fwd(ip_fwd_pd);
auto strm = make_stream(ip_fwd_pd.get_engine());
auto src = test::make_memory(ip_fwd_pd.src_desc(), e);
auto wei = test::make_memory(ip_fwd_pd.weights_desc(), e);
auto dst = test::make_memory(ip_fwd_pd.dst_desc(), e);
ip_fwd_prim.execute(strm,
{{DNNL_ARG_SRC, src}, {DNNL_ARG_WEIGHTS, wei},
{DNNL_ARG_DST, dst}});
strm.wait();
}
// no sense to test backward if forward was not created
if (!ip_fwd_pd) return;
// int8 is not supported on backward;
if (i_cfg[1] == dt::s8) return;
ip_bwd_d::primitive_desc ip_bwd_d_pd(
e, src_md, wei_md, dst_md, ip_fwd_pd, {}, true);
if (ip_bwd_d_pd) {
// Filter-out reference implementation since it doesn't have any
// format restrictions and too slow.
std::string pd_name = ip_bwd_d_pd.impl_info_str();
if (pd_name.find("ref") != std::string::npos) return;
auto ip_bwd_d_prim = ip_bwd_d(ip_bwd_d_pd);
auto strm = make_stream(ip_bwd_d_pd.get_engine());
auto d_src = memory(ip_bwd_d_pd.diff_src_desc(), e);
auto d_wei = memory(ip_bwd_d_pd.weights_desc(), e);
auto d_dst = memory(ip_bwd_d_pd.diff_dst_desc(), e);
ip_bwd_d_prim.execute(strm,
{{DNNL_ARG_DIFF_SRC, d_src}, {DNNL_ARG_WEIGHTS, d_wei},
{DNNL_ARG_DIFF_DST, d_dst}});
strm.wait();
}
ip_bwd_w::primitive_desc ip_bwd_w_pd(
e, src_md, wei_md, dst_md, ip_fwd_pd, {}, true);
if (ip_bwd_w_pd) {
// Filter-out reference implementation since it doesn't have any
// format restrictions and too slow.
std::string pd_name = ip_bwd_w_pd.impl_info_str();
if (pd_name.find("ref") != std::string::npos) return;
auto ip_bwd_w_prim = ip_bwd_w(ip_bwd_w_pd);
auto strm = make_stream(ip_bwd_w_pd.get_engine());
auto src = memory(ip_bwd_w_pd.src_desc(), e);
auto d_wei = memory(ip_bwd_w_pd.diff_weights_desc(), e);
auto d_dst = memory(ip_bwd_w_pd.diff_dst_desc(), e);
ip_bwd_w_prim.execute(strm,
{{DNNL_ARG_SRC, src}, {DNNL_ARG_DIFF_WEIGHTS, d_wei},
{DNNL_ARG_DIFF_DST, d_dst}});
strm.wait();
}
}
};
// The whole format sweep is performed by ip_formats_test_t::SetUp(), so this
// test body has nothing left to do.
TEST_F(ip_formats_test_t, TestChecksAllFormats) {
    // Intentionally empty.
}
} // namespace dnnl