[TensorExpr] Add classes for memory management in tensor expressions. (#33216)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33216

All tensor expressions belong to a kernel arena and are freed when the
arena is destroyed. Until it is destroyed, all expressions stay valid.

Test Plan: Imported from OSS

Differential Revision: D19848382

Pulled By: ZolotukhinM

fbshipit-source-id: a581ea2b635b9ba2cc53949616a13d8d3a47caae
This commit is contained in:
Mikhail Zolotukhin
2020-02-21 13:06:13 -08:00
committed by Facebook Github Bot
parent 616beb1412
commit 089d658153
4 changed files with 121 additions and 0 deletions

View File

@ -454,6 +454,8 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/fuser/fallback.cpp
${TORCH_SRC_DIR}/csrc/jit/function.cpp
${TORCH_SRC_DIR}/csrc/jit/vararg_functions.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/mem_arena.cpp
)
if (NOT INTERN_DISABLE_MOBILE_INTERP)

View File

@ -190,9 +190,11 @@ libtorch_sources = [
"torch/csrc/jit/mobile/register_mobile_ops.cpp",
"torch/csrc/jit/mobile/interpreter.cpp",
"torch/csrc/jit/mobile/type_parser.cpp",
"torch/csrc/jit/tensorexpr/mem_arena.cpp",
"torch/csrc/utils/byte_order.cpp",
"torch/csrc/utils/tensor_flatten.cpp",
"torch/csrc/utils/variadic.cpp",
"torch/csrc/jit/tensorexpr/mem_arena.cpp",
]
libtorch_cuda_sources = [

View File

@ -0,0 +1,56 @@
#include "torch/csrc/jit/tensorexpr/mem_arena.h"
namespace torch {
namespace jit {
namespace tensorexpr {
namespace {
// Define in an anonymous namespace to hide this symbol from other compilation
// units
thread_local KernelArena* current_arena = nullptr;
}
KernelArena::~KernelArena() {
for (KernelScopedObject* p : kernel_objects_) {
delete p;
}
}
KernelScopedObject::KernelScopedObject() {
KernelArena* kernel = KernelArena::GetCurrentKernelArena();
kernel->kernel_objects_.push_back(this);
}
static std::vector<KernelArena*>& GetKernelArenaStack() {
thread_local std::vector<KernelArena*> kernel_arena_stack;
return kernel_arena_stack;
}
void KernelArena::SetCurrentKernelArena(KernelArena *new_kernel_arena) {
current_arena = new_kernel_arena;
}
KernelArena* KernelArena::GetCurrentKernelArena() {
return current_arena;
}
KernelScope::KernelScope() : owning_(true) {
old_kernel_arena_ = KernelArena::GetCurrentKernelArena();
KernelArena::SetCurrentKernelArena(new KernelArena);
}
KernelScope::KernelScope(KernelArena* arena_) : owning_(false) {
old_kernel_arena_ = KernelArena::GetCurrentKernelArena();
KernelArena::SetCurrentKernelArena(arena_);
}
KernelScope::~KernelScope() {
if (owning_) {
delete KernelArena::GetCurrentKernelArena();
}
KernelArena::SetCurrentKernelArena(old_kernel_arena_);
}
} // namespace tensorexpr
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,61 @@
#pragma once
#include <vector>
#include "torch/csrc/WindowsTorchApiMacro.h"
namespace torch {
namespace jit {
namespace tensorexpr {
class KernelScopedObject;
// An arena that manages all the underlying kernel-scoped objects.
class KernelArena {
public:
static KernelArena* GetCurrentKernelArena();
static void SetCurrentKernelArena(KernelArena* new_arena);
TORCH_API KernelArena() {}
TORCH_API ~KernelArena();
private:
KernelArena(const KernelArena&) = delete;
KernelArena& operator=(const KernelArena&) = delete;
friend class KernelScopedObject;
std::vector<KernelScopedObject*> kernel_objects_; // owned
};
// A RAII convenience wrapper on top of a kernel.
// It either creates or takes an existing Kernel and sets it as the current
// Kernel. When this object is destroyed, the previous Kernel is set as current,
// and the created kernel is freed. If the kernel was passed, it stays alive.
class KernelScope {
public:
TORCH_API KernelScope();
TORCH_API explicit KernelScope(KernelArena* arena_);
TORCH_API ~KernelScope();
private:
KernelScope(const KernelScope&) = delete;
KernelScope& operator=(const KernelScope&) = delete;
KernelArena* kernel_arena_ = nullptr; // arena to be used in this scope
KernelArena* old_kernel_arena_ =
nullptr; // previous arena, will be restored in destructor
bool owning_ = false; // determines whether the arena will be freed along with
// the scope object
};
// The base object managed by the Kernel.
// The object must be created through "new", and when the Kernel is destroyed,
// All its registered objects are destroyed through "delete".
class TORCH_API KernelScopedObject {
public:
KernelScopedObject();
virtual ~KernelScopedObject() = default;
private:
KernelScopedObject(const KernelScopedObject&) = delete;
KernelScopedObject& operator=(const KernelScopedObject&) = delete;
};
} // namespace tensorexpr
} // namespace jit
} // namespace torch