[Static Runtime] Fallback to disabling manage_output_tensors instead of crashing when wrong API is used (#67939)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67939

With `manage_output_tensors` enabled, a client of `StaticRuntime` is required to call it via `PyTorchPredictor::predict_managed_result`. If the client uses `PyTorchPredictor::operator()` instead, it crashes; this is intended behavior so that the memory of managed output tensors is not leaked. Such a mistake can cause a catastrophic failure in production if it ever ships (e.g., via a gatekeeper or config change).

Considering the complexity of how `PyTorchPredictor` is used in different settings, the chance that this bug hits production is non-zero.

This change introduces `StaticRuntime::disableManageOutputTensors` to disable the `manage_output_tensors` feature, instead of crashing, when a client mistakenly uses `PyTorchPredictor::operator()`. When `StaticRuntime` is invoked via `PyTorchPredictor::operator()`, it first calls `StaticRuntime::disableManageOutputTensors` to turn the feature off, so that it can safely return non-managed output tensors to the client.

A slight perf degradation is expected from forcefully disabling `manage_output_tensors`, but the robustness gained outweighs the risk of a catastrophic, high-rate crash.
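To make the fallback concrete, here is a minimal sketch of a predictor-style wrapper. `PredictorSketch` and its method names are hypothetical (the real `PyTorchPredictor` is internal and not touched in this diff); only the `StaticModule`/`StaticRuntime` calls, including the new `disableManageOutputTensors`, come from this change:

```cpp
// Minimal sketch only: `PredictorSketch` is a hypothetical stand-in for the
// internal PyTorchPredictor and is not part of this diff; the StaticRuntime
// calls below are the ones exercised by the new unit test.
#include <torch/csrc/jit/runtime/static/impl.h>

#include <vector>

class PredictorSketch {
 public:
  // The StaticModule is assumed to be built with manage_output_tensors=true
  // and must outlive this wrapper (StaticRuntime keeps a reference to it).
  explicit PredictorSketch(const torch::jit::StaticModule& module)
      : runtime_(module) {}

  // Intended entry point when output tensor management is on; the caller is
  // responsible for the managed result's lifetime (details elided here).
  c10::IValue predict_managed_result(const std::vector<c10::IValue>& args) {
    return runtime_(args, /*kwargs=*/{});
  }

  // Wrong-API fallback introduced by this change: instead of crashing,
  // permanently turn output tensor management off, then run. The returned
  // IValue now owns its tensors, so it is safe to hand back to the client.
  c10::IValue operator()(const std::vector<c10::IValue>& args) {
    runtime_.disableManageOutputTensors();
    return runtime_(args, /*kwargs=*/{});
  }

 private:
  torch::jit::StaticRuntime runtime_;
};
```

With this shape, a caller that accidentally uses `operator()` gets ordinary, self-owning output tensors at the cost of the output-tensor-management optimization, which matches the behavior the new unit test checks.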

Test Plan: Added a unit test, `StaticRuntime.DisableManageOutputTensors`, to cover the newly added code.

Reviewed By: swolchok

Differential Revision: D32219731

fbshipit-source-id: caf5c910b34726c570e17435ede7d888443e90cf
Author: Don Jang
Date: 2021-11-11 17:28:58 -08:00
Committed by: Facebook GitHub Bot
Parent: 3dc0754c53
Commit: 9cb65df79f
3 changed files with 110 additions and 8 deletions


@@ -674,6 +674,88 @@ TEST(StaticRuntime, ManageOutputTensorsWithoutDeallocateOutputTensors) {
   runtime(input_tensors, {});
 }
 
+TEST(StaticRuntime, DisableManageOutputTensors) {
+  const std::string test_graph = R"IR(
+    graph(%0 : Tensor):
+      # With manage_output_tensor enabled, this tensor is managed.
+      %1 : Tensor = aten::abs(%0)
+      # The output container object is never managed.
+      %2 : (Tensor) = prim::TupleConstruct(%1)
+      return (%2)
+  )IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(test_graph, g.get());
+  torch::jit::StaticModuleOptions opts{
+      /*cleanup_activations=*/true,
+      /*enable_out_variant=*/true,
+      /*optimize_memory=*/true,
+      /*manage_output_tensors=*/true};
+  auto a = at::randn({2, 2});
+  std::vector<at::IValue> args{a};
+  torch::jit::StaticModule smod(g, opts);
+  torch::jit::StaticRuntime runtime(smod);
+  // Profile run.
+  {
+    IValue tuple = runtime(args, {});
+    IValue element = tuple.toTupleRef().elements()[0];
+    EXPECT_FALSE(runtime.isManagedOutputTensor(element));
+    tuple = IValue();
+    runtime.deallocateOutputTensors();
+    runtime.checkOutputTensorMemoryLeaks();
+  }
+  // Second run that manages output tensors.
+  {
+    IValue tuple = runtime(args, {});
+    IValue element = tuple.toTupleRef().elements()[0];
+    EXPECT_TRUE(runtime.isManagedOutputTensor(element));
+    tuple = IValue();
+    runtime.deallocateOutputTensors();
+    runtime.checkOutputTensorMemoryLeaks();
+  }
+
+  // Reset the runtime and start profiling again.
+  runtime.disableManageOutputTensors();
+
+  IValue copied_output_tensor;
+  IValue original_output_tensor;
+  // New profile run.
+  {
+    IValue tuple = runtime(args, {});
+    IValue element = tuple.toTupleRef().elements()[0];
+    EXPECT_FALSE(runtime.isManagedOutputTensor(element));
+    copied_output_tensor = element.deepcopy();
+    original_output_tensor = element;
+    tuple = IValue();
+    // No-op since manage_output_tensor is disabled now.
+    runtime.deallocateOutputTensors();
+    runtime.checkOutputTensorMemoryLeaks();
+  }
+  // Ensure that `original_output_tensor` is no longer managed: even after
+  // calling `runtime.deallocateOutputTensors();` `original_output_tensor` still
+  // contains a valid value.
+  EXPECT_TRUE(
+      original_output_tensor.toTensor().equal(copied_output_tensor.toTensor()));
+  // Ensure that the second optimized run does not manage the output tensor
+  // either.
+  {
+    IValue tuple = runtime(args, {});
+    IValue element = tuple.toTupleRef().elements()[0];
+    EXPECT_FALSE(runtime.isManagedOutputTensor(element));
+    copied_output_tensor = element.deepcopy();
+    original_output_tensor = element;
+    tuple = IValue();
+    // No-op since manage_output_tensor is disabled now.
+    runtime.deallocateOutputTensors();
+    runtime.checkOutputTensorMemoryLeaks();
+  }
+  // Ensure that `original_output_tensor` is no longer managed: even after
+  // calling `runtime.deallocateOutputTensors();` `original_output_tensor` still
+  // contains a valid value.
+  EXPECT_TRUE(
+      original_output_tensor.toTensor().equal(copied_output_tensor.toTensor()));
+}
+
 TEST(StaticRuntime, FusionPass) {
   const int embedding_size = 32;
   const int num_features = 50;


@@ -803,7 +803,9 @@ c10::IValue StaticModule::operator()(
   return runtime()(std::move(args), kwargs);
 }
 
-StaticRuntime::StaticRuntime(const StaticModule& sm) : static_module_(sm) {
+StaticRuntime::StaticRuntime(const StaticModule& sm)
+    : static_module_(sm),
+      manage_output_tensors_enabled_(sm.opts().manage_output_tensors) {
   // NB: create unchanging std::vector<IValue>s we can reference
   inputs_.resize(sm.num_inputs());
   nodes_.resize(sm.nodes().size());
@@ -938,7 +940,7 @@ void StaticRuntime::create_memory_planner() {
         static_module_.values_share_same_storage(),
         static_module_.value_group(),
         static_module_.opts().enable_out_variant,
-        static_module_.opts().manage_output_tensors);
+        manage_output_tensors_enabled_);
   }
 }
@@ -1085,9 +1087,7 @@ c10::IValue StaticRuntime::run_impl(
   c10::InferenceMode mode;
 
   if (planner_) {
-    DCHECK(
-        !static_module_.opts().manage_output_tensors ||
-        checkOutputTensorMemoryLeaks());
+    DCHECK(!manage_output_tensors_enabled_ || checkOutputTensorMemoryLeaks());
     planner_->allocate();
   }
@@ -1271,12 +1271,11 @@ float StaticRuntime::benchmark_model(
   const bool is_kwargs_empty = kwargs_list.size() == 0;
   const std::unordered_map<std::string, c10::IValue> empty_kwargs;
-  bool manage_output_tensors = static_module_.opts().manage_output_tensors;
   for (const auto i : c10::irange(warmup_runs)) {
     (void)i; // Suppress unused variable warning
     for (const auto j : c10::irange(args_list.size())) {
       operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
-      if (manage_output_tensors) {
+      if (manage_output_tensors_enabled_) {
         deallocateOutputTensors();
       }
     }
@@ -1286,7 +1285,7 @@ float StaticRuntime::benchmark_model(
     (void)i; // Suppress unused variable warning
     for (const auto j : c10::irange(args_list.size())) {
       operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
-      if (manage_output_tensors) {
+      if (manage_output_tensors_enabled_) {
         deallocateOutputTensors();
       }
     }
@@ -1600,6 +1599,24 @@ bool StaticRuntime::isManagedOutputTensor(const IValue& ivalue) {
   return planner_ && planner_->isManagedOutputTensor(ivalue);
 }
 
+void StaticRuntime::disableManageOutputTensors() {
+  if (!manage_output_tensors_enabled_) {
+    return;
+  }
+  manage_output_tensors_enabled_ = false;
+  if (!planner_) {
+    return;
+  }
+  // Reset all IValues and destruct planner_ so that it can be reconstructed in
+  // the next run.
+  for (auto& n : nodes_) {
+    for (const auto i : c10::irange(n.outputs().size())) {
+      n.Output(i) = IValue();
+    }
+  }
+  planner_.reset();
+}
+
 ProcessedNode::ProcessedNode(
     Node* node,
     std::unique_ptr<const IValue*[]> inputs,


@@ -358,6 +358,8 @@ class TORCH_API StaticRuntime {
   bool isManagedOutputTensor(const IValue& ivalue);
 
+  void disableManageOutputTensors();
+
  private:
   template <typename IValueList>
   c10::IValue run_impl(
@@ -400,6 +402,7 @@
   // Otherwise, the memory used by activations is cached inside the static
   // runtime.
   const StaticModule& static_module_;
+  bool manage_output_tensors_enabled_ = false;
   std::unique_ptr<MemoryPlanner> planner_;
   std::vector<IValue> inputs_;
   std::vector<IValue*> outputs_;