mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[nativert] Move SerialGraphExecutor to PyTorch core (#156459)
Summary: `SerialGraphExecutor` inherits from `GraphExecutorBase` and executes all nodes in the graph in a serial manner Test Plan: CI Rollback Plan: Differential Revision: D76917966 Pull Request resolved: https://github.com/pytorch/pytorch/pull/156459 Approved by: https://github.com/zhxchen17, https://github.com/jingsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
a67eb1a0d6
commit
e98dd95446
@ -603,6 +603,7 @@ libtorch_nativert_sources = [
|
||||
"torch/nativert/executor/GraphExecutorBase.cpp",
|
||||
"torch/nativert/executor/OpKernel.cpp",
|
||||
"torch/nativert/executor/PlacementUtils.cpp",
|
||||
"torch/nativert/executor/SerialGraphExecutor.cpp",
|
||||
"torch/nativert/executor/Weights.cpp",
|
||||
"torch/nativert/executor/memory/FunctionSchema.cpp",
|
||||
"torch/nativert/common/FileUtil.cpp",
|
||||
|
33
torch/nativert/executor/SerialGraphExecutor.cpp
Normal file
33
torch/nativert/executor/SerialGraphExecutor.cpp
Normal file
@ -0,0 +1,33 @@
|
||||
#include <torch/nativert/executor/ExecutionPlanner.h>
|
||||
#include <torch/nativert/executor/ExecutorConfig.h>
|
||||
#include <torch/nativert/executor/SerialGraphExecutor.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
std::vector<c10::IValue> SerialGraphExecutor::execute(
|
||||
ExecutionFrame& executionFrame,
|
||||
std::vector<c10::IValue> inputs) {
|
||||
fillUserInputs(executionFrame, std::move(inputs));
|
||||
|
||||
return executeWithPrefilledFrame(executionFrame);
|
||||
}
|
||||
|
||||
std::vector<c10::IValue> SerialGraphExecutor::executeWithPrefilledFrame(
|
||||
ExecutionFrame& executionFrame) {
|
||||
// Execute kernels for all nodes except prim.Input and prim.Output
|
||||
for (NodeIndex nodeIdx = 1; nodeIdx < nodeKernels_.size() - 1; ++nodeIdx) {
|
||||
nodeKernels_[nodeIdx]->compute(executionFrame);
|
||||
|
||||
// don't free intermediate values when static memory planning is enabled
|
||||
if (!executorConfig_.enableStaticMemoryPlanning) {
|
||||
// Free the intermediate values that are no used anymore
|
||||
for (const auto& valueKey : execPlan_->valuesToFree[nodeIdx]) {
|
||||
executionFrame.releaseValue(valueKey);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return executionFrame.tryMoveUserOutputs();
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
23
torch/nativert/executor/SerialGraphExecutor.h
Normal file
23
torch/nativert/executor/SerialGraphExecutor.h
Normal file
@ -0,0 +1,23 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/nativert/executor/GraphExecutorBase.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
class SerialGraphExecutor : public GraphExecutorBase {
|
||||
public:
|
||||
SerialGraphExecutor(
|
||||
const Graph& graph,
|
||||
std::vector<std::unique_ptr<OpKernel>> nodeKernels,
|
||||
const ExecutorConfig& executorConfig)
|
||||
: GraphExecutorBase(graph, std::move(nodeKernels), executorConfig) {}
|
||||
|
||||
std::vector<c10::IValue> execute(
|
||||
ExecutionFrame& frame,
|
||||
std::vector<c10::IValue> inputs) override;
|
||||
|
||||
std::vector<c10::IValue> executeWithPrefilledFrame(
|
||||
ExecutionFrame& frame) override;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
Reference in New Issue
Block a user