mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[nativert] Move Executor to PyTorch core (#157514)
Test Plan: CI Rollback Plan: Differential Revision: D77693984 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157514 Approved by: https://github.com/zhxchen17
This commit is contained in:
committed by
PyTorch MergeBot
parent
ad86c05b78
commit
f7130c097e
@ -601,6 +601,7 @@ libtorch_nativert_sources = [
|
|||||||
"torch/nativert/executor/Placement.cpp",
|
"torch/nativert/executor/Placement.cpp",
|
||||||
"torch/nativert/executor/ExecutionPlanner.cpp",
|
"torch/nativert/executor/ExecutionPlanner.cpp",
|
||||||
"torch/nativert/executor/ExecutionFrame.cpp",
|
"torch/nativert/executor/ExecutionFrame.cpp",
|
||||||
|
"torch/nativert/executor/Executor.cpp",
|
||||||
"torch/nativert/executor/GraphExecutorBase.cpp",
|
"torch/nativert/executor/GraphExecutorBase.cpp",
|
||||||
"torch/nativert/executor/ConstantFolder.cpp",
|
"torch/nativert/executor/ConstantFolder.cpp",
|
||||||
"torch/nativert/executor/OpKernel.cpp",
|
"torch/nativert/executor/OpKernel.cpp",
|
||||||
|
387
torch/nativert/executor/Executor.cpp
Normal file
387
torch/nativert/executor/Executor.cpp
Normal file
@ -0,0 +1,387 @@
|
|||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include <c10/util/Enumerate.h>
|
||||||
|
#include <c10/util/Synchronized.h>
|
||||||
|
#include <torch/nativert/executor/ExecutionFrame.h>
|
||||||
|
#include <torch/nativert/executor/Executor.h>
|
||||||
|
#include <torch/nativert/executor/ParallelGraphExecutor.h>
|
||||||
|
#include <torch/nativert/executor/SerialGraphExecutor.h>
|
||||||
|
#include <torch/nativert/executor/Weights.h>
|
||||||
|
#include <torch/nativert/kernels/C10Kernel.h>
|
||||||
|
#include <torch/nativert/kernels/KernelFactory.h>
|
||||||
|
|
||||||
|
// Upper bound on how many times getExecutorFrameFromPool() polls
// clearedExecutionFrames_ before falling back to the main frame pool.
constexpr uint32_t kClearExecutionFrameRetries{10};
|
||||||
|
|
||||||
|
namespace torch::nativert {
|
||||||
|
|
||||||
|
// Stores the configuration, graph and placement, sizes both frame queues to
// executorConfig_.maxNumConcurrentThreads, and -- only when `weights` is
// non-null -- runs initialize() to build kernels and the graph executor.
// The ConstantFolder is only constructed when executorConfig_.runConstFolding
// is set.
Executor::Executor(
    torch::nativert::ExecutorConfig executorConfig,
    std::shared_ptr<Graph> graph,
    std::shared_ptr<Weights> weights,
    const Placement& placement,
    std::shared_ptr<caffe2::serialize::PyTorchStreamReader> pytorchStreamReader,
    const MakeProxyExecutorFn& makeProxyExecutorFunc)
    : executorConfig_(std::move(executorConfig)),
      graph_(std::move(graph)),
      placement_(placement),
      constantFolder_(
          executorConfig_.runConstFolding
              ? std::optional<ConstantFolder>(*graph_)
              : std::nullopt),
      makeProxyExecutorFunc_(makeProxyExecutorFunc),
      executionFrames_(executorConfig_.maxNumConcurrentThreads),
      clearedExecutionFrames_(executorConfig_.maxNumConcurrentThreads),
      numExecutionFrames_(0),
      lastClearedTimestamp_(getCurrentTimestampSeconds()) {
  // Without weights there is nothing to initialize yet.
  if (!weights) {
    return;
  }
  initialize(std::move(weights), std::move(pytorchStreamReader));
}
|
||||||
|
|
||||||
|
void Executor::initialize(
|
||||||
|
std::shared_ptr<Weights> weights,
|
||||||
|
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
|
||||||
|
pytorchStreamReader) {
|
||||||
|
auto start = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
|
auto executionKernels = KernelFactory().initializeNodeKernels(
|
||||||
|
*graph_,
|
||||||
|
weights,
|
||||||
|
executorConfig_,
|
||||||
|
placement_,
|
||||||
|
std::move(pytorchStreamReader),
|
||||||
|
makeProxyExecutorFunc_);
|
||||||
|
|
||||||
|
if (constantFolder_.has_value()) {
|
||||||
|
constantFolder_->unlinkConstants(executionKernels.nodeKernels);
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto& kernelSchemas = getKernelSchemas(executionKernels.nodeKernels);
|
||||||
|
|
||||||
|
if (executorConfig_.maxParallelOps > 1) {
|
||||||
|
graphExecutor_ = std::make_unique<ParallelGraphExecutor>(
|
||||||
|
*graph_, std::move(executionKernels.nodeKernels), executorConfig_);
|
||||||
|
} else {
|
||||||
|
graphExecutor_ = std::make_unique<torch::nativert::SerialGraphExecutor>(
|
||||||
|
*graph_, std::move(executionKernels.nodeKernels), executorConfig_);
|
||||||
|
}
|
||||||
|
|
||||||
|
delegateExecutors_ = std::move(executionKernels.delegateExecutors);
|
||||||
|
constFoldingExecutions_ = std::move(executionKernels.constFoldingExecutions);
|
||||||
|
|
||||||
|
// initialize weights_
|
||||||
|
processWeights(weights);
|
||||||
|
atomicSwapWeights(weights);
|
||||||
|
|
||||||
|
if (executorConfig_.layoutPlannerSettings.enabled()) {
|
||||||
|
layoutPlanner_ = std::make_unique<LayoutPlanner>(
|
||||||
|
*graph_,
|
||||||
|
kernelSchemas,
|
||||||
|
ExecutionFrame::getPersistentValueMask(*graph_, weights.get()),
|
||||||
|
executorConfig_.layoutPlannerSettings);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto end = std::chrono::high_resolution_clock::now();
|
||||||
|
LOG(INFO) << "Initialization completed in "
|
||||||
|
<< std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||||
|
end - start)
|
||||||
|
.count()
|
||||||
|
<< " ms";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collects the FunctionSchema of every kernel that is a C10Kernel, keyed by
// the target string of the kernel's node. Kernels of any other type are
// skipped. When several kernels share a target, the first one inserted wins
// (insert() does not overwrite).
/* static */ c10::
    FastMap<std::string /* target */, torch::nativert::FunctionSchema>
    Executor::getKernelSchemas(
        const std::vector<std::unique_ptr<OpKernel>>& kernels) {
  c10::FastMap<std::string, torch::nativert::FunctionSchema> schemasByTarget;
  for (const auto& opKernel : kernels) {
    const auto* c10Kernel = dynamic_cast<C10Kernel*>(opKernel.get());
    if (c10Kernel == nullptr) {
      continue;
    }
    schemasByTarget.insert(
        {std::string(opKernel->node()->target()), c10Kernel->schema()});
  }
  return schemasByTarget;
}
|
||||||
|
|
||||||
|
void Executor::atomicSwapWeights(std::shared_ptr<Weights> weights) {
|
||||||
|
weights_.withLock([&](auto& w) { w = std::move(weights); });
|
||||||
|
|
||||||
|
// update weights in delegate executors
|
||||||
|
for (auto& delegateExecutor : delegateExecutors_) {
|
||||||
|
delegateExecutor->commitWeights();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Executor::maybeRunConstantFolding(std::shared_ptr<Weights> weights) {
|
||||||
|
for (auto& execution : constFoldingExecutions_) {
|
||||||
|
ExecutionFrame constFoldingFrame(execution.executor->graph());
|
||||||
|
std::vector<c10::IValue> inputs;
|
||||||
|
inputs.reserve(graph_->signature().inputsToWeights().size());
|
||||||
|
for (const auto& [_, name] : graph_->signature().inputsToWeights()) {
|
||||||
|
inputs.push_back(weights->at(name));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto outputs = execution.executor->execute(constFoldingFrame, inputs);
|
||||||
|
for (const auto& [idx, value] :
|
||||||
|
c10::enumerate(execution.executor->graph().outputs())) {
|
||||||
|
weights->updateFoldedConst(value->name(), outputs.at(idx));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Executor::processWeights(std::shared_ptr<Weights> weights) {
|
||||||
|
maybeRunConstantFolding(weights);
|
||||||
|
if (constantFolder_.has_value()) {
|
||||||
|
constantFolder_->evaluate(*weights);
|
||||||
|
}
|
||||||
|
for (auto& delegateExecutor : delegateExecutors_) {
|
||||||
|
delegateExecutor->processWeights(weights);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
void validateInput(
|
||||||
|
const std::string& inputName,
|
||||||
|
const at::Tensor& inputTensor,
|
||||||
|
const torch::nativert::TensorMeta& tensorValueMeta) {
|
||||||
|
CHECK(inputTensor.dtype() == tensorValueMeta.dtype())
|
||||||
|
<< "Input tensor dtype mismatch for " << inputName << ", expecting "
|
||||||
|
<< c10::toString(tensorValueMeta.dtype()) << " but got "
|
||||||
|
<< inputTensor.dtype().name();
|
||||||
|
|
||||||
|
CHECK(inputTensor.device() == tensorValueMeta.device())
|
||||||
|
<< "Input tensor device mismatch for " << inputName << ", expecting "
|
||||||
|
<< tensorValueMeta.device().str() << " but got "
|
||||||
|
<< inputTensor.device().str();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// validate input tensor's dtype matches tensorMeta
|
||||||
|
void Executor::validateInputs(const std::vector<c10::IValue>& inputs) const {
|
||||||
|
const auto& inputValues = graph_->userInputs();
|
||||||
|
const auto& tensorValuesMeta = graph_->tensorValuesMeta();
|
||||||
|
TORCH_CHECK(inputs.size() == inputValues.size(), "Input size mismatch");
|
||||||
|
for (auto&& [i, actualInput] : c10::enumerate(inputs)) {
|
||||||
|
if (actualInput.isTensor()) {
|
||||||
|
const auto& inputName = std::string(inputValues[i]->name());
|
||||||
|
auto it = tensorValuesMeta.find(inputName);
|
||||||
|
CHECK(it != tensorValuesMeta.end())
|
||||||
|
<< "Couldn't find " << inputName << " in tensorValuesMeta";
|
||||||
|
validateInput(inputName, actualInput.toTensor(), it->second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hands out an execution frame wrapped in an ExecutorFramePtr (which returns
// the frame to the pool when it is destroyed).
//
// Lookup order:
//   1. while a cleanup is in progress, poll clearedExecutionFrames_ up to
//      kClearExecutionFrameRetries times;
//   2. otherwise (or if that fails) pop from executionFrames_;
//   3. if that pool is empty and fewer than maxNumConcurrentThreads frames
//      exist, create a new frame; else block on sem_ and try again.
// A reused frame whose weight version differs from the current weights is
// re-pointed at the current weights before being returned.
Executor::ExecutorFramePtr Executor::getExecutorFrameFromPool() {
  // Snapshot the current weights under the lock.
  std::shared_ptr<Weights> weights;
  weights_.withLock([&](auto& w) { weights = w; });

  // Brings a recycled frame up to date with the snapshot taken above.
  const auto refreshWeights = [&weights](ExecutorFramePtr& framePtr) {
    if (framePtr->weightVersion() != weights->version()) {
      framePtr->setWeights(*weights);
    }
  };

  // First try to get a frame from clearedExecutionFrames_ if clearing is in
  // progress
  if (C10_UNLIKELY(clearingInProgress_)) {
    ExecutionFrameEntry frameEntry;
    // Limit retries to avoid infinite loop
    for (uint32_t retry = 0; retry < kClearExecutionFrameRetries; ++retry) {
      if (!clearedExecutionFrames_.readIfNotEmpty(frameEntry)) {
        continue;
      }
      if (retry > 0) {
        VLOG(1) << "Took " << retry
                << " retries to pop from clearedExecutionFrames_";
      }
      ExecutorFramePtr ptr{std::move(frameEntry.frame), *this};
      refreshWeights(ptr);
      return ptr;
    }
    // If we couldn't get a frame from cleared pool after retries, move onto
    // main pool
  }

  // Try to get a frame from the main pool or create a new one
  std::unique_ptr<ExecutionFrame> frame;
  while (!executionFrames_.readIfNotEmpty(frame)) {
    int64_t numFrames = numExecutionFrames_.load();
    if (numFrames >= executorConfig_.maxNumConcurrentThreads) {
      // Pool is at capacity: wait until some frame is returned.
      sem_.acquire();
      continue;
    }
    // Reserve a slot; on a lost race just loop and re-check the pool.
    if (numExecutionFrames_.compare_exchange_strong(
            numFrames, numFrames + 1)) {
      return ExecutorFramePtr{
          std::make_unique<ExecutionFrame>(
              *graph_, *weights, executorConfig_, layoutPlanner_.get()),
          *this};
    }
  }

  ExecutorFramePtr ptr{std::move(frame), *this};
  refreshWeights(ptr);
  return ptr;
}
|
||||||
|
|
||||||
|
void Executor::clearStaleExecutionFrames() {
|
||||||
|
if (!cleanupLock_.try_lock()) {
|
||||||
|
// Another thread is already doing cleanup
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Update timestamp first to minimize contention
|
||||||
|
lastClearedTimestamp_ = getCurrentTimestampSeconds();
|
||||||
|
|
||||||
|
int numPopped = 0;
|
||||||
|
std::unique_ptr<ExecutionFrame> frame;
|
||||||
|
|
||||||
|
// Move frames from executionFrames_ to clearedExecutionFrames_
|
||||||
|
while (executionFrames_.readIfNotEmpty(frame)) {
|
||||||
|
++numPopped;
|
||||||
|
// Keep the first popped entries up to minimum size
|
||||||
|
if (numPopped > executorConfig_.minNumExecutionFrames) {
|
||||||
|
// Discard stale frames
|
||||||
|
frame.reset();
|
||||||
|
numExecutionFrames_ -= 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
ExecutionFrameEntry entry;
|
||||||
|
entry.used = false;
|
||||||
|
entry.frame = std::move(frame);
|
||||||
|
clearedExecutionFrames_.writeIfNotFull(std::move(entry));
|
||||||
|
// Enable clients to pop from clearedExecutionFrames_ while clearing is in
|
||||||
|
// progress
|
||||||
|
clearingInProgress_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t numPushed = 0;
|
||||||
|
ExecutionFrameEntry frameEntry;
|
||||||
|
// Move frames back from clearedExecutionFrames_ to executionFrames_
|
||||||
|
while (clearedExecutionFrames_.readIfNotEmpty(frameEntry)) {
|
||||||
|
++numPushed;
|
||||||
|
executionFrames_.writeIfNotFull(std::move(frameEntry.frame));
|
||||||
|
clearingInProgress_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
clearingInProgress_ = false;
|
||||||
|
VLOG(1) << "Cleared " << (numPopped - numPushed) << " out of " << numPopped
|
||||||
|
<< " ExecutionFrame instances in the pool";
|
||||||
|
|
||||||
|
cleanupLock_.unlock();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Executor::returnExecutorFrameToPool(
|
||||||
|
std::unique_ptr<ExecutionFrame> frame) {
|
||||||
|
// Check if it's time to clean up stale frames
|
||||||
|
if (executorConfig_.doExecutionFrameCleanup &&
|
||||||
|
lastClearedTimestamp_ +
|
||||||
|
executorConfig_.executionFramePoolCleanupIntervalSec <
|
||||||
|
getCurrentTimestampSeconds()) {
|
||||||
|
clearStaleExecutionFrames();
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
frame->destroyBorrowedIValues();
|
||||||
|
|
||||||
|
// Create an entry with used=true
|
||||||
|
if (C10_UNLIKELY(!clearingInProgress_)) {
|
||||||
|
CHECK(executionFrames_.writeIfNotFull(std::move(frame)))
|
||||||
|
<< "ExecutionFrame pool full";
|
||||||
|
} else {
|
||||||
|
ExecutionFrameEntry frameEntry;
|
||||||
|
frameEntry.used = true;
|
||||||
|
frameEntry.frame = std::move(frame);
|
||||||
|
|
||||||
|
CHECK(clearedExecutionFrames_.writeIfNotFull(std::move(frameEntry)))
|
||||||
|
<< "Cleared ExecutionFrame pool full";
|
||||||
|
}
|
||||||
|
} catch (...) {
|
||||||
|
sem_.release();
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
sem_.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Executes the graph on already-flattened user inputs and returns the
// flattened outputs. Inputs are validated first when
// executorConfig_.validateInputs is set. The frame is borrowed from the pool
// for the duration of the call.
std::vector<c10::IValue> Executor::execute(std::vector<c10::IValue> inputs) {
  const bool shouldValidate = executorConfig_.validateInputs;
  if (shouldValidate) {
    validateInputs(inputs);
  }

  auto framePtr = getExecutorFrameFromPool();
  return graphExecutor_->execute(*framePtr, std::move(inputs));
}
|
||||||
|
|
||||||
|
std::vector<c10::IValue> Executor::execute(
|
||||||
|
const std::vector<c10::IValue>& args,
|
||||||
|
const std::unordered_map<std::string, c10::IValue>& kwargs,
|
||||||
|
const ITreeSpec& inputTreeSpec) {
|
||||||
|
auto executionFrame = getExecutorFrameFromPool();
|
||||||
|
|
||||||
|
std::optional<std::vector<c10::IValue>> outputs;
|
||||||
|
const auto userInputs = graph_->userInputs();
|
||||||
|
const auto& tensorValuesMeta = graph_->tensorValuesMeta();
|
||||||
|
TORCH_CHECK_EQ(userInputs.size(), inputTreeSpec.numIValues());
|
||||||
|
|
||||||
|
auto executionFrameFillUserInputs = [&](const c10::IValue& leaf,
|
||||||
|
const Value* value) {
|
||||||
|
// validate input tensor's dtype and device matches tensorMeta
|
||||||
|
if (executorConfig_.validateInputs && leaf.isTensor()) {
|
||||||
|
const auto& inputName = std::string(value->name());
|
||||||
|
auto it = tensorValuesMeta.find(inputName);
|
||||||
|
CHECK(it != tensorValuesMeta.end())
|
||||||
|
<< "Couldn't find " << inputName << " in tensorValuesMeta";
|
||||||
|
validateInput(inputName, leaf.toTensor(), it->second);
|
||||||
|
}
|
||||||
|
executionFrame->setBorrowedIValue(
|
||||||
|
value->id(), c10::MaybeOwnedTraits<c10::IValue>::createBorrow(leaf));
|
||||||
|
};
|
||||||
|
ivalueApplyFromArgs(
|
||||||
|
executionFrameFillUserInputs, args, kwargs, inputTreeSpec);
|
||||||
|
try {
|
||||||
|
outputs = graphExecutor_->executeWithPrefilledFrame(*executionFrame);
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
LOG(ERROR) << "Exception during executeWithPrefilledFrame: " << e.what();
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::move(*outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmarks the graph node by node on the given input sets.
// Requires at least one input set and at least one warmup and one main run
// (enforced with CHECK). Every input set is validated up front when
// executorConfig_.validateInputs is set; the measurement itself is delegated
// to the graph executor using a frame borrowed from the pool.
ProfileMetrics Executor::benchmarkIndividualNodes(
    std::vector<std::vector<c10::IValue>> inputsList,
    const uint32_t warmupRuns,
    const uint32_t mainRuns) {
  CHECK(inputsList.size() > 0) << "Need at least one input to benchmark";
  CHECK(warmupRuns >= 1 && mainRuns >= 1) << "Need at least one run";

  // The config flag does not change while looping, so test it once.
  if (executorConfig_.validateInputs) {
    for (const auto& inputs : inputsList) {
      validateInputs(inputs);
    }
  }

  auto framePtr = getExecutorFrameFromPool();
  return graphExecutor_->benchmarkIndividualNodes(
      *framePtr, inputsList, warmupRuns, mainRuns);
}
|
||||||
|
|
||||||
|
// Returns the std::chrono::steady_clock reading, truncated to whole seconds
// since that clock's epoch. Only differences between two such values are
// meaningful (it is not wall-clock time).
int64_t Executor::getCurrentTimestampSeconds() const {
  const auto sinceEpoch = std::chrono::steady_clock::now().time_since_epoch();
  return std::chrono::duration_cast<std::chrono::seconds>(sinceEpoch).count();
}
|
||||||
|
|
||||||
|
std::vector<DelegateExecutor*> Executor::getDelegates() {
|
||||||
|
std::vector<DelegateExecutor*> delegates;
|
||||||
|
for (const auto& delegateExecutor : delegateExecutors_) {
|
||||||
|
delegates.push_back(delegateExecutor.get());
|
||||||
|
}
|
||||||
|
return delegates;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace torch::nativert
|
206
torch/nativert/executor/Executor.h
Normal file
206
torch/nativert/executor/Executor.h
Normal file
@ -0,0 +1,206 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include <c10/util/FbcodeMaps.h>
|
||||||
|
#include <c10/util/Logging.h>
|
||||||
|
#include <c10/util/Semaphore.h>
|
||||||
|
#include <c10/util/Synchronized.h>
|
||||||
|
|
||||||
|
#include <torch/nativert/detail/ITree.h>
|
||||||
|
#include <torch/nativert/detail/MPMCQueue.h>
|
||||||
|
#include <torch/nativert/executor/ConstantFolder.h>
|
||||||
|
#include <torch/nativert/executor/DelegateExecutor.h>
|
||||||
|
#include <torch/nativert/executor/ExecutionPlanner.h>
|
||||||
|
#include <torch/nativert/executor/ExecutorConfig.h>
|
||||||
|
#include <torch/nativert/executor/GraphExecutorBase.h>
|
||||||
|
#include <torch/nativert/executor/Placement.h>
|
||||||
|
#include <torch/nativert/executor/memory/FunctionSchema.h>
|
||||||
|
#include <torch/nativert/executor/memory/LayoutPlanner.h>
|
||||||
|
#include <torch/nativert/graph/Graph.h>
|
||||||
|
#include <torch/nativert/graph/GraphSignature.h>
|
||||||
|
#include <torch/nativert/kernels/KernelFactory.h>
|
||||||
|
|
||||||
|
namespace torch::nativert {
|
||||||
|
|
||||||
|
using namespace torch::nativert::detail;
|
||||||
|
|
||||||
|
struct DistributedRunConfig;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A very dumb executor. Basically just runs each node in order and contains a
|
||||||
|
* giant unordered map for every intermediate, no optimizations applied.
|
||||||
|
*/
|
||||||
|
class Executor {
|
||||||
|
class ExecutorFrameDeleter {
|
||||||
|
public:
|
||||||
|
explicit ExecutorFrameDeleter(Executor& e) : e_(&e) {}
|
||||||
|
ExecutorFrameDeleter(ExecutorFrameDeleter&&) = default;
|
||||||
|
ExecutorFrameDeleter& operator=(ExecutorFrameDeleter&&) = default;
|
||||||
|
ExecutorFrameDeleter(const ExecutorFrameDeleter&) = default;
|
||||||
|
ExecutorFrameDeleter& operator=(const ExecutorFrameDeleter&) = default;
|
||||||
|
~ExecutorFrameDeleter() = default;
|
||||||
|
|
||||||
|
void operator()(ExecutionFrame* p) {
|
||||||
|
e_->returnExecutorFrameToPool(std::unique_ptr<ExecutionFrame>(p));
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Executor* e_;
|
||||||
|
};
|
||||||
|
class ExecutorFramePtr {
|
||||||
|
public:
|
||||||
|
ExecutorFramePtr(std::unique_ptr<ExecutionFrame> ptr, Executor& e)
|
||||||
|
: ptr_(std::unique_ptr<ExecutionFrame, ExecutorFrameDeleter>(
|
||||||
|
ptr.release(),
|
||||||
|
ExecutorFrameDeleter{e})) {}
|
||||||
|
ExecutorFramePtr() = delete;
|
||||||
|
ExecutorFramePtr(ExecutorFramePtr&&) = default;
|
||||||
|
ExecutorFramePtr& operator=(ExecutorFramePtr&&) = default;
|
||||||
|
ExecutorFramePtr(const ExecutorFramePtr&) = delete;
|
||||||
|
ExecutorFramePtr& operator=(const ExecutorFramePtr&) = delete;
|
||||||
|
~ExecutorFramePtr() = default;
|
||||||
|
|
||||||
|
ExecutionFrame& operator*() {
|
||||||
|
return *ptr_;
|
||||||
|
}
|
||||||
|
|
||||||
|
ExecutionFrame* operator->() {
|
||||||
|
return ptr_.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<ExecutionFrame, ExecutorFrameDeleter> ptr_;
|
||||||
|
};
|
||||||
|
|
||||||
|
public:
|
||||||
|
// Constrcutor used for Inference Path
|
||||||
|
Executor(
|
||||||
|
torch::nativert::ExecutorConfig executorConfig,
|
||||||
|
std::shared_ptr<Graph> graph,
|
||||||
|
std::shared_ptr<Weights> weights,
|
||||||
|
const Placement& placement = Placement(),
|
||||||
|
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
|
||||||
|
pytorchStreamReader = nullptr,
|
||||||
|
const MakeProxyExecutorFn& makeProxyExecutorFunc = nullptr);
|
||||||
|
|
||||||
|
std::shared_ptr<Weights> getWeights() {
|
||||||
|
std::shared_ptr<Weights> ret;
|
||||||
|
weights_.withLock([&](auto& w) { ret = w; });
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
void processWeights(std::shared_ptr<Weights> weights);
|
||||||
|
void atomicSwapWeights(std::shared_ptr<Weights> weights);
|
||||||
|
|
||||||
|
// This API only returns the flattened UserOutputs,
|
||||||
|
// intended to be used for Inference path
|
||||||
|
// TODO Investigate whether we should remove this, still seems
|
||||||
|
// useful for testing.
|
||||||
|
std::vector<c10::IValue> execute(std::vector<c10::IValue> inputs);
|
||||||
|
|
||||||
|
std::vector<c10::IValue> execute(
|
||||||
|
const std::vector<c10::IValue>& args,
|
||||||
|
const std::unordered_map<std::string, c10::IValue>& kwargs,
|
||||||
|
const ITreeSpec& inputTreeSpec);
|
||||||
|
|
||||||
|
ProfileMetrics benchmarkIndividualNodes(
|
||||||
|
std::vector<std::vector<c10::IValue>> inputsList,
|
||||||
|
const uint32_t warmupRuns,
|
||||||
|
const uint32_t mainRuns);
|
||||||
|
|
||||||
|
const torch::nativert::GraphSignature& graphSignature() const {
|
||||||
|
return graph_->signature();
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string className() {
|
||||||
|
return "Executor.v0";
|
||||||
|
}
|
||||||
|
|
||||||
|
const torch::nativert::ExecutorConfig& executorConfig() const {
|
||||||
|
return executorConfig_;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<DelegateExecutor*> getDelegates();
|
||||||
|
|
||||||
|
// Get the number of execution frames in the pool
|
||||||
|
int getNumExecutionFrames() const {
|
||||||
|
return numExecutionFrames_.load();
|
||||||
|
}
|
||||||
|
|
||||||
|
static c10::FastMap<std::string /* target */, torch::nativert::FunctionSchema>
|
||||||
|
getKernelSchemas(const std::vector<std::unique_ptr<OpKernel>>& kernels);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
torch::nativert::ExecutorConfig executorConfig_;
|
||||||
|
|
||||||
|
std::shared_ptr<Graph> graph_;
|
||||||
|
|
||||||
|
// manages the parameters, buffers and tensor constants
|
||||||
|
c10::Synchronized<std::shared_ptr<Weights>> weights_;
|
||||||
|
|
||||||
|
void initialize(
|
||||||
|
std::shared_ptr<Weights> weights,
|
||||||
|
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
|
||||||
|
pytorchStreamReader);
|
||||||
|
|
||||||
|
ExecutorFramePtr getExecutorFrameFromPool();
|
||||||
|
void returnExecutorFrameToPool(std::unique_ptr<ExecutionFrame> frame);
|
||||||
|
|
||||||
|
// Clears stale execution frames from the pool
|
||||||
|
void clearStaleExecutionFrames();
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Structure to track execution frame usage
|
||||||
|
struct ExecutionFrameEntry {
|
||||||
|
bool used{false};
|
||||||
|
std::unique_ptr<ExecutionFrame> frame;
|
||||||
|
|
||||||
|
// Add move constructor and assignment operator
|
||||||
|
ExecutionFrameEntry() = default;
|
||||||
|
ExecutionFrameEntry(ExecutionFrameEntry&& other) noexcept
|
||||||
|
: used(other.used), frame(std::move(other.frame)) {}
|
||||||
|
ExecutionFrameEntry& operator=(ExecutionFrameEntry&& other) noexcept {
|
||||||
|
used = other.used;
|
||||||
|
frame = std::move(other.frame);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
// Delete copy constructor and assignment operator
|
||||||
|
ExecutionFrameEntry(const ExecutionFrameEntry&) = delete;
|
||||||
|
ExecutionFrameEntry& operator=(const ExecutionFrameEntry&) = delete;
|
||||||
|
};
|
||||||
|
|
||||||
|
void maybeRunConstantFolding(std::shared_ptr<Weights> weights);
|
||||||
|
void validateInputs(const std::vector<c10::IValue>& inputs) const;
|
||||||
|
|
||||||
|
// Helper method to get current timestamp in seconds
|
||||||
|
int64_t getCurrentTimestampSeconds() const;
|
||||||
|
|
||||||
|
std::unique_ptr<GraphExecutorBase> graphExecutor_;
|
||||||
|
|
||||||
|
const Placement placement_;
|
||||||
|
|
||||||
|
// NOTE: delegateExecutors_ is used by nodeKernels_ inside graphExecutor_.
|
||||||
|
std::vector<std::unique_ptr<DelegateExecutor>> delegateExecutors_;
|
||||||
|
|
||||||
|
std::vector<ConstFoldingExecution> constFoldingExecutions_;
|
||||||
|
|
||||||
|
std::optional<ConstantFolder> constantFolder_;
|
||||||
|
|
||||||
|
MakeProxyExecutorFn makeProxyExecutorFunc_;
|
||||||
|
|
||||||
|
c10::Semaphore sem_;
|
||||||
|
torch::nativert::detail::MPMCQueue<std::unique_ptr<ExecutionFrame>>
|
||||||
|
executionFrames_;
|
||||||
|
torch::nativert::detail::MPMCQueue<ExecutionFrameEntry>
|
||||||
|
clearedExecutionFrames_;
|
||||||
|
std::atomic_int64_t numExecutionFrames_;
|
||||||
|
|
||||||
|
std::unique_ptr<LayoutPlanner> layoutPlanner_;
|
||||||
|
std::atomic_int64_t lastClearedTimestamp_;
|
||||||
|
std::mutex cleanupLock_;
|
||||||
|
std::atomic_bool clearingInProgress_{false};
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace torch::nativert
|
Reference in New Issue
Block a user