graph: backend: dnnl: use dnnl_primitive_execute_without_tp_hook

This commit is contained in:
Gu, Yonghao
2025-06-04 03:14:46 +00:00
committed by Dmitrii Zarukin
parent 9e48052993
commit 4cb0e039cd
6 changed files with 89 additions and 9 deletions

View File

@ -22,6 +22,10 @@
#include <utility>
#include <vector>
#include "common/primitive_desc_iface.hpp"
#include "common/primitive_iface.hpp"
#include "common/stream.hpp"
#include "graph/interface/allocator.hpp"
#include "graph/interface/backend.hpp"
#include "graph/interface/shape_infer.hpp"
@ -695,6 +699,46 @@ dnnl::accumulation_mode str2accumulation_mode(
}
}
status_t dnnl_primitive_execute_without_tp_hook(const primitive &prim,
const stream &astream,
const std::unordered_map<int, memory> &exec_args) {
std::vector<dnnl_exec_arg_t> vec_args;
vec_args.reserve(exec_args.size());
for (const auto &a : exec_args)
vec_args.push_back({a.first, a.second.get(true)});
const primitive_iface_t *primitive_iface = prim.get();
stream_t *stream = astream.get();
int nargs = (int)vec_args.size();
const dnnl_exec_arg_t *c_args = vec_args.data();
bool ok = true && !dnnl::impl::utils::any_null(primitive_iface, stream)
&& primitive_iface->engine() == stream->engine()
&& IMPLICATION(nargs > 0, c_args != nullptr);
if (!ok) return status::invalid_arguments;
exec_args_t args;
status_t status = cvt_primitive_args(
primitive_iface->pd()->impl().get(), nargs, c_args, args);
if (status != status::success) return status;
exec_ctx_t ctx(stream, std::move(args));
#ifdef DNNL_ENABLE_STACK_CHECKER
stack_checker::stack_checker_t sc("dnnl_primitive_execute");
const auto *pd_iface = primitive_iface->pd();
bool is_wino
= std::string(pd_iface->info()).find("wino") != std::string::npos;
if (!is_wino) {
status = sc.check(
dnnl::impl::primitive_execute, primitive_iface, std::ref(ctx));
}
#else
status = dnnl::impl::primitive_execute(primitive_iface, ctx);
#endif
return status;
}
} // namespace dnnl_impl
} // namespace graph
} // namespace impl

View File

