mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[nativert] hook up memory planning to execution frame (#157053)
Summary: Pretty simple. If a planner exists, which implies that planning is enabled, create a manager for each frame. The associated serial executor will use the withMemoryPlanner fn to ensure the deallocation is done after execution completes. Test Plan: CI Differential Revision: D73635809 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157053 Approved by: https://github.com/henryoier, https://github.com/georgiaphillips
This commit is contained in:
@ -21,6 +21,9 @@ set(NATIVERT_TEST_SRCS
|
||||
${TORCH_ROOT}/torch/nativert/executor/memory/GreedyBySize.cpp
|
||||
${TORCH_ROOT}/torch/nativert/executor/memory/Bump.cpp
|
||||
${TORCH_ROOT}/torch/nativert/executor/memory/DisjointStorageGroups.cpp
|
||||
${TORCH_ROOT}/torch/nativert/executor/memory/LayoutPlanner.cpp
|
||||
${TORCH_ROOT}/torch/nativert/executor/memory/LayoutManager.cpp
|
||||
${TORCH_ROOT}/torch/nativert/executor/memory/AliasAnalyzer.cpp
|
||||
)
|
||||
|
||||
add_executable(test_nativert
|
||||
|
@ -90,7 +90,9 @@ TEST(ExecutionFrameTest, TestPersistentValue) {
|
||||
auto wid = graph->getValue("my_weight")->id();
|
||||
|
||||
EXPECT_NO_THROW(frame.getTensor(wid));
|
||||
EXPECT_DEATH(frame.releaseValue(wid), "Cannot release persistent value");
|
||||
// can't release persistent value
|
||||
frame.releaseValueIfNeeded(wid);
|
||||
EXPECT_FALSE(frame.getIValue(wid).isNone());
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
|
@ -29,9 +29,20 @@ ExecutionFrame::ExecutionFrame(const Graph& graph)
|
||||
}
|
||||
}
|
||||
|
||||
ExecutionFrame::ExecutionFrame(const Graph& graph, const Weights& weights)
|
||||
ExecutionFrame::ExecutionFrame(
|
||||
const Graph& graph,
|
||||
const Weights& weights,
|
||||
const torch::nativert::ExecutorConfig& cfg,
|
||||
LayoutPlanner* layoutPlanner)
|
||||
: ExecutionFrame(graph) {
|
||||
setWeights(weights);
|
||||
if (layoutPlanner != nullptr) {
|
||||
layoutPlanner_ = layoutPlanner;
|
||||
layoutManager_ = std::make_unique<LayoutManager>(
|
||||
*layoutPlanner,
|
||||
*this,
|
||||
cfg.layoutPlannerSettings.layoutManagerSettings());
|
||||
}
|
||||
}
|
||||
|
||||
void ExecutionFrame::setWeights(const Weights& weights) {
|
||||
|
@ -3,7 +3,9 @@
|
||||
#include <unordered_map>
|
||||
|
||||
#include <torch/csrc/distributed/c10d/Work.hpp>
|
||||
#include <torch/nativert/executor/ExecutorConfig.h>
|
||||
#include <torch/nativert/executor/Weights.h>
|
||||
#include <torch/nativert/executor/memory/LayoutManager.h>
|
||||
#include <torch/nativert/graph/Graph.h>
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
@ -21,7 +23,11 @@ class ExecutionFrame {
|
||||
// torch.cond
|
||||
explicit ExecutionFrame(const Graph& graph);
|
||||
|
||||
explicit ExecutionFrame(const Graph& graph, const Weights& weights);
|
||||
explicit ExecutionFrame(
|
||||
const Graph& graph,
|
||||
const Weights& weights,
|
||||
const torch::nativert::ExecutorConfig& executorConfig = {},
|
||||
LayoutPlanner* layoutPlanner = nullptr);
|
||||
|
||||
// Constructor for testing purpose
|
||||
explicit ExecutionFrame(
|
||||
@ -34,6 +40,16 @@ class ExecutionFrame {
|
||||
destroyBorrowedIValues();
|
||||
}
|
||||
|
||||
template <typename CB>
|
||||
auto withMemoryPlanner(CB&& cb) {
|
||||
if (!layoutManager_) {
|
||||
return std::forward<CB>(cb)();
|
||||
}
|
||||
|
||||
LayoutManagerGuard guard(*layoutManager_);
|
||||
return std::forward<CB>(cb)();
|
||||
}
|
||||
|
||||
std::vector<c10::IValue> tryMoveUserOutputs();
|
||||
|
||||
c10::IValue moveIValue(ValueId id) {
|
||||
@ -79,14 +95,19 @@ class ExecutionFrame {
|
||||
return persistent_;
|
||||
}
|
||||
|
||||
C10_ALWAYS_INLINE bool isManagedValue(const ValueId id) const {
|
||||
return layoutPlanner_ != nullptr && layoutPlanner_->is_managed(id);
|
||||
}
|
||||
|
||||
void setPersistentIValue(ValueId id, c10::IValue ivalue) {
|
||||
setIValue(id, std::move(ivalue));
|
||||
persistent_[id] = true;
|
||||
}
|
||||
|
||||
void releaseValue(ValueId id) {
|
||||
CHECK(!persistent_[id]) << "Cannot release persistent value";
|
||||
allValues_[id] = c10::IValue();
|
||||
void releaseValueIfNeeded(ValueId id) {
|
||||
if (!isManagedValue(id) && !persistent_[id]) {
|
||||
allValues_[id] = c10::IValue();
|
||||
}
|
||||
}
|
||||
|
||||
void destroyBorrowedIValues() {
|
||||
@ -122,6 +143,9 @@ class ExecutionFrame {
|
||||
const Graph& graph_;
|
||||
WeightVersion weightVersion_ = -1;
|
||||
|
||||
std::unique_ptr<LayoutManager> layoutManager_;
|
||||
LayoutPlanner* layoutPlanner_{nullptr};
|
||||
|
||||
// All the intermediate values for the entire graph, including graph inputs
|
||||
// and outputs This table is fixed once constructed
|
||||
std::vector<c10::IValue> allValues_;
|
||||
|
@ -14,19 +14,20 @@ std::vector<c10::IValue> SerialGraphExecutor::execute(
|
||||
|
||||
std::vector<c10::IValue> SerialGraphExecutor::executeWithPrefilledFrame(
|
||||
ExecutionFrame& executionFrame) {
|
||||
// Execute kernels for all nodes except prim.Input and prim.Output
|
||||
for (NodeIndex nodeIdx = 1; nodeIdx < nodeKernels_.size() - 1; ++nodeIdx) {
|
||||
nodeKernels_[nodeIdx]->compute(executionFrame);
|
||||
executionFrame.withMemoryPlanner([&]() {
|
||||
// Execute kernels for all nodes except prim.Input and prim.Output
|
||||
for (NodeIndex nodeIdx = 1; nodeIdx < nodeKernels_.size() - 1; ++nodeIdx) {
|
||||
nodeKernels_[nodeIdx]->compute(executionFrame);
|
||||
|
||||
// don't free intermediate values when static memory planning is enabled
|
||||
if (executorConfig_.tryFreeUnmanagedValuesAfterUse) {
|
||||
// Free the intermediate values that are no used anymore
|
||||
for (const auto& valueKey : execPlan_->valuesToFree[nodeIdx]) {
|
||||
executionFrame.releaseValue(valueKey);
|
||||
// don't free intermediate values when static memory planning is enabled
|
||||
if (executorConfig_.tryFreeUnmanagedValuesAfterUse) {
|
||||
// Free the intermediate values that are no used anymore
|
||||
for (const auto& valueKey : execPlan_->valuesToFree[nodeIdx]) {
|
||||
executionFrame.releaseValueIfNeeded(valueKey);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
});
|
||||
return executionFrame.tryMoveUserOutputs();
|
||||
}
|
||||
|
||||
|
@ -162,12 +162,13 @@ void AliasAnalyzer::log_state() const {
|
||||
for (const auto* a : alias) {
|
||||
ss << a->name() << ", ";
|
||||
}
|
||||
ss << "\n";
|
||||
ss << '\n';
|
||||
}
|
||||
|
||||
ss << '\n';
|
||||
|
||||
return ss.str();
|
||||
}() << std::endl
|
||||
<< std::flush;
|
||||
}() << std::flush;
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
|
@ -64,6 +64,7 @@ void LayoutManager::allocate_plan(const LayoutPlan& plan) {
|
||||
|
||||
void* offset_ptr =
|
||||
layout_buffer_.get_ptr_with_offset(planned_allocation.offset);
|
||||
// NOLINTNEXTLINE(bugprone-pointer-arithmetic-on-polymorphic-object)
|
||||
auto& storage = storage_buf[i];
|
||||
|
||||
// if the existing data ptr doesn't have an associated deleter then we
|
||||
@ -124,12 +125,15 @@ void LayoutManager::ensure_managed_storages(bool allocate) {
|
||||
} else if (
|
||||
C10_UNLIKELY(
|
||||
&storage !=
|
||||
// NOLINTNEXTLINE(bugprone-pointer-arithmetic-on-polymorphic-object)
|
||||
&storage_buf
|
||||
[i]) /* managed storage was replaced for some reason */) {
|
||||
storage.reset();
|
||||
tensor->unsafeGetTensorImpl()->set_storage_keep_dtype(at::Storage(
|
||||
c10::intrusive_ptr<at::StorageImpl>::unsafe_adapt_non_heap_allocated(
|
||||
&storage_buf[i], 1)));
|
||||
// NOLINTNEXTLINE(bugprone-pointer-arithmetic-on-polymorphic-object)
|
||||
&storage_buf[i],
|
||||
1)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -80,7 +80,7 @@ LayoutPlanner::LayoutPlanner(
|
||||
continue;
|
||||
}
|
||||
|
||||
if (bool is_consumed = output->users().size() > 0; !is_consumed) {
|
||||
if (bool is_not_consumed = output->users().empty(); is_not_consumed) {
|
||||
VLOG(1) << "not planning " << output->name() << " as it has no users";
|
||||
continue;
|
||||
}
|
||||
@ -154,7 +154,7 @@ void LayoutPlanner::initialize_vectors(
|
||||
|
||||
planned_values_[i] = v->id();
|
||||
planned_values_historical_max_nbytes_[i] = spec.size;
|
||||
planned_allocation_specs_[i] = std::move(spec);
|
||||
planned_allocation_specs_[i] = spec;
|
||||
|
||||
i++;
|
||||
}
|
||||
@ -178,9 +178,8 @@ void LayoutPlanner::start_worker_if_not_started() {
|
||||
// make sure plan is populated by the time this
|
||||
// returns for the first time :P
|
||||
create_plan();
|
||||
worker_ = std::thread([this]() {
|
||||
run_periodic(std::bind(&LayoutPlanner::create_plan, this));
|
||||
});
|
||||
worker_ =
|
||||
std::thread([this]() { run_periodic([this] { create_plan(); }); });
|
||||
});
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user