pytorch/torch/nativert/executor/ParallelGraphExecutor.cpp
dolpm 725c327284 [nativert] add memory overlap debug assertion (#157290)
Summary: Better safe than sorry. This will throw if a memory overlap is detected when using planned tensors and debug mode is enabled -- this will make our planning unit tests more robust. (An illustrative sketch of such a check follows the commit metadata below.)

Test Plan:
ci

Differential Revision: D77327841

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157290
Approved by: https://github.com/SherlockNoMad, https://github.com/zhxchen17
2025-07-14 19:12:41 +00:00
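
As a point of reference, a debug-only overlap check over planned tensor buffers can be as simple as sorting the planned allocations by offset and verifying that each one ends before the next begins. The sketch below only illustrates that idea and is not the code added by this commit; `PlannedAllocation` and `checkNoPlannedOverlap` are hypothetical names, not part of the file that follows.

#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical illustration (not from this commit): the byte range one
// planned tensor occupies inside a shared arena.
struct PlannedAllocation {
  size_t offset; // byte offset into the arena
  size_t size; // byte size reserved for the tensor
};

// Throws if any two planned allocations overlap; intended to run only when
// debug mode is enabled, mirroring the behavior described in the summary.
inline void checkNoPlannedOverlap(std::vector<PlannedAllocation> allocs) {
  std::sort(allocs.begin(), allocs.end(), [](const auto& a, const auto& b) {
    return a.offset < b.offset;
  });
  for (size_t i = 1; i < allocs.size(); ++i) {
    const auto& prev = allocs[i - 1];
    if (prev.offset + prev.size > allocs[i].offset) {
      throw std::runtime_error(
          "memory overlap detected between planned tensors");
    }
  }
}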

245 lines · 5.9 KiB · C++

#include <moodycamel/concurrentqueue.h>
#include <torch/nativert/executor/ExecutorConfig.h>
#include <torch/nativert/executor/ParallelGraphExecutor.h>
namespace {
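// Runs `block` while holding a std::unique_lock on mutex `m`.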
#define WITH_LOCK(m, block)               \
  {                                       \
    std::unique_lock<decltype(m)> lk_(m); \
    block                                 \
  }
} // namespace

namespace torch::nativert {

ThreadPoolExecutor::ThreadPoolExecutor()
    : work_(std::make_unique<moodycamel::ConcurrentQueue<Work>>()) {}

ThreadPoolExecutor::~ThreadPoolExecutor() {
  stop();
}
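
// Thread-local producer/consumer tokens bound to the work queue; tokens let
// moodycamel speed up repeated enqueue/dequeue from the same thread.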
C10_ALWAYS_INLINE moodycamel::ProducerToken& ThreadPoolExecutor::ptok() {
  // NOLINTNEXTLINE(misc-use-internal-linkage)
  thread_local moodycamel::ProducerToken ptok(*work_);
  return ptok;
}

C10_ALWAYS_INLINE moodycamel::ConsumerToken& ThreadPoolExecutor::ctok() {
  // NOLINTNEXTLINE(misc-use-internal-linkage)
  thread_local moodycamel::ConsumerToken ctok(*work_);
  return ctok;
}
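
// Run a unit on the calling thread instead of handing it to the pool.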
void ThreadPoolExecutor::execute_inline(SessionState* session, WorkUnit* unit) {
  session->addWork();
  unit->run(this, session);
}

void ThreadPoolExecutor::start(int32_t numThreads) {
  stopped_ = false;
  for (int32_t i = 0; i < numThreads; ++i) {
    threads_.emplace_back(&ThreadPoolExecutor::loop, this);
  }
}
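
// Worker loop: block on the semaphore until work (or a stop signal) is
// available, then spin until an item can be dequeued, and run it.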
void ThreadPoolExecutor::loop() {
  while (true) {
    Work unit;

    sem_->acquire();

    if (stopped_) {
      return;
    }

    while (!work_->try_dequeue(ctok(), unit)) {
    };

    unit();
  }
}

void ThreadPoolExecutor::add(SessionState* session, WorkUnit* unit) {
  session->addWork();
  work_->enqueue(ptok(), [unit, this, session] { unit->run(this, session); });
  sem_->release();
}
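
// Enqueue a contiguous range of work units in one bulk operation, releasing
// the semaphore once per enqueued unit.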
void ThreadPoolExecutor::add(
    SessionState* session,
    std::vector<WorkUnit*>::const_iterator begin,
    const std::vector<WorkUnit*>::const_iterator& end) {
  const auto count = end - begin;

  // NOLINTNEXTLINE(bugprone-switch-missing-default-case)
  switch (count) {
    case 0: {
      return;
    }
    case 1: {
      return add(session, *begin);
    }
  }

  session->addWork(count);

  std::vector<Work> runnables;
  runnables.reserve(count);
  for (; begin != end; ++begin) {
    runnables.emplace_back(
        [capture0 = *begin, this, session] { capture0->run(this, session); });
  }

  work_->enqueue_bulk(ptok(), runnables.begin(), count);
  sem_->release(static_cast<int32_t>(count));
}
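
// Shut down the pool: wake every worker, join the threads, then reset the
// semaphore and flush any work still sitting in the queue.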
void ThreadPoolExecutor::stop() {
  stopped_ = true;
  sem_->release(static_cast<int32_t>(threads_.size()));

  std::for_each(threads_.begin(), threads_.end(), [](auto& t) { t.join(); });
  threads_.clear();

  {
    // reset sem
    auto tmp = std::make_unique<c10::Semaphore>();
    sem_.swap(tmp);
  }

  {
    // flush queue
    auto tmp = moodycamel::ConcurrentQueue<Work>();
    work_->swap(tmp);
  }
}
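
// Execute a session: the first root runs inline on the calling thread, the
// remaining roots are enqueued to the pool, and the call blocks until the
// session reports that all of its work has drained.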
void ThreadPoolExecutor::run(
    SessionState& session,
    const std::vector<WorkUnit*>& roots) {
  // case where thread ptok exists but work_ was swapped
  if (auto& tok = ptok(); C10_UNLIKELY(!tok.valid())) {
    moodycamel::ProducerToken tmp(*work_);
    tok.swap(tmp);
  }

  const auto rootCount = roots.size();
  if (C10_UNLIKELY(rootCount == 0)) {
    return;
  } else if (C10_LIKELY(rootCount > 1)) {
    add(&session, roots.begin() + 1, roots.end());
  }

  execute_inline(&session, roots[0]);

  session.wait();
}
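
// Run this unit's kernel, then chase ready successors: every user that
// decrementProducers() reports as ready becomes runnable. The first runnable
// successor continues on this thread; any others are handed back to the
// executor.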
void WorkUnit::run(ThreadPoolExecutor* executor, SessionState* session) {
  /* thread_local */ std::vector<WorkUnit*> newWorkUnits;
  /* thread_local */ c10::InferenceMode mode;
  /* thread_local */ WorkUnit* unit = this;

  while (true) {
    unit->kernel->compute(session->frame());

    for (auto* user : unit->users) {
      if (session->decrementProducers(user->node)) {
        newWorkUnits.push_back(user);
      }
    }

    switch (newWorkUnits.size()) {
      case 0: {
        return session->removeWork();
      }
      case 1: {
        break;
      }
      case 2: {
        executor->add(session, newWorkUnits[1]);
        break;
      }
      default: {
        executor->add(session, newWorkUnits.begin() + 1, newWorkUnits.end());
        break;
      }
    }

    unit = newWorkUnits[0];
    newWorkUnits.clear();
  }
}
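
// Build one WorkUnit per graph node (excluding prim.Input and prim.Output),
// wire up user edges and per-node producer counts, and collect the
// zero-producer units that can start immediately.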
ParallelGraphExecutor::ParallelGraphExecutor(
    const Graph& graph,
    std::vector<std::unique_ptr<OpKernel>> nodeKernels,
    const ExecutorConfig& executorConfig)
    : GraphExecutorBase(graph, std::move(nodeKernels), executorConfig),
      workUnits_(
          graph.nodes().size() - 2 /* no need for prim.Input or prim.Output */),
      graph_(graph) {
  auto& nodes = graph_.nodes();

  auto input = &*nodes.begin();
  auto output = &*nodes.rbegin();

  {
    // get rid of prim.Input and prim.Output kernels
    // since we won't be needing them
    nodeKernels_.erase(nodeKernels_.begin());
    nodeKernels_.pop_back();
  }

  size_t idx = 0;
  for (const auto& node : nodes) {
    if (&node == input || &node == output) {
      continue;
    }
    auto& workUnit =
        nodeToWorkUnit_.insert_or_assign(&node, &workUnits_[idx]).first->second;
    workUnit->node = &node;
    workUnit->kernel = nodeKernels_[idx++].get();
    producers_.insert({&node, 0});
  }

  for (auto& unit : workUnits_) {
    for (const auto* dep : unit.node->users()) {
      if (dep != output) {
        unit.users.push_back(nodeToWorkUnit_[dep]);
        producers_[dep] += 1;
      }
    }
  }

  for (auto& [node, p] : producers_) {
    if (p == 0) {
      inputWorkUnits_.push_back(nodeToWorkUnit_[node]);
    }
  }

  executor_.start(static_cast<int32_t>(executorConfig.maxParallelOps));
}

std::vector<c10::IValue> ParallelGraphExecutor::execute(
    ExecutionFrame& executionFrame,
    std::vector<c10::IValue> inputs) {
  fillUserInputs(executionFrame, std::move(inputs));
  return executeWithPrefilledFrame(executionFrame);
}
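
// Execute the graph assuming user inputs have already been written into the
// frame: build a fresh SessionState seeded with the producer counts and run
// all zero-producer work units on the thread pool.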
std::vector<c10::IValue> ParallelGraphExecutor::executeWithPrefilledFrame(
    ExecutionFrame& executionFrame) {
  auto session = SessionState(executionFrame, producers_);
  executor_.run(session, inputWorkUnits_);

  return executionFrame.tryMoveUserOutputs();
}

} // namespace torch::nativert