diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 4850c0dd8842..ca330e34fa30 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -454,6 +454,8 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/jit/fuser/fallback.cpp ${TORCH_SRC_DIR}/csrc/jit/function.cpp ${TORCH_SRC_DIR}/csrc/jit/vararg_functions.cpp + + ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/mem_arena.cpp ) if (NOT INTERN_DISABLE_MOBILE_INTERP) diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 63a22ec90274..469a02d1e4c3 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -190,9 +190,11 @@ libtorch_sources = [ "torch/csrc/jit/mobile/register_mobile_ops.cpp", "torch/csrc/jit/mobile/interpreter.cpp", "torch/csrc/jit/mobile/type_parser.cpp", + "torch/csrc/jit/tensorexpr/mem_arena.cpp", "torch/csrc/utils/byte_order.cpp", "torch/csrc/utils/tensor_flatten.cpp", "torch/csrc/utils/variadic.cpp", + "torch/csrc/jit/tensorexpr/mem_arena.cpp", ] libtorch_cuda_sources = [ diff --git a/torch/csrc/jit/tensorexpr/mem_arena.cpp b/torch/csrc/jit/tensorexpr/mem_arena.cpp new file mode 100644 index 000000000000..c011c659306a --- /dev/null +++ b/torch/csrc/jit/tensorexpr/mem_arena.cpp @@ -0,0 +1,56 @@ +#include "torch/csrc/jit/tensorexpr/mem_arena.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +namespace { +// Define in an anonymous namespace to hide this symbol from other compilation +// units +thread_local KernelArena* current_arena = nullptr; +} + +KernelArena::~KernelArena() { + for (KernelScopedObject* p : kernel_objects_) { + delete p; + } +} + +KernelScopedObject::KernelScopedObject() { + KernelArena* kernel = KernelArena::GetCurrentKernelArena(); + kernel->kernel_objects_.push_back(this); +} + +static std::vector& GetKernelArenaStack() { + thread_local std::vector kernel_arena_stack; + return kernel_arena_stack; +} + +void KernelArena::SetCurrentKernelArena(KernelArena *new_kernel_arena) { + current_arena = new_kernel_arena; +} + +KernelArena* KernelArena::GetCurrentKernelArena() { + return current_arena; +} + +KernelScope::KernelScope() : owning_(true) { + old_kernel_arena_ = KernelArena::GetCurrentKernelArena(); + KernelArena::SetCurrentKernelArena(new KernelArena); +} + +KernelScope::KernelScope(KernelArena* arena_) : owning_(false) { + old_kernel_arena_ = KernelArena::GetCurrentKernelArena(); + KernelArena::SetCurrentKernelArena(arena_); +} + +KernelScope::~KernelScope() { + if (owning_) { + delete KernelArena::GetCurrentKernelArena(); + } + KernelArena::SetCurrentKernelArena(old_kernel_arena_); +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/mem_arena.h b/torch/csrc/jit/tensorexpr/mem_arena.h new file mode 100644 index 000000000000..121bdb60e02a --- /dev/null +++ b/torch/csrc/jit/tensorexpr/mem_arena.h @@ -0,0 +1,61 @@ +#pragma once +#include +#include "torch/csrc/WindowsTorchApiMacro.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +class KernelScopedObject; + +// An arena that manages all the underlying kernel-scoped objects. +class KernelArena { + public: + static KernelArena* GetCurrentKernelArena(); + static void SetCurrentKernelArena(KernelArena* new_arena); + TORCH_API KernelArena() {} + TORCH_API ~KernelArena(); + + private: + KernelArena(const KernelArena&) = delete; + KernelArena& operator=(const KernelArena&) = delete; + friend class KernelScopedObject; + std::vector kernel_objects_; // owned +}; + +// A RAII convenience wrapper on top of a kernel. +// It either creates or takes an existing Kernel and sets it as the current +// Kernel. When this object is destroyed, the previous Kernel is set as current, +// and the created kernel is freed. If the kernel was passed, it stays alive. +class KernelScope { + public: + TORCH_API KernelScope(); + TORCH_API explicit KernelScope(KernelArena* arena_); + TORCH_API ~KernelScope(); + + private: + KernelScope(const KernelScope&) = delete; + KernelScope& operator=(const KernelScope&) = delete; + KernelArena* kernel_arena_ = nullptr; // arena to be used in this scope + KernelArena* old_kernel_arena_ = + nullptr; // previous arena, will be restored in destructor + bool owning_ = false; // determines whether the arena will be freed along with + // the scope object +}; + +// The base object managed by the Kernel. +// The object must be created through "new", and when the Kernel is destroyed, +// All its registered objects are destroyed through "delete". +class TORCH_API KernelScopedObject { + public: + KernelScopedObject(); + virtual ~KernelScopedObject() = default; + + private: + KernelScopedObject(const KernelScopedObject&) = delete; + KernelScopedObject& operator=(const KernelScopedObject&) = delete; +}; + +} // namespace tensorexpr +} // namespace jit +} // namespace torch