[aten] Pass std::function<> to thread_pool by value, instead of const ref. (#37681)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37681

By passing by value, we can std::move the function, avoiding unnecessary
copies of any state captured in the std::function/lambda (e.g. in the JIT
interpreter, the InterpreterContinuation carries a std::vector<> stack).
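
As a rough sketch of the mechanics (a standalone toy example; the Queue type and names below are hypothetical, not the actual ThreadPool API): a by-value "sink" parameter lets the caller hand the callable over with a move, and the callee move it onward, so captured state is never copied.

#include <functional>
#include <queue>
#include <utility>
#include <vector>

struct Queue {
  std::queue<std::function<void()>> tasks_;

  // By-value sink parameter: an rvalue argument is moved in, not copied.
  void run(std::function<void()> func) {
    tasks_.emplace(std::move(func));
  }
};

int main() {
  Queue q;
  std::vector<int> big(1 << 20, 0);
  // The vector is moved into the lambda, and the lambda is moved into the
  // queue. With run(const std::function<void()>&), storing the task would
  // copy the std::function and, with it, the captured vector.
  q.run([v = std::move(big)]() { /* use v */ });
  return 0;
}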

This also brings the API in line with best practices elsewhere (e.g. folly).
Added a small at::launch() benchmark to test/cpp/; the difference is most
noticeable when copying the state captured in the std::function<> is
non-trivial.

Benchmarks pre/post (min over ~5 runs)
NoData: 5.81 us -> 5.63 us (-3.2%)
WithData(0): 6.67 us -> 5.88 us (-11.8%)
WithData(4): 6.98 us -> 6.51 us (-6.7%)
WithData(256): 9.44 us -> 7.89 us (-16.5%)

ghstack-source-id: 103322321

Test Plan:
- perf: buck run mode/opt caffe2/test/cpp/api:parallel_benchmark pre/post
- correctness: buck test mode/dev-nosan caffe2/test/...

Reviewed By: dzhulgakov

Differential Revision: D21355148

fbshipit-source-id: 3567e730845106f1991091e4a892d093e00571c3
Author: Jeremy Lilley
Date: 2020-05-05 08:39:15 -07:00
Committed by: Facebook GitHub Bot
Parent: d7ccb4b392
Commit: 468a9d448e
5 changed files with 104 additions and 13 deletions

View File

@@ -78,9 +78,9 @@ void launch(std::function<void()> func) {
);
#if AT_EXPERIMENTAL_SINGLE_THREAD_POOL
- intraop_launch(fn);
+ intraop_launch(std::move(fn));
#else
- get_pool().run(fn);
+ get_pool().run(std::move(fn));
#endif
}

View File

@@ -55,7 +55,7 @@ bool ThreadPool::inThreadPool() const {
return false;
}
- void ThreadPool::run(const std::function<void()>& func) {
+ void ThreadPool::run(std::function<void()> func) {
if (threads_.size() == 0) {
throw std::runtime_error("No threads to run a task");
}
@@ -63,7 +63,7 @@ void ThreadPool::run(const std::function<void()>& func) {
// Set task and signal condition variable so that a worker thread will
// wake up and use the task.
- tasks_.push(task_element_t(func));
+ tasks_.emplace(std::move(func));
complete_ = false;
condition_.notify_one();
}
@@ -94,7 +94,7 @@ void ThreadPool::main_loop(std::size_t index) {
// useful in the event that the function contains
// shared_ptr arguments bound via bind.
{
- auto tasks = tasks_.front();
+ task_element_t tasks = std::move(tasks_.front());
tasks_.pop();
// Decrement count, indicating thread is no longer available.
--available_;
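
For reference, the dequeue side in isolation (a minimal, hypothetical sketch rather than the actual main_loop): moving out of front() before pop() transfers the stored callable instead of copying it, and the enclosing scope destroys whatever it captured as soon as the task has run.

#include <functional>
#include <iostream>
#include <queue>
#include <utility>

int main() {
  std::queue<std::function<void()>> tasks;
  tasks.emplace([]() { std::cout << "task ran\n"; });

  {
    // Move the task out of the queue rather than copying it; pop() then
    // discards the moved-from element.
    std::function<void()> task = std::move(tasks.front());
    tasks.pop();
    task();
  }  // The task, and anything it captured, is destroyed here.
  return 0;
}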

View File

@@ -17,7 +17,7 @@ namespace c10 {
// TODO: move this to C10 and make it C10_API
class C10_API TaskThreadPoolBase {
public:
- virtual void run(const std::function<void()>& func) = 0;
+ virtual void run(std::function<void()> func) = 0;
virtual size_t size() const = 0;
@@ -49,10 +49,10 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase {
const std::function<void()> no_id;
const std::function<void(std::size_t)> with_id;
- explicit task_element_t(const std::function<void()>& f)
-     : run_with_id(false), no_id(f), with_id(nullptr) {}
- explicit task_element_t(const std::function<void(std::size_t)>& f)
-     : run_with_id(true), no_id(nullptr), with_id(f) {}
+ explicit task_element_t(std::function<void()> f)
+     : run_with_id(false), no_id(std::move(f)), with_id(nullptr) {}
+ explicit task_element_t(std::function<void(std::size_t)> f)
+     : run_with_id(true), no_id(nullptr), with_id(std::move(f)) {}
};
std::queue<task_element_t> tasks_;
@@ -82,7 +82,7 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase {
bool inThreadPool() const override;
- void run(const std::function<void()>& func) override;
+ void run(std::function<void()> func) override;
template <typename Task>
void runTaskWithID(Task task) {
@@ -90,8 +90,7 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase {
// Set task and signal condition variable so that a worker thread will
// wake up and use the task.
- tasks_.push(
-     task_element_t(static_cast<std::function<void(std::size_t)>>(task)));
+ tasks_.emplace(static_cast<std::function<void(std::size_t)>>(task));
complete_ = false;
condition_.notify_one();
}

View File

@@ -73,3 +73,7 @@ if(INSTALL_TEST)
install(FILES $<TARGET_PDB_FILE:test_api> DESTINATION bin OPTIONAL)
endif()
endif()
add_executable(parallel_benchmark ${TORCH_API_TEST_DIR}/parallel_benchmark.cpp)
target_include_directories(parallel_benchmark PRIVATE ${ATen_CPU_INCLUDE})
target_link_libraries(parallel_benchmark PRIVATE torch)

View File

@@ -0,0 +1,88 @@
#include <torch/torch.h>
#include <chrono>
#include <condition_variable>
#include <mutex>

class Baton {
public:
void post() {
std::unique_lock<std::mutex> l(lock_);
done_ = true;
cv_.notify_all();
}
void wait() {
std::unique_lock<std::mutex> l(lock_);
while (!done_) {
cv_.wait(l);
}
}
private:
std::mutex lock_;
std::condition_variable cv_;
bool done_{false};
};

void AtLaunch_Base(int32_t numIters) {
struct Helper {
explicit Helper(int32_t lim) : limit_(lim) {}
void operator()() {
if (++val_ == limit_) {
done.post();
} else {
at::launch([this]() { (*this)(); });
}
}
int val_{0};
int limit_;
Baton done;
};
Helper h(numIters);
auto start = std::chrono::system_clock::now();
h();
h.done.wait();
std::cout << "NoData "
<< static_cast<double>(
std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now() - start)
.count()) /
static_cast<double>(numIters)
<< " usec/each\n";
}

void AtLaunch_WithData(int32_t numIters, int32_t vecSize) {
struct Helper {
explicit Helper(int32_t lim) : limit_(lim) {}
void operator()(std::vector<int32_t> v) {
if (++val_ == limit_) {
done.post();
} else {
at::launch([this, v = std::move(v)]() { (*this)(v); });
}
}
int val_{0};
int limit_;
Baton done;
};
Helper h(numIters);
std::vector<int32_t> v(vecSize, 0);
auto start = std::chrono::system_clock::now();
h(v);
h.done.wait();
std::cout << "WithData(" << vecSize << "): "
<< static_cast<double>(
std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now() - start)
.count()) /
static_cast<double>(numIters)
<< " usec/each\n";
}

int main(int argc, char** argv) {
int32_t N = 1000000;
AtLaunch_Base(N);
AtLaunch_WithData(N, 0);
AtLaunch_WithData(N, 4);
AtLaunch_WithData(N, 256);
return 0;
}