mirror of
https://github.com/uxlfoundation/oneDNN.git
synced 2025-10-20 18:43:49 +08:00
graph: backend: dnnl: use dnnl_primitive_execute_without_tp_hook
This commit is contained in:
committed by
Dmitrii Zarukin
parent
9e48052993
commit
4cb0e039cd
@ -22,6 +22,10 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "common/primitive_desc_iface.hpp"
|
||||
#include "common/primitive_iface.hpp"
|
||||
#include "common/stream.hpp"
|
||||
|
||||
#include "graph/interface/allocator.hpp"
|
||||
#include "graph/interface/backend.hpp"
|
||||
#include "graph/interface/shape_infer.hpp"
|
||||
@ -695,6 +699,46 @@ dnnl::accumulation_mode str2accumulation_mode(
|
||||
}
|
||||
}
|
||||
|
||||
status_t dnnl_primitive_execute_without_tp_hook(const primitive &prim,
|
||||
const stream &astream,
|
||||
const std::unordered_map<int, memory> &exec_args) {
|
||||
std::vector<dnnl_exec_arg_t> vec_args;
|
||||
vec_args.reserve(exec_args.size());
|
||||
for (const auto &a : exec_args)
|
||||
vec_args.push_back({a.first, a.second.get(true)});
|
||||
|
||||
const primitive_iface_t *primitive_iface = prim.get();
|
||||
stream_t *stream = astream.get();
|
||||
int nargs = (int)vec_args.size();
|
||||
const dnnl_exec_arg_t *c_args = vec_args.data();
|
||||
|
||||
bool ok = true && !dnnl::impl::utils::any_null(primitive_iface, stream)
|
||||
&& primitive_iface->engine() == stream->engine()
|
||||
&& IMPLICATION(nargs > 0, c_args != nullptr);
|
||||
if (!ok) return status::invalid_arguments;
|
||||
|
||||
exec_args_t args;
|
||||
status_t status = cvt_primitive_args(
|
||||
primitive_iface->pd()->impl().get(), nargs, c_args, args);
|
||||
if (status != status::success) return status;
|
||||
|
||||
exec_ctx_t ctx(stream, std::move(args));
|
||||
#ifdef DNNL_ENABLE_STACK_CHECKER
|
||||
stack_checker::stack_checker_t sc("dnnl_primitive_execute");
|
||||
const auto *pd_iface = primitive_iface->pd();
|
||||
bool is_wino
|
||||
= std::string(pd_iface->info()).find("wino") != std::string::npos;
|
||||
if (!is_wino) {
|
||||
status = sc.check(
|
||||
dnnl::impl::primitive_execute, primitive_iface, std::ref(ctx));
|
||||
}
|
||||
#else
|
||||
status = dnnl::impl::primitive_execute(primitive_iface, ctx);
|
||||
#endif
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
} // namespace dnnl_impl
|
||||
} // namespace graph
|
||||
} // namespace impl
|
||||
|
@ -146,6 +146,20 @@ dnnl::accumulation_mode str2accumulation_mode(
|
||||
size_t generate_constant_md_hash(
|
||||
size_t part_id, const std::vector<dnnl::memory::desc> &const_mds);
|
||||
|
||||
status_t dnnl_primitive_execute_without_tp_hook(const primitive &prim,
|
||||
const stream &astream,
|
||||
const std::unordered_map<int, memory> &exec_args);
|
||||
|
||||
#ifndef NDEBUG
|
||||
#define BACKEND_DNNL_ENFORCE(condition, message) \
|
||||
do { \
|
||||
error::wrap_c_api((condition) ? dnnl_success : dnnl_invalid_arguments, \
|
||||
(message)); \
|
||||
} while (false)
|
||||
#else
|
||||
#define BACKEND_DNNL_ENFORCE(condition, message)
|
||||
#endif
|
||||
|
||||
#define BACKEND_DNNL_CHECK(statement) \
|
||||
do { \
|
||||
status_t ret = (statement); \
|
||||
|
@ -206,6 +206,10 @@ status_t mqa_decomp_kernel_t<quantized, dt>::execute_impl(
|
||||
UNUSED(scratchpad);
|
||||
|
||||
// prepare execution args and allocate real memory
|
||||
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
|
||||
auto tp = threadpool_utils::get_active_threadpool();
|
||||
threadpool_utils::deactivate_threadpool();
|
||||
#endif
|
||||
prepare_sub_args(var_grantor, tid, block_size, res->mem_map);
|
||||
|
||||
// reorder0
|
||||
@ -257,14 +261,20 @@ status_t mqa_decomp_kernel_t<quantized, dt>::execute_impl(
|
||||
// in parallel region - these primitives should use single thread.
|
||||
mqa_cfg_.sub_reorder0.execute(strm, res->sub_reorder0_args[tid]);
|
||||
mqa_cfg_.sub_reorder1.execute(strm, res->sub_reorder1_args[tid]);
|
||||
mqa_cfg_.sub_mm1_prim.execute(strm, res->sub_mm1_args[tid]);
|
||||
dnnl_primitive_execute_without_tp_hook(
|
||||
mqa_cfg_.sub_mm1_prim, strm, res->sub_mm1_args[tid]);
|
||||
|
||||
mqa_cfg_.sub_softmax_prim.execute(strm, res->sub_softmax_args[tid]);
|
||||
dnnl_primitive_execute_without_tp_hook(
|
||||
mqa_cfg_.sub_softmax_prim, strm, res->sub_softmax_args[tid]);
|
||||
|
||||
mqa_cfg_.sub_reorder2.execute(strm, res->sub_reorder2_args[tid]);
|
||||
|
||||
mqa_cfg_.sub_mm2_prim.execute(strm, res->sub_mm2_args[tid]);
|
||||
dnnl_primitive_execute_without_tp_hook(
|
||||
mqa_cfg_.sub_mm2_prim, strm, res->sub_mm2_args[tid]);
|
||||
mqa_cfg_.sub_reorder3.execute(strm, res->sub_reorder3_args[tid]);
|
||||
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
|
||||
threadpool_utils::activate_threadpool(tp);
|
||||
#endif
|
||||
};
|
||||
// TODO: remove this when primitive new API ready
|
||||
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_OMP
|
||||
|
@ -60,7 +60,7 @@ public:
|
||||
void *handle = args.at(DNNL_ARG_SRC).get_data_handle();
|
||||
args.at(DNNL_ARG_DST).set_data_handle(handle);
|
||||
} else
|
||||
reorder_.execute(astream, args);
|
||||
dnnl_primitive_execute_without_tp_hook(reorder_, astream, args);
|
||||
return status::success;
|
||||
}
|
||||
status_t reset_engine(const dnnl::engine &p_engine) {
|
||||
|
@ -206,6 +206,10 @@ status_t sdp_decomp_kernel_t<quantized, dt>::execute_impl(
|
||||
const auto loop = [=](int tid, int nthr, dim_t bo, dim_t bi) {
|
||||
UNUSED(scratchpad);
|
||||
|
||||
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
|
||||
auto tp = threadpool_utils::get_active_threadpool();
|
||||
threadpool_utils::deactivate_threadpool();
|
||||
#endif
|
||||
// prepare execution args and allocate real memory
|
||||
prepare_sub_args(var_grantor, tid, block_size, res->mem_map);
|
||||
|
||||
@ -380,15 +384,22 @@ status_t sdp_decomp_kernel_t<quantized, dt>::execute_impl(
|
||||
// in parallel region - these primitives should use single thread.
|
||||
sdp_cfg_.sub_reorder0.execute(strm, res->sub_reorder0_args[tid]);
|
||||
sdp_cfg_.sub_reorder1.execute(strm, res->sub_reorder1_args[tid]);
|
||||
sdp_cfg_.sub_mm1_prim.execute(strm, res->sub_mm1_args[tid]);
|
||||
dnnl_primitive_execute_without_tp_hook(
|
||||
sdp_cfg_.sub_mm1_prim, strm, res->sub_mm1_args[tid]);
|
||||
if (sdp_cfg_.has_select && !sdp_cfg_.select_fusiable)
|
||||
sdp_cfg_.sub_select_prim.execute(strm, res->sub_select_args[tid]);
|
||||
sdp_cfg_.sub_softmax_prim.execute(strm, res->sub_softmax_args[tid]);
|
||||
dnnl_primitive_execute_without_tp_hook(
|
||||
sdp_cfg_.sub_select_prim, strm, res->sub_select_args[tid]);
|
||||
dnnl_primitive_execute_without_tp_hook(
|
||||
sdp_cfg_.sub_softmax_prim, strm, res->sub_softmax_args[tid]);
|
||||
|
||||
sdp_cfg_.sub_reorder2.execute(strm, res->sub_reorder2_args[tid]);
|
||||
|
||||
sdp_cfg_.sub_mm2_prim.execute(strm, res->sub_mm2_args[tid]);
|
||||
dnnl_primitive_execute_without_tp_hook(
|
||||
sdp_cfg_.sub_mm2_prim, strm, res->sub_mm2_args[tid]);
|
||||
sdp_cfg_.sub_reorder3.execute(strm, res->sub_reorder3_args[tid]);
|
||||
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
|
||||
threadpool_utils::activate_threadpool(tp);
|
||||
#endif
|
||||
};
|
||||
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
|
||||
tp_stream->before_exec_hook();
|
||||
|
@ -61,7 +61,8 @@ public:
|
||||
void *handle = args.at(DNNL_ARG_SRC).get_data_handle();
|
||||
args.at(DNNL_ARG_DST).set_data_handle(handle);
|
||||
} else
|
||||
reorder_prim_.execute(astream, args);
|
||||
dnnl_primitive_execute_without_tp_hook(
|
||||
reorder_prim_, astream, args);
|
||||
return status::success;
|
||||
}
|
||||
status_t reset_engine(const dnnl::engine &p_engine) {
|
||||
|
Reference in New Issue
Block a user