[SR] verify_and_correct_memory_overlap handles tensor lists (#69774)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69774

We recently ran into a nasty bug caused by incorrect schema annotations on an `aten::split` overload. `verify_and_correct_memory_overlap` is supposed to prevent crashes in this scenario, but it did not, because it had no handling for `Tensor[]` outputs.
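For context (not part of this change), the reason a wrong schema on `split` is dangerous is that its outputs are views of the input, so they share storage that the memory planner may otherwise consider free to reuse. A minimal ATen snippet illustrating that aliasing:

```cpp
#include <cassert>

#include <ATen/ATen.h>

int main() {
  at::Tensor x = at::arange(6, at::kFloat);
  // split() returns narrowed views, not copies: every chunk shares x's storage.
  std::vector<at::Tensor> chunks = at::split(x, /*split_size=*/2, /*dim=*/0);
  assert(chunks[0].is_alias_of(x));
  assert(chunks[0].data_ptr() == x.data_ptr());
  return 0;
}
```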

This change extends the memory-overlap correction mechanism to cover `Tensor[]` outputs as well as plain `Tensor` outputs.
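Conceptually, the correction now runs per element of a `Tensor[]` output in addition to plain `Tensor` outputs. A simplified, hypothetical sketch of that shape (here `correct_output` and `overlaps_managed_buffer` are stand-ins, not real APIs; the actual implementation is `fast_check_and_correct_overlap_with` / `check_and_correct_overlap_with` in the diff below):

```cpp
#include <functional>

#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>

// Hypothetical sketch: `overlaps_managed_buffer` stands in for the memory
// planner's overlapWithInternalBuffer() check; `correct_output` is not a real API.
bool correct_output(
    c10::IValue& output,
    const std::function<bool(const void*)>& overlaps_managed_buffer) {
  bool corrected = false;
  auto fix = [&](c10::IValue& ival) {
    auto& t = ival.toTensor();
    if (overlaps_managed_buffer(t.data_ptr())) {
      // Clone to detach the output from the internally managed buffer
      // (the diff itself uses at::native::clone).
      ival = t.clone();
      corrected = true;
    }
  };
  if (output.isTensor()) {
    fix(output);
  } else if (output.isTensorList()) {
    // New in this change: apply the same correction to every list element.
    for (auto& ival : output.toListRef()) {
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      fix(const_cast<c10::IValue&>(ival));
    }
  }
  return corrected;
}
```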
ghstack-source-id: 146152478

Test Plan: `buck test caffe2/benchmarks/static_runtime/...`
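For illustration, a rough sketch of the kind of regression coverage this targets, assuming the `testStaticRuntime` helper from `caffe2/benchmarks/static_runtime/test_utils.h` (test name, include path, and helper signature are approximations, not the exact test added here):

```cpp
#include <string>

#include <gtest/gtest.h>

#include <ATen/ATen.h>

#include "test_utils.h" // testStaticRuntime (assumed available on the include path)

TEST(StaticRuntime, SplitProducesTensorList) {
  // torch.split yields a Tensor[] whose elements are views of the input, so the
  // runtime must verify/correct memory overlap per list element.
  const std::string src = R"JIT(
    def forward(self, a: Tensor):
        chunks = torch.split(a, 2, 0)
        return chunks[0].clone(), chunks[1].clone()
  )JIT";
  torch::jit::test::testStaticRuntime(src, {at::randn({4, 4})});
}
```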

Reviewed By: hlu1

Differential Revision: D33022494

fbshipit-source-id: 8d1d41ca1d4fd5dfb7c8a66028c391ba63551eb0
Author: Mike Iovine
Date: 2021-12-22 17:16:42 -08:00
Committed by: Facebook GitHub Bot
Parent: 385c12852e
Commit: 682fab19d4
2 changed files with 72 additions and 13 deletions


@@ -29,6 +29,7 @@
 #include <stdexcept>
 
 #ifdef FBCODE_CAFFE2
+#include <common/logging/logging.h>
 #include <folly/dynamic.h>
 #include <folly/json.h>
 #endif
@@ -282,10 +283,19 @@ void ValueGroup::init(
 namespace {
+bool isTensorList(const Value* value) {
+  auto* type = value->type()->castRaw<ListType>();
+  if (!type) {
+    return false;
+  }
+  return type->getElementType()->kind() == c10::TypeKind::TensorType;
+}
+
 bool containTensorsOnly(at::ArrayRef<Value*> values) {
   // return true only if all outputs are tensors
   return std::all_of(values.begin(), values.end(), [](const Value* value) {
-    return value->type()->castRaw<TensorType>() != nullptr;
+    return value->type()->kind() == c10::TypeKind::TensorType ||
+        isTensorList(value);
   });
 }
@@ -1023,13 +1033,19 @@ void StaticRuntime::verify_and_correct_memory_overlap(ProcessedNode& n) {
   } else if (planner_) {
     bool overlap_detected_with_fast_check = false;
     for (size_t i = 0; i < n.outputs().size(); i++) {
-      at::Tensor& t = n.Output(i).toTensor();
-      if (planner_->overlapWithInternalBuffer(t.data_ptr())) {
-        DLOG(INFO) << "Detected alias for node: " << PrintNode(n.node());
-        n.Output(i) = at::native::clone(t, c10::nullopt);
-        // set flag if overlap detected
-        overlap_detected_with_fast_check = true;
-        n.set_outputs_memory_overlap_detected();
-      }
+      auto& output = n.Output(i);
+      if (output.isTensor()) {
+        overlap_detected_with_fast_check |=
+            fast_check_and_correct_overlap_with(n, output);
+      } else if (output.isTensorList()) {
+        auto tensor_list = output.toListRef();
+        for (auto& ival : tensor_list) {
+          overlap_detected_with_fast_check |=
+              fast_check_and_correct_overlap_with(
+                  n,
+                  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+                  const_cast<c10::IValue&>(ival));
+        }
+      }
     }
     if (n.outputs_memory_overlap_detected() &&
@@ -1041,6 +1057,19 @@ void StaticRuntime::verify_and_correct_memory_overlap(ProcessedNode& n) {
   }
 }
 
+bool StaticRuntime::fast_check_and_correct_overlap_with(
+    ProcessedNode& n,
+    c10::IValue& tensor_ival) {
+  auto& tensor = tensor_ival.toTensor();
+  if (planner_->overlapWithInternalBuffer(tensor.data_ptr())) {
+    DLOG(INFO) << "Detected alias for node: " << PrintNode(n.node());
+    tensor_ival = at::native::clone(tensor, c10::nullopt);
+    n.set_outputs_memory_overlap_detected();
+    return true;
+  }
+  return false;
+}
+
 StaticRuntime::Deallocator::~Deallocator() {
   // Assume cleanup cannot throw.
   cleanupImpl();
@@ -1451,6 +1480,7 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
       nodes_[k].run();
       millis = timer.MilliSeconds();
       results.time_per_node[k] += millis;
+      verify_and_correct_memory_overlap(nodes_[k]);
     }
     timer.Start();
     if (static_module_.opts().cleanup_activations) {
@@ -1840,6 +1870,19 @@ bool ProcessedNode::verify_inputs_dont_overlap_outputs(bool force_check) const {
   return true;
 }
 
+bool ProcessedNode::check_and_correct_overlap_with(
+    const at::Tensor& input,
+    c10::IValue& output_ival) {
+  auto& tensor = output_ival.toTensor();
+  if (!checkNoMemoryOverlap(input, tensor)) {
+    DLOG(INFO) << "Detected alias for node: " << PrintNode(node());
+    output_ival = at::native::clone(tensor, c10::nullopt);
+    set_outputs_memory_overlap_detected();
+    return true;
+  }
+  return false;
+}
+
 void ProcessedNode::verify_and_correct_memory_overlap() {
   for (const auto i : c10::irange(inputs_.size())) {
     const IValue& in = Input(i);
@@ -1848,11 +1891,21 @@ void ProcessedNode::verify_and_correct_memory_overlap() {
     }
     const auto& in_t = in.toTensor();
     for (const auto j : c10::irange(num_outputs_)) {
-      const auto& out_t = Output(j).toTensor();
-      if (!checkNoMemoryOverlap(in_t, out_t)) {
-        DLOG(INFO) << "Detected alias for node: " << PrintNode(node());
-        Output(i) = at::native::clone(out_t, c10::nullopt);
-        set_outputs_memory_overlap_detected();
+      auto& output = Output(j);
+      if (output.isTensor()) {
+        check_and_correct_overlap_with(in_t, output);
+      } else if (output.isTensorList()) {
+        auto tensors = output.toListRef();
+        for (const auto& ival : tensors) {
+          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+          check_and_correct_overlap_with(in_t, const_cast<c10::IValue&>(ival));
+        }
+#ifdef FBCODE_CAFFE2
+        if (outputs_memory_overlap_detected()) {
+          LOG_EVERY_MS(WARNING, 60000)
+              << "Detected alias for node: " << PrintNode(node());
+        }
+#endif
       }
     }
   }


@@ -560,6 +560,9 @@ class TORCH_API StaticRuntime {
   // Set Input(idx) to arg. Always copies. Used for kwargs.
   void set_arg(const size_t idx, const IValue& arg);
 
+  bool fast_check_and_correct_overlap_with(
+      ProcessedNode& n,
+      c10::IValue& tensor_ival);
+
   void verify_and_correct_memory_overlap(ProcessedNode& n);
 
   // clean up owning refs of input IValues
@@ -719,6 +722,9 @@ class TORCH_API ProcessedNode {
     return overlap_detected_;
   }
 
+  bool check_and_correct_overlap_with(
+      const at::Tensor& input,
+      c10::IValue& output);
+
   void verify_and_correct_memory_overlap();
 
   void set_values(IValue* values) {