[nativert] alias analyzer + layout planner/manager to pytorch core (#156897)

Summary: as titled; moves the alias analyzer and the layout planner/manager into PyTorch core under torch/nativert.

Test Plan:
CI. The unit tests still have some unresolved dependencies; they will be moved over in a later change.

Rollback Plan:

Differential Revision: D77320950

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156897
Approved by: https://github.com/zhxchen17
dolpm authored on 2025-06-27 03:01:22 +00:00; committed by PyTorch MergeBot
parent 382c6190c1
commit 7392470da4
7 changed files with 1002 additions and 0 deletions
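For orientation, a minimal sketch (not part of the diff) of how the pieces introduced here compose, based on the constructors and the LayoutManagerGuard declared below. The graph, kernel schemas, persistent-value mask, planner settings, and execution frame are assumed to come from the existing executor setup, and the helper function name is hypothetical.

#include <string>
#include <vector>

#include <torch/nativert/executor/memory/LayoutManager.h>
#include <torch/nativert/executor/memory/LayoutPlanner.h>

namespace torch::nativert {

// Hypothetical wiring helper; real callers keep the planner alive per graph
// and the manager alive per ExecutionFrame rather than rebuilding them here.
void run_with_planned_layout(
    const Graph& graph,
    const c10::FastMap<std::string, FunctionSchema>& kernelSchemas,
    const std::vector<bool>& persistentValues,
    const LayoutPlannerSettings& plannerSettings,
    ExecutionFrame& frame) {
  // One planner per graph: runs AliasAnalyzer internally, decides which
  // values may share a contiguous buffer, and refreshes the LayoutPlan on a
  // background worker thread.
  LayoutPlanner planner(graph, kernelSchemas, persistentValues, plannerSettings);

  // One manager per ExecutionFrame: applies the planner's current plan to
  // that frame's planned tensors.
  LayoutManager manager(planner, frame);

  {
    // Per-iteration usage: allocate() on entry; on exit, record observed
    // tensor sizes and release/replan for the next run.
    LayoutManagerGuard guard(manager);
    // ... execute the graph against `frame` ...
  }
}

} // namespace torch::nativert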


@ -619,6 +619,9 @@ libtorch_nativert_sources = [
"torch/nativert/kernels/CallTorchBindKernel.cpp",
"torch/nativert/kernels/PrimKernelRegistry.cpp",
"torch/nativert/executor/memory/DisjointStorageGroups.cpp",
"torch/nativert/executor/memory/AliasAnalyzer.cpp",
"torch/nativert/executor/memory/LayoutPlanner.cpp",
"torch/nativert/executor/memory/LayoutManager.cpp",
]
torch_mobile_tracer_sources = [

torch/nativert/executor/memory/AliasAnalyzer.cpp

@ -0,0 +1,173 @@
#include <torch/nativert/executor/memory/AliasAnalyzer.h>
#include <c10/util/Enumerate.h>
namespace torch::nativert {
AliasAnalyzer::AliasAnalyzer(
const Graph& graph,
const c10::FastMap<std::string /* target */, FunctionSchema>& schemas) {
for (const auto&& [i, node] : c10::enumerate(graph.nodes())) {
for (const auto& input : node.inputs()) {
create_or_update_lifetime(input.value, i);
}
for (const auto& output : node.outputs()) {
create_or_update_lifetime(output, i);
}
if (update_aliases_if_packed_listunpack(node, i) /* applied? */) {
continue;
}
maybe_update_aliases_from_schema(node, schemas);
}
// set all non-aliasing outputs. outputs
// that are aliased will be set later when
// lifetimes are extended
for (const auto* output : graph.outputs()) {
if (!is_alias(output)) {
values_associated_with_outputs_.insert(output);
}
}
maybe_extend_lifetimes(graph);
log_state();
}
bool /* applied */ AliasAnalyzer::update_aliases_if_packed_listunpack(
const Node& node,
size_t i) {
if (node.target() != "prim.ListUnpack") {
return false;
}
const auto* list = node.inputs()[0].value;
// we can't infer how this list was made in this case,
// so fall back to the default always-aliasing behaviour
if (const auto* p = list->producer(); p && p->target() != "prim.ListPack") {
return false;
}
const auto& list_elems = list->getListElements();
TORCH_CHECK_EQ(list_elems.size(), node.numOutputs());
for (const auto j : c10::irange(node.numOutputs())) {
const Value* input = list_elems.at(j);
const Value* output = node.outputs().at(j);
TORCH_CHECK_NE(input, output);
create_or_update_lifetime(input, i);
create_or_update_lifetime(output, i);
aliases_[output].insert(input);
}
return true;
}
void AliasAnalyzer::maybe_update_aliases_from_schema(
const Node& node,
const c10::FastMap<std::string /* target */, FunctionSchema>& schemas) {
std::function<bool(size_t, size_t)> is_alias =
[]([[maybe_unused]] size_t input_idx,
[[maybe_unused]] size_t output_idx) { return true; };
const FunctionSchema* schema = nullptr;
if (auto schemaIt = schemas.find(std::string(node.target()));
schemaIt != schemas.end()) {
schema = &schemaIt->second;
}
if (!schema) {
VLOG(1) << "schema not found for " << node.target()
<< " assuming worst case aliasing";
}
for (size_t j = 0; j < node.numInputs(); j += 1) {
for (size_t k = 0; k < node.numOutputs(); k += 1) {
const Value* input = node.inputs().at(j).value;
const Value* output = node.outputs().at(k);
if (!schema || schema->alias(j, k)) {
VLOG(1) << node.target()
<< " may contain input/output alias: " << input->id() << " -> "
<< output->id();
aliases_[output].insert(input);
}
}
}
}
void AliasAnalyzer::create_or_update_lifetime(const Value* value, size_t i) {
if (auto [lifetimeIt, inserted] = lifetimes_.try_emplace(value, i, i);
!inserted) {
lifetimeIt->second.end = i;
}
}
void AliasAnalyzer::maybe_extend_lifetimes(const Graph& graph) {
c10::FastSet<const Value*> extended;
for (auto nodeIt = graph.nodes().rbegin(); nodeIt != graph.nodes().rend();
++nodeIt) {
const auto& inputs = nodeIt->inputs();
for (const auto& input : inputs) {
if (auto aliasIt = aliases_.find(input.value);
aliasIt != aliases_.end()) {
const auto& alias = aliasIt->second;
for (const auto& src : alias) {
if (extended.find(src) != extended.end()) {
continue;
}
auto& eol = lifetimes_[src].end;
eol = lifetimes_[input.value].end;
VLOG(1) << "extended EOL of value " << src->id() << " to " << eol;
extended.insert(src);
if (eol == graph.nodes().size() - 1 /* aliases output */) {
values_associated_with_outputs_.insert(src);
}
}
}
}
}
}
void AliasAnalyzer::log_state() const {
if (!VLOG_IS_ON(
1) /* this is usually too large to be logged with VLOG directly */) {
return;
}
std::cout << [&]() -> std::string {
std::ostringstream ss;
ss << "[sigmoid layout planner] AliasAnalyzer ran....\n";
ss << "lifetimes:\n";
for (const auto& [v, lifetime] : lifetimes_) {
ss << " " << v->name() << ": [" << lifetime.start << ", " << lifetime.end
<< "]\n";
}
ss << "\naliases:\n";
for (const auto& [v, alias] : aliases_) {
ss << " " << v->name() << " -> ";
for (const auto* a : alias) {
ss << a->name() << ", ";
}
ss << "\n";
}
return ss.str();
}() << std::endl
<< std::flush;
}
} // namespace torch::nativert
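A standalone mirror of the create_or_update_lifetime bookkeeping above (the Lifetime struct, helper, and std::map are hypothetical stand-ins for AllocationLifetime and lifetimes_): the first time a value is seen at node i its interval becomes [i, i], and every later sighting only pushes the end forward.

#include <cstddef>
#include <map>

struct Lifetime {
  size_t start;
  size_t end;
};

// First sighting inserts [i, i]; subsequent sightings extend the end only.
void create_or_update(std::map<int, Lifetime>& lifetimes, int value_id, size_t i) {
  if (auto [it, inserted] = lifetimes.try_emplace(value_id, Lifetime{i, i});
      !inserted) {
    it->second.end = i;
  }
}

int main() {
  std::map<int, Lifetime> lifetimes;
  create_or_update(lifetimes, /*value_id=*/7, /*i=*/0); // produced at node 0
  create_or_update(lifetimes, /*value_id=*/7, /*i=*/3); // last used at node 3
  return (lifetimes.at(7).start == 0 && lifetimes.at(7).end == 3) ? 0 : 1;
}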

torch/nativert/executor/memory/AliasAnalyzer.h

@ -0,0 +1,85 @@
#pragma once
#include <c10/util/FbcodeMaps.h>
#include <torch/nativert/executor/memory/FunctionSchema.h>
#include <torch/nativert/executor/memory/LayoutPlannerAlgorithm.h>
#include <torch/nativert/graph/Graph.h>
namespace torch::nativert {
class AliasAnalyzer {
public:
explicit AliasAnalyzer(
const Graph& graph,
const c10::FastMap<std::string /* target */, FunctionSchema>& schemas);
C10_ALWAYS_INLINE const AllocationLifetime& lifetime(
const Value* value) const {
return lifetimes_.at(value);
}
C10_ALWAYS_INLINE bool is_alias(const Value* value) const {
return aliases_.find(value) != aliases_.end();
}
C10_ALWAYS_INLINE bool is_storage_associated_with_output(
const Value* value) const {
return values_associated_with_outputs_.find(value) !=
values_associated_with_outputs_.end();
}
C10_ALWAYS_INLINE const c10::FastSet<const Value*>&
values_associated_with_output_storage() const {
return values_associated_with_outputs_;
}
private:
// listunpack operations that take a list that has
// been created with a listpack operation should
// be transparent with respect to aliasing
//
// e.g., given the op
// %t[] = prim.ListPack(l0=%t0, l1=%t1)
// %x1, %x2 = prim.ListUnpack(self=%t)
// x1 should directly alias t0
// and likewise x2 should directly alias t1
//
// this will make sure that the lifetimes of x1 and x2
// are not just the max of the lifetimes of t0 and t1
// which can make tensor-packing more efficient if list
// element EOL's differ by large amounts
bool /* applied */ update_aliases_if_packed_listunpack(
const Node& node,
size_t i);
// use the schema aliasing spec, or if none is provided,
// assume all outputs alias all inputs
void maybe_update_aliases_from_schema(
const Node& node,
const c10::FastMap<std::string /* target */, FunctionSchema>& schemas);
void create_or_update_lifetime(const Value* value, size_t i);
// work our way from the DAG's output node to the input node
// propagating the maximum EOL of all aliases back to their
// source value(s).
//
// in addition, if a graph output is an alias, we need to ensure
// that the source values are treated as graph outputs
// so that we don't free them before the graph output is copied
// back to the user (and we ignore them when creating a memory plan
// even if they aren't explicitly considered outputs)
void maybe_extend_lifetimes(const Graph& graph);
void log_state() const;
// mapping from alias to the set of values that it aliases
c10::FastMap<const Value*, c10::FastSet<const Value*>> aliases_;
c10::FastMap<const Value*, AllocationLifetime> lifetimes_;
// non-aliasing outputs or non-aliasing intermediates that are aliased by
// outputs
c10::FastSet<const Value*> values_associated_with_outputs_;
};
} // namespace torch::nativert
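A hedged sketch of how a caller might consume this interface; the helper name is hypothetical, and the skip conditions mirror the checks LayoutPlanner.cpp performs later in this diff.

#include <iostream>

#include <torch/nativert/executor/memory/AliasAnalyzer.h>

namespace torch::nativert {

// Hypothetical inspection helper: prints the lifetime interval of every
// value a planner would consider managing.
void inspect_planning_candidates(
    const Graph& graph,
    const c10::FastMap<std::string, FunctionSchema>& schemas) {
  AliasAnalyzer analyzer(graph, schemas);
  for (const auto& node : graph.nodes()) {
    for (const auto* output : node.outputs()) {
      if (analyzer.is_alias(output)) {
        continue; // aliases don't own their storage, so they aren't planned
      }
      if (analyzer.is_storage_associated_with_output(output)) {
        continue; // storage escapes to the caller, so it can't be reused
      }
      // [start, end] is the node-index interval over which the buffer must
      // stay live; a planner may place two buffers at the same offset only
      // if their intervals don't overlap.
      const auto& lt = analyzer.lifetime(output);
      std::cout << output->name() << " live over [" << lt.start << ", "
                << lt.end << "]\n";
    }
  }
}

} // namespace torch::nativert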

torch/nativert/executor/memory/LayoutManager.cpp

@ -0,0 +1,191 @@
#include <torch/nativert/executor/memory/LayoutManager.h>
#include <torch/nativert/executor/ExecutionFrame.h>
#include <c10/core/CPUAllocator.h>
#include <c10/util/Enumerate.h>
namespace torch::nativert {
LayoutManager::LayoutManager(
LayoutPlanner& planner,
ExecutionFrame& parent_frame,
const torch::nativert::LayoutManagerSettings settings)
: planner_(planner), parent_frame_(parent_frame), settings_(settings) {
VLOG(1) << "layout manager created for execution frame";
}
void ContiguousLayoutBuffer::allocate(size_t size) {
VLOG(1) << "allocating " << size << " bytes";
if (C10_LIKELY(size_ > 0)) {
if (C10_LIKELY(
size <= size_) /* NOTE: size will be monotonically increasing */) {
return clear(size_);
} else {
deallocate();
}
}
data_ptr_ = c10::GetCPUCachingAllocator()->allocate(size);
size_ = size;
}
void LayoutManager::allocate() {
if (C10_UNLIKELY(state_ == LayoutManagerState::WaitingForValues)) {
return;
}
bool should_allocate_storages =
state_ == LayoutManagerState::AllocatingStorages;
ensure_managed_storages(/* allocate= */ should_allocate_storages);
planner_.with_plan([&](const auto& plan) { allocate_plan(plan); });
if (should_allocate_storages) {
state_ = LayoutManagerState::Running;
}
}
void LayoutManager::allocate_plan(const LayoutPlan& plan) {
if (C10_UNLIKELY(storage_impl_buffer_.size() == 0 || plan.total_size == 0)) {
return;
}
layout_buffer_.allocate(plan.total_size);
VLOG(1) << "allocated " << layout_buffer_.size()
<< " bytes for planned layout";
auto* storage_buf = storage_impl_buffer_.buffer();
for (const auto i : c10::irange(plan.allocations.size())) {
auto& planned_allocation = plan.allocations[i];
auto& local_max_nbytes = planned_tensors_max_nbytes_local_[i];
local_max_nbytes = std::max(local_max_nbytes, planned_allocation.size);
void* offset_ptr =
layout_buffer_.get_ptr_with_offset(planned_allocation.offset);
auto& storage = storage_buf[i];
// if the existing data ptr doesn't have an associated deleter then we
// will set the offset and size directly, as opposed to creating and
// swapping it with a new one
//
// apart from the first allocation, when the storage still has its
// allocator-created dataptr (https://fburl.com/code/u7dsspjm) whose
// deleter is non-null (https://fburl.com/code/7hiwo5zo), this should
// always be true
if (C10_LIKELY(
storage._mutable_data_ptr_no_checks().unsafe_reset_data_and_ctx(
offset_ptr))) {
storage.unsafe_set_nbytes(planned_allocation.size);
} else {
storage.set_data_ptr_noswap(at::DataPtr(
offset_ptr, offset_ptr, nullptr, c10::Device(c10::DeviceType::CPU)));
storage.set_nbytes(planned_allocation.size);
}
}
}
void LayoutManager::ensure_managed_storages(bool allocate) {
if (C10_UNLIKELY(planned_tensors_.empty())) {
return;
}
if (C10_UNLIKELY(allocate)) {
storage_impl_buffer_.allocate(planned_tensors_.size());
VLOG(1) << "allocated " << planned_tensors_.size() * sizeof(at::StorageImpl)
<< " bytes for contiguous storages";
}
auto* storage_buf = storage_impl_buffer_.buffer();
for (size_t i = 0; i < planned_tensors_.size(); i += 1) {
auto* tensor = planned_tensors_[i];
at::StorageImpl& storage = *tensor->storage().unsafeGetStorageImpl();
if (C10_UNLIKELY(allocate)) {
// from: https://fburl.com/code/4it00yph
//
// We want to manage StorageImpls' lifetimes ourselves, but TensorImpl
// expects to refcount them. unsafe_adapt_non_heap_allocated is our
// escape hatch: it sets the reference count for the StorageImpl to an
// impractically high value so that it will never get deallocated by
// intrusive_ptr, leaving us free to manage its lifetime as we see fit.
// (Note that allowing it to be deallocated by intrusive_ptr would be
// UB, because that would entail deleting an object that wasn't
// allocated with operator new.)
//
// For more information, see the doc comment for
// intrusive_ptr::unsafe_adapt_non_heap_allocated.
tensor->unsafeGetTensorImpl()->set_storage_keep_dtype(at::Storage(
c10::intrusive_ptr<at::StorageImpl>::unsafe_adapt_non_heap_allocated(
&storage_impl_buffer_.to_managed(storage), 1)));
} else if (
C10_UNLIKELY(
&storage !=
&storage_buf
[i]) /* managed storage was replaced for some reason */) {
storage.reset();
tensor->unsafeGetTensorImpl()->set_storage_keep_dtype(at::Storage(
c10::intrusive_ptr<at::StorageImpl>::unsafe_adapt_non_heap_allocated(
&storage_buf[i], 1)));
}
}
}
void LayoutManager::populate_tensor_values() {
CHECK(planned_tensors_.empty());
CHECK(unplanned_ivalues_.empty());
const auto& value_ids = planner_.get_planned_values();
planned_tensors_.resize(value_ids.size());
planned_tensors_max_nbytes_local_.resize(value_ids.size());
for (const auto&& [i, v] : c10::enumerate(value_ids)) {
planned_tensors_[i] = &parent_frame_.getIValue(v).toTensor();
}
const auto& unplanned_value_ids = planner_.get_unplanned_values();
unplanned_ivalues_.resize(unplanned_value_ids.size());
for (const auto&& [i, v] : c10::enumerate(unplanned_value_ids)) {
unplanned_ivalues_[i] = &parent_frame_.getIValue(v);
}
}
void LayoutManager::try_update_historical_max_nbytes() {
for (const auto i : c10::irange(planned_tensors_.size())) {
auto nbytes = get_aligned_nbytes(planned_tensors_[i]->nbytes());
if (auto& old_max = planned_tensors_max_nbytes_local_[i];
nbytes > old_max) {
old_max = nbytes;
planner_.try_update_max_size_at_index(i, nbytes);
}
}
}
void LayoutManager::deallocate_and_plan() {
const auto uninitialized = state_ == LayoutManagerState::WaitingForValues;
if (C10_UNLIKELY(uninitialized)) {
populate_tensor_values();
}
try_update_historical_max_nbytes();
if (C10_UNLIKELY(uninitialized)) {
planner_.start_worker_if_not_started();
}
if (C10_UNLIKELY(uninitialized)) {
state_ = LayoutManagerState::AllocatingStorages;
} else if (settings_.deallocateBetweenRequests()) {
layout_buffer_.deallocate();
}
for (auto* ivalue : unplanned_ivalues_) {
*ivalue = c10::IValue();
}
}
} // namespace torch::nativert
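To make allocate_plan concrete, a standalone toy sketch (sizes invented for illustration): a plan is just a total size plus per-tensor (offset, size) slots, and every planned tensor's storage is pointed at its slot inside one contiguous buffer.

#include <cstddef>
#include <cstdint>
#include <vector>

struct ToyAllocation {
  size_t offset;
  size_t size;
};

int main() {
  // Stand-ins for LayoutPlan::allocations and LayoutPlan::total_size.
  const std::vector<ToyAllocation> plan = {{0, 512}, {512, 256}, {768, 1024}};
  const size_t total_size = 1792;

  // Stand-in for ContiguousLayoutBuffer: one allocation covers the whole plan.
  std::vector<uint8_t> buffer(total_size);
  for (const auto& a : plan) {
    // What get_ptr_with_offset(a.offset) hands back; LayoutManager points
    // the i-th managed StorageImpl at this address and sets its nbytes to
    // a.size, so no per-tensor allocation happens on the hot path.
    void* storage_data = buffer.data() + a.offset;
    static_cast<void>(storage_data);
  }
  return 0;
}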

torch/nativert/executor/memory/LayoutManager.h

@ -0,0 +1,206 @@
#pragma once
#include <torch/nativert/executor/memory/LayoutPlanner.h>
#include <torch/nativert/executor/memory/LayoutPlannerAlgorithm.h>
#include <torch/nativert/executor/memory/LayoutPlannerSettings.h>
#include <c10/core/alignment.h>
#include <c10/core/impl/alloc_cpu.h>
namespace torch::nativert {
class ExecutionFrame;
struct ContiguousLayoutBuffer {
public:
ContiguousLayoutBuffer() = default;
~ContiguousLayoutBuffer() {
deallocate();
}
ContiguousLayoutBuffer(ContiguousLayoutBuffer&& other) = delete;
ContiguousLayoutBuffer(const ContiguousLayoutBuffer& other) = delete;
ContiguousLayoutBuffer operator=(ContiguousLayoutBuffer&& other) = delete;
ContiguousLayoutBuffer& operator=(const ContiguousLayoutBuffer& other) =
delete;
void* get_ptr_with_offset(size_t offset) {
void* raw_ptr = data_ptr_.get();
TORCH_CHECK_NOTNULL(raw_ptr);
TORCH_CHECK_LE(offset, size_);
return reinterpret_cast<void*>(
reinterpret_cast<uint8_t*>(raw_ptr) + offset);
}
size_t size() {
return size_;
}
void allocate(size_t size);
void deallocate() {
VLOG(1) << "deallocating layout buffer of size " << size_;
size_ = 0;
data_ptr_ = {};
}
void clear(size_t size) {
VLOG(1) << "clearing first " << size << "bytes of layout buffer of size "
<< size_;
TORCH_CHECK_LE(size, size_);
std::memset(data_ptr_.get(), 0, size);
}
private:
// the size of the buffer in bytes
size_t size_{0};
// the dataptr returned by the allocator
at::DataPtr data_ptr_{};
};
struct ContiguousStorageImplBuffer {
ContiguousStorageImplBuffer() = default;
~ContiguousStorageImplBuffer() {
deallocate();
}
ContiguousStorageImplBuffer(ContiguousStorageImplBuffer&& other) = delete;
ContiguousStorageImplBuffer(const ContiguousStorageImplBuffer& other) =
delete;
ContiguousStorageImplBuffer operator=(ContiguousStorageImplBuffer&& other) =
delete;
ContiguousStorageImplBuffer& operator=(
const ContiguousStorageImplBuffer& other) = delete;
void deallocate() {
if (buffer_ == nullptr) {
return;
}
for (const size_t idx : c10::irange(size_)) {
buffer_[idx].~StorageImpl();
}
delete[] reinterpret_cast<unsigned char*>(buffer_);
buffer_ = nullptr;
size_ = capacity_ = 0;
}
void allocate(size_t capacity) {
if (size_ > 0) {
deallocate();
}
capacity_ = capacity;
static_assert(alignof(at::StorageImpl) <= 8);
buffer_ = reinterpret_cast<at::StorageImpl*>(
new unsigned char[capacity * sizeof(at::StorageImpl)]);
}
size_t capacity() {
return capacity_;
}
size_t size() {
return size_;
}
c10::StorageImpl* buffer() const {
return buffer_;
}
c10::StorageImpl& at(size_t i) {
TORCH_CHECK_LT(i, size_)
<< "requested storage index " << i << " out of bounds " << size_;
return buffer_[i];
}
void reset_all() {
for (const size_t idx : c10::irange(size_)) {
buffer_[idx].reset();
}
}
c10::StorageImpl& to_managed(at::StorageImpl& s) {
TORCH_CHECK_LT(size_, capacity_);
return *(new (&buffer_[size_++]) at::StorageImpl(
at::StorageImpl::use_byte_size_t(),
static_cast<int64_t>(s.nbytes()),
s.allocator(),
s.resizable()));
}
private:
size_t size_{0};
size_t capacity_{0};
c10::StorageImpl* buffer_{nullptr};
};
enum class LayoutManagerState { WaitingForValues, AllocatingStorages, Running };
class LayoutManager {
public:
LayoutManager(
LayoutPlanner& planner,
ExecutionFrame& parent_frame,
torch::nativert::LayoutManagerSettings settings = {});
~LayoutManager() = default;
void allocate();
void deallocate_and_plan();
private:
#ifdef LayoutPlannerTests_TEST_FRIENDS
LayoutPlannerTests_TEST_FRIENDS;
#endif
static size_t get_aligned_nbytes(size_t nbytes) {
#if defined(__linux__) && !defined(__ANDROID__)
auto alignment = c10::c10_compute_alignment(nbytes);
#else
auto alignment = c10::gAlignment;
#endif
return ((nbytes) + alignment - 1) & (~(alignment - 1));
}
void allocate_plan(const LayoutPlan& plan);
void ensure_managed_storages(bool allocate);
void populate_tensor_values();
void try_update_historical_max_nbytes();
LayoutPlanner& planner_;
ExecutionFrame& parent_frame_;
std::vector<c10::IValue*> unplanned_ivalues_;
std::vector<const at::Tensor*> planned_tensors_;
std::vector<size_t> planned_tensors_max_nbytes_local_;
ContiguousLayoutBuffer layout_buffer_;
ContiguousStorageImplBuffer storage_impl_buffer_;
LayoutManagerState state_{LayoutManagerState::WaitingForValues};
torch::nativert::LayoutManagerSettings settings_;
};
class LayoutManagerGuard {
public:
explicit LayoutManagerGuard(LayoutManager& manager) : manager_(manager) {
manager_.allocate();
}
~LayoutManagerGuard() {
manager_.deallocate_and_plan();
}
LayoutManagerGuard(LayoutManagerGuard&& other) = delete;
LayoutManagerGuard(const LayoutManagerGuard& other) = delete;
LayoutManagerGuard operator=(LayoutManagerGuard&& other) = delete;
LayoutManagerGuard& operator=(const LayoutManagerGuard& other) = delete;
LayoutManager& manager_;
};
} // namespace torch::nativert
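LayoutManager::get_aligned_nbytes above rounds a byte count up to the next multiple of a power-of-two alignment; a standalone sketch of the same bit trick with concrete numbers (the 64-byte alignment is an assumption, matching the usual c10::gAlignment on server builds).

#include <cassert>
#include <cstddef>

// Same power-of-two round-up used by LayoutManager::get_aligned_nbytes.
size_t round_up(size_t nbytes, size_t alignment) {
  return (nbytes + alignment - 1) & ~(alignment - 1);
}

int main() {
  assert(round_up(1, 64) == 64);    // anything in (0, 64] maps to 64
  assert(round_up(64, 64) == 64);   // already-aligned sizes are unchanged
  assert(round_up(100, 64) == 128); // 100 -> next multiple of 64 is 128
  return 0;
}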

torch/nativert/executor/memory/LayoutPlanner.cpp

@ -0,0 +1,218 @@
#include <torch/nativert/executor/memory/LayoutPlanner.h>
#include <c10/util/CallOnce.h>
#include <c10/util/Enumerate.h>
#include <torch/nativert/executor/ExecutionPlanner.h>
#include <torch/nativert/executor/memory/AliasAnalyzer.h>
#include <torch/nativert/executor/memory/Bump.h>
#include <torch/nativert/executor/memory/DisjointStorageGroups.h>
#include <torch/nativert/executor/memory/GreedyBySize.h>
namespace torch::nativert {
LayoutPlanner::LayoutPlanner(
const Graph& graph,
const c10::FastMap<std::string /* target */, FunctionSchema>& kernelSchemas,
const std::vector<bool>& persistentValues,
const torch::nativert::LayoutPlannerSettings& settings)
: managed_values_(graph.values().size()), settings_(settings) {
auto value_to_allocation_spec = c10::FastMap<const Value*, AllocationSpec>{};
auto alias_analyzer = AliasAnalyzer(graph, kernelSchemas);
std::set<const Value*> input_values_set_;
for (const auto* nv : graph.userInputs()) {
if (nv->type() == Type::Kind::Tensor) {
input_values_set_.insert(nv);
}
}
const auto& tensor_meta = graph.tensorValuesMeta();
for (auto&& [i, node] : at::enumerate(graph.nodes())) {
// only manage out variant values
if (const auto schemaIt = kernelSchemas.find(std::string(node.target()));
schemaIt == kernelSchemas.end() ||
schemaIt->second.kernel_kind() != OpKernelKind::kStaticDispatchKernel) {
VLOG(1) << "not able to plan outputs for node " << node.target()
<< " as it is derived from an unsupported kernel kind.";
continue;
}
for (const auto& output : node.outputs()) {
// don't manage persistent values
if (bool is_persistent = persistentValues[output->id()]; is_persistent) {
VLOG(1)
<< "not planning " << output->name()
<< " as it is a persistent value (likely a weight or const-folded)";
continue;
}
// only manage tensors
if (bool is_tensor = output->type().kind() == Type::Kind::Tensor;
!is_tensor) {
VLOG(1) << "not planning " << output->name()
<< " as it is not a raw tensor. type: " << output->type();
continue;
}
// output storage ownership must be given to the caller.
if (const auto& values_associated_with_output =
alias_analyzer.values_associated_with_output_storage();
values_associated_with_output.find(output) !=
values_associated_with_output.end()) {
VLOG(1)
<< "not planning " << output->name()
<< " as its underlying storage may be associated with a graph output";
continue;
}
// inputs are borrowed -- this is merely a sanity check
if (input_values_set_.find(output) != input_values_set_.end()) {
VLOG(1) << "not planning " << output->name()
<< " as it is a graph input that is borrowed from the user";
continue;
}
// don't plan aliases -- they don't own the associated dataptr
if (bool is_alias = alias_analyzer.is_alias(output); is_alias) {
VLOG(1) << "not planning " << output->name() << " as it is an alias";
continue;
}
if (bool is_consumed = output->users().size() > 0; !is_consumed) {
VLOG(1) << "not planning " << output->name() << " as it has no users";
continue;
}
if (auto meta_it = tensor_meta.find(std::string(output->name()));
meta_it != tensor_meta.end()) {
if (const auto& meta = meta_it->second; meta.device() == c10::kCPU) {
auto& spec = value_to_allocation_spec[output];
spec.lifetime = alias_analyzer.lifetime(output);
managed_values_[output->id()] = true;
continue;
} else {
VLOG(1) << "tensor " << output->name()
<< " not placed on cpu so we cannot plan it";
}
} else /* possible if runtime pass didn't populate meta info */ {
VLOG(1) << "tensor " << output->name() << " has no meta information";
}
managed_values_[output->id()] = true;
value_to_allocation_spec[output].lifetime =
alias_analyzer.lifetime(output);
}
}
LOG(INFO) << "layout planner created with " << value_to_allocation_spec.size()
<< " values";
switch (settings_.algorithmType()) {
case torch::nativert::LayoutPlannerAlgorithmType::Bump: {
algorithm_ = &BumpAllocationPlanner;
break;
}
case torch::nativert::LayoutPlannerAlgorithmType::GreedyBySize: {
algorithm_ = &GreedyBySizeAllocationPlanner;
break;
}
case LayoutPlannerAlgorithmType::DisjointStorageGroups: {
algorithm_ = &DisjointStorageGroupsPlanner;
break;
}
}
TORCH_CHECK_NOTNULL(algorithm_);
initialize_vectors(value_to_allocation_spec);
auto exec_planner = ExecutionPlanner{graph};
auto p = exec_planner.createPlan();
for (const auto& freeable : p->valuesToFree) {
for (const auto v : freeable) {
if (!is_managed(v)) {
unplanned_values_.push_back(v);
}
}
}
}
void LayoutPlanner::initialize_vectors(
c10::FastMap<const Value*, AllocationSpec> value_to_allocation_spec) {
size_t num_managed = value_to_allocation_spec.size();
planned_values_.resize(num_managed);
planned_allocation_specs_.resize(num_managed);
planned_values_historical_max_nbytes_ =
std::vector<std::atomic_size_t>(num_managed);
size_t i = 0;
for (auto& [v, spec] : value_to_allocation_spec) {
TORCH_CHECK_LE(spec.lifetime.start, spec.lifetime.end);
planned_values_[i] = v->id();
planned_values_historical_max_nbytes_[i] = spec.size;
planned_allocation_specs_[i] = std::move(spec);
i++;
}
// for sanity in case anyone tries to use this after this method
// is called with a bunch of junk (i.e., moved specs) in it
value_to_allocation_spec.clear();
}
const std::vector<ValueId>& LayoutPlanner::get_planned_values() const {
return planned_values_;
}
const std::vector<ValueId>& LayoutPlanner::get_unplanned_values() const {
return unplanned_values_;
}
void LayoutPlanner::start_worker_if_not_started() {
static c10::once_flag flag;
c10::call_once(flag, [&]() {
// make sure plan is populated by the time this
// returns for the first time :P
create_plan();
worker_ = std::thread([this]() {
run_periodic(std::bind(&LayoutPlanner::create_plan, this));
});
});
}
LayoutPlanner::~LayoutPlanner() {
{
std::unique_lock<std::mutex> l(mutex_);
stopped_ = true;
}
cv_.notify_one();
if (worker_.joinable()) {
worker_.join();
}
}
void LayoutPlanner::run_periodic(const std::function<void()>& f) {
std::unique_lock<std::mutex> l(mutex_);
while (!cv_.wait_for(
l, settings_.planningInterval(), [&]() { return stopped_; })) {
f();
}
}
void LayoutPlanner::create_plan() {
// update spec sizes to use historical maximums set
// by execution frames before creating the new plan
for (const auto i : c10::irange(planned_allocation_specs_.size())) {
auto& spec = planned_allocation_specs_[i];
spec.size = planned_values_historical_max_nbytes_[i].load(
std::memory_order_relaxed);
}
plan_.write([p_new = (*algorithm_)(planned_allocation_specs_)](
LayoutPlan& plan) { plan = p_new; });
}
} // namespace torch::nativert
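The replanning thread above follows the usual stoppable interval-worker pattern; a self-contained sketch of the same condition_variable loop (class name and interval are placeholders, and, as in run_periodic, the work runs with the lock held).

#include <chrono>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <thread>
#include <utility>

class PeriodicWorker {
 public:
  PeriodicWorker(std::chrono::milliseconds interval, std::function<void()> work)
      : thread_([this, interval, work = std::move(work)] {
          std::unique_lock<std::mutex> lock(mutex_);
          // wait_for returns false on timeout (run another iteration) and
          // true once stopped_ is observed, mirroring run_periodic above.
          while (!cv_.wait_for(lock, interval, [this] { return stopped_; })) {
            work();
          }
        }) {}

  ~PeriodicWorker() {
    {
      std::lock_guard<std::mutex> guard(mutex_);
      stopped_ = true;
    }
    cv_.notify_one();
    thread_.join();
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  bool stopped_{false};
  std::thread thread_;
};

int main() {
  int ticks = 0;
  {
    PeriodicWorker worker(std::chrono::milliseconds(10), [&ticks] { ++ticks; });
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
  } // destructor stops and joins the worker before ticks is read
  return ticks >= 0 ? 0 : 1;
}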

torch/nativert/executor/memory/LayoutPlanner.h

@ -0,0 +1,126 @@
#pragma once
#include <condition_variable>
#include <functional>
#include <thread>
#include <c10/macros/Macros.h>
#include <c10/util/FbcodeMaps.h>
#include <c10/util/LeftRight.h>
#include <torch/nativert/executor/memory/FunctionSchema.h>
#include <torch/nativert/executor/memory/LayoutPlannerAlgorithm.h>
#include <torch/nativert/executor/memory/LayoutPlannerSettings.h>
#include <torch/nativert/graph/Graph.h>
namespace {
constexpr inline std::memory_order drop_release(std::memory_order m) noexcept {
return (
m == std::memory_order_release
? std::memory_order_relaxed
: ((m == std::memory_order_acq_rel || m == std::memory_order_seq_cst)
? std::memory_order_acquire
: m));
}
// derivation of
// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2024/p0493r5.pdf
template <typename T>
void atomic_set_max(
std::atomic<T>* pv,
typename std::atomic<T>::value_type v,
std::memory_order m = std::memory_order_seq_cst) noexcept {
auto const mr = drop_release(m);
auto t = (mr != m) ? pv->fetch_add(0, m) : pv->load(mr);
while (std::max(v, t) != t) {
if (pv->compare_exchange_weak(t, v, m, mr)) {
return;
}
}
}
} // namespace
namespace torch::nativert {
class LayoutPlanner {
public:
explicit LayoutPlanner(
const Graph& graph,
const c10::FastMap<std::string /* target */, FunctionSchema>&
kernelSchemas,
const std::vector<bool>& persistentValues,
const torch::nativert::LayoutPlannerSettings& settings);
~LayoutPlanner();
LayoutPlanner(LayoutPlanner&& other) = delete;
LayoutPlanner(const LayoutPlanner& other) = delete;
LayoutPlanner operator=(LayoutPlanner&& other) = delete;
LayoutPlanner& operator=(const LayoutPlanner& other) = delete;
void start_worker_if_not_started();
const std::vector<ValueId>& get_planned_values() const;
const std::vector<ValueId>& get_unplanned_values() const;
C10_ALWAYS_INLINE bool is_managed(ValueId id) {
TORCH_CHECK_LT(static_cast<size_t>(id), managed_values_.size());
return managed_values_[id];
}
C10_ALWAYS_INLINE void try_update_max_size_at_index(size_t idx, size_t size) {
atomic_set_max<size_t>(&planned_values_historical_max_nbytes_[idx], size);
}
C10_ALWAYS_INLINE
void with_plan(std::function<void(const LayoutPlan&)>&& cb) {
plan_.read(
std::forward<std::function<void(const LayoutPlan&)>>(std::move(cb)));
}
private:
#ifdef LayoutPlannerTests_TEST_FRIENDS
LayoutPlannerTests_TEST_FRIENDS;
#endif
// we need some way of mapping graph values to other information
// (e.g., allocation spec, max historical size)
//
// since there is a 1:1 mapping to/from each of these
// we can create+initialize them here
//
// note: planning algorithms are allowed to change the ordering
// of allocation specs -- so we pass the index of the spec during
// its insertion s.t. each execution frame can use it to
// reference the correct associated max historical size / underlying
// tensor value
void initialize_vectors(
c10::FastMap<const Value*, AllocationSpec> value_to_allocation_spec);
void run_periodic(const std::function<void()>& f);
void create_plan();
// variables for managing the state of the
// interval worker thread that refreshes
// the plan
std::condition_variable cv_;
std::mutex mutex_;
bool stopped_{false};
std::thread worker_;
std::vector<ValueId> unplanned_values_;
std::vector<ValueId> planned_values_;
std::vector<AllocationSpec> planned_allocation_specs_;
std::vector<std::atomic_size_t> planned_values_historical_max_nbytes_;
// managed_values_[value_id] == true
// if graph.values()[value_id] has
// an associated allocation spec
std::vector<bool> managed_values_;
LayoutPlannerAlgorithm* algorithm_;
c10::LeftRight<LayoutPlan> plan_;
torch::nativert::LayoutPlannerSettings settings_;
};
} // namespace torch::nativert
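atomic_set_max above is a CAS-based fetch-max; a standalone sketch of the same loop (helper and thread setup are hypothetical) showing that concurrent updaters, such as several execution frames calling try_update_max_size_at_index, leave only the largest value behind.

#include <algorithm>
#include <atomic>
#include <cstddef>
#include <thread>
#include <vector>

// Same effect as atomic_set_max: after any interleaving, `maximum` holds the
// largest value ever offered; smaller offers are dropped without a store.
void set_max(std::atomic<size_t>& maximum, size_t value) {
  size_t observed = maximum.load(std::memory_order_relaxed);
  while (std::max(value, observed) != observed &&
         !maximum.compare_exchange_weak(observed, value)) {
    // compare_exchange_weak refreshes `observed` on failure; loop re-checks.
  }
}

int main() {
  std::atomic<size_t> maximum{0};
  std::vector<std::thread> frames;
  for (size_t nbytes : {256, 4096, 1024}) {
    frames.emplace_back([&maximum, nbytes] { set_max(maximum, nbytes); });
  }
  for (auto& t : frames) {
    t.join();
  }
  return maximum.load() == 4096 ? 0 : 1;
}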