mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[TensorExpr] Add classes for memory management in tensor expressions. (#33216)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33216 All tensor expressions belong to a kernel arena and are freed when the arena is destroyed. Until it is destroyed, all expressions stay valid. Test Plan: Imported from OSS Differential Revision: D19848382 Pulled By: ZolotukhinM fbshipit-source-id: a581ea2b635b9ba2cc53949616a13d8d3a47caae
This commit is contained in:
committed by
Facebook Github Bot
parent
616beb1412
commit
089d658153
@ -454,6 +454,8 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
||||
${TORCH_SRC_DIR}/csrc/jit/fuser/fallback.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/function.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/vararg_functions.cpp
|
||||
|
||||
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/mem_arena.cpp
|
||||
)
|
||||
|
||||
if (NOT INTERN_DISABLE_MOBILE_INTERP)
|
||||
|
@ -190,9 +190,11 @@ libtorch_sources = [
|
||||
"torch/csrc/jit/mobile/register_mobile_ops.cpp",
|
||||
"torch/csrc/jit/mobile/interpreter.cpp",
|
||||
"torch/csrc/jit/mobile/type_parser.cpp",
|
||||
"torch/csrc/jit/tensorexpr/mem_arena.cpp",
|
||||
"torch/csrc/utils/byte_order.cpp",
|
||||
"torch/csrc/utils/tensor_flatten.cpp",
|
||||
"torch/csrc/utils/variadic.cpp",
|
||||
"torch/csrc/jit/tensorexpr/mem_arena.cpp",
|
||||
]
|
||||
|
||||
libtorch_cuda_sources = [
|
||||
|
56
torch/csrc/jit/tensorexpr/mem_arena.cpp
Normal file
56
torch/csrc/jit/tensorexpr/mem_arena.cpp
Normal file
@ -0,0 +1,56 @@
|
||||
#include "torch/csrc/jit/tensorexpr/mem_arena.h"
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace tensorexpr {
|
||||
|
||||
namespace {
|
||||
// Define in an anonymous namespace to hide this symbol from other compilation
|
||||
// units
|
||||
thread_local KernelArena* current_arena = nullptr;
|
||||
}
|
||||
|
||||
KernelArena::~KernelArena() {
|
||||
for (KernelScopedObject* p : kernel_objects_) {
|
||||
delete p;
|
||||
}
|
||||
}
|
||||
|
||||
KernelScopedObject::KernelScopedObject() {
|
||||
KernelArena* kernel = KernelArena::GetCurrentKernelArena();
|
||||
kernel->kernel_objects_.push_back(this);
|
||||
}
|
||||
|
||||
static std::vector<KernelArena*>& GetKernelArenaStack() {
|
||||
thread_local std::vector<KernelArena*> kernel_arena_stack;
|
||||
return kernel_arena_stack;
|
||||
}
|
||||
|
||||
void KernelArena::SetCurrentKernelArena(KernelArena *new_kernel_arena) {
|
||||
current_arena = new_kernel_arena;
|
||||
}
|
||||
|
||||
KernelArena* KernelArena::GetCurrentKernelArena() {
|
||||
return current_arena;
|
||||
}
|
||||
|
||||
KernelScope::KernelScope() : owning_(true) {
|
||||
old_kernel_arena_ = KernelArena::GetCurrentKernelArena();
|
||||
KernelArena::SetCurrentKernelArena(new KernelArena);
|
||||
}
|
||||
|
||||
KernelScope::KernelScope(KernelArena* arena_) : owning_(false) {
|
||||
old_kernel_arena_ = KernelArena::GetCurrentKernelArena();
|
||||
KernelArena::SetCurrentKernelArena(arena_);
|
||||
}
|
||||
|
||||
KernelScope::~KernelScope() {
|
||||
if (owning_) {
|
||||
delete KernelArena::GetCurrentKernelArena();
|
||||
}
|
||||
KernelArena::SetCurrentKernelArena(old_kernel_arena_);
|
||||
}
|
||||
|
||||
} // namespace tensorexpr
|
||||
} // namespace jit
|
||||
} // namespace torch
|
61
torch/csrc/jit/tensorexpr/mem_arena.h
Normal file
61
torch/csrc/jit/tensorexpr/mem_arena.h
Normal file
@ -0,0 +1,61 @@
|
||||
#pragma once
|
||||
#include <vector>
|
||||
#include "torch/csrc/WindowsTorchApiMacro.h"
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace tensorexpr {
|
||||
|
||||
class KernelScopedObject;
|
||||
|
||||
// An arena that manages all the underlying kernel-scoped objects.
|
||||
class KernelArena {
|
||||
public:
|
||||
static KernelArena* GetCurrentKernelArena();
|
||||
static void SetCurrentKernelArena(KernelArena* new_arena);
|
||||
TORCH_API KernelArena() {}
|
||||
TORCH_API ~KernelArena();
|
||||
|
||||
private:
|
||||
KernelArena(const KernelArena&) = delete;
|
||||
KernelArena& operator=(const KernelArena&) = delete;
|
||||
friend class KernelScopedObject;
|
||||
std::vector<KernelScopedObject*> kernel_objects_; // owned
|
||||
};
|
||||
|
||||
// A RAII convenience wrapper on top of a kernel.
|
||||
// It either creates or takes an existing Kernel and sets it as the current
|
||||
// Kernel. When this object is destroyed, the previous Kernel is set as current,
|
||||
// and the created kernel is freed. If the kernel was passed, it stays alive.
|
||||
class KernelScope {
|
||||
public:
|
||||
TORCH_API KernelScope();
|
||||
TORCH_API explicit KernelScope(KernelArena* arena_);
|
||||
TORCH_API ~KernelScope();
|
||||
|
||||
private:
|
||||
KernelScope(const KernelScope&) = delete;
|
||||
KernelScope& operator=(const KernelScope&) = delete;
|
||||
KernelArena* kernel_arena_ = nullptr; // arena to be used in this scope
|
||||
KernelArena* old_kernel_arena_ =
|
||||
nullptr; // previous arena, will be restored in destructor
|
||||
bool owning_ = false; // determines whether the arena will be freed along with
|
||||
// the scope object
|
||||
};
|
||||
|
||||
// The base object managed by the Kernel.
|
||||
// The object must be created through "new", and when the Kernel is destroyed,
|
||||
// All its registered objects are destroyed through "delete".
|
||||
class TORCH_API KernelScopedObject {
|
||||
public:
|
||||
KernelScopedObject();
|
||||
virtual ~KernelScopedObject() = default;
|
||||
|
||||
private:
|
||||
KernelScopedObject(const KernelScopedObject&) = delete;
|
||||
KernelScopedObject& operator=(const KernelScopedObject&) = delete;
|
||||
};
|
||||
|
||||
} // namespace tensorexpr
|
||||
} // namespace jit
|
||||
} // namespace torch
|
Reference in New Issue
Block a user