[SR] Scope exit guard for memory planner deallocation (#68795)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68795

This change improves static runtime exception safety by adding a scope exit guard that invokes `MemoryPlanner::deallocate` in its destructor.

Caveat: we have to be very careful with the exception behavior of `MemoryPlanner::deallocate` and `MemoryPlanner`'s constructor, because both can now be called from the destructor of the scope exit guard. Letting exceptions escape destructors is playing with fire, since 1) the destructor of `Deallocator` is (implicitly) `noexcept`, and 2) even if it weren't, `std::terminate` would be called if an exception escaped while the stack was already unwinding. To get around this, we wrap the deallocation logic in a try/catch. If deallocation throws, we simply reset all of the memory planner state and carry on.
There's a catch: the code path we take when handling a deallocation exception must not itself throw. However, that path is much simpler than memory planner construction/deallocation, so its correctness is much easier to audit manually.
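
To make the pattern concrete, here is a minimal sketch of the guard (schematic only, not the literal code from this diff; in the actual change the throwing work lives in `Deallocator::cleanupImpl`, and `StaticRuntime::resetMemory` is the `noexcept` fallback):

class Deallocator {
 public:
  explicit Deallocator(StaticRuntime& runtime) : runtime_(runtime) {}

  // Implicitly noexcept: an exception escaping here would call std::terminate.
  ~Deallocator() {
    try {
      // Potentially throwing: may construct the MemoryPlanner (after a
      // finished first run) and call MemoryPlanner::deallocate.
      cleanupImpl();
    } catch (...) {
      // Fallback path: must not throw. Drop the planner and reset all input
      // and intermediate IValues manually.
      runtime_.resetMemory();
    }
  }

  void setFinished() {
    finished_ = true;
  }

 private:
  void cleanupImpl();

  bool finished_ = false;
  StaticRuntime& runtime_;
};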

Test Plan:
**New unit tests**

`buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest`

Reviewed By: hlu1

Differential Revision: D32609915

fbshipit-source-id: 71fbe6994fd573ca6b7dd859b2e6fbd7eeabcd9e
Author: Mike Iovine
Date: 2021-12-08 16:40:11 -08:00
Committed by: Facebook GitHub Bot
Parent: 3b27304d20
Commit: 1c43b1602c

6 changed files with 257 additions and 51 deletions


@@ -1,9 +1,11 @@
#include <ATen/core/dispatch/OperatorOptions.h>
#include <c10/core/ScalarType.h>
#include <gtest/gtest.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/static/ProcessedNodeInputs.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <stdexcept>
#include "deep_wide_pt.h"
#include "test_utils.h"
@@ -2180,3 +2182,78 @@ TEST(StaticRuntime, Split) {
testStaticRuntime(src, {a, 1, 1});
testStaticRuntime(src, {a, 2, -1}, {b, 2, 2});
}
namespace {
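// Test-only op that throws std::runtime_error when its argument is true,
// letting the tests below crash a model at an arbitrary point in the graph.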
void maybe_throw(bool should_throw) {
if (should_throw) {
throw std::runtime_error("test exception");
}
}
TORCH_LIBRARY(static_runtime_tests, m) {
// Conservative so this op doesn't get deleted by dead
// code elimination
m.def(torch::schema(
"static_runtime_tests::maybe_throw(bool throw) -> ()",
at::AliasAnalysisKind::CONSERVATIVE));
m.impl("maybe_throw", maybe_throw);
}
} // namespace
TEST(StaticRuntime, ModelCrashOnFirstRun) {
const auto src = R"JIT(
graph(%0: Tensor, %throw: bool):
%1: Tensor = aten::mul(%0, %0)
static_runtime_tests::maybe_throw(%throw)
%2: Tensor = aten::mul(%1, %1)
%3: Tensor = aten::mul(%2, %2)
return (%3)
)JIT";
auto graph = getGraphFromIR(src);
auto static_module = StaticModule(graph);
auto& runtime = static_module.runtime();
std::vector<IValue> args_crash{at::randn({1}), true};
std::vector<IValue> args_no_crash{at::randn({1}), false};
EXPECT_THROW(runtime(args_crash, {}), std::runtime_error);
// The run didn't finish, so the memory planner was never allocated
EXPECT_EQ(runtime.get_memory_planner(), nullptr);
runtime.check_for_memory_leak();
// We guarantee that the runtime is still usable after the crash.
// Run again to verify this.
compareResultsWithJIT(runtime, graph, args_no_crash);
EXPECT_NE(runtime.get_memory_planner(), nullptr);
}
TEST(StaticRuntime, ModelCrashOnSecondRun) {
const auto src = R"JIT(
graph(%0: Tensor, %throw: bool):
%1: Tensor = aten::mul(%0, %0)
static_runtime_tests::maybe_throw(%throw)
%2: Tensor = aten::mul(%1, %1)
%3: Tensor = aten::mul(%2, %2)
return (%3)
)JIT";
auto graph = getGraphFromIR(src);
auto static_module = StaticModule(graph);
auto& runtime = static_module.runtime();
std::vector<IValue> args_crash{at::randn({1}), true};
std::vector<IValue> args_no_crash{at::randn({1}), false};
runtime(args_no_crash, {});
EXPECT_NE(runtime.get_memory_planner(), nullptr);
runtime.check_for_memory_leak();
EXPECT_THROW(runtime(args_crash, {}), std::runtime_error);
runtime.check_for_memory_leak();
// We guarantee that the runtime is still usable after the crash.
// Run again to verify this.
compareResultsWithJIT(runtime, graph, args_no_crash);
}


@@ -22,6 +22,27 @@ namespace test {
namespace {
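// Thin wrapper around GraphExecutor that runs a graph on a vector of args
// and collapses the resulting stack into a single IValue (a Tuple when the
// graph has multiple outputs).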
class GraphExecutorWrapper {
public:
GraphExecutorWrapper() = default;
explicit GraphExecutorWrapper(const std::shared_ptr<Graph>& graph)
: graph_exec_(graph, "") {}
c10::IValue operator()(const std::vector<c10::IValue>& args) {
Stack stack(args);
graph_exec_.run(stack);
if (stack.size() == 1) {
return stack[0];
}
return c10::ivalue::Tuple::create(stack);
}
private:
GraphExecutor graph_exec_;
};
// Test scripts passed to testStaticRuntime can either be IR or JIT.
// The logic for running the script and producing a corresponding StaticModule
// is a bit different for each case. This logic is encapsulated within concrete
@@ -62,17 +83,11 @@ class GraphStaticRuntimeContext : public StaticRuntimeTestContext {
std::unordered_map<std::string, Value*> vmap;
parseIR(source_ir, graph_.get(), vmap);
graph_exec_ = GraphExecutor(graph_, "");
graph_exec_ = GraphExecutorWrapper(graph_);
}
IValue getExpected(const std::vector<IValue>& args) override {
Stack stack(args);
graph_exec_.run(stack);
if (stack.size() == 1) {
return stack[0];
}
return c10::ivalue::Tuple::create(stack);
return graph_exec_(args);
}
StaticModule makeStaticModule(const StaticModuleOptions& opt) const override {
@@ -81,7 +96,7 @@ class GraphStaticRuntimeContext : public StaticRuntimeTestContext {
private:
std::shared_ptr<Graph> graph_;
GraphExecutor graph_exec_;
GraphExecutorWrapper graph_exec_;
};
std::unique_ptr<StaticRuntimeTestContext> makeTestContext(
@@ -208,6 +223,19 @@ std::shared_ptr<Graph> getGraphFromIR(const std::string& ir) {
return graph;
}
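// Runs `graph` through a fresh GraphExecutor and through `runtime`, checks
// the runtime for memory leaks, and asserts that the two results match.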
void compareResultsWithJIT(
StaticRuntime& runtime,
const std::shared_ptr<Graph>& graph,
const std::vector<c10::IValue>& args,
const bool use_allclose,
const bool use_equalnan) {
GraphExecutorWrapper graph_exec(graph);
auto expected = graph_exec(args);
auto actual = runtime(args, {});
runtime.check_for_memory_leak();
compareResults(expected, actual, use_allclose, use_equalnan);
}
void testStaticRuntime(
const std::string& source,
const std::vector<IValue>& args,


@@ -42,6 +42,13 @@ Node* getNodeWithKind(const StaticModule& smodule, const std::string& kind);
bool hasNodeWithKind(const StaticModule& smodule, const std::string& kind);
void compareResultsWithJIT(
StaticRuntime& runtime,
const std::shared_ptr<Graph>& graph,
const std::vector<c10::IValue>& args,
const bool use_allclose = false,
const bool use_equalnan = false);
} // namespace test
} // namespace jit
} // namespace torch


@@ -5,6 +5,8 @@
#include <ATen/record_function.h>
#include <c10/core/CPUAllocator.h>
#include <c10/core/InferenceMode.h>
#include <c10/macros/Macros.h>
#include <c10/util/MaybeOwned.h>
#include <c10/util/irange.h>
#include <caffe2/core/scope_guard.h>
#include <caffe2/core/timer.h>
@@ -841,6 +843,40 @@ void StaticRuntime::create_memory_planner() {
}
}
namespace {
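// Resets each output of `p_node` that can heap-allocate: borrowed outputs
// are destroyed without an incref, owned outputs are replaced with empty
// IValues.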
void destroyNodeOutputs(ProcessedNode& p_node) {
const auto borrows_outputs = borrowsOutputs(p_node.node()->kind());
for (const auto i : c10::irange(p_node.num_outputs())) {
auto& output = p_node.Output(i);
if (doesNotHeapAllocateWhenStoredInIValue(*output.type())) {
continue;
}
if (borrows_outputs) {
// NB: No need to incref here. This codepath is only hit if the run didn't
// finish, so we shouldn't be returning anything to the client.
c10::MaybeOwnedTraits<IValue>::destroyBorrow(output);
} else {
output = IValue();
}
}
}
} // namespace
void StaticRuntime::clean_up_intermediate_ivalues() noexcept {
for (auto& p_node : nodes_) {
destroyNodeOutputs(p_node);
}
}
void StaticRuntime::resetMemory() noexcept {
planner_.reset();
clean_up_input_ivalues();
clean_up_intermediate_ivalues();
}
c10::IValue StaticRuntime::move_outputs_to_tuple(uint32_t num_outputs) {
#ifndef NDEBUG
for (const auto i : c10::irange(num_outputs)) {
@@ -973,6 +1009,38 @@ void StaticRuntime::verify_and_correct_memory_overlap(ProcessedNode& n) {
}
}
StaticRuntime::Deallocator::~Deallocator() {
// Assume cleanup cannot throw.
cleanupImpl();
#ifndef NDEBUG
runtime_.check_for_memory_leak(false);
#endif
}
void StaticRuntime::Deallocator::cleanupImpl() {
if (runtime_.static_module_.opts().cleanup_activations) {
// MemoryPlanner is created after the first invocation of `run()`. This
// is done intentionally because MemoryPlanner uses `Tensor` sizes of
// the previous `run()` for memory planning of subsequent runs
if (C10_LIKELY(finished_)) {
runtime_.create_memory_planner();
}
if (C10_LIKELY(runtime_.planner_)) {
runtime_.planner_->deallocate();
} else {
// This is the first run, and it didn't finish, so we can't use a
// `MemoryPlanner` to deallocate stuff. Just reset everything manually.
runtime_.resetMemory();
}
// clean up owning refs of input tensors
runtime_.clean_up_input_ivalues();
if (C10_UNLIKELY(!finished_)) {
runtime_.deallocateOutputTensors();
}
}
}
template <typename IValueList>
c10::IValue StaticRuntime::run_impl(
IValueList&& args,
@@ -983,37 +1051,29 @@ c10::IValue StaticRuntime::run_impl(
// functions, such as resize_ and resize_as_.
c10::InferenceMode mode;
if (planner_) {
DCHECK(!manage_output_tensors_enabled_ || checkOutputTensorMemoryLeaks());
planner_->allocate();
}
{
auto on_exit = Deallocator(*this);
set_inputs(std::forward<IValueList>(args), kwargs);
if (planner_) {
DCHECK(!manage_output_tensors_enabled_ || checkOutputTensorMemoryLeaks());
planner_->allocate();
}
for (auto& n : nodes_) {
// LOG(INFO) << "Running node: " << PrintNode(n.node());
n.run();
// Check for incorrect schema alias info.
verify_and_correct_memory_overlap(n);
}
set_inputs(std::forward<IValueList>(args), kwargs);
if (static_module_.opts().cleanup_activations) {
// MemoryPlanner is created after the first invocation of `run()`. This is
// done intentionally because MemoryPlanner uses `Tensor` sizes of the
// previous `run()` for memory planning of subsequent runs
create_memory_planner();
planner_->deallocate();
// clean up owning refs of input tensors
clean_up_input_ivalues();
for (auto& n : nodes_) {
// LOG(INFO) << "Running node: " << PrintNode(n.node());
n.run();
// Check for incorrect schema alias info.
verify_and_correct_memory_overlap(n);
}
on_exit.setFinished();
}
// no need to keep references of outputs in static runtime anymore
if (static_module_.num_outputs() > 1) {
return move_outputs_to_tuple(static_module_.num_outputs());
}
#ifndef NDEBUG
check_for_memory_leak(false);
#endif
// The exact output tensor should never be managed.
DCHECK(!isManagedOutputTensor(*outputs_[0]));
// use move here. Otherwise, clean up outputs_[0] explicitly
@@ -1272,6 +1332,9 @@ void StaticRuntime::display_nodes(
const std::vector<c10::IValue>& args,
const KeywordArgs& kwargs) {
c10::InferenceMode mode;
auto on_exit = Deallocator(*this);
if (planner_) {
planner_->allocate();
}
@@ -1281,16 +1344,7 @@ void StaticRuntime::display_nodes(
node.run();
display_pnode_info(node);
}
if (static_module_.opts().cleanup_activations) {
// MemoryPlanner is created after the first invocation of `run()`. This is
// done intentionally because MemoryPlanner uses `Tensor` sizes of the
// previous `run()` for memory planning of subsequent runs
create_memory_planner();
planner_->deallocate();
// clean up owning refs of input tensors
clean_up_input_ivalues();
}
on_exit.setFinished();
}
StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(


@@ -56,6 +56,17 @@ TORCH_API inline bool doesNotHeapAllocateWhenStoredInIValue(const Type& type) {
}
}
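// True if ops of this kind produce borrowed (rather than owned) output
// IValues; these need special handling during cleanup and memory planning.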
TORCH_API inline bool borrowsOutputs(c10::Symbol kind) {
static const std::array<c10::Symbol, 2> symbols_with_borrowed_outputs = {
c10::Symbol::fromQualString("static_runtime::dict_unpack"),
c10::Symbol::fromQualString("static_runtime::VarTupleUnpack"),
};
return std::find(
symbols_with_borrowed_outputs.begin(),
symbols_with_borrowed_outputs.end(),
kind) != symbols_with_borrowed_outputs.end();
}
// Group values used by `graph` into three categories:
//
// - output_aliases:
@@ -456,7 +467,41 @@ class TORCH_API StaticRuntime {
void disableManageOutputTensors();
// This is the fallback path taken if we can't construct the memory planner
// on the first iteration.
// IMPORTANT: Nothing here should be able to throw!!!
// This function can be called from the (implicitly) `noexcept` destructor
// of Deallocator, meaning that std::terminate will be called
// if any exception escapes. Even if resetMemory and ~Deallocator were
// `noexcept(false)`, it's possible that when ~Deallocator is called, the
// stack is already unwinding, so there's still danger of calling
// std::terminate.
void resetMemory() noexcept;
private:
// A helper object that invokes memory planner deallocation code
// when destructed.
class Deallocator {
public:
explicit Deallocator(StaticRuntime& runtime) : runtime_(runtime) {}
Deallocator(Deallocator&&) = default;
Deallocator(const Deallocator&) = default;
Deallocator& operator=(const Deallocator&) = delete;
Deallocator& operator=(Deallocator&&) = delete;
~Deallocator();
void setFinished() {
finished_ = true;
}
private:
void cleanupImpl();
bool finished_ = false;
StaticRuntime& runtime_;
};
template <typename IValueList>
c10::IValue run_impl(IValueList&& args, const KeywordArgs& kwargs);
@@ -474,12 +519,14 @@ class TORCH_API StaticRuntime {
void verify_and_correct_memory_overlap(ProcessedNode& n);
// clean up owning refs of input IValues
void clean_up_input_ivalues() {
void clean_up_input_ivalues() noexcept {
for (const auto idx : c10::irange(static_module_.num_inputs())) {
values_[idx] = IValue();
}
}
void clean_up_intermediate_ivalues() noexcept;
IValue move_outputs_to_tuple(uint32_t num_outputs);
void create_memory_planner();


@@ -162,6 +162,7 @@ MemoryPlanner::MemoryPlanner(
FastSet<IValue*> unmanaged_ivalues;
FastSet<IValue*> unmanaged_borrowed_ivalues;
for (ProcessedNode& pnode : runtime->nodes()) {
const auto borrows_outputs = borrowsOutputs(pnode.node()->kind());
for (const auto i : c10::irange(pnode.outputs().size())) {
const Value* out_v = pnode.node()->outputs()[i];
const bool in_managed_sets = setIncludes(managed_tensor_values, out_v) ||
@@ -174,18 +175,10 @@ MemoryPlanner::MemoryPlanner(
if (in_managed_sets && !isUnmanagedSpecialCase(pnode, i)) {
continue;
}
static const std::array<c10::Symbol, 2> symbols_with_borrowed_outputs = {
c10::Symbol::fromQualString("static_runtime::dict_unpack"),
c10::Symbol::fromQualString("static_runtime::VarTupleUnpack"),
};
if (doesNotHeapAllocateWhenStoredInIValue(*out_v->type())) {
// Scalars do not need to be freed after each iteration.
num_unmanaged_scalar_ivalues_++;
} else if (
std::find(
symbols_with_borrowed_outputs.begin(),
symbols_with_borrowed_outputs.end(),
pnode.node()->kind()) != symbols_with_borrowed_outputs.end()) {
} else if (borrows_outputs) {
IValue& out = pnode.Output(i);
unmanaged_borrowed_ivalues.insert(&out);
} else {