mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[SR] verify_and_correct_memory_overlap handles tensor lists (#69774)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69774 We recently ran into a nasty bug caused by incorrect schema annotations on an `aten::split` overload. `verify_and_correct_memory_overlap` is supposed to prevent crashes in this scenario, but it didn't because it did not handle `Tensor[]` outputs. This change extends the memory correction mechanism to handle tensor lists. ghstack-source-id: 146152478 Test Plan: `buck test caffe2/benchmarks/static_runtime/...` Reviewed By: hlu1 Differential Revision: D33022494 fbshipit-source-id: 8d1d41ca1d4fd5dfb7c8a66028c391ba63551eb0
This commit is contained in:
committed by
Facebook GitHub Bot
parent
385c12852e
commit
682fab19d4
@ -29,6 +29,7 @@
|
||||
#include <stdexcept>
|
||||
|
||||
#ifdef FBCODE_CAFFE2
|
||||
#include <common/logging/logging.h>
|
||||
#include <folly/dynamic.h>
|
||||
#include <folly/json.h>
|
||||
#endif
|
||||
@ -282,10 +283,19 @@ void ValueGroup::init(
|
||||
|
||||
namespace {
|
||||
|
||||
bool isTensorList(const Value* value) {
|
||||
auto* type = value->type()->castRaw<ListType>();
|
||||
if (!type) {
|
||||
return false;
|
||||
}
|
||||
return type->getElementType()->kind() == c10::TypeKind::TensorType;
|
||||
}
|
||||
|
||||
bool containTensorsOnly(at::ArrayRef<Value*> values) {
|
||||
// return true only if all outputs are tensors
|
||||
return std::all_of(values.begin(), values.end(), [](const Value* value) {
|
||||
return value->type()->castRaw<TensorType>() != nullptr;
|
||||
return value->type()->kind() == c10::TypeKind::TensorType ||
|
||||
isTensorList(value);
|
||||
});
|
||||
}
|
||||
|
||||
@ -1023,13 +1033,19 @@ void StaticRuntime::verify_and_correct_memory_overlap(ProcessedNode& n) {
|
||||
} else if (planner_) {
|
||||
bool overlap_detected_with_fast_check = false;
|
||||
for (size_t i = 0; i < n.outputs().size(); i++) {
|
||||
at::Tensor& t = n.Output(i).toTensor();
|
||||
if (planner_->overlapWithInternalBuffer(t.data_ptr())) {
|
||||
DLOG(INFO) << "Detected alias for node: " << PrintNode(n.node());
|
||||
n.Output(i) = at::native::clone(t, c10::nullopt);
|
||||
// set flag if overlap detected
|
||||
overlap_detected_with_fast_check = true;
|
||||
n.set_outputs_memory_overlap_detected();
|
||||
auto& output = n.Output(i);
|
||||
if (output.isTensor()) {
|
||||
overlap_detected_with_fast_check |=
|
||||
fast_check_and_correct_overlap_with(n, output);
|
||||
} else if (output.isTensorList()) {
|
||||
auto tensor_list = output.toListRef();
|
||||
for (auto& ival : tensor_list) {
|
||||
overlap_detected_with_fast_check |=
|
||||
fast_check_and_correct_overlap_with(
|
||||
n,
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
const_cast<c10::IValue&>(ival));
|
||||
}
|
||||
}
|
||||
}
|
||||
if (n.outputs_memory_overlap_detected() &&
|
||||
@ -1041,6 +1057,19 @@ void StaticRuntime::verify_and_correct_memory_overlap(ProcessedNode& n) {
|
||||
}
|
||||
}
|
||||
|
||||
bool StaticRuntime::fast_check_and_correct_overlap_with(
|
||||
ProcessedNode& n,
|
||||
c10::IValue& tensor_ival) {
|
||||
auto& tensor = tensor_ival.toTensor();
|
||||
if (planner_->overlapWithInternalBuffer(tensor.data_ptr())) {
|
||||
DLOG(INFO) << "Detected alias for node: " << PrintNode(n.node());
|
||||
tensor_ival = at::native::clone(tensor, c10::nullopt);
|
||||
n.set_outputs_memory_overlap_detected();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
StaticRuntime::Deallocator::~Deallocator() {
|
||||
// Assume cleanup cannot throw.
|
||||
cleanupImpl();
|
||||
@ -1451,6 +1480,7 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
|
||||
nodes_[k].run();
|
||||
millis = timer.MilliSeconds();
|
||||
results.time_per_node[k] += millis;
|
||||
verify_and_correct_memory_overlap(nodes_[k]);
|
||||
}
|
||||
timer.Start();
|
||||
if (static_module_.opts().cleanup_activations) {
|
||||
@ -1840,6 +1870,19 @@ bool ProcessedNode::verify_inputs_dont_overlap_outputs(bool force_check) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ProcessedNode::check_and_correct_overlap_with(
|
||||
const at::Tensor& input,
|
||||
c10::IValue& output_ival) {
|
||||
auto& tensor = output_ival.toTensor();
|
||||
if (!checkNoMemoryOverlap(input, tensor)) {
|
||||
DLOG(INFO) << "Detected alias for node: " << PrintNode(node());
|
||||
output_ival = at::native::clone(tensor, c10::nullopt);
|
||||
set_outputs_memory_overlap_detected();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void ProcessedNode::verify_and_correct_memory_overlap() {
|
||||
for (const auto i : c10::irange(inputs_.size())) {
|
||||
const IValue& in = Input(i);
|
||||
@ -1848,11 +1891,21 @@ void ProcessedNode::verify_and_correct_memory_overlap() {
|
||||
}
|
||||
const auto& in_t = in.toTensor();
|
||||
for (const auto j : c10::irange(num_outputs_)) {
|
||||
const auto& out_t = Output(j).toTensor();
|
||||
if (!checkNoMemoryOverlap(in_t, out_t)) {
|
||||
DLOG(INFO) << "Detected alias for node: " << PrintNode(node());
|
||||
Output(i) = at::native::clone(out_t, c10::nullopt);
|
||||
set_outputs_memory_overlap_detected();
|
||||
auto& output = Output(j);
|
||||
if (output.isTensor()) {
|
||||
check_and_correct_overlap_with(in_t, output);
|
||||
} else if (output.isTensorList()) {
|
||||
auto tensors = output.toListRef();
|
||||
for (const auto& ival : tensors) {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
check_and_correct_overlap_with(in_t, const_cast<c10::IValue&>(ival));
|
||||
}
|
||||
#ifdef FBCODE_CAFFE2
|
||||
if (outputs_memory_overlap_detected()) {
|
||||
LOG_EVERY_MS(WARNING, 60000)
|
||||
<< "Detected alias for node: " << PrintNode(node());
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -560,6 +560,9 @@ class TORCH_API StaticRuntime {
|
||||
// Set Input(idx) to arg. Always copies. Used for kwargs.
|
||||
void set_arg(const size_t idx, const IValue& arg);
|
||||
|
||||
bool fast_check_and_correct_overlap_with(
|
||||
ProcessedNode& n,
|
||||
c10::IValue& tensor_ival);
|
||||
void verify_and_correct_memory_overlap(ProcessedNode& n);
|
||||
|
||||
// clean up owning refs of input IValues
|
||||
@ -719,6 +722,9 @@ class TORCH_API ProcessedNode {
|
||||
return overlap_detected_;
|
||||
}
|
||||
|
||||
bool check_and_correct_overlap_with(
|
||||
const at::Tensor& input,
|
||||
c10::IValue& output);
|
||||
void verify_and_correct_memory_overlap();
|
||||
|
||||
void set_values(IValue* values) {
|
||||
|
Reference in New Issue
Block a user