Fix execution frame cleanup logic (#158717)

Summary: This fixes a bug in the execution frame cleanup logic. Previously, whenever we hit the time interval for clearing out frames, we removed any cached execution frames beyond the configured minimum count (frameEntry.used was never consulted). Instead, we only want to clear frames that were NOT used during the last time interval. This diff refactors the executor to implement the correct logic.
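As a rough illustration of the intended behavior, here is a minimal sketch of the rotation between an "active" pool (frames handed out or returned during the last interval) and an "inactive" pool (frames that sat idle). The container types and the function are simplified stand-ins, not the executor's actual MPMCQueue-based implementation:

```cpp
#include <algorithm>
#include <deque>
#include <memory>

struct ExecutionFrame {}; // stand-in for the real ExecutionFrame

// Sketch of the cleanup rotation described above (simplified, single-threaded).
void clearStaleFrames(
    std::deque<std::unique_ptr<ExecutionFrame>>& active,
    std::deque<std::unique_ptr<ExecutionFrame>>& inactive,
    size_t minNumFrames) {
  const size_t total = active.size() + inactive.size();
  const size_t minToKeep = std::min(minNumFrames, total);

  // Top up the active pool from the inactive pool so that at least minToKeep
  // frames survive this pass.
  while (active.size() < minToKeep && !inactive.empty()) {
    active.push_back(std::move(inactive.front()));
    inactive.pop_front();
  }

  // Whatever is still inactive was not used during the last interval: drop it.
  inactive.clear();

  // Demote the surviving frames; if nobody uses them before the next pass,
  // they will be cleared then.
  while (!active.empty()) {
    inactive.push_back(std::move(active.front()));
    active.pop_front();
  }
}
```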

Test Plan:
```
buck2 test 'mode/dev-nosan' fbcode//sigmoid/inference/test_gpu:model_runner_test -- ModelRunnerTest.Basic_InterpreterCuda_Multithread_Cleanup --run-disabled --print-passing-details
```

Rollback Plan:

Differential Revision: D78621408

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158717
Approved by: https://github.com/dolpm
Author: Georgia Phillips
Date: 2025-08-06 18:04:24 +00:00
Committed by: PyTorch MergeBot
Parent: d7a855d67d
Commit: c669b0ab87
3 changed files with 61 additions and 97 deletions


@@ -55,6 +55,15 @@ class MPMCQueue {
return true;
}
/**
* Get the current size of the queue.
* @return The number of elements in the queue.
*/
size_t size() {
std::lock_guard<std::mutex> lock(mutex_);
return storage_.size();
}
private:
std::mutex mutex_;
std::deque<T> storage_;
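The new size() accessor simply reports the number of buffered elements under the queue's mutex, so the cleanup pass below can snapshot how many frames sit in each pool. A self-contained sketch of a queue with the same shape (only readIfNotEmpty, writeIfNotFull, and size mirror methods visible in this diff; the class itself and its capacity handling are illustrative, not the real torch::nativert::detail::MPMCQueue):

```cpp
#include <cstddef>
#include <deque>
#include <mutex>

// Illustrative stand-in for a locked multi-producer/multi-consumer queue.
template <typename T>
class LockedQueue {
 public:
  bool writeIfNotFull(T value) {
    std::lock_guard<std::mutex> lock(mutex_);
    storage_.push_back(std::move(value)); // capacity check omitted in this sketch
    return true;
  }

  bool readIfNotEmpty(T& out) {
    std::lock_guard<std::mutex> lock(mutex_);
    if (storage_.empty()) {
      return false;
    }
    out = std::move(storage_.front());
    storage_.pop_front();
    return true;
  }

  // Same pattern as the method added in this diff: take the lock, report size.
  size_t size() {
    std::lock_guard<std::mutex> lock(mutex_);
    return storage_.size();
  }

 private:
  std::mutex mutex_;
  std::deque<T> storage_;
};
```

A size taken this way is only a snapshot; another thread may push or pop before the caller acts on it, which is why the cleanup code below treats these sizes as hints rather than hard invariants.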


@@ -10,10 +10,6 @@
#include <torch/nativert/kernels/C10Kernel.h>
#include <torch/nativert/kernels/KernelFactory.h>
// Maximum number of retries when trying to get a frame from
// clearedExecutionFrames_
constexpr uint32_t kClearExecutionFrameRetries = 10;
namespace torch::nativert {
Executor::Executor(
@@ -29,7 +25,7 @@ Executor::Executor(
? std::optional<ConstantFolder>(*graph_)
: std::nullopt),
executionFrames_(executorConfig_.maxNumConcurrentThreads),
clearedExecutionFrames_(executorConfig_.maxNumConcurrentThreads),
inactiveExecutionFrames_(executorConfig_.maxNumConcurrentThreads),
numExecutionFrames_(0),
lastClearedTimestamp_(getCurrentTimestampSeconds()) {
if (weights) {
@@ -193,34 +189,12 @@ Executor::ExecutorFramePtr Executor::getExecutorFrameFromPool() {
std::shared_ptr<Weights> weights;
weights_.withLock([&](auto& w) { weights = w; });
// First try to get a frame from clearedExecutionFrames_ if clearing is in
// progress
if (C10_UNLIKELY(clearingInProgress_)) {
ExecutionFrameEntry frameEntry;
uint32_t retry = 0;
while (
retry <
kClearExecutionFrameRetries) { // Limit retries to avoid infinite loop
if (clearedExecutionFrames_.readIfNotEmpty(frameEntry)) {
if (retry > 0) {
VLOG(1) << "Took " << retry
<< " retries to pop from clearedExecutionFrames_";
}
ExecutorFramePtr ptr{std::move(frameEntry.frame), *this};
if (ptr->weightVersion() != weights->version()) {
ptr->setWeights(*weights);
}
return ptr;
}
retry++;
}
// If we couldn't get a frame from cleared pool after retries, move onto
// main pool
}
// Try to get a frame from the main pool or create a new one
std::unique_ptr<ExecutionFrame> frame;
while (!executionFrames_.readIfNotEmpty(frame)) {
// Try to get a frame from executionFrames_ or inactiveExecutionFrames_
while (!executionFrames_.readIfNotEmpty(frame) &&
!inactiveExecutionFrames_.readIfNotEmpty(frame)) {
int64_t numFrames = numExecutionFrames_.load();
if (numFrames < executorConfig_.maxNumConcurrentThreads) {
if (numExecutionFrames_.compare_exchange_strong(
@@ -243,6 +217,7 @@ Executor::ExecutorFramePtr Executor::getExecutorFrameFromPool() {
}
void Executor::clearStaleExecutionFrames() {
LOG(INFO) << "Clearing stale execution frames";
if (!cleanupLock_.try_lock()) {
// Another thread is already doing cleanup
return;
@@ -250,40 +225,47 @@ void Executor::clearStaleExecutionFrames() {
// Update timestamp first to minimize contention
lastClearedTimestamp_ = getCurrentTimestampSeconds();
int numPopped = 0;
// Get the size of active execution frames queue directly
size_t activeFramesSize = executionFrames_.size();
size_t inactiveFramesSize = inactiveExecutionFrames_.size();
size_t total = activeFramesSize + inactiveFramesSize;
size_t numCleared = 0;
std::unique_ptr<ExecutionFrame> frame;
// Move frames from executionFrames_ to clearedExecutionFrames_
while (executionFrames_.readIfNotEmpty(frame)) {
++numPopped;
// Keep the first popped entries up to minimum size
if (numPopped > executorConfig_.minNumExecutionFrames) {
// Discard stale frames
// If number of active frames is less than the configured min, then transfer
// the difference from inactive frames
size_t minFramesToKeep = std::min(
static_cast<size_t>(executorConfig_.minNumExecutionFrames), total);
size_t framesToTransfer =
(minFramesToKeep - activeFramesSize) > minFramesToKeep
? static_cast<size_t>(0)
: minFramesToKeep - activeFramesSize;
for (size_t i = 0;
i < framesToTransfer && inactiveExecutionFrames_.readIfNotEmpty(frame);
++i) {
executionFrames_.writeIfNotFull(std::move(frame));
}
size_t newActiveFramesSize = executionFrames_.size();
// Clear remaining inactive frames (i.e. those that were not used in the last
// time interval)
while (inactiveExecutionFrames_.readIfNotEmpty(frame)) {
++numCleared;
frame.reset();
numExecutionFrames_ -= 1;
continue;
}
ExecutionFrameEntry entry;
entry.used = false;
entry.frame = std::move(frame);
clearedExecutionFrames_.writeIfNotFull(std::move(entry));
// Enable clients to pop from clearedExecutionFrames_ while clearing is in
// progress
clearingInProgress_ = true;
// Move active frames to inactive so they are cleared next time if not used
// Check newActiveFramesSize > 0 to guard against other threads adding
// frames to the active queue during the while loop
while (executionFrames_.readIfNotEmpty(frame) && newActiveFramesSize > 0) {
--newActiveFramesSize;
inactiveExecutionFrames_.writeIfNotFull(std::move(frame));
}
uint32_t numPushed = 0;
ExecutionFrameEntry frameEntry;
// Move frames back from clearedExecutionFrames_ to executionFrames_
while (clearedExecutionFrames_.readIfNotEmpty(frameEntry)) {
++numPushed;
executionFrames_.writeIfNotFull(std::move(frameEntry.frame));
clearingInProgress_ = false;
}
clearingInProgress_ = false;
VLOG(1) << "Cleared " << (numPopped - numPushed) << " out of " << numPopped
LOG(INFO) << "Cleared " << numCleared << " out of " << total
<< " ExecutionFrame instances in the pool";
cleanupLock_.unlock();
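To make the transfer arithmetic above concrete, here is a small worked example with assumed numbers (minNumExecutionFrames = 4, one active frame, six inactive frames; none of these values come from the actual config):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  // Assumed example values, not taken from the executor's real configuration.
  const size_t minNumExecutionFrames = 4;
  const size_t activeFramesSize = 1;
  const size_t inactiveFramesSize = 6;

  const size_t total = activeFramesSize + inactiveFramesSize;            // 7
  const size_t minFramesToKeep = std::min(minNumExecutionFrames, total); // 4

  // The subtraction happens on unsigned values, so the comparison guards
  // against wrap-around when activeFramesSize already exceeds the minimum.
  const size_t framesToTransfer =
      (minFramesToKeep - activeFramesSize) > minFramesToKeep
          ? static_cast<size_t>(0)
          : minFramesToKeep - activeFramesSize;                          // 3

  // Three inactive frames are promoted back to the active pool; the other
  // three are cleared, and the surviving active frames are demoted.
  std::printf("transfer %zu of %zu inactive frames\n",
              framesToTransfer, inactiveFramesSize);
  return 0;
}
```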
@@ -292,6 +274,8 @@ void Executor::clearStaleExecutionFrames() {
void Executor::returnExecutorFrameToPool(
std::unique_ptr<ExecutionFrame> frame) {
// Check if it's time to clean up stale frames
// TODO: consider moving cleanup to a dedicated thread so it does not impact
// p99 latency
if (executorConfig_.doExecutionFrameCleanup &&
lastClearedTimestamp_ +
executorConfig_.executionFramePoolCleanupIntervalSec <
@ -301,21 +285,11 @@ void Executor::returnExecutorFrameToPool(
try {
frame->destroyBorrowedIValues();
// Create an entry with used=true
if (C10_UNLIKELY(!clearingInProgress_)) {
// Always return to active execution frame pool, indicating that frame was
// used in the previous time interval
TORCH_CHECK(
executionFrames_.writeIfNotFull(std::move(frame)),
"ExecutionFrame pool full");
} else {
ExecutionFrameEntry frameEntry;
frameEntry.used = true;
frameEntry.frame = std::move(frame);
TORCH_CHECK(
clearedExecutionFrames_.writeIfNotFull(std::move(frameEntry)),
"Cleared ExecutionFrame pool full");
}
} catch (...) {
sem_.release();
throw;


@@ -122,7 +122,7 @@ class Executor {
std::vector<DelegateExecutor*> getDelegates();
// Get the number of execution frames in the pool
int getNumExecutionFrames() const {
auto getNumExecutionFrames() const {
return numExecutionFrames_.load();
}
@@ -149,25 +149,6 @@
void clearStaleExecutionFrames();
private:
// Structure to track execution frame usage
struct ExecutionFrameEntry {
bool used{false};
std::unique_ptr<ExecutionFrame> frame;
// Add move constructor and assignment operator
ExecutionFrameEntry() = default;
ExecutionFrameEntry(ExecutionFrameEntry&& other) noexcept
: used(other.used), frame(std::move(other.frame)) {}
ExecutionFrameEntry& operator=(ExecutionFrameEntry&& other) noexcept {
used = other.used;
frame = std::move(other.frame);
return *this;
}
// Delete copy constructor and assignment operator
ExecutionFrameEntry(const ExecutionFrameEntry&) = delete;
ExecutionFrameEntry& operator=(const ExecutionFrameEntry&) = delete;
};
void maybeRunConstantFolding(const std::shared_ptr<Weights>& weights);
void validateInputs(const std::vector<c10::IValue>& inputs) const;
@@ -188,8 +169,8 @@ class Executor {
c10::Semaphore sem_;
torch::nativert::detail::MPMCQueue<std::unique_ptr<ExecutionFrame>>
executionFrames_;
torch::nativert::detail::MPMCQueue<ExecutionFrameEntry>
clearedExecutionFrames_;
torch::nativert::detail::MPMCQueue<std::unique_ptr<ExecutionFrame>>
inactiveExecutionFrames_;
std::atomic_int64_t numExecutionFrames_;
std::unique_ptr<LayoutPlanner> layoutPlanner_;