[nativert] move constantfolder to libtorch (#156918)

Summary: att -- unit tests will be migrated later, since they still have unresolved deps.

Test Plan:
ci

Rollback Plan:

Differential Revision: D77159278

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156918
Approved by: https://github.com/henryoier, https://github.com/zhxchen17
This commit is contained in:
dolpm
2025-06-26 21:26:33 +00:00
committed by PyTorch MergeBot
parent 7f6e7103a3
commit 262654ee51
3 changed files with 221 additions and 0 deletions

View File

@ -602,6 +602,7 @@ libtorch_nativert_sources = [
"torch/nativert/executor/ExecutionPlanner.cpp",
"torch/nativert/executor/ExecutionFrame.cpp",
"torch/nativert/executor/GraphExecutorBase.cpp",
"torch/nativert/executor/ConstantFolder.cpp",
"torch/nativert/executor/OpKernel.cpp",
"torch/nativert/executor/PlacementUtils.cpp",
"torch/nativert/executor/SerialGraphExecutor.cpp",

View File

@ -0,0 +1,167 @@
#include <torch/nativert/executor/ConstantFolder.h>
#include <algorithm>
#include <queue>
#include <c10/util/Enumerate.h>
#include <torch/nativert/executor/DelegateExecutor.h>
#include <torch/nativert/executor/Weights.h>
namespace torch::nativert {
/*
side effects:
1. nodes deemed const-foldable nodes are unlinked from the graph.
they are still owned by the graph (i.e., show up in graph.nodeOwner_)
but are not accessible through the node iterator.
2. kernels associated with const-foldable nodes are removed from the
'kernels' input
3. mark values deemed foldable as such, removing thier producers
*/
void ConstantFolder::unlinkConstants(
std::vector<std::unique_ptr<OpKernel>>& kernels) {
TORCH_CHECK_EQ(kernels.size(), graph_.nodes().size())
<< "graph node count and kernel count should be equal";
unlinked_ = true;
/* resolve all of the nodes that are const foldable */
c10::FastMap<Node*, uint32_t> nodeDynInputs;
nodeDynInputs.reserve(graph_.nodes().size());
c10::FastMap<const Node*, std::unique_ptr<OpKernel>*> nodeKernels;
nodeKernels.reserve(graph_.nodes().size());
const auto* input = &*graph_.nodes().begin();
const auto* output = &*graph_.nodes().end();
{ // ignore prim.Input and prim.Output
auto ct = 0;
for (auto& n : graph_.nodes()) {
if (&n == input || &n == output) {
continue;
}
nodeDynInputs[&n] = n.numInputs();
nodeKernels[&n] = &kernels[++ct];
}
}
const auto& inputsToWeights = graph_.signature().inputsToWeights();
for (const auto& [inputName, weightName] : inputsToWeights) {
for (auto* user : graph_.getValue(inputName)->users()) {
if (user == input || user == output) {
continue;
}
nodeDynInputs[user] -= 1;
}
}
// set of foldable nodes for dedupe purposes
c10::FastSet<const Node*> foldable;
std::queue<Node*> constFoldableCandidates;
for (auto& [node, ct] : nodeDynInputs) {
if (ct++ /* will be decremented once dequeued */ == 0) {
constFoldableCandidates.push(node);
}
}
while (!constFoldableCandidates.empty()) {
auto* candidate = constFoldableCandidates.front();
constFoldableCandidates.pop();
if (auto& ct = nodeDynInputs[candidate]; --ct == 0) {
foldable.insert(candidate);
Foldable f;
f.node = candidate;
f.kernel = std::move(*nodeKernels[candidate]);
foldables_.push_back(std::move(f));
candidate->unlink();
for (auto* user : candidate->users()) {
if (user == output) {
continue;
}
if (foldable.find(user) == foldable.end()) {
constFoldableCandidates.push(user);
}
}
for (auto* out : candidate->outputs()) {
auto* value = graph_.getValue(out->name());
value->setIsFolded();
// we only store folded values if there is a non-foldable user
if (const auto& users = value->users();
std::any_of(users.begin(), users.end(), [&](const auto* u) {
return foldable.find(u) == foldable.end();
})) {
foldedOutputValueIds_.insert(value->id());
}
}
}
}
for (const auto& f : foldables_) {
VLOG(1) << "Const-folded node: " << *f.node;
}
// remove moved (i.e., associated w/ const-folded nodes) kernels
// from the input kernel vector
kernels.erase(
std::remove_if(
kernels.begin(),
kernels.end(),
[](const auto& k) { return k == nullptr; }),
kernels.end());
graph_.renumberValues();
graph_.finalize();
graph_.lint();
return;
}
/*
side effects:
1. weights whose users are ONLY const-foldable nodes will be removed
from the 'weights' input
*/
void ConstantFolder::evaluate(Weights& weights) {
CHECK(unlinked_)
<< "cannot evaluate weights for a graph whose constants have not been unlinked via ConstFolder::unlinkConstants";
weights.validateAllWeightsLoaded();
ExecutionFrame frame(graph_);
frame.setWeights(weights);
c10::FastMap<std::string, c10::IValue> foldedValues;
for (const auto& f : foldables_) {
f.kernel->compute(frame);
for (auto&& [i, out] : c10::enumerate(f.node->outputs())) {
if (foldedOutputValueIds_.find(out->id()) !=
foldedOutputValueIds_.end()) {
foldedValues[std::string{out->name()}] = f.kernel->output(i, frame);
}
}
}
for (auto it = std::make_move_iterator(foldedValues.begin());
it != std::make_move_iterator(foldedValues.end());
++it) {
auto [n, iv] = std::move(*it);
weights.setConstFoldedValue(n, std::move(iv));
}
}
} // namespace torch::nativert

View File

@ -0,0 +1,53 @@
#pragma once
#include <memory>
#include <vector>
#include <torch/nativert/executor/OpKernel.h>
#include <torch/nativert/executor/Weights.h>
#include <torch/nativert/graph/Graph.h>
namespace torch::nativert {
struct Foldable {
Node* node;
std::unique_ptr<OpKernel> kernel;
};
class ConstantFolder {
public:
explicit ConstantFolder(Graph& graph) : graph_(graph) {}
/*
1. identify nodes without dynamic inputs, mark as foldable
2. traverse the nodes deemed foldable as if they were being evaluated,
pushing nodes that become foldable after it's inputs were traversed.
unlink foldable nodes from the graph in the topological order in which
they were traversed, storing the node and its associated kernel (moved
from 'kernels') as a foldable in Constantfolder
*/
void unlinkConstants(
/* kernels for const-foldable nodes will be removed from this vector */
std::vector<std::unique_ptr<OpKernel>>& kernels);
/*
1. execute foldables_ on an execution frame initialized with the passed-in
weights, calling Weights::setConstFoldedValue if the folded value is
consumed by a non-foldable node
*/
void evaluate(Weights& weights);
private:
Graph& graph_;
// unlinked nodes sorted in their topological order
// s.t., they can be evaluated sequentially
std::vector<Foldable> foldables_;
bool unlinked_{false};
c10::FastSet<ValueId> foldedOutputValueIds_;
};
} // namespace torch::nativert