From e98dd95446e009ace1722498effbf32250d623e4 Mon Sep 17 00:00:00 2001 From: Yiming Zhou Date: Sat, 21 Jun 2025 01:32:02 +0000 Subject: [PATCH] [nativert] Move SerialGraphExecutor to PyTorch core (#156459) Summary: `SerialGraphExecutor` inherits from `GraphExecutorBase` and executes all nodes in the graph in a serial manner Test Plan: CI Rollback Plan: Differential Revision: D76917966 Pull Request resolved: https://github.com/pytorch/pytorch/pull/156459 Approved by: https://github.com/zhxchen17, https://github.com/jingsh --- build_variables.bzl | 1 + .../nativert/executor/SerialGraphExecutor.cpp | 33 +++++++++++++++++++ torch/nativert/executor/SerialGraphExecutor.h | 23 +++++++++++++ 3 files changed, 57 insertions(+) create mode 100644 torch/nativert/executor/SerialGraphExecutor.cpp create mode 100644 torch/nativert/executor/SerialGraphExecutor.h diff --git a/build_variables.bzl b/build_variables.bzl index 0c7442d36262..77e6ab46837b 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -603,6 +603,7 @@ libtorch_nativert_sources = [ "torch/nativert/executor/GraphExecutorBase.cpp", "torch/nativert/executor/OpKernel.cpp", "torch/nativert/executor/PlacementUtils.cpp", + "torch/nativert/executor/SerialGraphExecutor.cpp", "torch/nativert/executor/Weights.cpp", "torch/nativert/executor/memory/FunctionSchema.cpp", "torch/nativert/common/FileUtil.cpp", diff --git a/torch/nativert/executor/SerialGraphExecutor.cpp b/torch/nativert/executor/SerialGraphExecutor.cpp new file mode 100644 index 000000000000..f1ef0491eda1 --- /dev/null +++ b/torch/nativert/executor/SerialGraphExecutor.cpp @@ -0,0 +1,33 @@ +#include +#include +#include + +namespace torch::nativert { + +std::vector SerialGraphExecutor::execute( + ExecutionFrame& executionFrame, + std::vector inputs) { + fillUserInputs(executionFrame, std::move(inputs)); + + return executeWithPrefilledFrame(executionFrame); +} + +std::vector SerialGraphExecutor::executeWithPrefilledFrame( + ExecutionFrame& executionFrame) { + // Execute kernels for all nodes except prim.Input and prim.Output + for (NodeIndex nodeIdx = 1; nodeIdx < nodeKernels_.size() - 1; ++nodeIdx) { + nodeKernels_[nodeIdx]->compute(executionFrame); + + // don't free intermediate values when static memory planning is enabled + if (!executorConfig_.enableStaticMemoryPlanning) { + // Free the intermediate values that are no used anymore + for (const auto& valueKey : execPlan_->valuesToFree[nodeIdx]) { + executionFrame.releaseValue(valueKey); + } + } + } + + return executionFrame.tryMoveUserOutputs(); +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/SerialGraphExecutor.h b/torch/nativert/executor/SerialGraphExecutor.h new file mode 100644 index 000000000000..cae3313e61e8 --- /dev/null +++ b/torch/nativert/executor/SerialGraphExecutor.h @@ -0,0 +1,23 @@ +#pragma once + +#include + +namespace torch::nativert { + +class SerialGraphExecutor : public GraphExecutorBase { + public: + SerialGraphExecutor( + const Graph& graph, + std::vector> nodeKernels, + const ExecutorConfig& executorConfig) + : GraphExecutorBase(graph, std::move(nodeKernels), executorConfig) {} + + std::vector execute( + ExecutionFrame& frame, + std::vector inputs) override; + + std::vector executeWithPrefilledFrame( + ExecutionFrame& frame) override; +}; + +} // namespace torch::nativert