pytorch/torch/nativert/executor/ExecutionFrame.h
dolpm 018e9826a2 [nativert] hook up memory planning to execution frame (#157053)
Summary: Pretty simple. If a planner exists, which implies that planning is enabled, create a layout manager for each frame. The associated serial executor will use the withMemoryPlanner function to ensure the deallocation is done after execution completes (see the usage sketch below).

Test Plan: CI

Differential Revision: D73635809

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157053
Approved by: https://github.com/henryoier, https://github.com/georgiaphillips
2025-06-30 00:06:37 +00:00
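
Roughly, per the summary above, the serial executor wraps per-frame execution in ExecutionFrame::withMemoryPlanner so that deallocation of planned buffers happens once execution completes. A minimal sketch, not the actual executor code; setGraphInputs and executeAllNodes are assumed placeholders:

std::vector<c10::IValue> runFrame(
    ExecutionFrame& frame,
    std::vector<c10::IValue> inputs) {
  // The LayoutManagerGuard created inside withMemoryPlanner lives for the
  // duration of the callback; when the callback returns, planned
  // allocations can be deallocated.
  return frame.withMemoryPlanner([&]() {
    setGraphInputs(frame, std::move(inputs)); // assumed placeholder
    executeAllNodes(frame);                   // assumed placeholder
    return frame.tryMoveUserOutputs();
  });
}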


#pragma once
#include <string>
#include <unordered_map>
#include <vector>
#include <torch/csrc/distributed/c10d/Work.hpp>
#include <torch/nativert/executor/ExecutorConfig.h>
#include <torch/nativert/executor/Weights.h>
#include <torch/nativert/executor/memory/LayoutManager.h>
#include <torch/nativert/graph/Graph.h>
#include <c10/util/Logging.h>
namespace torch::nativert {
/**
* This class encapsulates the stateful values of an execution,
* most notably, the tensor values passed between nodes, aka intermediate
* activations.
*/
class ExecutionFrame {
public:
// Constructor for weight-less graph, used for higher order ops, e.g.
// torch.cond
explicit ExecutionFrame(const Graph& graph);
explicit ExecutionFrame(
const Graph& graph,
const Weights& weights,
const torch::nativert::ExecutorConfig& executorConfig = {},
LayoutPlanner* layoutPlanner = nullptr);
// Constructor for testing purposes
explicit ExecutionFrame(
const Graph& graph,
size_t numValues,
const std::vector<ValueId>& graphInputIds,
const std::vector<ValueId>& graphOutputIds);
~ExecutionFrame() {
destroyBorrowedIValues();
}
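// Runs `cb` under a LayoutManagerGuard when memory planning is enabled for
// this frame; otherwise invokes `cb` directly. The guard ensures that
// planned allocations are deallocated once `cb` (i.e. execution) completes.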
template <typename CB>
auto withMemoryPlanner(CB&& cb) {
if (!layoutManager_) {
return std::forward<CB>(cb)();
}
LayoutManagerGuard guard(*layoutManager_);
return std::forward<CB>(cb)();
}
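// Returns the user-visible outputs, moving them out of the frame where
// isOutputMovable() permits.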
std::vector<c10::IValue> tryMoveUserOutputs();
c10::IValue moveIValue(ValueId id) {
return std::move(allValues_[id]);
}
const c10::IValue& getIValue(ValueId id, bool allowNone = true) const {
const auto& iValue = allValues_[id];
if (allowNone && iValue.isNone()) {
return iValue;
}
DCHECK(!iValue.isNone());
return iValue;
}
c10::IValue& getIValue(ValueId id, bool allowNone = true) {
auto& iValue = allValues_[id];
if (allowNone && iValue.isNone()) {
return iValue;
}
DCHECK(!iValue.isNone());
return iValue;
}
void setIValue(ValueId id, c10::IValue ivalue);
void setBorrowedIValue(ValueId id, c10::IValue ivalue);
at::Tensor getTensor(ValueId id) const;
std::vector<at::Tensor> getTensorVector(ValueId id) const {
return getIValue(id).toTensorVector();
}
int64_t getSymInt(ValueId id) const {
return getIValue(id).toInt();
}
double getSymFloat(ValueId id) const {
return getIValue(id).toDouble();
}
const std::vector<bool>& persistentValues() const {
return persistent_;
}
C10_ALWAYS_INLINE bool isManagedValue(const ValueId id) const {
return layoutPlanner_ != nullptr && layoutPlanner_->is_managed(id);
}
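// Marks a value (e.g. a weight) as persistent so it survives across
// executions and is never cleared by releaseValueIfNeeded().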
void setPersistentIValue(ValueId id, c10::IValue ivalue) {
setIValue(id, std::move(ivalue));
persistent_[id] = true;
}
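// Clears the value's slot once it is no longer needed, unless the value is
// persistent or its storage is managed by the layout planner.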
void releaseValueIfNeeded(ValueId id) {
if (!isManagedValue(id) && !persistent_[id]) {
allValues_[id] = c10::IValue();
}
}
void destroyBorrowedIValues() {
for (const auto& id : borrowedValueIds_) {
c10::MaybeOwnedTraits<c10::IValue>::destroyBorrow(getIValue(id));
}
borrowedValueIds_.clear();
}
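// Bookkeeping for asynchronous collective work handles (c10d::Work), keyed
// by workId.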
void setWork(int64_t workId, const c10::intrusive_ptr<c10d::Work>& work) {
work_[workId] = work;
}
c10::intrusive_ptr<c10d::Work> getWork(int64_t workId) const {
CHECK(work_.find(workId) != work_.end())
<< "Couldn't find work with Id: " << workId;
return work_.at(workId);
}
WeightVersion weightVersion() const {
return weightVersion_;
}
void setWeights(const Weights& weights);
private:
bool isOutputMovable(size_t idx) const {
TORCH_CHECK_LT(idx, moveable_output_mask_.size());
return moveable_output_mask_[idx];
}
void updateMovableOutputs();
const Graph& graph_;
WeightVersion weightVersion_ = -1;
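// Created only when a LayoutPlanner is provided (i.e. memory planning is
// enabled); manages planned allocations for this frame.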
std::unique_ptr<LayoutManager> layoutManager_;
LayoutPlanner* layoutPlanner_{nullptr};
// All the intermediate values for the entire graph, including graph inputs
// and outputs. This table is fixed once constructed.
std::vector<c10::IValue> allValues_;
std::vector<bool> persistent_;
std::unordered_map<int64_t, c10::intrusive_ptr<c10d::Work>> work_;
std::vector<ValueId> borrowedValueIds_;
std::unordered_map<std::string, ValueId> foldedConstIds_;
// moveable_output_mask_[i] corresponds to user_outputs_[i]
//
// if moveable_output_mask_[i] is true, then user_outputs_[i]
// can be moved
std::vector<bool> moveable_output_mask_;
};
} // namespace torch::nativert