Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[nativert] Move ParallelGraphExecutor to PyTorch core (#156751)
Summary: `ParallelGraphExecutor` inherits from `GraphExecutorBase` and executes all nodes in the graph in parallel.

Test Plan: CI

Rollback Plan:

Differential Revision: D77088996

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156751
Approved by: https://github.com/zhxchen17, https://github.com/dolpm
Committed by: PyTorch MergeBot
Parent: 44a5f93462
Commit: 6c008e2fb5
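For context, here is a minimal caller-side sketch (not part of this commit) of how the new executor could be driven, based only on the constructor and `execute` signature added in the files below. The helper name `runInParallel` and the provenance of the Graph, kernels, config, and frame are assumptions; in practice they would come from the existing nativert model-loading path.

// Hypothetical usage sketch; only ParallelGraphExecutor's public API is taken
// from this commit, everything else is illustrative.
#include <torch/nativert/executor/ParallelGraphExecutor.h>

namespace torch::nativert {

std::vector<c10::IValue> runInParallel(
    const Graph& graph,
    std::vector<std::unique_ptr<OpKernel>> nodeKernels,
    const ExecutorConfig& config, // config.maxParallelOps sets the worker-thread count
    ExecutionFrame& frame,
    std::vector<c10::IValue> inputs) {
  // Builds the per-node WorkUnit graph and starts the internal thread pool.
  ParallelGraphExecutor executor(graph, std::move(nodeKernels), config);
  // Fills user inputs into the frame, runs ready nodes in parallel as their
  // producers complete, and returns the user outputs moved out of the frame.
  return executor.execute(frame, std::move(inputs));
}

} // namespace torch::nativert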
@@ -614,6 +614,7 @@ libtorch_nativert_sources = [
     "torch/nativert/kernels/HigherOrderKernel.cpp",
     "torch/nativert/executor/memory/GreedyBySize.cpp",
     "torch/nativert/executor/memory/Bump.cpp",
+    "torch/nativert/executor/ParallelGraphExecutor.cpp",
     "torch/nativert/kernels/CallTorchBindKernel.cpp",
     "torch/nativert/kernels/PrimKernelRegistry.cpp",
 ]
240  torch/nativert/executor/ParallelGraphExecutor.cpp  Normal file
@@ -0,0 +1,240 @@
#include <moodycamel/concurrentqueue.h>
#include <torch/nativert/executor/ExecutorConfig.h>
#include <torch/nativert/executor/ParallelGraphExecutor.h>

namespace {

#define WITH_LOCK(m, block)                 \
  {                                         \
    std::unique_lock<decltype(m)> lk_(m);   \
    block                                   \
  }

} // namespace

namespace torch::nativert {

ThreadPoolExecutor::ThreadPoolExecutor()
    : work_(std::make_unique<moodycamel::ConcurrentQueue<Work>>()) {}

ThreadPoolExecutor::~ThreadPoolExecutor() {
  stop();
}

C10_ALWAYS_INLINE moodycamel::ProducerToken& ThreadPoolExecutor::ptok() {
  thread_local moodycamel::ProducerToken ptok(*work_);
  return ptok;
}

C10_ALWAYS_INLINE moodycamel::ConsumerToken& ThreadPoolExecutor::ctok() {
  thread_local moodycamel::ConsumerToken ctok(*work_);
  return ctok;
}

void ThreadPoolExecutor::execute_inline(SessionState* session, WorkUnit* unit) {
  session->addWork();
  unit->run(this, session);
}

void ThreadPoolExecutor::start(int32_t numThreads) {
  stopped_ = false;
  for (int32_t i = 0; i < numThreads; ++i) {
    threads_.emplace_back(std::thread(&ThreadPoolExecutor::loop, this));
  }
}

void ThreadPoolExecutor::loop() {
  while (true) {
    Work unit;

    sem_->acquire();

    if (stopped_) {
      return;
    }

    while (!work_->try_dequeue(ctok(), unit)) {
    };

    unit();
  }
}

void ThreadPoolExecutor::add(SessionState* session, WorkUnit* unit) {
  session->addWork();
  work_->enqueue(ptok(), std::bind(&WorkUnit::run, unit, this, session));
  sem_->release();
}

void ThreadPoolExecutor::add(
    SessionState* session,
    std::vector<WorkUnit*>::const_iterator&& begin,
    const std::vector<WorkUnit*>::const_iterator&& end) {
  const auto count = end - begin;

  switch (count) {
    case 0: {
      return;
    }
    case 1: {
      return add(session, *begin);
    }
  }

  session->addWork(count);

  std::vector<Work> runnables;
  runnables.reserve(count);
  for (; begin != end; ++begin) {
    runnables.push_back(std::bind(&WorkUnit::run, *begin, this, session));
  }

  work_->enqueue_bulk(ptok(), runnables.begin(), count);
  sem_->release(count);
}

void ThreadPoolExecutor::stop() {
  stopped_ = true;
  sem_->release(threads_.size());

  std::for_each(threads_.begin(), threads_.end(), [](auto& t) { t.join(); });
  threads_.clear();

  {
    // reset sem
    auto tmp = std::make_unique<c10::Semaphore>();
    sem_.swap(tmp);
  }

  {
    // flush queue
    auto tmp = moodycamel::ConcurrentQueue<Work>();
    work_->swap(tmp);
  }
}

void ThreadPoolExecutor::run(
    SessionState& session,
    const std::vector<WorkUnit*>& roots) {
  // case where thread ptok exists but work_ was swapped
  if (auto& tok = ptok(); C10_UNLIKELY(!tok.valid())) {
    moodycamel::ProducerToken tmp(*work_);
    tok.swap(tmp);
  }

  const auto rootCount = roots.size();

  if (C10_UNLIKELY(rootCount == 0)) {
    return;
  } else if (C10_LIKELY(rootCount > 1)) {
    add(&session, roots.begin() + 1, roots.end());
  }

  execute_inline(&session, roots[0]);

  session.wait();
}

void WorkUnit::run(ThreadPoolExecutor* executor, SessionState* session) {
  thread_local std::vector<WorkUnit*> newWorkUnits;
  thread_local c10::InferenceMode mode;

  WorkUnit* unit = this;

  while (true) {
    unit->kernel->compute(session->frame());

    for (auto* user : unit->users) {
      if (session->decrementProducers(user->node)) {
        newWorkUnits.push_back(user);
      }
    }

    switch (newWorkUnits.size()) {
      case 0: {
        return session->removeWork();
      }
      case 1: {
        break;
      }
      case 2: {
        executor->add(session, newWorkUnits[1]);
        break;
      }
      default: {
        executor->add(session, newWorkUnits.begin() + 1, newWorkUnits.end());
        break;
      }
    }

    unit = newWorkUnits[0];
    newWorkUnits.clear();
  }
}

ParallelGraphExecutor::ParallelGraphExecutor(
    const Graph& graph,
    std::vector<std::unique_ptr<OpKernel>> nodeKernels,
    const torch::nativert::ExecutorConfig& executorConfig)
    : GraphExecutorBase(graph, std::move(nodeKernels), executorConfig),
      workUnits_(
          graph.nodes().size() - 2 /* no need for prim.Input or prim.Output */),
      graph_(graph) {
  auto& nodes = graph_.nodes();

  auto input = &*nodes.begin();
  auto output = &*nodes.rbegin();

  {
    // get rid of prim.Input and prim.Output kernels
    // since we won't be needing them
    nodeKernels_.erase(nodeKernels_.begin());
    nodeKernels_.pop_back();
  }

  size_t idx = 0;
  for (const auto& node : nodes) {
    if (&node == input || &node == output) {
      continue;
    }
    auto& workUnit =
        nodeToWorkUnit_.insert_or_assign(&node, &workUnits_[idx]).first->second;
    workUnit->node = &node;
    workUnit->kernel = nodeKernels_[idx++].get();
    producers_.insert({&node, 0});
  }

  for (auto& unit : workUnits_) {
    for (const auto* dep : unit.node->users()) {
      if (dep != output) {
        unit.users.push_back(nodeToWorkUnit_[dep]);
        producers_[dep] += 1;
      }
    }
  }

  for (auto& [node, p] : producers_) {
    if (p == 0) {
      inputWorkUnits_.push_back(nodeToWorkUnit_[node]);
    }
  }

  executor_.start(executorConfig.maxParallelOps);
}

std::vector<c10::IValue> ParallelGraphExecutor::execute(
    ExecutionFrame& executionFrame,
    std::vector<c10::IValue> inputs) {
  fillUserInputs(executionFrame, std::move(inputs));
  return executeWithPrefilledFrame(executionFrame);
}

std::vector<c10::IValue> ParallelGraphExecutor::executeWithPrefilledFrame(
    ExecutionFrame& executionFrame) {
  auto session = SessionState(executionFrame, producers_);
  executor_.run(session, inputWorkUnits_);

  return executionFrame.tryMoveUserOutputs();
}

} // namespace torch::nativert
94  torch/nativert/executor/ParallelGraphExecutor.h  Normal file
@@ -0,0 +1,94 @@
#pragma once

#include <c10/util/Semaphore.h>
#include <torch/nativert/executor/GraphExecutorBase.h> // @manual
#include <torch/nativert/executor/SessionState.h> // @manual
#include <thread>

namespace moodycamel {
struct ProducerToken;
struct ConsumerToken;
struct ConcurrentQueueDefaultTraits;
template <typename T, typename Traits>
class ConcurrentQueue;
} // namespace moodycamel

namespace torch::nativert {
class ThreadPoolExecutor;

typedef std::function<void()> Work;

struct WorkUnit {
  const Node* node;
  OpKernel* kernel;
  std::vector<WorkUnit*> users;
  void run(ThreadPoolExecutor* executor, SessionState* sessionState);
};

class ThreadPoolExecutor {
 public:
  explicit ThreadPoolExecutor();
  ~ThreadPoolExecutor();
  ThreadPoolExecutor(const ThreadPoolExecutor&) = delete;
  ThreadPoolExecutor& operator=(ThreadPoolExecutor const&) = delete;
  ThreadPoolExecutor(ThreadPoolExecutor&&) = delete;
  ThreadPoolExecutor& operator=(ThreadPoolExecutor&&) = delete;

  void run(SessionState& session, const std::vector<WorkUnit*>& roots);

  void start(int32_t numThreads);
  void stop();

  // execute unit on the current thread
  // NOTE: children can still be offloaded to other threads
  C10_ALWAYS_INLINE void execute_inline(SessionState* session, WorkUnit* unit);

  void add(SessionState* session, WorkUnit* unit);
  void add(
      SessionState* session,
      std::vector<WorkUnit*>::const_iterator&& begin,
      const std::vector<WorkUnit*>::const_iterator&& end);

  C10_ALWAYS_INLINE moodycamel::ProducerToken& ptok();
  C10_ALWAYS_INLINE moodycamel::ConsumerToken& ctok();

 private:
  void loop();

  std::atomic_bool stopped_{false};

  std::unique_ptr<c10::Semaphore> sem_{std::make_unique<c10::Semaphore>()};

  std::unique_ptr<moodycamel::ConcurrentQueue<
      Work,
      moodycamel::ConcurrentQueueDefaultTraits>>
      work_;
  std::vector<std::thread> threads_;
};

class ParallelGraphExecutor : public GraphExecutorBase {
 public:
  ParallelGraphExecutor(
      const Graph& graph,
      std::vector<std::unique_ptr<OpKernel>> nodeKernels,
      const torch::nativert::ExecutorConfig& executorConfig);

  std::vector<c10::IValue> execute(
      ExecutionFrame& frame,
      std::vector<c10::IValue> inputs) override;

  std::vector<c10::IValue> executeWithPrefilledFrame(
      ExecutionFrame& frame) override;

 private:
  ThreadPoolExecutor executor_;

  std::vector<WorkUnit*> inputWorkUnits_;
  c10::FastMap<const Node*, WorkUnit*> nodeToWorkUnit_;
  std::vector<WorkUnit> workUnits_;

  const Graph& graph_;
  c10::FastMap<const Node*, copyable_atomic<std::uint_fast32_t>> producers_;
};

} // namespace torch::nativert