From 7392470da4386e654f303eb526dfba7e7777b06b Mon Sep 17 00:00:00 2001 From: dolpm <34420038+dolpm@users.noreply.github.com> Date: Fri, 27 Jun 2025 03:01:22 +0000 Subject: [PATCH] [nativert] alias analyzer + layout planner/manager to pytorch core (#156897) Summary: att Test Plan: ci - unit tests still have some unresolved deps but will move them later. Rollback Plan: Differential Revision: D77320950 Pull Request resolved: https://github.com/pytorch/pytorch/pull/156897 Approved by: https://github.com/zhxchen17 --- build_variables.bzl | 3 + .../executor/memory/AliasAnalyzer.cpp | 173 ++++++++++++++ .../nativert/executor/memory/AliasAnalyzer.h | 85 +++++++ .../executor/memory/LayoutManager.cpp | 191 +++++++++++++++ .../nativert/executor/memory/LayoutManager.h | 206 +++++++++++++++++ .../executor/memory/LayoutPlanner.cpp | 218 ++++++++++++++++++ .../nativert/executor/memory/LayoutPlanner.h | 126 ++++++++++ 7 files changed, 1002 insertions(+) create mode 100644 torch/nativert/executor/memory/AliasAnalyzer.cpp create mode 100644 torch/nativert/executor/memory/AliasAnalyzer.h create mode 100644 torch/nativert/executor/memory/LayoutManager.cpp create mode 100644 torch/nativert/executor/memory/LayoutManager.h create mode 100644 torch/nativert/executor/memory/LayoutPlanner.cpp create mode 100644 torch/nativert/executor/memory/LayoutPlanner.h diff --git a/build_variables.bzl b/build_variables.bzl index 296260912b81..77fad7cdc5cb 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -619,6 +619,9 @@ libtorch_nativert_sources = [ "torch/nativert/kernels/CallTorchBindKernel.cpp", "torch/nativert/kernels/PrimKernelRegistry.cpp", "torch/nativert/executor/memory/DisjointStorageGroups.cpp", + "torch/nativert/executor/memory/AliasAnalyzer.cpp", + "torch/nativert/executor/memory/LayoutPlanner.cpp", + "torch/nativert/executor/memory/LayoutManager.cpp", ] torch_mobile_tracer_sources = [ diff --git a/torch/nativert/executor/memory/AliasAnalyzer.cpp b/torch/nativert/executor/memory/AliasAnalyzer.cpp new file mode 100644 index 000000000000..b3a6a9c93048 --- /dev/null +++ b/torch/nativert/executor/memory/AliasAnalyzer.cpp @@ -0,0 +1,173 @@ +#include + +#include + +namespace torch::nativert { + +AliasAnalyzer::AliasAnalyzer( + const Graph& graph, + const c10::FastMap& schemas) { + for (const auto&& [i, node] : c10::enumerate(graph.nodes())) { + for (const auto& input : node.inputs()) { + create_or_update_lifetime(input.value, i); + } + + for (const auto& output : node.outputs()) { + create_or_update_lifetime(output, i); + } + + if (update_aliases_if_packed_listunpack(node, i) /* applied? */) { + continue; + } + + maybe_update_aliases_from_schema(node, schemas); + } + + // set all non-aliasing outputs. 
outputs + // that are aliased will be set later when + // lifetimes are extended + for (const auto* output : graph.outputs()) { + if (!is_alias(output)) { + values_associated_with_outputs_.insert(output); + } + } + + maybe_extend_lifetimes(graph); + log_state(); +} + +bool /* applied */ AliasAnalyzer::update_aliases_if_packed_listunpack( + const Node& node, + size_t i) { + if (node.target() != "prim.ListUnpack") { + return false; + } + + const auto* list = node.inputs()[0].value; + + // we can't infer about how this list was made in this case + // so fallback to default always-aliasing behaviour + if (const auto* p = list->producer(); p && p->target() != "prim.ListPack") { + return false; + } + + const auto& list_elems = list->getListElements(); + TORCH_CHECK_EQ(list_elems.size(), node.numOutputs()); + + for (const auto j : c10::irange(node.numOutputs())) { + const Value* input = list_elems.at(j); + const Value* output = node.outputs().at(j); + + TORCH_CHECK_NE(input, output); + + create_or_update_lifetime(input, i); + create_or_update_lifetime(output, i); + + aliases_[output].insert(input); + } + + return true; +} + +void AliasAnalyzer::maybe_update_aliases_from_schema( + const Node& node, + const c10::FastMap& schemas) { + std::function is_alias = + []([[maybe_unused]] size_t input_idx, + [[maybe_unused]] size_t output_idx) { return true; }; + + const FunctionSchema* schema = nullptr; + if (auto schemaIt = schemas.find(std::string(node.target())); + schemaIt != schemas.end()) { + schema = &schemaIt->second; + } + + if (!schema) { + VLOG(1) << "schema not found for " << node.target() + << " assuming worst case aliasing"; + } + + for (size_t j = 0; j < node.numInputs(); j += 1) { + for (size_t k = 0; k < node.numOutputs(); k += 1) { + const Value* input = node.inputs().at(j).value; + const Value* output = node.outputs().at(k); + + if (!schema || schema->alias(j, k)) { + VLOG(1) << node.target() + << " may contain input/output alias: " << input->id() << " -> " + << output->id(); + aliases_[output].insert(input); + } + } + } +} + +void AliasAnalyzer::create_or_update_lifetime(const Value* value, size_t i) { + if (auto [lifetimeIt, inserted] = lifetimes_.try_emplace(value, i, i); + !inserted) { + lifetimeIt->second.end = i; + } +} + +void AliasAnalyzer::maybe_extend_lifetimes(const Graph& graph) { + c10::FastSet extended; + + for (auto nodeIt = graph.nodes().rbegin(); nodeIt != graph.nodes().rend(); + ++nodeIt) { + const auto& inputs = nodeIt->inputs(); + for (const auto& input : inputs) { + if (auto aliasIt = aliases_.find(input.value); + aliasIt != aliases_.end()) { + const auto& alias = aliasIt->second; + for (const auto& src : alias) { + if (extended.find(src) != extended.end()) { + continue; + } + + auto& eol = lifetimes_[src].end; + eol = lifetimes_[input.value].end; + + VLOG(1) << "extended EOL of value " << src->id() << " to " << eol; + + extended.insert(src); + + if (eol == graph.nodes().size() - 1 /* aliases output */) { + values_associated_with_outputs_.insert(src); + } + } + } + } + } +} + +void AliasAnalyzer::log_state() const { + if (!VLOG_IS_ON( + 1) /* this is usually too large to be logged with VLOG directly */) { + return; + } + + std::cout << [&]() -> std::string { + std::ostringstream ss; + ss << "[sigmoid layout planner] AliasAnalyzer ran....\n"; + ss << "lifetimes:\n"; + + for (const auto& [v, lifetime] : lifetimes_) { + ss << " " << v->name() << ": [" << lifetime.start << ", " << lifetime.end + << "]\n"; + } + + ss << "\naliases:\n"; + for (const auto& [v, alias] : 
aliases_) { + ss << " " << v->name() << " -> "; + for (const auto* a : alias) { + ss << a->name() << ", "; + } + ss << "\n"; + } + + return ss.str(); + }() << std::endl + << std::flush; +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/memory/AliasAnalyzer.h b/torch/nativert/executor/memory/AliasAnalyzer.h new file mode 100644 index 000000000000..c9784d5d84ab --- /dev/null +++ b/torch/nativert/executor/memory/AliasAnalyzer.h @@ -0,0 +1,85 @@ +#pragma once + +#include + +#include +#include +#include + +namespace torch::nativert { + +class AliasAnalyzer { + public: + explicit AliasAnalyzer( + const Graph& graph, + const c10::FastMap& schemas); + + C10_ALWAYS_INLINE const AllocationLifetime& lifetime( + const Value* value) const { + return lifetimes_.at(value); + } + + C10_ALWAYS_INLINE bool is_alias(const Value* value) const { + return aliases_.find(value) != aliases_.end(); + } + + C10_ALWAYS_INLINE bool is_storage_associated_with_output( + const Value* value) const { + return values_associated_with_outputs_.find(value) != + values_associated_with_outputs_.end(); + } + + C10_ALWAYS_INLINE const c10::FastSet& + values_associated_with_output_storage() const { + return values_associated_with_outputs_; + } + + private: + // listunpack operations who take a list that has + // been created with a listpack operation should + // be transparent with respect to aliasing + // + // e.g., given the op + // %t[] = prim.ListPack(l0=%t0, l1=%t1) + // %x1, %x2 = prim.ListUnpack(self=%t) + // x1 should directly alias t0 + // and likewise x2 should directly alias t1 + // + // this will make sure that the lifetimes of x1 and x2 + // are not just the max of the lifetimes of t0 and t1 + // which can make tensor-packing more efficient if list + // element EOL's differ by large amounts + bool /* applied */ update_aliases_if_packed_listunpack( + const Node& node, + size_t i); + + // use the schema aliasing spec, or if none is provided, + // assume all outputs alias all inputs + void maybe_update_aliases_from_schema( + const Node& node, + const c10::FastMap& schemas); + + void create_or_update_lifetime(const Value* value, size_t i); + + // work our way from the DAG's output node to the input node + // propagating the maximum EOL of all aliases back to their + // source value(s). 
+ // + // in addition, if a graph output is an alias, we need to ensure + // that the source values are treated as graph outputs + // so that we don't free them before the graph output is copied + // back to the user (and we ignore them when creating a memory plan + // even if they aren't explicitly considered outputs) + void maybe_extend_lifetimes(const Graph& graph); + + void log_state() const; + + // mapping from alias to the set of values that it aliases + c10::FastMap> aliases_; + c10::FastMap lifetimes_; + // non-aliasing outputs or non-aliasing intermediates that are aliased by + // outputs + c10::FastSet values_associated_with_outputs_; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/executor/memory/LayoutManager.cpp b/torch/nativert/executor/memory/LayoutManager.cpp new file mode 100644 index 000000000000..322cef1c1d78 --- /dev/null +++ b/torch/nativert/executor/memory/LayoutManager.cpp @@ -0,0 +1,191 @@ +#include + +#include + +#include +#include + +namespace torch::nativert { + +LayoutManager::LayoutManager( + LayoutPlanner& planner, + ExecutionFrame& parent_frame, + const torch::nativert::LayoutManagerSettings settings) + : planner_(planner), parent_frame_(parent_frame), settings_(settings) { + VLOG(1) << "layout manager created for execution frame"; +} + +void ContiguousLayoutBuffer::allocate(size_t size) { + VLOG(1) << "allocating " << size << " bytes"; + if (C10_LIKELY(size_ > 0)) { + if (C10_LIKELY( + size <= size_) /* NOTE: size will be monotonically increasing */) { + return clear(size_); + } else { + deallocate(); + } + } + data_ptr_ = c10::GetCPUCachingAllocator()->allocate(size); + size_ = size; +} + +void LayoutManager::allocate() { + if (C10_UNLIKELY(state_ == LayoutManagerState::WaitingForValues)) { + return; + } + + bool should_allocate_storages = + state_ == LayoutManagerState::AllocatingStorages; + + ensure_managed_storages(/* allocate= */ should_allocate_storages); + + planner_.with_plan([&](const auto& plan) { allocate_plan(plan); }); + + if (should_allocate_storages) { + state_ = LayoutManagerState::Running; + } +} + +void LayoutManager::allocate_plan(const LayoutPlan& plan) { + if (C10_UNLIKELY(storage_impl_buffer_.size() == 0 || plan.total_size == 0)) { + return; + } + + layout_buffer_.allocate(plan.total_size); + VLOG(1) << "allocated " << layout_buffer_.size() + << " bytes for planned layout"; + + auto* storage_buf = storage_impl_buffer_.buffer(); + + for (const auto i : c10::irange(plan.allocations.size())) { + auto& planned_allocation = plan.allocations[i]; + auto& local_max_nbytes = planned_tensors_max_nbytes_local_[i]; + local_max_nbytes = std::max(local_max_nbytes, planned_allocation.size); + + void* offset_ptr = + layout_buffer_.get_ptr_with_offset(planned_allocation.offset); + auto& storage = storage_buf[i]; + + // if the existing data ptr doesn't have an associated deleter then we + // will set the offset and size directly, as oposed to creating and + // swapping it with a new one + // + // apart from the first allocation when the storage still has the its + // allocator-created dataptr (https://fburl.com/code/u7dsspjm) whose + // deleter is non-null (https://fburl.com/code/7hiwo5zo), this should + // always be true + if (C10_LIKELY( + storage._mutable_data_ptr_no_checks().unsafe_reset_data_and_ctx( + offset_ptr))) { + storage.unsafe_set_nbytes(planned_allocation.size); + } else { + storage.set_data_ptr_noswap(at::DataPtr( + offset_ptr, offset_ptr, nullptr, c10::Device(c10::DeviceType::CPU))); + 
storage.set_nbytes(planned_allocation.size); + } + } +} + +void LayoutManager::ensure_managed_storages(bool allocate) { + if (C10_UNLIKELY(planned_tensors_.empty())) { + return; + } + + if (C10_UNLIKELY(allocate)) { + storage_impl_buffer_.allocate(planned_tensors_.size()); + VLOG(1) << "allocated " << planned_tensors_.size() * sizeof(at::StorageImpl) + << " bytes for contiguous storages"; + } + + auto* storage_buf = storage_impl_buffer_.buffer(); + + for (size_t i = 0; i < planned_tensors_.size(); i += 1) { + auto* tensor = planned_tensors_[i]; + + at::StorageImpl& storage = *tensor->storage().unsafeGetStorageImpl(); + + if (C10_UNLIKELY(allocate)) { + // from: https://fburl.com/code/4it00yph + // + // We want to manage StorageImpls' lifetimes ourselves, but TensorImpl + // expects to refcount them. unsafe_adapt_non_heap_allocated is our + // escape hatch: it sets the reference count for the StorageImpl to an + // impractically high value so that it will never get deallocated by + // intrusive_ptr, leaving us free to manage its lifetime as we see fit. + // (Note that allowing it to be deallocated by intrusive_ptr would be + // UB, because that would entail deleting an object that wasn't + // allocated with operator new.) + // + // For more information, see the doc comment for + // intrusive_ptr::unsafe_adapt_non_heap_allocated. + tensor->unsafeGetTensorImpl()->set_storage_keep_dtype(at::Storage( + c10::intrusive_ptr::unsafe_adapt_non_heap_allocated( + &storage_impl_buffer_.to_managed(storage), 1))); + } else if ( + C10_UNLIKELY( + &storage != + &storage_buf + [i]) /* managed storage was replaced for some reason */) { + storage.reset(); + tensor->unsafeGetTensorImpl()->set_storage_keep_dtype(at::Storage( + c10::intrusive_ptr::unsafe_adapt_non_heap_allocated( + &storage_buf[i], 1))); + } + } +} + +void LayoutManager::populate_tensor_values() { + CHECK(planned_tensors_.empty()); + CHECK(unplanned_ivalues_.empty()); + + const auto& value_ids = planner_.get_planned_values(); + planned_tensors_.resize(value_ids.size()); + planned_tensors_max_nbytes_local_.resize(value_ids.size()); + + for (const auto&& [i, v] : c10::enumerate(value_ids)) { + planned_tensors_[i] = &parent_frame_.getIValue(v).toTensor(); + } + + const auto& unplanned_value_ids = planner_.get_unplanned_values(); + unplanned_ivalues_.resize(unplanned_value_ids.size()); + for (const auto&& [i, v] : c10::enumerate(unplanned_value_ids)) { + unplanned_ivalues_[i] = &parent_frame_.getIValue(v); + } +} + +void LayoutManager::try_update_historical_max_nbytes() { + for (const auto i : c10::irange(planned_tensors_.size())) { + auto nbytes = get_aligned_nbytes(planned_tensors_[i]->nbytes()); + if (auto& old_max = planned_tensors_max_nbytes_local_[i]; + nbytes > old_max) { + old_max = nbytes; + planner_.try_update_max_size_at_index(i, nbytes); + } + } +} + +void LayoutManager::deallocate_and_plan() { + const auto uninitialized = state_ == LayoutManagerState::WaitingForValues; + + if (C10_UNLIKELY(uninitialized)) { + populate_tensor_values(); + } + + try_update_historical_max_nbytes(); + + if (C10_UNLIKELY(uninitialized)) { + planner_.start_worker_if_not_started(); + } + + if (C10_UNLIKELY(uninitialized)) { + state_ = LayoutManagerState::AllocatingStorages; + } else if (settings_.deallocateBetweenRequests()) { + layout_buffer_.deallocate(); + } + + for (auto* ivalue : unplanned_ivalues_) { + *ivalue = c10::IValue(); + } +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/memory/LayoutManager.h 
b/torch/nativert/executor/memory/LayoutManager.h new file mode 100644 index 000000000000..76f658e09d08 --- /dev/null +++ b/torch/nativert/executor/memory/LayoutManager.h @@ -0,0 +1,206 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch::nativert { + +class ExecutionFrame; + +struct ContiguousLayoutBuffer { + public: + ContiguousLayoutBuffer() = default; + ~ContiguousLayoutBuffer() { + deallocate(); + } + + ContiguousLayoutBuffer(ContiguousLayoutBuffer&& other) = delete; + ContiguousLayoutBuffer(const ContiguousLayoutBuffer& other) = delete; + ContiguousLayoutBuffer operator=(ContiguousLayoutBuffer&& other) = delete; + ContiguousLayoutBuffer& operator=(const ContiguousLayoutBuffer& other) = + delete; + + void* get_ptr_with_offset(size_t offset) { + void* raw_ptr = data_ptr_.get(); + TORCH_CHECK_NOTNULL(raw_ptr); + TORCH_CHECK_LE(offset, size_); + return reinterpret_cast( + reinterpret_cast(raw_ptr) + offset); + } + + size_t size() { + return size_; + } + + void allocate(size_t size); + + void deallocate() { + VLOG(1) << "deallocating layout buffer of size " << size_; + size_ = 0; + data_ptr_ = {}; + } + + void clear(size_t size) { + VLOG(1) << "clearing first " << size << "bytes of layout buffer of size " + << size_; + TORCH_CHECK_LE(size, size_); + std::memset(data_ptr_.get(), 0, size); + } + + private: + // the size of the buffer in bytes + size_t size_{0}; + + // the dataptr returned by the allocator + at::DataPtr data_ptr_{}; +}; + +struct ContiguousStorageImplBuffer { + ContiguousStorageImplBuffer() = default; + ~ContiguousStorageImplBuffer() { + deallocate(); + } + + ContiguousStorageImplBuffer(ContiguousStorageImplBuffer&& other) = delete; + ContiguousStorageImplBuffer(const ContiguousStorageImplBuffer& other) = + delete; + ContiguousStorageImplBuffer operator=(ContiguousStorageImplBuffer&& other) = + delete; + ContiguousStorageImplBuffer& operator=( + const ContiguousStorageImplBuffer& other) = delete; + + void deallocate() { + if (buffer_ == nullptr) { + return; + } + + for (const size_t idx : c10::irange(size_)) { + buffer_[idx].~StorageImpl(); + } + + delete[] reinterpret_cast(buffer_); + buffer_ = nullptr; + size_ = capacity_ = 0; + } + + void allocate(size_t capacity) { + if (size_ > 0) { + deallocate(); + } + + capacity_ = capacity; + + static_assert(alignof(at::StorageImpl) <= 8); + buffer_ = reinterpret_cast( + new unsigned char[capacity * sizeof(at::StorageImpl)]); + } + + size_t capacity() { + return capacity_; + } + + size_t size() { + return size_; + } + + c10::StorageImpl* buffer() const { + return buffer_; + } + + c10::StorageImpl& at(size_t i) { + TORCH_CHECK_LT(i, size_) + << "requested storage index " << i << " out of bounds " << size_; + return buffer_[i]; + } + + void reset_all() { + for (const size_t idx : c10::irange(size_)) { + buffer_[idx].reset(); + } + } + + c10::StorageImpl& to_managed(at::StorageImpl& s) { + TORCH_CHECK_LT(size_, capacity_); + return *(new (&buffer_[size_++]) at::StorageImpl( + at::StorageImpl::use_byte_size_t(), + static_cast(s.nbytes()), + s.allocator(), + s.resizable())); + } + + private: + size_t size_{0}; + size_t capacity_{0}; + c10::StorageImpl* buffer_{nullptr}; +}; + +enum class LayoutManagerState { WaitingForValues, AllocatingStorages, Running }; + +class LayoutManager { + public: + LayoutManager( + LayoutPlanner& planner, + ExecutionFrame& parent_frame, + torch::nativert::LayoutManagerSettings settings = {}); + ~LayoutManager() = default; + + void allocate(); + void deallocate_and_plan(); + 
+ private: +#ifdef LayoutPlannerTests_TEST_FRIENDS + LayoutPlannerTests_TEST_FRIENDS; +#endif + + static size_t get_aligned_nbytes(size_t nbytes) { +#if defined(__linux__) && !defined(__ANDROID__) + auto alignment = c10::c10_compute_alignment(nbytes); +#else + auto alignment = c10::gAlignment; +#endif + return ((nbytes) + alignment - 1) & (~(alignment - 1)); + } + + void allocate_plan(const LayoutPlan& plan); + void ensure_managed_storages(bool allocate); + + void populate_tensor_values(); + void try_update_historical_max_nbytes(); + + LayoutPlanner& planner_; + ExecutionFrame& parent_frame_; + + std::vector unplanned_ivalues_; + + std::vector planned_tensors_; + std::vector planned_tensors_max_nbytes_local_; + + ContiguousLayoutBuffer layout_buffer_; + ContiguousStorageImplBuffer storage_impl_buffer_; + + LayoutManagerState state_{LayoutManagerState::WaitingForValues}; + torch::nativert::LayoutManagerSettings settings_; +}; + +class LayoutManagerGuard { + public: + explicit LayoutManagerGuard(LayoutManager& manager) : manager_(manager) { + manager_.allocate(); + } + ~LayoutManagerGuard() { + manager_.deallocate_and_plan(); + } + + LayoutManagerGuard(LayoutManagerGuard&& other) = delete; + LayoutManagerGuard(const LayoutManagerGuard& other) = delete; + LayoutManagerGuard operator=(LayoutManagerGuard&& other) = delete; + LayoutManagerGuard& operator=(const LayoutManagerGuard& other) = delete; + + LayoutManager& manager_; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/executor/memory/LayoutPlanner.cpp b/torch/nativert/executor/memory/LayoutPlanner.cpp new file mode 100644 index 000000000000..87913304d4d7 --- /dev/null +++ b/torch/nativert/executor/memory/LayoutPlanner.cpp @@ -0,0 +1,218 @@ +#include + +#include +#include + +#include +#include +#include +#include +#include + +namespace torch::nativert { + +LayoutPlanner::LayoutPlanner( + const Graph& graph, + const c10::FastMap& kernelSchemas, + const std::vector& persistentValues, + const torch::nativert::LayoutPlannerSettings& settings) + : managed_values_(graph.values().size()), settings_(settings) { + auto value_to_allocation_spec = c10::FastMap{}; + auto alias_analyzer = AliasAnalyzer(graph, kernelSchemas); + + std::set input_values_set_; + for (const auto* nv : graph.userInputs()) { + if (nv->type() == Type::Kind::Tensor) { + input_values_set_.insert(nv); + } + } + + const auto& tensor_meta = graph.tensorValuesMeta(); + + for (auto&& [i, node] : at::enumerate(graph.nodes())) { + // only manage out variant values + if (const auto schemaIt = kernelSchemas.find(std::string(node.target())); + schemaIt == kernelSchemas.end() || + schemaIt->second.kernel_kind() != OpKernelKind::kStaticDispatchKernel) { + VLOG(1) << "not able to plan outputs for node " << node.target() + << " as it is derived from an unsupported kernel kind."; + continue; + } + + for (const auto& output : node.outputs()) { + // don't manage persistent values + if (bool is_persistent = persistentValues[output->id()]; is_persistent) { + VLOG(1) + << "not planning " << output->name() + << " as it is a persistent value (likely a weight or const-folded)"; + continue; + } + + // only manage tensors + if (bool is_tensor = output->type().kind() == Type::Kind::Tensor; + !is_tensor) { + VLOG(1) << "not planning " << output->name() + << " as it is not a raw tensor. type: " << output->type(); + continue; + } + + // output storage ownership must be given to the caller. 
+ if (const auto& values_associated_with_output = + alias_analyzer.values_associated_with_output_storage(); + values_associated_with_output.find(output) != + values_associated_with_output.end()) { + VLOG(1) + << "not planning " << output->name() + << " as its underlying storage may be associated with a graph output"; + continue; + } + + // inputs are borrowed -- this is merely a sanity check + if (input_values_set_.find(output) != input_values_set_.end()) { + VLOG(1) << "not planning " << output->name() + << " as it is a graph input that is borrowed from the user"; + continue; + } + + // don't plan aliases -- they don't own the associated dataptr + if (bool is_alias = alias_analyzer.is_alias(output); is_alias) { + VLOG(1) << "not planning " << output->name() << " as it is an alias"; + continue; + } + + if (bool is_consumed = output->users().size() > 0; !is_consumed) { + VLOG(1) << "not planning " << output->name() << " as it has no users"; + continue; + } + + if (auto meta_it = tensor_meta.find(std::string(output->name())); + meta_it != tensor_meta.end()) { + if (const auto& meta = meta_it->second; meta.device() == c10::kCPU) { + auto& spec = value_to_allocation_spec[output]; + spec.lifetime = alias_analyzer.lifetime(output); + managed_values_[output->id()] = true; + continue; + } else { + VLOG(1) << "tensor " << output->name() + << " not placed on cpu so we cannot plan it"; + } + } else /* possible if runtime pass didn't populate meta info */ { + VLOG(1) << "tensor " << output->name() << " has no meta information"; + } + + managed_values_[output->id()] = true; + value_to_allocation_spec[output].lifetime = + alias_analyzer.lifetime(output); + } + } + + LOG(INFO) << "layout planner created with " << value_to_allocation_spec.size() + << " values"; + + switch (settings_.algorithmType()) { + case torch::nativert::LayoutPlannerAlgorithmType::Bump: { + algorithm_ = &BumpAllocationPlanner; + break; + } + case torch::nativert::LayoutPlannerAlgorithmType::GreedyBySize: { + algorithm_ = &GreedyBySizeAllocationPlanner; + break; + } + case LayoutPlannerAlgorithmType::DisjointStorageGroups: { + algorithm_ = &DisjointStorageGroupsPlanner; + break; + } + } + + TORCH_CHECK_NOTNULL(algorithm_); + + initialize_vectors(value_to_allocation_spec); + + auto exec_planner = ExecutionPlanner{graph}; + auto p = exec_planner.createPlan(); + for (const auto& freeable : p->valuesToFree) { + for (const auto v : freeable) { + if (!is_managed(v)) { + unplanned_values_.push_back(v); + } + } + } +} + +void LayoutPlanner::initialize_vectors( + c10::FastMap value_to_allocation_spec) { + size_t num_managed = value_to_allocation_spec.size(); + + planned_values_.resize(num_managed); + planned_allocation_specs_.resize(num_managed); + planned_values_historical_max_nbytes_ = + std::vector(num_managed); + + size_t i = 0; + for (auto& [v, spec] : value_to_allocation_spec) { + TORCH_CHECK_LE(spec.lifetime.start, spec.lifetime.end); + + planned_values_[i] = v->id(); + planned_values_historical_max_nbytes_[i] = spec.size; + planned_allocation_specs_[i] = std::move(spec); + + i++; + } + + // for sanity in case anyone tries to use this after this method + // is called with a bunch of junk (i.e., moved specs) in it + value_to_allocation_spec.clear(); +} + +const std::vector& LayoutPlanner::get_planned_values() const { + return planned_values_; +} + +const std::vector& LayoutPlanner::get_unplanned_values() const { + return unplanned_values_; +} + +void LayoutPlanner::start_worker_if_not_started() { + static c10::once_flag flag; + 
c10::call_once(flag, [&]() { + // make sure plan is populated by the time this + // returns for the first time :P + create_plan(); + worker_ = std::thread([this]() { + run_periodic(std::bind(&LayoutPlanner::create_plan, this)); + }); + }); +} + +LayoutPlanner::~LayoutPlanner() { + { + std::unique_lock l(mutex_); + stopped_ = true; + } + cv_.notify_one(); + if (worker_.joinable()) { + worker_.join(); + } +} + +void LayoutPlanner::run_periodic(const std::function& f) { + std::unique_lock l(mutex_); + while (!cv_.wait_for( + l, settings_.planningInterval(), [&]() { return stopped_; })) { + f(); + } +} + +void LayoutPlanner::create_plan() { + // update spec sizes to use historical maximums set + // by execution frames before creating the new plan + for (const auto i : c10::irange(planned_allocation_specs_.size())) { + auto& spec = planned_allocation_specs_[i]; + spec.size = planned_values_historical_max_nbytes_[i].load( + std::memory_order_relaxed); + } + plan_.write([p_new = (*algorithm_)(planned_allocation_specs_)]( + LayoutPlan& plan) { plan = p_new; }); +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/memory/LayoutPlanner.h b/torch/nativert/executor/memory/LayoutPlanner.h new file mode 100644 index 000000000000..6382fdbba01b --- /dev/null +++ b/torch/nativert/executor/memory/LayoutPlanner.h @@ -0,0 +1,126 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace { +constexpr inline std::memory_order drop_release(std::memory_order m) noexcept { + return ( + m == std::memory_order_release + ? std::memory_order_relaxed + : ((m == std::memory_order_acq_rel || m == std::memory_order_seq_cst) + ? std::memory_order_acquire + : m)); +} +// derivation of +// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2024/p0493r5.pdf +template +void atomic_set_max( + std::atomic* pv, + typename std::atomic::value_type v, + std::memory_order m = std::memory_order_seq_cst) noexcept { + auto const mr = drop_release(m); + auto t = (mr != m) ? 
pv->fetch_add(0, m) : pv->load(mr); + while (std::max(v, t) != t) { + if (pv->compare_exchange_weak(t, v, m, mr)) { + return; + } + } +} +} // namespace + +namespace torch::nativert { + +class LayoutPlanner { + public: + explicit LayoutPlanner( + const Graph& graph, + const c10::FastMap& + kernelSchemas, + const std::vector& persistentValues, + const torch::nativert::LayoutPlannerSettings& settings); + ~LayoutPlanner(); + + LayoutPlanner(LayoutPlanner&& other) = delete; + LayoutPlanner(const LayoutPlanner& other) = delete; + LayoutPlanner operator=(LayoutPlanner&& other) = delete; + LayoutPlanner& operator=(const LayoutPlanner& other) = delete; + + void start_worker_if_not_started(); + + const std::vector& get_planned_values() const; + const std::vector& get_unplanned_values() const; + + C10_ALWAYS_INLINE bool is_managed(ValueId id) { + TORCH_CHECK_LT(static_cast(id), managed_values_.size()); + return managed_values_[id]; + } + + C10_ALWAYS_INLINE void try_update_max_size_at_index(size_t idx, size_t size) { + atomic_set_max(&planned_values_historical_max_nbytes_[idx], size); + } + + C10_ALWAYS_INLINE + void with_plan(std::function&& cb) { + plan_.read( + std::forward>(std::move(cb))); + } + + private: +#ifdef LayoutPlannerTests_TEST_FRIENDS + LayoutPlannerTests_TEST_FRIENDS; +#endif + + // we need some way of mapping graph values to other information + // (e.g., allocation spec, max historical size) + // + // since there is a 1:1 mapping to/from each of these + // we can create+initialize them here + // + // note: planning algorithms are allowed to change the ordering + // of allocation specs -- so we pass the index of the spec during + // it's insertion s.t., each execution frame can use it to + // reference the correct associated max historical size / underlying + // tensor value + void initialize_vectors( + c10::FastMap value_to_allocation_spec); + + void run_periodic(const std::function& f); + void create_plan(); + + // variables for managing the state of the + // interval worker thread that refreshes + // the plan + std::condition_variable cv_; + std::mutex mutex_; + bool stopped_{false}; + std::thread worker_; + + std::vector unplanned_values_; + + std::vector planned_values_; + std::vector planned_allocation_specs_; + std::vector planned_values_historical_max_nbytes_; + + // managed_values_[value_id] == true + // if graph.values()[value_id] has + // an associated allocation spec + std::vector managed_values_; + + LayoutPlannerAlgorithm* algorithm_; + c10::LeftRight plan_; + + torch::nativert::LayoutPlannerSettings settings_; +}; + +} // namespace torch::nativert
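
The AliasAnalyzer above tracks a [start, end] node interval per value and, in maybe_extend_lifetimes, pushes an alias's end-of-life back onto the values it aliases so their storage is not reused while the alias is still live. Below is a minimal standalone sketch of that bookkeeping using plain standard containers and illustrative names (Lifetime, lifetimes, aliases); it is not the real torch::nativert code, just the idea.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>

// Each value lives from the node index that produces it to the node index of
// its last direct use.
struct Lifetime {
  std::size_t start;
  std::size_t end;
};

int main() {
  std::unordered_map<std::string, Lifetime> lifetimes{
      {"t0", {0, 1}}, // produced by node 0, last used directly at node 1
      {"x1", {1, 3}}, // alias of t0 created at node 1, used through node 3
  };
  // alias -> set of values it aliases
  std::unordered_map<std::string, std::unordered_set<std::string>> aliases{
      {"x1", {"t0"}},
  };

  // Extend each source's end-of-life to cover its alias's last use,
  // mirroring what maybe_extend_lifetimes does over the real graph.
  for (const auto& [alias, sources] : aliases) {
    for (const auto& src : sources) {
      auto& eol = lifetimes[src].end;
      eol = std::max(eol, lifetimes[alias].end);
    }
  }

  std::cout << "t0 end-of-life: node " << lifetimes["t0"].end << '\n'; // 3
}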
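
ContiguousStorageImplBuffer preallocates one raw byte array, placement-news StorageImpls into it, and invokes their destructors by hand on teardown so the objects never go through operator new/delete individually. Here is a self-contained sketch of that pattern with a stand-in Slot type; ContiguousSlotBuffer and Slot are illustrative names, not types from the patch.

#include <cassert>
#include <cstddef>
#include <iostream>
#include <new>

// Stand-in for c10::StorageImpl: anything with a non-trivial constructor.
struct Slot {
  explicit Slot(std::size_t nbytes) : nbytes(nbytes) {}
  std::size_t nbytes;
};

class ContiguousSlotBuffer {
 public:
  explicit ContiguousSlotBuffer(std::size_t capacity)
      : capacity_(capacity), raw_(new unsigned char[capacity * sizeof(Slot)]) {
    // operator new[] returns storage aligned for max_align_t, which covers
    // Slot (the real buffer asserts alignof(at::StorageImpl) <= 8 similarly).
    static_assert(alignof(Slot) <= alignof(std::max_align_t));
  }
  ~ContiguousSlotBuffer() {
    for (std::size_t i = 0; i < size_; ++i) {
      buffer()[i].~Slot(); // manual destruction, as in deallocate()
    }
    delete[] raw_;
  }
  Slot& emplace(std::size_t nbytes) {
    assert(size_ < capacity_);
    return *(new (&buffer()[size_++]) Slot(nbytes)); // placement new
  }
  Slot* buffer() { return reinterpret_cast<Slot*>(raw_); }

 private:
  std::size_t size_{0};
  std::size_t capacity_;
  unsigned char* raw_;
};

int main() {
  ContiguousSlotBuffer storages(4);
  storages.emplace(1024);
  storages.emplace(2048);
  std::cout << storages.buffer()[1].nbytes << '\n'; // 2048
}

The real buffer additionally pins each StorageImpl's refcount via intrusive_ptr::unsafe_adapt_non_heap_allocated (as the comment in ensure_managed_storages explains) so TensorImpls can hold them without ever freeing them.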
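
get_aligned_nbytes rounds each tensor's byte count up to the allocator's alignment with the usual power-of-two bit trick, so consecutive offsets in the contiguous buffer stay aligned. A quick worked sketch, assuming a 64-byte alignment (what c10's default CPU alignment typically is):

#include <cstddef>
#include <cstdio>

// Round nbytes up to the next multiple of a power-of-two alignment:
// (nbytes + alignment - 1) & ~(alignment - 1)
constexpr std::size_t aligned_nbytes(std::size_t nbytes,
                                     std::size_t alignment = 64) {
  return (nbytes + alignment - 1) & ~(alignment - 1);
}

static_assert(aligned_nbytes(1) == 64);
static_assert(aligned_nbytes(64) == 64);
static_assert(aligned_nbytes(65) == 128);

int main() {
  std::printf("%zu\n", aligned_nbytes(100)); // prints 128
}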
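
LayoutPlanner re-runs its planning algorithm on a background thread: run_periodic waits on a condition_variable with a timeout and invokes the callback each time the wait times out, so shutdown (setting stopped_ and calling notify_one) wakes it immediately instead of sleeping out the interval. A minimal sketch of that loop wrapped in an illustrative PeriodicWorker class (not a class from the patch):

#include <chrono>
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>

// Run f() every `interval` until destruction; the destructor wakes the
// worker immediately via the condition variable and joins it.
class PeriodicWorker {
 public:
  PeriodicWorker(std::function<void()> f, std::chrono::milliseconds interval)
      : thread_([this, f = std::move(f), interval] {
          std::unique_lock<std::mutex> l(mutex_);
          // wait_for returns true (and the loop exits) once stopped_ is set
          while (!cv_.wait_for(l, interval, [this] { return stopped_; })) {
            f();
          }
        }) {}
  ~PeriodicWorker() {
    {
      std::lock_guard<std::mutex> l(mutex_);
      stopped_ = true;
    }
    cv_.notify_one();
    thread_.join();
  }

 private:
  std::condition_variable cv_;
  std::mutex mutex_;
  bool stopped_{false};
  std::thread thread_;
};

int main() {
  PeriodicWorker w([] { std::cout << "replanning layout...\n"; },
                   std::chrono::milliseconds(50));
  std::this_thread::sleep_for(std::chrono::milliseconds(200));
} // destructor stops and joins the worker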
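
Each ExecutionFrame reports the historical maximum nbytes it has observed for every planned value through try_update_max_size_at_index, which relies on the compare-exchange "atomic max" idiom (the P0493 derivation in LayoutPlanner.h) so frames never take a lock on the hot path. A simplified standalone version using seq_cst ordering throughout; set_max is an illustrative name:

#include <atomic>
#include <cstddef>
#include <iostream>
#include <thread>
#include <vector>

// Publish `candidate` into `slot` only if it is larger than what is already
// there; concurrent larger updates win and terminate the loop early.
void set_max(std::atomic<std::size_t>& slot, std::size_t candidate) {
  std::size_t observed = slot.load();
  while (candidate > observed) {
    if (slot.compare_exchange_weak(observed, candidate)) {
      return;
    }
    // on failure, observed has been refreshed with the current value
  }
}

int main() {
  std::atomic<std::size_t> max_nbytes{0};
  std::vector<std::thread> frames;
  // Simulate several execution frames reporting tensor sizes concurrently.
  for (std::size_t nbytes : {512, 4096, 1024, 2048}) {
    frames.emplace_back([&max_nbytes, nbytes] { set_max(max_nbytes, nbytes); });
  }
  for (auto& t : frames) {
    t.join();
  }
  std::cout << "historical max: " << max_nbytes.load() << '\n'; // 4096
}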
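
The planner selects one of several algorithms (Bump, GreedyBySize, DisjointStorageGroups) that all consume the per-value AllocationSpecs and emit a LayoutPlan of offsets into one contiguous buffer. The real planners live outside this patch; the toy below only illustrates the simplest, bump-style layout, with Spec/Allocation/Plan as illustrative mirrors of the patch's AllocationSpec/LayoutPlan shapes.

#include <cstddef>
#include <iostream>
#include <vector>

struct Spec {
  std::size_t size;
};
struct Allocation {
  std::size_t offset;
  std::size_t size;
};
struct Plan {
  std::size_t total_size{0};
  std::vector<Allocation> allocations;
};

// Bump planning: place every spec immediately after the previous one at an
// aligned offset, with no attempt to reuse intervals freed by expired values.
Plan bump_plan(const std::vector<Spec>& specs, std::size_t alignment = 64) {
  Plan plan;
  for (const auto& spec : specs) {
    std::size_t aligned = (spec.size + alignment - 1) & ~(alignment - 1);
    plan.allocations.push_back({plan.total_size, spec.size});
    plan.total_size += aligned;
  }
  return plan;
}

int main() {
  auto plan = bump_plan({{100}, {64}, {1}});
  for (const auto& a : plan.allocations) {
    std::cout << "offset " << a.offset << " size " << a.size << '\n';
  }
  std::cout << "total " << plan.total_size << '\n'; // 128 + 64 + 64 = 256
}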