@ -146,6 +146,20 @@ dnnl::accumulation_mode str2accumulation_mode(
size_t generate_constant_md_hash(
size_t part_id, const std::vector<dnnl::memory::desc> &const_mds);
status_t dnnl_primitive_execute_without_tp_hook(const primitive &prim,
const stream &astream,
const std::unordered_map<int, memory> &exec_args);
#ifndef NDEBUG
#define BACKEND_DNNL_ENFORCE(condition, message) \
do { \
error::wrap_c_api((condition) ? dnnl_success : dnnl_invalid_arguments, \
(message)); \
} while (false)
#else
#define BACKEND_DNNL_ENFORCE(condition, message)
#endif
#define BACKEND_DNNL_CHECK(statement) \
do { \
status_t ret = (statement); \

View File

@ -206,6 +206,10 @@ status_t mqa_decomp_kernel_t<quantized, dt>::execute_impl(
UNUSED(scratchpad);
// prepare execution args and allocate real memory
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
auto tp = threadpool_utils::get_active_threadpool();
threadpool_utils::deactivate_threadpool();
#endif
prepare_sub_args(var_grantor, tid, block_size, res->mem_map);
// reorder0
@ -257,14 +261,20 @@ status_t mqa_decomp_kernel_t<quantized, dt>::execute_impl(
// in parallel region - these primitives should use single thread.
mqa_cfg_.sub_reorder0.execute(strm, res->sub_reorder0_args[tid]);
mqa_cfg_.sub_reorder1.execute(strm, res->sub_reorder1_args[tid]);
mqa_cfg_.sub_mm1_prim.execute(strm, res->sub_mm1_args[tid]);
dnnl_primitive_execute_without_tp_hook(
mqa_cfg_.sub_mm1_prim, strm, res->sub_mm1_args[tid]);
mqa_cfg_.sub_softmax_prim.execute(strm, res->sub_softmax_args[tid]);
dnnl_primitive_execute_without_tp_hook(
mqa_cfg_.sub_softmax_prim, strm, res->sub_softmax_args[tid]);
mqa_cfg_.sub_reorder2.execute(strm, res->sub_reorder2_args[tid]);
mqa_cfg_.sub_mm2_prim.execute(strm, res->sub_mm2_args[tid]);
dnnl_primitive_execute_without_tp_hook(
mqa_cfg_.sub_mm2_prim, strm, res->sub_mm2_args[tid]);
mqa_cfg_.sub_reorder3.execute(strm, res->sub_reorder3_args[tid]);
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
threadpool_utils::activate_threadpool(tp);
#endif
};
// TODO: remove this when primitive new API ready
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_OMP

View File

@ -60,7 +60,7 @@ public:
void *handle = args.at(DNNL_ARG_SRC).get_data_handle();
args.at(DNNL_ARG_DST).set_data_handle(handle);
} else
reorder_.execute(astream, args);
dnnl_primitive_execute_without_tp_hook(reorder_, astream, args);
return status::success;
}
status_t reset_engine(const dnnl::engine &p_engine) {

View File

@ -206,6 +206,10 @@ status_t sdp_decomp_kernel_t<quantized, dt>::execute_impl(
const auto loop = [=](int tid, int nthr, dim_t bo, dim_t bi) {
UNUSED(scratchpad);
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
auto tp = threadpool_utils::get_active_threadpool();
threadpool_utils::deactivate_threadpool();
#endif
// prepare execution args and allocate real memory
prepare_sub_args(var_grantor, tid, block_size, res->mem_map);
@ -380,15 +384,22 @@ status_t sdp_decomp_kernel_t<quantized, dt>::execute_impl(
// in parallel region - these primitives should use single thread.
sdp_cfg_.sub_reorder0.execute(strm, res->sub_reorder0_args[tid]);
sdp_cfg_.sub_reorder1.execute(strm, res->sub_reorder1_args[tid]);
sdp_cfg_.sub_mm1_prim.execute(strm, res->sub_mm1_args[tid]);
dnnl_primitive_execute_without_tp_hook(
sdp_cfg_.sub_mm1_prim, strm, res->sub_mm1_args[tid]);
if (sdp_cfg_.has_select && !sdp_cfg_.select_fusiable)
sdp_cfg_.sub_select_prim.execute(strm, res->sub_select_args[tid]);
sdp_cfg_.sub_softmax_prim.execute(strm, res->sub_softmax_args[tid]);
dnnl_primitive_execute_without_tp_hook(
sdp_cfg_.sub_select_prim, strm, res->sub_select_args[tid]);
dnnl_primitive_execute_without_tp_hook(
sdp_cfg_.sub_softmax_prim, strm, res->sub_softmax_args[tid]);
sdp_cfg_.sub_reorder2.execute(strm, res->sub_reorder2_args[tid]);
sdp_cfg_.sub_mm2_prim.execute(strm, res->sub_mm2_args[tid]);
dnnl_primitive_execute_without_tp_hook(
sdp_cfg_.sub_mm2_prim, strm, res->sub_mm2_args[tid]);
sdp_cfg_.sub_reorder3.execute(strm, res->sub_reorder3_args[tid]);
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
threadpool_utils::activate_threadpool(tp);
#endif
};
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
tp_stream->before_exec_hook();

View File

@ -61,7 +61,8 @@ public:
void *handle = args.at(DNNL_ARG_SRC).get_data_handle();
args.at(DNNL_ARG_DST).set_data_handle(handle);
} else
reorder_prim_.execute(astream, args);
dnnl_primitive_execute_without_tp_hook(
reorder_prim_, astream, args);
return status::success;
}
status_t reset_engine(const dnnl::engine &p_engine) {