From 262654ee518eb314678f53baf4e133e4767eca3d Mon Sep 17 00:00:00 2001 From: dolpm <34420038+dolpm@users.noreply.github.com> Date: Thu, 26 Jun 2025 21:26:33 +0000 Subject: [PATCH] [nativert] move constantfolder to libtorch (#156918) Summary: att -- unit tests will be migrated later, since they still have unresolved deps. Test Plan: ci Rollback Plan: Differential Revision: D77159278 Pull Request resolved: https://github.com/pytorch/pytorch/pull/156918 Approved by: https://github.com/henryoier, https://github.com/zhxchen17 --- build_variables.bzl | 1 + torch/nativert/executor/ConstantFolder.cpp | 167 +++++++++++++++++++++ torch/nativert/executor/ConstantFolder.h | 53 +++++++ 3 files changed, 221 insertions(+) create mode 100644 torch/nativert/executor/ConstantFolder.cpp create mode 100644 torch/nativert/executor/ConstantFolder.h diff --git a/build_variables.bzl b/build_variables.bzl index da49ed05dada..296260912b81 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -602,6 +602,7 @@ libtorch_nativert_sources = [ "torch/nativert/executor/ExecutionPlanner.cpp", "torch/nativert/executor/ExecutionFrame.cpp", "torch/nativert/executor/GraphExecutorBase.cpp", + "torch/nativert/executor/ConstantFolder.cpp", "torch/nativert/executor/OpKernel.cpp", "torch/nativert/executor/PlacementUtils.cpp", "torch/nativert/executor/SerialGraphExecutor.cpp", diff --git a/torch/nativert/executor/ConstantFolder.cpp b/torch/nativert/executor/ConstantFolder.cpp new file mode 100644 index 000000000000..177bcc82833a --- /dev/null +++ b/torch/nativert/executor/ConstantFolder.cpp @@ -0,0 +1,167 @@ +#include + +#include +#include + +#include + +#include +#include + +namespace torch::nativert { + +/* + side effects: + 1. nodes deemed const-foldable nodes are unlinked from the graph. + they are still owned by the graph (i.e., show up in graph.nodeOwner_) + but are not accessible through the node iterator. + + 2. 
kernels associated with const-foldable nodes are removed from the + 'kernels' input + + 3. mark values deemed foldable as such, removing thier producers +*/ + +void ConstantFolder::unlinkConstants( + std::vector>& kernels) { + TORCH_CHECK_EQ(kernels.size(), graph_.nodes().size()) + << "graph node count and kernel count should be equal"; + + unlinked_ = true; + + /* resolve all of the nodes that are const foldable */ + + c10::FastMap nodeDynInputs; + nodeDynInputs.reserve(graph_.nodes().size()); + + c10::FastMap*> nodeKernels; + nodeKernels.reserve(graph_.nodes().size()); + + const auto* input = &*graph_.nodes().begin(); + const auto* output = &*graph_.nodes().end(); /* NOTE(review): this is the address of whatever end() refers to, not of the last node; whether it can ever equal the prim.Output node depends on the node list's iterator, which is not visible here -- confirm */ + + { // ignore prim.Input and prim.Output + auto ct = 0; + for (auto& n : graph_.nodes()) { + if (&n == input || &n == output) { + continue; + } + nodeDynInputs[&n] = n.numInputs(); + nodeKernels[&n] = &kernels[++ct]; /* pre-increment: the first node that is not skipped gets &kernels[1] */ + } + } + + const auto& inputsToWeights = graph_.signature().inputsToWeights(); + for (const auto& [inputName, weightName] : inputsToWeights) { + for (auto* user : graph_.getValue(inputName)->users()) { + if (user == input || user == output) { + continue; + } + nodeDynInputs[user] -= 1; + } + } + + // set of foldable nodes for dedupe purposes + c10::FastSet foldable; + + std::queue constFoldableCandidates; + for (auto& [node, ct] : nodeDynInputs) { + if (ct++ /* will be decremented once dequeued */ == 0) { /* NOTE(review): the post-increment runs for every node, not only for those that compare equal to 0 -- confirm that a node reached later through a folded producer can still count down to 0 */ + constFoldableCandidates.push(node); + } + } + + while (!constFoldableCandidates.empty()) { + auto* candidate = constFoldableCandidates.front(); + constFoldableCandidates.pop(); + if (auto& ct = nodeDynInputs[candidate]; --ct == 0) { + foldable.insert(candidate); + Foldable f; + f.node = candidate; + f.kernel = std::move(*nodeKernels[candidate]); /* presumably leaves a null entry behind in 'kernels' (the nullptr entries are erased at the end of this function) -- element type not visible here */ + foldables_.push_back(std::move(f)); + + candidate->unlink(); + + for (auto* user : candidate->users()) { + if (user == output) { + continue; + } + if (foldable.find(user) == foldable.end()) { + constFoldableCandidates.push(user); + } + } + + 
for (auto* out : candidate->outputs()) { + auto* value = graph_.getValue(out->name()); + + value->setIsFolded(); + + // we only store folded values if there is a non-foldable user + if (const auto& users = value->users(); + std::any_of(users.begin(), users.end(), [&](const auto* u) { + return foldable.find(u) == foldable.end(); + })) { + foldedOutputValueIds_.insert(value->id()); + } + } + } + } + + for (const auto& f : foldables_) { + VLOG(1) << "Const-folded node: " << *f.node; + } + + // remove moved (i.e., associated w/ const-folded nodes) kernels + // from the input kernel vector + kernels.erase( + std::remove_if( + kernels.begin(), + kernels.end(), + [](const auto& k) { return k == nullptr; }), + kernels.end()); + + graph_.renumberValues(); + graph_.finalize(); + graph_.lint(); + + return; +} + +/* + side effects: + 1. weights whose users are ONLY const-foldable nodes will be removed + from the 'weights' input + (NOTE(review): no removal is visible in this function; the only mutation + of 'weights' here is setConstFoldedValue -- confirm where it happens) +*/ + +void ConstantFolder::evaluate(Weights& weights) { /* runs each stored foldable's kernel, in foldables_ order, on a frame built from 'weights' */ + CHECK(unlinked_) + << "cannot evaluate weights for a graph whose constants have not been unlinked via ConstFolder::unlinkConstants"; /* NOTE(review): the message says ConstFolder; the class is ConstantFolder */ + + weights.validateAllWeightsLoaded(); + + ExecutionFrame frame(graph_); + frame.setWeights(weights); + + c10::FastMap foldedValues; + + for (const auto& f : foldables_) { + f.kernel->compute(frame); + + for (auto&& [i, out] : c10::enumerate(f.node->outputs())) { + if (foldedOutputValueIds_.find(out->id()) != + foldedOutputValueIds_.end()) { + foldedValues[std::string{out->name()}] = f.kernel->output(i, frame); /* keyed by output name; only outputs recorded in foldedOutputValueIds_ are kept */ + } + } + } + + for (auto it = std::make_move_iterator(foldedValues.begin()); + it != std::make_move_iterator(foldedValues.end()); + ++it) { + auto [n, iv] = std::move(*it); + weights.setConstFoldedValue(n, std::move(iv)); + } +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/ConstantFolder.h b/torch/nativert/executor/ConstantFolder.h new file mode 100644 index 000000000000..b1d1afa12f4f --- /dev/null +++ 
b/torch/nativert/executor/ConstantFolder.h @@ -0,0 +1,53 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace torch::nativert { + +struct Foldable { /* a node together with the kernel that was moved out of the caller's kernel vector for it */ + Node* node; + std::unique_ptr kernel; +}; + +class ConstantFolder { + public: + explicit ConstantFolder(Graph& graph) : graph_(graph) {} + + /* + 1. identify nodes without dynamic inputs, mark as foldable + + 2. traverse the nodes deemed foldable as if they were being evaluated, + pushing nodes that become foldable after their inputs were traversed. + + unlink foldable nodes from the graph in the topological order in which + they were traversed, storing the node and its associated kernel (moved + from 'kernels') as a foldable in ConstantFolder + */ + void unlinkConstants( + /* kernels for const-foldable nodes will be removed from this vector */ + std::vector>& kernels); + + /* + 1. execute foldables_ on an execution frame initialized with the passed-in + weights, calling Weights::setConstFoldedValue if the folded value is + consumed by a non-foldable node + */ + void evaluate(Weights& weights); + + private: + Graph& graph_; + // unlinked nodes sorted in their topological order + // s.t., they can be evaluated sequentially + std::vector foldables_; + + bool unlinked_{false}; /* set to true by unlinkConstants(); evaluate() CHECKs it */ + + c10::FastSet foldedOutputValueIds_; /* ids of folded values that had a user outside the foldable set when their producer was folded */ +}; + +} // namespace torch::nativert