/*******************************************************************************
* Copyright 2020-2023 Intel Corporation
* Copyright 2025 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "dnnl_test_common.hpp"

#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
#include "tests/test_isa_common.hpp"
#endif

#include "gtest/gtest.h"

#include "oneapi/dnnl/dnnl.hpp"

namespace dnnl {

using dt = memory::data_type;
using tag = memory::format_tag;
using md = memory::desc;

using ip_fwd = inner_product_forward;
using ip_bwd_d = inner_product_backward_data;
using ip_bwd_w = inner_product_backward_weights;

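// Exercises inner product primitive creation for every combination of src and
// weights format tags and every data type configuration, and executes the
// primitives that report a non-reference implementation.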
class ip_formats_test_t : public ::testing::Test {
public:
    engine e;

protected:
    void SetUp() override {
        e = get_test_engine();
        SKIP_IF(get_test_engine_kind() == engine::kind::gpu,
                "GPU takes a lot of time to complete this test.");

#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
        const bool supports_bf16
                = dnnl::impl::cpu::platform::has_data_type_support(dnnl_bf16);
        const bool supports_f16
                = dnnl::impl::cpu::platform::has_data_type_support(dnnl_f16);
#else
        const bool supports_bf16 = false;
        const bool supports_f16 = false;
#endif

        bool is_cpu = get_test_engine_kind() == engine::kind::cpu;

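        // Problem shapes from 1D to 12D; every dimension is 2 to keep
        // primitive creation and execution cheap.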
        memory::dims SP1D = {2};
        memory::dims SP2D = {2, 2};
        memory::dims SP3D = {2, 2, 2};
        memory::dims SP4D = {2, 2, 2, 2};
        memory::dims SP5D = {2, 2, 2, 2, 2};
        memory::dims SP6D = {2, 2, 2, 2, 2, 2};
        memory::dims SP7D = {2, 2, 2, 2, 2, 2, 2};
        memory::dims SP8D = {2, 2, 2, 2, 2, 2, 2, 2};
        memory::dims SP9D = {2, 2, 2, 2, 2, 2, 2, 2, 2};
        memory::dims SP10D = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2};
        memory::dims SP11D = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2};
        memory::dims SP12D = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2};
        std::vector<memory::dims> v_dims = {SP1D, SP2D, SP3D, SP4D, SP5D, SP6D,
                SP7D, SP8D, SP9D, SP10D, SP11D, SP12D};
        std::vector<memory::dims> unsup_dims
                = {SP1D, SP6D, SP7D, SP8D, SP9D, SP10D, SP11D, SP12D};

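        // Iterate over every format tag the library defines, from tag::any up
        // to (but not including) tag::format_tag_last.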
        unsigned start_tag = static_cast<unsigned>(tag::any);
        unsigned end_tag = static_cast<unsigned>(tag::format_tag_last);

        md src_md, wei_md, dst_md;
        tag src_tag {tag::any}, wei_tag {tag::any}, dst_tag {tag::any};
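        // Data type configurations in {src, weights, dst} order.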
        std::vector<std::vector<dt>> cfg {
                {dt::f32, dt::f32, dt::f32},
                {dt::bf16, dt::bf16, dt::bf16},
                {dt::u8, dt::s8, dt::u8},
        };

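        // Skip configurations the engine cannot run, then probe every source
        // format tag for each remaining configuration.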
        for (const auto &i_cfg : cfg) {
            if (i_cfg[0] == dt::f16 && is_cpu) continue;

            bool cfg_has_bf16 = i_cfg[0] == dt::bf16 || i_cfg[1] == dt::bf16
                    || i_cfg[2] == dt::bf16;
            if (cfg_has_bf16 && !supports_bf16) continue;
            bool cfg_has_f16 = i_cfg[0] == dt::f16 || i_cfg[1] == dt::f16
                    || i_cfg[2] == dt::f16;
            if (cfg_has_f16 && !supports_f16) continue;

            for (unsigned stag = start_tag; stag < end_tag; stag++) {
                src_tag = static_cast<tag>(stag);

                // ip does not support 1D and 6D-12D cases
                bool skip_tag = false;
                for (const auto &i_dims : unsup_dims) {
                    src_md = md(i_dims, i_cfg[0], src_tag, true);
                    if (src_md) {
                        skip_tag = true;
                        break;
                    }
                }
                if (skip_tag) continue;

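                // Pick the first number of dims for which this tag yields a
                // valid source descriptor.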
                memory::dims cur_dims {};
                for (const auto &i_dims : v_dims) {
                    src_md = md(i_dims, i_cfg[0], src_tag, true);
                    if (src_md) {
                        cur_dims = i_dims;
                        break;
                    }
                }
                ASSERT_TRUE(src_md);

                for (unsigned wtag = start_tag; wtag < end_tag; wtag++) {
                    wei_tag = static_cast<tag>(wtag);
                    wei_md = md(cur_dims, i_cfg[1], wei_tag, true);
                    if (!wei_md) continue;

                    dst_md = md(SP2D, i_cfg[2], dst_tag);
                    ASSERT_TRUE(dst_md);

                    catch_expected_failures(
                            [&]() {
                                TestFormat(src_md, wei_md, dst_md, i_cfg);
                            },
                            false, dnnl_success);
                }
            }
        }
    }

    void TestFormat(const md &src_md, const md &wei_md, const md &dst_md,
            const std::vector<dt> &i_cfg) const {
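        // Create a forward primitive descriptor with allow_empty = true so an
        // unsupported src/weights/dst combination produces an empty pd instead
        // of throwing.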
        ip_fwd::primitive_desc ip_fwd_pd(e, prop_kind::forward_training, src_md,
                wei_md, dst_md, {}, true);
        if (ip_fwd_pd) {
            // Filter out the reference implementation since it doesn't have
            // any format restrictions and is too slow.
            std::string pd_name = ip_fwd_pd.impl_info_str();
            if (pd_name.find("ref") != std::string::npos) return;

            auto ip_fwd_prim = ip_fwd(ip_fwd_pd);
            auto strm = make_stream(ip_fwd_pd.get_engine());
            auto src = test::make_memory(ip_fwd_pd.src_desc(), e);
            auto wei = test::make_memory(ip_fwd_pd.weights_desc(), e);
            auto dst = test::make_memory(ip_fwd_pd.dst_desc(), e);
            ip_fwd_prim.execute(strm,
                    {{DNNL_ARG_SRC, src}, {DNNL_ARG_WEIGHTS, wei},
                            {DNNL_ARG_DST, dst}});
            strm.wait();
        }

        // No sense in testing backward if forward was not created.
        if (!ip_fwd_pd) return;
        // int8 is not supported for backward.
        if (i_cfg[1] == dt::s8) return;

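        // Backward-by-data primitive descriptor; the forward pd serves as the
        // hint, and allow_empty = true again tolerates unsupported cases.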
        ip_bwd_d::primitive_desc ip_bwd_d_pd(
                e, src_md, wei_md, dst_md, ip_fwd_pd, {}, true);
        if (ip_bwd_d_pd) {
            // Filter out the reference implementation since it doesn't have
            // any format restrictions and is too slow.
            std::string pd_name = ip_bwd_d_pd.impl_info_str();
            if (pd_name.find("ref") != std::string::npos) return;

            auto ip_bwd_d_prim = ip_bwd_d(ip_bwd_d_pd);
            auto strm = make_stream(ip_bwd_d_pd.get_engine());
            auto d_src = memory(ip_bwd_d_pd.diff_src_desc(), e);
            auto d_wei = memory(ip_bwd_d_pd.weights_desc(), e);
            auto d_dst = memory(ip_bwd_d_pd.diff_dst_desc(), e);
            ip_bwd_d_prim.execute(strm,
                    {{DNNL_ARG_DIFF_SRC, d_src}, {DNNL_ARG_WEIGHTS, d_wei},
                            {DNNL_ARG_DIFF_DST, d_dst}});
            strm.wait();
        }

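        // Backward-by-weights primitive descriptor, created with the same
        // forward hint.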
        ip_bwd_w::primitive_desc ip_bwd_w_pd(
                e, src_md, wei_md, dst_md, ip_fwd_pd, {}, true);
        if (ip_bwd_w_pd) {
            // Filter out the reference implementation since it doesn't have
            // any format restrictions and is too slow.
            std::string pd_name = ip_bwd_w_pd.impl_info_str();
            if (pd_name.find("ref") != std::string::npos) return;

            auto ip_bwd_w_prim = ip_bwd_w(ip_bwd_w_pd);
            auto strm = make_stream(ip_bwd_w_pd.get_engine());
            auto src = memory(ip_bwd_w_pd.src_desc(), e);
            auto d_wei = memory(ip_bwd_w_pd.diff_weights_desc(), e);
            auto d_dst = memory(ip_bwd_w_pd.diff_dst_desc(), e);
            ip_bwd_w_prim.execute(strm,
                    {{DNNL_ARG_SRC, src}, {DNNL_ARG_DIFF_WEIGHTS, d_wei},
                            {DNNL_ARG_DIFF_DST, d_dst}});
            strm.wait();
        }
    }
};

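// All checks run from SetUp(); the test body is intentionally empty.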
TEST_F(ip_formats_test_t, TestChecksAllFormats) {}

} // namespace dnnl