mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix execution frame cleanup logic (#158717)
Summary: This fixes a bug in the execution fram cleanup logic - previously, whenever we hit the time interval to clear out the frames, we were removing any cached execution frames beyond the configured minimum number (frameEntry.used was unused). Instead, we only want to clear frames that were NOT USED in during the last time interval. This diff refactors the executor to have the correct logic. Test Plan: ``` buck2 test 'mode/dev-nosan' fbcode//sigmoid/inference/test_gpu:model_runner_test -- ModelRunnerTest.Basic_InterpreterCuda_Multithread_Cleanup --run-disabled --print-passing-details ``` Rollback Plan: Differential Revision: D78621408 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158717 Approved by: https://github.com/dolpm
This commit is contained in:
committed by
PyTorch MergeBot
parent
d7a855d67d
commit
c669b0ab87
@ -55,6 +55,15 @@ class MPMCQueue {
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current size of the queue.
|
||||
* @return The number of elements in the queue.
|
||||
*/
|
||||
size_t size() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return storage_.size();
|
||||
}
|
||||
|
||||
private:
|
||||
std::mutex mutex_;
|
||||
std::deque<T> storage_;
|
||||
|
@ -10,10 +10,6 @@
|
||||
#include <torch/nativert/kernels/C10Kernel.h>
|
||||
#include <torch/nativert/kernels/KernelFactory.h>
|
||||
|
||||
// Maximum number of retries when trying to get a frame from
|
||||
// clearedExecutionFrames_
|
||||
constexpr uint32_t kClearExecutionFrameRetries = 10;
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
Executor::Executor(
|
||||
@ -29,7 +25,7 @@ Executor::Executor(
|
||||
? std::optional<ConstantFolder>(*graph_)
|
||||
: std::nullopt),
|
||||
executionFrames_(executorConfig_.maxNumConcurrentThreads),
|
||||
clearedExecutionFrames_(executorConfig_.maxNumConcurrentThreads),
|
||||
inactiveExecutionFrames_(executorConfig_.maxNumConcurrentThreads),
|
||||
numExecutionFrames_(0),
|
||||
lastClearedTimestamp_(getCurrentTimestampSeconds()) {
|
||||
if (weights) {
|
||||
@ -193,34 +189,12 @@ Executor::ExecutorFramePtr Executor::getExecutorFrameFromPool() {
|
||||
std::shared_ptr<Weights> weights;
|
||||
weights_.withLock([&](auto& w) { weights = w; });
|
||||
|
||||
// First try to get a frame from clearedExecutionFrames_ if clearing is in
|
||||
// progress
|
||||
if (C10_UNLIKELY(clearingInProgress_)) {
|
||||
ExecutionFrameEntry frameEntry;
|
||||
uint32_t retry = 0;
|
||||
while (
|
||||
retry <
|
||||
kClearExecutionFrameRetries) { // Limit retries to avoid infinite loop
|
||||
if (clearedExecutionFrames_.readIfNotEmpty(frameEntry)) {
|
||||
if (retry > 0) {
|
||||
VLOG(1) << "Took " << retry
|
||||
<< " retries to pop from clearedExecutionFrames_";
|
||||
}
|
||||
ExecutorFramePtr ptr{std::move(frameEntry.frame), *this};
|
||||
if (ptr->weightVersion() != weights->version()) {
|
||||
ptr->setWeights(*weights);
|
||||
}
|
||||
return ptr;
|
||||
}
|
||||
retry++;
|
||||
}
|
||||
// If we couldn't get a frame from cleared pool after retries, move onto
|
||||
// main pool
|
||||
}
|
||||
|
||||
// Try to get a frame from the main pool or create a new one
|
||||
std::unique_ptr<ExecutionFrame> frame;
|
||||
while (!executionFrames_.readIfNotEmpty(frame)) {
|
||||
|
||||
// Try to get a frame from executionFrames_ or inactiveExecutionFrames_
|
||||
while (!executionFrames_.readIfNotEmpty(frame) &&
|
||||
!inactiveExecutionFrames_.readIfNotEmpty(frame)) {
|
||||
int64_t numFrames = numExecutionFrames_.load();
|
||||
if (numFrames < executorConfig_.maxNumConcurrentThreads) {
|
||||
if (numExecutionFrames_.compare_exchange_strong(
|
||||
@ -243,6 +217,7 @@ Executor::ExecutorFramePtr Executor::getExecutorFrameFromPool() {
|
||||
}
|
||||
|
||||
void Executor::clearStaleExecutionFrames() {
|
||||
LOG(INFO) << "Clearing stale execution frames";
|
||||
if (!cleanupLock_.try_lock()) {
|
||||
// Another thread is already doing cleanup
|
||||
return;
|
||||
@ -250,40 +225,47 @@ void Executor::clearStaleExecutionFrames() {
|
||||
// Update timestamp first to minimize contention
|
||||
lastClearedTimestamp_ = getCurrentTimestampSeconds();
|
||||
|
||||
int numPopped = 0;
|
||||
// Get the size of active execution frames queue directly
|
||||
size_t activeFramesSize = executionFrames_.size();
|
||||
size_t inactiveFramesSize = inactiveExecutionFrames_.size();
|
||||
size_t total = activeFramesSize + inactiveFramesSize;
|
||||
size_t numCleared = 0;
|
||||
std::unique_ptr<ExecutionFrame> frame;
|
||||
|
||||
// Move frames from executionFrames_ to clearedExecutionFrames_
|
||||
while (executionFrames_.readIfNotEmpty(frame)) {
|
||||
++numPopped;
|
||||
// Keep the first popped entries up to minimum size
|
||||
if (numPopped > executorConfig_.minNumExecutionFrames) {
|
||||
// Discard stale frames
|
||||
// If number of active frames is less than the configured min, then transfer
|
||||
// the difference from inactive frames
|
||||
size_t minFramesToKeep = std::min(
|
||||
static_cast<size_t>(executorConfig_.minNumExecutionFrames), total);
|
||||
size_t framesToTransfer =
|
||||
(minFramesToKeep - activeFramesSize) > minFramesToKeep
|
||||
? static_cast<size_t>(0)
|
||||
: minFramesToKeep - activeFramesSize;
|
||||
;
|
||||
for (size_t i = 0;
|
||||
i < framesToTransfer && inactiveExecutionFrames_.readIfNotEmpty(frame);
|
||||
++i) {
|
||||
executionFrames_.writeIfNotFull(std::move(frame));
|
||||
}
|
||||
|
||||
size_t newActiveFramesSize = executionFrames_.size();
|
||||
|
||||
// Clear remaining inactive frames (i.e. those that were not used in the last
|
||||
// time interval)
|
||||
while (inactiveExecutionFrames_.readIfNotEmpty(frame)) {
|
||||
++numCleared;
|
||||
frame.reset();
|
||||
numExecutionFrames_ -= 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
ExecutionFrameEntry entry;
|
||||
entry.used = false;
|
||||
entry.frame = std::move(frame);
|
||||
clearedExecutionFrames_.writeIfNotFull(std::move(entry));
|
||||
// Enable clients to pop from clearedExecutionFrames_ while clearing is in
|
||||
// progress
|
||||
clearingInProgress_ = true;
|
||||
// Move active frames to inactive so they are cleared next time if not used
|
||||
// Check newActiveFramesSize > 0 to guuard against other threads adding
|
||||
// frames to active queue during while loop
|
||||
while (executionFrames_.readIfNotEmpty(frame) && newActiveFramesSize > 0) {
|
||||
--newActiveFramesSize;
|
||||
inactiveExecutionFrames_.writeIfNotFull(std::move(frame));
|
||||
}
|
||||
|
||||
uint32_t numPushed = 0;
|
||||
ExecutionFrameEntry frameEntry;
|
||||
// Move frames back from clearedExecutionFrames_ to executionFrames_
|
||||
while (clearedExecutionFrames_.readIfNotEmpty(frameEntry)) {
|
||||
++numPushed;
|
||||
executionFrames_.writeIfNotFull(std::move(frameEntry.frame));
|
||||
clearingInProgress_ = false;
|
||||
}
|
||||
|
||||
clearingInProgress_ = false;
|
||||
VLOG(1) << "Cleared " << (numPopped - numPushed) << " out of " << numPopped
|
||||
LOG(INFO) << "Cleared " << numCleared << " out of " << total
|
||||
<< " ExecutionFrame instances in the pool";
|
||||
|
||||
cleanupLock_.unlock();
|
||||
@ -292,6 +274,8 @@ void Executor::clearStaleExecutionFrames() {
|
||||
void Executor::returnExecutorFrameToPool(
|
||||
std::unique_ptr<ExecutionFrame> frame) {
|
||||
// Check if it's time to clean up stale frames
|
||||
// TODO: consider moving cleanup to a dedicated thread so it does not impact
|
||||
// p99 latency
|
||||
if (executorConfig_.doExecutionFrameCleanup &&
|
||||
lastClearedTimestamp_ +
|
||||
executorConfig_.executionFramePoolCleanupIntervalSec <
|
||||
@ -301,21 +285,11 @@ void Executor::returnExecutorFrameToPool(
|
||||
|
||||
try {
|
||||
frame->destroyBorrowedIValues();
|
||||
|
||||
// Create an entry with used=true
|
||||
if (C10_UNLIKELY(!clearingInProgress_)) {
|
||||
// Always return to active execution frame pool, indicating that frame was
|
||||
// used in the previous time interval
|
||||
TORCH_CHECK(
|
||||
executionFrames_.writeIfNotFull(std::move(frame)),
|
||||
"ExecutionFrame pool full");
|
||||
} else {
|
||||
ExecutionFrameEntry frameEntry;
|
||||
frameEntry.used = true;
|
||||
frameEntry.frame = std::move(frame);
|
||||
|
||||
TORCH_CHECK(
|
||||
clearedExecutionFrames_.writeIfNotFull(std::move(frameEntry)),
|
||||
"Cleared ExecutionFrame pool full");
|
||||
}
|
||||
} catch (...) {
|
||||
sem_.release();
|
||||
throw;
|
||||
|
@ -122,7 +122,7 @@ class Executor {
|
||||
std::vector<DelegateExecutor*> getDelegates();
|
||||
|
||||
// Get the number of execution frames in the pool
|
||||
int getNumExecutionFrames() const {
|
||||
auto getNumExecutionFrames() const {
|
||||
return numExecutionFrames_.load();
|
||||
}
|
||||
|
||||
@ -149,25 +149,6 @@ class Executor {
|
||||
void clearStaleExecutionFrames();
|
||||
|
||||
private:
|
||||
// Structure to track execution frame usage
|
||||
struct ExecutionFrameEntry {
|
||||
bool used{false};
|
||||
std::unique_ptr<ExecutionFrame> frame;
|
||||
|
||||
// Add move constructor and assignment operator
|
||||
ExecutionFrameEntry() = default;
|
||||
ExecutionFrameEntry(ExecutionFrameEntry&& other) noexcept
|
||||
: used(other.used), frame(std::move(other.frame)) {}
|
||||
ExecutionFrameEntry& operator=(ExecutionFrameEntry&& other) noexcept {
|
||||
used = other.used;
|
||||
frame = std::move(other.frame);
|
||||
return *this;
|
||||
}
|
||||
// Delete copy constructor and assignment operator
|
||||
ExecutionFrameEntry(const ExecutionFrameEntry&) = delete;
|
||||
ExecutionFrameEntry& operator=(const ExecutionFrameEntry&) = delete;
|
||||
};
|
||||
|
||||
void maybeRunConstantFolding(const std::shared_ptr<Weights>& weights);
|
||||
void validateInputs(const std::vector<c10::IValue>& inputs) const;
|
||||
|
||||
@ -188,8 +169,8 @@ class Executor {
|
||||
c10::Semaphore sem_;
|
||||
torch::nativert::detail::MPMCQueue<std::unique_ptr<ExecutionFrame>>
|
||||
executionFrames_;
|
||||
torch::nativert::detail::MPMCQueue<ExecutionFrameEntry>
|
||||
clearedExecutionFrames_;
|
||||
torch::nativert::detail::MPMCQueue<std::unique_ptr<ExecutionFrame>>
|
||||
inactiveExecutionFrames_;
|
||||
std::atomic_int64_t numExecutionFrames_;
|
||||
|
||||
std::unique_ptr<LayoutPlanner> layoutPlanner_;
|
||||
|
Reference in New Issue
Block a user