mirror of
https://github.com/uxlfoundation/oneDNN.git
synced 2025-10-20 18:43:49 +08:00
tests: prepare Eigen_threadpool implementation for new Eigen versions
This commit is contained in:
committed by
Dmitry Zarukin
parent
08beeba4c9
commit
105a7cd7fa
@ -1,5 +1,5 @@
|
|||||||
#===============================================================================
|
#===============================================================================
|
||||||
# Copyright 2020-2021 Intel Corporation
|
# Copyright 2020-2025 Intel Corporation
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -33,10 +33,10 @@ if("${DNNL_CPU_THREADING_RUNTIME}" STREQUAL "THREADPOOL")
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
if("${_DNNL_TEST_THREADPOOL_IMPL}" STREQUAL "EIGEN")
|
if("${_DNNL_TEST_THREADPOOL_IMPL}" STREQUAL "EIGEN")
|
||||||
find_package(Eigen3 REQUIRED 3.3 NO_MODULE)
|
find_package(Eigen3 3.3...<5.1 REQUIRED NO_MODULE)
|
||||||
if(Eigen3_FOUND)
|
if(Eigen3_FOUND)
|
||||||
list(APPEND EXTRA_STATIC_LIBS Eigen3::Eigen)
|
list(APPEND EXTRA_STATIC_LIBS Eigen3::Eigen)
|
||||||
message(STATUS "Threadpool testing: Eigen (${EIGEN3_ROOT_DIR})")
|
message(STATUS "Threadpool testing: Eigen (${PACKAGE_PREFIX_DIR})")
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright 2020-2022 Intel Corporation
|
* Copyright 2020-2025 Intel Corporation
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
@ -97,35 +97,23 @@ inline int read_num_threads_from_env() {
|
|||||||
|
|
||||||
#if defined(DNNL_TEST_THREADPOOL_USE_EIGEN)
|
#if defined(DNNL_TEST_THREADPOOL_USE_EIGEN)
|
||||||
|
|
||||||
#include <memory>
|
#define EIGEN_USE_THREADS
|
||||||
#include "Eigen/Core"
|
#include "unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "unsupported/Eigen/CXX11/ThreadPool"
|
#include "unsupported/Eigen/CXX11/ThreadPool"
|
||||||
|
|
||||||
#if EIGEN_WORLD_VERSION + 10 * EIGEN_MAJOR_VERSION < 33
|
#include <memory>
|
||||||
#define STR_(x) #x
|
|
||||||
#define STR(x) STR_(x)
|
|
||||||
#pragma message("EIGEN_WORLD_VERSION " STR(EIGEN_WORLD_VERSION))
|
|
||||||
#pragma message("EIGEN_MAJOR_VERSION " STR(EIGEN_MAJOR_VERSION))
|
|
||||||
#error Unsupported Eigen version (need 3.3.x or higher)
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if EIGEN_MINOR_VERSION >= 90
|
|
||||||
using EigenThreadPool = Eigen::ThreadPool;
|
|
||||||
#else
|
|
||||||
using EigenThreadPool = Eigen::NonBlockingThreadPool;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace dnnl {
|
namespace dnnl {
|
||||||
namespace testing {
|
namespace testing {
|
||||||
|
|
||||||
class threadpool_t : public dnnl::threadpool_interop::threadpool_iface {
|
class threadpool_t : public dnnl::threadpool_interop::threadpool_iface {
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<EigenThreadPool> tp_;
|
std::unique_ptr<Eigen::ThreadPool> tp_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit threadpool_t(int num_threads = 0) {
|
explicit threadpool_t(int num_threads = 0) {
|
||||||
if (num_threads <= 0) num_threads = read_num_threads_from_env();
|
if (num_threads <= 0) num_threads = read_num_threads_from_env();
|
||||||
tp_.reset(new EigenThreadPool(num_threads));
|
tp_.reset(new Eigen::ThreadPool(num_threads));
|
||||||
}
|
}
|
||||||
int get_num_threads() const override { return tp_->NumThreads(); }
|
int get_num_threads() const override { return tp_->NumThreads(); }
|
||||||
bool get_in_parallel() const override {
|
bool get_in_parallel() const override {
|
||||||
|
Reference in New Issue
Block a user