mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[aten] Pass std::function<> to thread_pool by value, instead of const ref. (#37681)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/37681 By passing by value, we can std::move, and avoid unnecessarily copying args that are part of any std::function/lambda state (e.g. in the jit interpreter, there is a std::vector<> stack passed in the InterpreterContinuation). This also makes the API consistent with e.g. folly and with best practices. Added a minor at::launch() benchmark to test/cpp/api/; the difference is mostly noticeable when copying the std::function<> internal args is non-trivial. Benchmarks pre/post (min over ~5 runs): NoData: 5.81 us -> 5.63 us (-3.2%) WithData(0): 6.67 us -> 5.88 us (-11.8%) WithData(4): 6.98 us -> 6.51 us (-6.7%) WithData(256): 9.44 us -> 7.89 us (-16.5%) ghstack-source-id: 103322321 Test Plan: - perf: buck run mode/opt caffe2/test/cpp/api:parallel_benchmark pre/post - correctness: buck test mode/dev-nosan caffe2/test/... Reviewed By: dzhulgakov Differential Revision: D21355148 fbshipit-source-id: 3567e730845106f1991091e4a892d093e00571c3
This commit is contained in:
committed by
Facebook GitHub Bot
parent
d7ccb4b392
commit
468a9d448e
@ -78,9 +78,9 @@ void launch(std::function<void()> func) {
|
||||
);
|
||||
|
||||
#if AT_EXPERIMENTAL_SINGLE_THREAD_POOL
|
||||
intraop_launch(fn);
|
||||
intraop_launch(std::move(fn));
|
||||
#else
|
||||
get_pool().run(fn);
|
||||
get_pool().run(std::move(fn));
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@ -55,7 +55,7 @@ bool ThreadPool::inThreadPool() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
void ThreadPool::run(const std::function<void()>& func) {
|
||||
void ThreadPool::run(std::function<void()> func) {
|
||||
if (threads_.size() == 0) {
|
||||
throw std::runtime_error("No threads to run a task");
|
||||
}
|
||||
@ -63,7 +63,7 @@ void ThreadPool::run(const std::function<void()>& func) {
|
||||
|
||||
// Set task and signal condition variable so that a worker thread will
|
||||
// wake up and use the task.
|
||||
tasks_.push(task_element_t(func));
|
||||
tasks_.emplace(std::move(func));
|
||||
complete_ = false;
|
||||
condition_.notify_one();
|
||||
}
|
||||
@ -94,7 +94,7 @@ void ThreadPool::main_loop(std::size_t index) {
|
||||
// useful in the event that the function contains
|
||||
// shared_ptr arguments bound via bind.
|
||||
{
|
||||
auto tasks = tasks_.front();
|
||||
task_element_t tasks = std::move(tasks_.front());
|
||||
tasks_.pop();
|
||||
// Decrement count, indicating thread is no longer available.
|
||||
--available_;
|
||||
|
||||
@ -17,7 +17,7 @@ namespace c10 {
|
||||
// TODO: move this to C10 and make it C10_API
|
||||
class C10_API TaskThreadPoolBase {
|
||||
public:
|
||||
virtual void run(const std::function<void()>& func) = 0;
|
||||
virtual void run(std::function<void()> func) = 0;
|
||||
|
||||
virtual size_t size() const = 0;
|
||||
|
||||
@ -49,10 +49,10 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase {
|
||||
const std::function<void()> no_id;
|
||||
const std::function<void(std::size_t)> with_id;
|
||||
|
||||
explicit task_element_t(const std::function<void()>& f)
|
||||
: run_with_id(false), no_id(f), with_id(nullptr) {}
|
||||
explicit task_element_t(const std::function<void(std::size_t)>& f)
|
||||
: run_with_id(true), no_id(nullptr), with_id(f) {}
|
||||
explicit task_element_t(std::function<void()> f)
|
||||
: run_with_id(false), no_id(std::move(f)), with_id(nullptr) {}
|
||||
explicit task_element_t(std::function<void(std::size_t)> f)
|
||||
: run_with_id(true), no_id(nullptr), with_id(std::move(f)) {}
|
||||
};
|
||||
|
||||
std::queue<task_element_t> tasks_;
|
||||
@ -82,7 +82,7 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase {
|
||||
|
||||
bool inThreadPool() const override;
|
||||
|
||||
void run(const std::function<void()>& func) override;
|
||||
void run(std::function<void()> func) override;
|
||||
|
||||
template <typename Task>
|
||||
void runTaskWithID(Task task) {
|
||||
@ -90,8 +90,7 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase {
|
||||
|
||||
// Set task and signal condition variable so that a worker thread will
|
||||
// wake up and use the task.
|
||||
tasks_.push(
|
||||
task_element_t(static_cast<std::function<void(std::size_t)>>(task)));
|
||||
tasks_.emplace(static_cast<std::function<void(std::size_t)>>(task));
|
||||
complete_ = false;
|
||||
condition_.notify_one();
|
||||
}
|
||||
|
||||
@ -73,3 +73,7 @@ if(INSTALL_TEST)
|
||||
install(FILES $<TARGET_PDB_FILE:test_api> DESTINATION bin OPTIONAL)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_executable(parallel_benchmark ${TORCH_API_TEST_DIR}/parallel_benchmark.cpp)
|
||||
target_include_directories(parallel_benchmark PRIVATE ${ATen_CPU_INCLUDE})
|
||||
target_link_libraries(parallel_benchmark PRIVATE torch)
|
||||
|
||||
88
test/cpp/api/parallel_benchmark.cpp
Normal file
88
test/cpp/api/parallel_benchmark.cpp
Normal file
@ -0,0 +1,88 @@
|
||||
#include <torch/torch.h>
|
||||
#include <chrono>
|
||||
#include <condition_variable>
|
||||
#include <mutex>
|
||||
|
||||
// One-shot signal between threads: wait() blocks until post() has been
// called. The signalled state is never reset, so every wait() issued after
// post() returns immediately.
class Baton {
 public:
  // Marks the baton as signalled and wakes all threads blocked in wait().
  // The flag is set and the notification is issued while holding the mutex.
  void post() {
    std::lock_guard<std::mutex> guard(lock_);
    done_ = true;
    cv_.notify_all();
  }

  // Blocks the caller until post() has been called. The predicate is
  // re-checked after every wakeup, so spurious wakeups are harmless.
  void wait() {
    std::unique_lock<std::mutex> guard(lock_);
    cv_.wait(guard, [this] { return done_; });
  }

 private:
  std::mutex lock_; // guards done_
  std::condition_variable cv_; // signalled by post()
  bool done_{false}; // true once post() has run
};
|
||||
|
||||
void AtLaunch_Base(int32_t numIters) {
|
||||
struct Helper {
|
||||
explicit Helper(int32_t lim) : limit_(lim) {}
|
||||
void operator()() {
|
||||
if (++val_ == limit_) {
|
||||
done.post();
|
||||
} else {
|
||||
at::launch([this]() { (*this)(); });
|
||||
}
|
||||
}
|
||||
int val_{0};
|
||||
int limit_;
|
||||
Baton done;
|
||||
};
|
||||
Helper h(numIters);
|
||||
auto start = std::chrono::system_clock::now();
|
||||
h();
|
||||
h.done.wait();
|
||||
std::cout << "NoData "
|
||||
<< static_cast<double>(
|
||||
std::chrono::duration_cast<std::chrono::microseconds>(
|
||||
std::chrono::system_clock::now() - start)
|
||||
.count()) /
|
||||
static_cast<double>(numIters)
|
||||
<< " usec/each\n";
|
||||
}
|
||||
|
||||
void AtLaunch_WithData(int32_t numIters, int32_t vecSize) {
|
||||
struct Helper {
|
||||
explicit Helper(int32_t lim) : limit_(lim) {}
|
||||
void operator()(std::vector<int32_t> v) {
|
||||
if (++val_ == limit_) {
|
||||
done.post();
|
||||
} else {
|
||||
at::launch([this, v = std::move(v)]() { (*this)(v); });
|
||||
}
|
||||
}
|
||||
int val_{0};
|
||||
int limit_;
|
||||
Baton done;
|
||||
};
|
||||
Helper h(numIters);
|
||||
std::vector<int32_t> v(vecSize, 0);
|
||||
auto start = std::chrono::system_clock::now();
|
||||
h(v);
|
||||
h.done.wait();
|
||||
std::cout << "WithData(" << vecSize << "): "
|
||||
<< static_cast<double>(
|
||||
std::chrono::duration_cast<std::chrono::microseconds>(
|
||||
std::chrono::system_clock::now() - start)
|
||||
.count()) /
|
||||
static_cast<double>(numIters)
|
||||
<< " usec/each\n";
|
||||
}
|
||||
|
||||
// Benchmark driver: runs the no-data case first, then the with-data case for
// each payload size, all with the same iteration count. Command-line
// arguments are not used.
int main(int argc, char** argv) {
  const int32_t iterations = 1000000;
  const int32_t payloadSizes[] = {0, 4, 256};

  AtLaunch_Base(iterations);
  for (const int32_t payloadSize : payloadSizes) {
    AtLaunch_WithData(iterations, payloadSize);
  }
  return 0;
}
|
||||
Reference in New Issue
Block a user