tests: benchdnn: graph: remove useless data type alignment

This commit is contained in:
Wang, Zhitao
2023-08-18 08:15:27 +00:00
committed by thuang6
parent 35ed364d05
commit d94f87c9a6

View File

@ -26,18 +26,6 @@ dnn_graph_mem_t::dnn_graph_mem_t(const dnn_mem_t &mem,
const auto &prim_dt = mem.dt();
const auto &graph_dt = static_cast<dnnl_data_type_t>(lt.get_data_type());
// For int8 cases, as graph driver will modify the data type of leading
// ops to u8/s8 in the reference path and use corresponding drivers to
// generate data, special handling is needed. If it's found that data
// type in ref path is u8/s8, it will be used.
//
// The reason why not always using primitive data type is that the driver
// rewrites data type in graph path for bf16 case handling. So we prefer
// data type in graph, and for int8 cases, that from ref path will be used.
//
dnnl_data_type_t c_data_type
= prim_dt == dnnl_s8 || prim_dt == dnnl_u8 ? prim_dt : graph_dt;
// Get memory tag of primitive memory
int ndims = mem.ndims();
dims_t strides(mem.strides(), mem.strides() + ndims);
@ -58,7 +46,7 @@ dnn_graph_mem_t::dnn_graph_mem_t(const dnn_mem_t &mem,
// otherwise use shape & tag from ref path side
// Create memory for graph path
const auto data_type = static_cast<dnnl::memory::data_type>(c_data_type);
const auto data_type = static_cast<dnnl::memory::data_type>(graph_dt);
if (is_op_input) {
if (graph_dims_.empty()) graph_dims_.push_back(1);
if (graph_strides_.empty()) graph_strides_.push_back(1);
@ -74,9 +62,9 @@ dnn_graph_mem_t::dnn_graph_mem_t(const dnn_mem_t &mem,
std::memcpy(graph_data_handle, prim_data_handle, graph_mem.size());
};
if (prim_dt != c_data_type) {
if (prim_dt != graph_dt) {
dnn_mem_t c_mem(
ndims, mem.dims(), c_data_type, mtag, ::get_test_engine());
ndims, mem.dims(), graph_dt, mtag, ::get_test_engine());
c_mem.reorder(mem);
prim_to_graph_memcpy(mem_, c_mem);
} else {
@ -87,7 +75,7 @@ dnn_graph_mem_t::dnn_graph_mem_t(const dnn_mem_t &mem,
dnnl::memory::desc md(graph_dims_, data_type, graph_strides_);
mem_ = dnn_mem_t(md.get(), ::get_test_engine());
} else {
mem_ = dnn_mem_t(mem.md_, c_data_type, mtag, ::get_test_engine());
mem_ = dnn_mem_t(mem.md_, graph_dt, mtag, ::get_test_engine());
}
}
}