[nativert] hook up memory planning to execution frame (#157053)

Summary: Pretty simple: if a planner exists (which implies that planning is enabled), create a layout manager for each execution frame. The associated serial executor will use the withMemoryPlanner fn to ensure that deallocation is done after execution completes.

Test Plan: CI

Differential Revision: D73635809

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157053
Approved by: https://github.com/henryoier, https://github.com/georgiaphillips
This commit is contained in:
dolpm
2025-06-30 00:06:37 +00:00
committed by PyTorch MergeBot
parent 41f6acef83
commit 018e9826a2
8 changed files with 70 additions and 25 deletions

View File

@ -21,6 +21,9 @@ set(NATIVERT_TEST_SRCS
${TORCH_ROOT}/torch/nativert/executor/memory/GreedyBySize.cpp
${TORCH_ROOT}/torch/nativert/executor/memory/Bump.cpp
${TORCH_ROOT}/torch/nativert/executor/memory/DisjointStorageGroups.cpp
${TORCH_ROOT}/torch/nativert/executor/memory/LayoutPlanner.cpp
${TORCH_ROOT}/torch/nativert/executor/memory/LayoutManager.cpp
${TORCH_ROOT}/torch/nativert/executor/memory/AliasAnalyzer.cpp
)
add_executable(test_nativert

View File

@ -90,7 +90,9 @@ TEST(ExecutionFrameTest, TestPersistentValue) {
auto wid = graph->getValue("my_weight")->id();
EXPECT_NO_THROW(frame.getTensor(wid));
EXPECT_DEATH(frame.releaseValue(wid), "Cannot release persistent value");
// can't release persistent value
frame.releaseValueIfNeeded(wid);
EXPECT_FALSE(frame.getIValue(wid).isNone());
}
} // namespace torch::nativert

View File

@ -29,9 +29,20 @@ ExecutionFrame::ExecutionFrame(const Graph& graph)
}
}
ExecutionFrame::ExecutionFrame(const Graph& graph, const Weights& weights)
ExecutionFrame::ExecutionFrame(
const Graph& graph,
const Weights& weights,
const torch::nativert::ExecutorConfig& cfg,
LayoutPlanner* layoutPlanner)
: ExecutionFrame(graph) {
setWeights(weights);
if (layoutPlanner != nullptr) {
layoutPlanner_ = layoutPlanner;
layoutManager_ = std::make_unique<LayoutManager>(
*layoutPlanner,
*this,
cfg.layoutPlannerSettings.layoutManagerSettings());
}
}
void ExecutionFrame::setWeights(const Weights& weights) {

View File

@ -3,7 +3,9 @@
#include <unordered_map>
#include <torch/csrc/distributed/c10d/Work.hpp>
#include <torch/nativert/executor/ExecutorConfig.h>
#include <torch/nativert/executor/Weights.h>
#include <torch/nativert/executor/memory/LayoutManager.h>
#include <torch/nativert/graph/Graph.h>
#include <c10/util/Logging.h>
@ -21,7 +23,11 @@ class ExecutionFrame {
// torch.cond
explicit ExecutionFrame(const Graph& graph);
explicit ExecutionFrame(const Graph& graph, const Weights& weights);
explicit ExecutionFrame(
const Graph& graph,
const Weights& weights,
const torch::nativert::ExecutorConfig& executorConfig = {},
LayoutPlanner* layoutPlanner = nullptr);
// Constructor for testing purpose
explicit ExecutionFrame(
@ -34,6 +40,16 @@ class ExecutionFrame {
destroyBorrowedIValues();
}
template <typename CB>
auto withMemoryPlanner(CB&& cb) {
if (!layoutManager_) {
return std::forward<CB>(cb)();
}
LayoutManagerGuard guard(*layoutManager_);
return std::forward<CB>(cb)();
}
std::vector<c10::IValue> tryMoveUserOutputs();
c10::IValue moveIValue(ValueId id) {
@ -79,14 +95,19 @@ class ExecutionFrame {
return persistent_;
}
C10_ALWAYS_INLINE bool isManagedValue(const ValueId id) const {
return layoutPlanner_ != nullptr && layoutPlanner_->is_managed(id);
}
void setPersistentIValue(ValueId id, c10::IValue ivalue) {
setIValue(id, std::move(ivalue));
persistent_[id] = true;
}
void releaseValue(ValueId id) {
CHECK(!persistent_[id]) << "Cannot release persistent value";
allValues_[id] = c10::IValue();
void releaseValueIfNeeded(ValueId id) {
if (!isManagedValue(id) && !persistent_[id]) {
allValues_[id] = c10::IValue();
}
}
void destroyBorrowedIValues() {
@ -122,6 +143,9 @@ class ExecutionFrame {
const Graph& graph_;
WeightVersion weightVersion_ = -1;
std::unique_ptr<LayoutManager> layoutManager_;
LayoutPlanner* layoutPlanner_{nullptr};
// All the intermediate values for the entire graph, including graph inputs
// and outputs This table is fixed once constructed
std::vector<c10::IValue> allValues_;

View File

@ -14,19 +14,20 @@ std::vector<c10::IValue> SerialGraphExecutor::execute(
std::vector<c10::IValue> SerialGraphExecutor::executeWithPrefilledFrame(
ExecutionFrame& executionFrame) {
// Execute kernels for all nodes except prim.Input and prim.Output
for (NodeIndex nodeIdx = 1; nodeIdx < nodeKernels_.size() - 1; ++nodeIdx) {
nodeKernels_[nodeIdx]->compute(executionFrame);
executionFrame.withMemoryPlanner([&]() {
// Execute kernels for all nodes except prim.Input and prim.Output
for (NodeIndex nodeIdx = 1; nodeIdx < nodeKernels_.size() - 1; ++nodeIdx) {
nodeKernels_[nodeIdx]->compute(executionFrame);
// don't free intermediate values when static memory planning is enabled
if (executorConfig_.tryFreeUnmanagedValuesAfterUse) {
// Free the intermediate values that are no used anymore
for (const auto& valueKey : execPlan_->valuesToFree[nodeIdx]) {
executionFrame.releaseValue(valueKey);
// don't free intermediate values when static memory planning is enabled
if (executorConfig_.tryFreeUnmanagedValuesAfterUse) {
// Free the intermediate values that are no used anymore
for (const auto& valueKey : execPlan_->valuesToFree[nodeIdx]) {
executionFrame.releaseValueIfNeeded(valueKey);
}
}
}
}
});
return executionFrame.tryMoveUserOutputs();
}

View File

@ -162,12 +162,13 @@ void AliasAnalyzer::log_state() const {
for (const auto* a : alias) {
ss << a->name() << ", ";
}
ss << "\n";
ss << '\n';
}
ss << '\n';
return ss.str();
}() << std::endl
<< std::flush;
}() << std::flush;
}
} // namespace torch::nativert

View File

@ -64,6 +64,7 @@ void LayoutManager::allocate_plan(const LayoutPlan& plan) {
void* offset_ptr =
layout_buffer_.get_ptr_with_offset(planned_allocation.offset);
// NOLINTNEXTLINE(bugprone-pointer-arithmetic-on-polymorphic-object)
auto& storage = storage_buf[i];
// if the existing data ptr doesn't have an associated deleter then we
@ -124,12 +125,15 @@ void LayoutManager::ensure_managed_storages(bool allocate) {
} else if (
C10_UNLIKELY(
&storage !=
// NOLINTNEXTLINE(bugprone-pointer-arithmetic-on-polymorphic-object)
&storage_buf
[i]) /* managed storage was replaced for some reason */) {
storage.reset();
tensor->unsafeGetTensorImpl()->set_storage_keep_dtype(at::Storage(
c10::intrusive_ptr<at::StorageImpl>::unsafe_adapt_non_heap_allocated(
&storage_buf[i], 1)));
// NOLINTNEXTLINE(bugprone-pointer-arithmetic-on-polymorphic-object)
&storage_buf[i],
1)));
}
}
}

View File

@ -80,7 +80,7 @@ LayoutPlanner::LayoutPlanner(
continue;
}
if (bool is_consumed = output->users().size() > 0; !is_consumed) {
if (bool is_not_consumed = output->users().empty(); is_not_consumed) {
VLOG(1) << "not planning " << output->name() << " as it has no users";
continue;
}
@ -154,7 +154,7 @@ void LayoutPlanner::initialize_vectors(
planned_values_[i] = v->id();
planned_values_historical_max_nbytes_[i] = spec.size;
planned_allocation_specs_[i] = std::move(spec);
planned_allocation_specs_[i] = spec;
i++;
}
@ -178,9 +178,8 @@ void LayoutPlanner::start_worker_if_not_started() {
// make sure plan is populated by the time this
// returns for the first time :P
create_plan();
worker_ = std::thread([this]() {
run_periodic(std::bind(&LayoutPlanner::create_plan, this));
});
worker_ =
std::thread([this]() { run_periodic([this] { create_plan(); }); });
});
}