Record view stacks if running anomaly mode (#103185)

Now, when you do an inplace mutation on a naughty view (one created in no_grad mode while the inplace runs with grad mode enabled), you get this message:

```
RuntimeError: A view was created in no_grad mode and is being modified inplace with grad mode enabled. Given that this use case is ambiguous and error-prone, it is forbidden. You can clarify your code by moving both the view and the inplace either both inside the no_grad block (if you don't want the inplace to be tracked) or both outside (if you want the inplace to be tracked). To find out where this view was allocated, run your entire forward region under anomaly mode (torch.autograd.detect_anomaly(check_nan=False)).
```
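
A minimal sketch of code that hits this path (mirroring the new test added below; the tensor names are just illustrative):

```python
import torch

base = torch.randn(4)

# The view is created while grad mode is disabled...
with torch.no_grad():
    view = base.view(2, 2)

# ...but the inplace op runs with grad mode enabled, so it now raises the
# RuntimeError shown above, suggesting detect_anomaly(check_nan=False).
view.add_(torch.randn(2, 2, requires_grad=True))
```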

When you run under anomaly mode, you get:

```
RuntimeError: A view was created in no_grad mode and is being modified inplace with grad mode enabled. Given that this use case is ambiguous and error-prone, it is forbidden. You can clarify your code by moving both the view and the inplace either both inside the no_grad block (if you don't want the inplace to be tracked) or both outside (if you want the inplace to be tracked). This view was allocated at:
  File "/data/users/ezyang/c/pytorch/test/test_autograd.py", line 4299, in arglebargle
  File "/data/users/ezyang/c/pytorch/test/test_autograd.py", line 4306, in test_anomaly_gives_view_stack
  File "/home/ezyang/local/c/pytorch-env/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
  File "/home/ezyang/local/c/pytorch-env/lib/python3.10/unittest/case.py", line 591, in run
  File "/data/users/ezyang/c/pytorch/torch/testing/_internal/common_utils.py", line 2266, in _run_with_retry
  File "/data/users/ezyang/c/pytorch/torch/testing/_internal/common_utils.py", line 2337, in run
  File "/home/ezyang/local/c/pytorch-env/lib/python3.10/unittest/case.py", line 650, in __call__
  File "/home/ezyang/local/c/pytorch-env/lib/python3.10/unittest/suite.py", line 122, in run
  File "/home/ezyang/local/c/pytorch-env/lib/python3.10/unittest/suite.py", line 84, in __call__
  File "/home/ezyang/local/c/pytorch-env/lib/python3.10/unittest/suite.py", line 122, in run
  File "/home/ezyang/local/c/pytorch-env/lib/python3.10/unittest/suite.py", line 84, in __call__
  File "/home/ezyang/local/c/pytorch-env/lib/python3.10/unittest/runner.py", line 184, in run
  File "/home/ezyang/local/c/pytorch-env/lib/python3.10/unittest/main.py", line 271, in runTests
  File "/home/ezyang/local/c/pytorch-env/lib/python3.10/unittest/main.py", line 101, in __init__
  File "/data/users/ezyang/c/pytorch/torch/testing/_internal/common_utils.py", line 894, in run_tests
  File "/data/users/ezyang/c/pytorch/test/test_autograd.py", line 11209, in <module>
```
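
As a sketch of the workflow (paraphrasing the new test; `make_view` is an illustrative name): only the view creation has to happen under anomaly mode, since the Python stack is captured when the view's autograd metadata is built.

```python
import torch

def make_view(x):
    # The problematic view: created with grad mode disabled.
    with torch.no_grad():
        return x.view(2, 2)

# Run the forward region under anomaly mode so the allocation stack is recorded.
with torch.autograd.detect_anomaly(check_nan=False):
    view = make_view(torch.randn(4))

# The inplace mutation still fails, but the error message now ends with
# "This view was allocated at:" followed by the stack pointing at make_view.
view.add_(torch.randn(2, 2, requires_grad=True))
```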

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103185
Approved by: https://github.com/zdevito
Author: Edward Z. Yang
Date: 2023-06-09 10:06:16 -04:00
Committed by: PyTorch MergeBot
parent 79e0a1eacb
commit a02c573a89
7 changed files with 72 additions and 7 deletions

View File

@@ -175,6 +175,8 @@ core_trainer_sources = [
    "torch/csrc/jit/ir/type_hashing.cpp",
    "torch/csrc/jit/serialization/pickler.cpp",
    "torch/csrc/jit/serialization/type_name_uniquer.cpp",
    "torch/csrc/profiler/unwind/unwind.cpp",
    "torch/csrc/profiler/combined_traceback.cpp",
]

torch_mobile_core = [
@@ -403,8 +405,6 @@ core_sources_full_mobile_no_backend_interface_xplat = [
    "torch/csrc/jit/tensorexpr/types.cpp",
    "torch/csrc/jit/tensorexpr/unique_name_manager.cpp",
    "torch/csrc/jit/testing/file_check.cpp",
    "torch/csrc/profiler/unwind/unwind.cpp",
    "torch/csrc/profiler/combined_traceback.cpp",
    "torch/csrc/jit/testing/hooks_for_testing.cpp",
    "torch/csrc/utils/cpp_stacktraces.cpp",
    "torch/csrc/utils/schema_info.cpp",

View File

@@ -1172,6 +1172,7 @@ def main():
        'include/torch/csrc/jit/codegen/cuda/scheduler/*.h',
        'include/torch/csrc/onnx/*.h',
        'include/torch/csrc/profiler/*.h',
        'include/torch/csrc/profiler/unwind/*.h',
        'include/torch/csrc/profiler/orchestration/*.h',
        'include/torch/csrc/profiler/stubs/*.h',
        'include/torch/csrc/utils/*.h',

View File

@@ -4293,6 +4293,20 @@ Done""")
            out.backward()
        self.assertIn('MyFunc.apply', str(w[0].message))

    def test_anomaly_gives_view_stack(self):
        def arglebargle(x):
            with torch.no_grad():
                return x.view(2, 2)

        r = arglebargle(torch.randn(4))
        with self.assertRaisesRegex(RuntimeError, r"detect_anomaly\(check_nan=False\)"):
            r.add_(torch.randn(4, requires_grad=True))

        with detect_anomaly(check_nan=False):
            r = arglebargle(torch.randn(4))
        with self.assertRaisesRegex(RuntimeError, "arglebargle"):
            r.add_(torch.randn(4, requires_grad=True))

    def test_calculate_shape_util(self):
        out = torch.randn(10, 5, requires_grad=True)
        grad = torch.randn(5, 10, requires_grad=True)

View File

@@ -1,6 +1,7 @@
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/autograd/InferenceMode.h>
#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
@@ -40,7 +41,13 @@ DifferentiableViewMeta::DifferentiableViewMeta(
      backward_info_(std::move(backward_info)),
      forward_info_(std::move(forward_info)),
      shared_view_info_(shared_view_info),
      creation_meta_(creation_meta) {
      creation_meta_(creation_meta),
      creation_traceback_(
          AnomalyMode::is_enabled() ? torch::CapturedTraceback::gather(
                                          /*python*/ true,
                                          /*script*/ false,
                                          /*cpp*/ false)
                                    : nullptr) {
  is_view_ = true;
  if (backward_info_.has_value()) {
    self_impl->set_version_counter(
@@ -59,6 +66,16 @@ DifferentiableViewMeta::DifferentiableViewMeta(
  }
}

void DifferentiableViewMeta::set_creation_meta(CreationMeta new_creation_meta) {
  TORCH_CHECK(
      has_bw_view(), "creation_meta can only exist for backward views.");
  creation_meta_ = new_creation_meta;
  if (AnomalyMode::is_enabled()) {
    creation_traceback_ = torch::CapturedTraceback::gather(
        /*python*/ true, /*script*/ false, /*cpp*/ false);
  }
}

// Chain this view info with the new view op between base and tensor
ViewInfo ViewInfo::chain(
    const Variable& base,
@@ -838,6 +855,24 @@ void handle_view_on_rebase(
        TORCH_INTERNAL_ASSERT(false, "Invalid CreationMeta state");
    }

    auto* tb = diff_view_meta->get_creation_traceback().get();
    if (tb) {
      std::ostringstream oss;
      torch::SymbolizedTracebacks st = torch::symbolize({tb});
      const std::vector<uint64_t>& traceback = st.tracebacks[0];
      for (uint64_t idx : traceback) {
        const unwind::Frame& frame = st.all_frames[idx];
        oss << "  File \"" << frame.filename << "\", line " << frame.lineno
            << ", in " << frame.funcname << "\n";
      }
      msg = c10::str(msg, " This view was allocated at:\n", oss.str());
    } else {
      msg = c10::str(
          msg,
          " To find out where this view was allocated, run your entire forward region under"
          " anomaly mode (torch.autograd.detect_anomaly(check_nan=False)).");
    }

    TORCH_CHECK(false, msg);
  }
}

View File

@@ -7,6 +7,7 @@
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/forward_grad.h>
#include <torch/csrc/autograd/function_hook.h>
#include <torch/csrc/profiler/combined_traceback.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/core/Tensor.h>
@@ -594,6 +595,7 @@ struct TORCH_API DifferentiableViewMeta : public AutogradMeta {
  /// version_counter.current_version().
  uint32_t attr_version_;
  CreationMeta creation_meta_;
  std::shared_ptr<torch::CapturedTraceback> creation_traceback_;

 public:
  /// requires_grad is a backward AD field so we only use the view specific
@@ -635,12 +637,13 @@ struct TORCH_API DifferentiableViewMeta : public AutogradMeta {
    return creation_meta_;
  }

  void set_creation_meta(CreationMeta new_creation_meta) {
    TORCH_CHECK(
        has_bw_view(), "creation_meta can only exist for backward views.");
    creation_meta_ = new_creation_meta;
  const std::shared_ptr<torch::CapturedTraceback>& get_creation_traceback()
      const {
    return creation_traceback_;
  }

  void set_creation_meta(CreationMeta new_creation_meta);

  bool has_fw_view() const {
    return shared_view_info_ || forward_info_.has_value();
  }

View File

@@ -1,4 +1,5 @@
#include <torch/csrc/profiler/combined_traceback.h>
#include <atomic>
namespace torch {
@@ -17,9 +18,11 @@ std::shared_ptr<CapturedTraceback> CapturedTraceback::gather(
      p = p->next_;
    }
  }
#ifndef BUILD_LITE_INTERPRETER
  if (script) {
    r->script_frames_ = torch::jit::currentCallstack();
  }
#endif
  if (cpp) {
    r->cpp_frames_ = unwind::unwind();
  }
@@ -114,6 +117,7 @@ SymbolizedTracebacks symbolize(
  };

  auto append_jit = [&]() {
#ifndef BUILD_LITE_INTERPRETER
    if (jit_appended) {
      return;
    }
@@ -133,6 +137,7 @@ SymbolizedTracebacks symbolize(
      r.tracebacks.back().push_back(r.all_frames.size());
      r.all_frames.emplace_back(std::move(frame));
    }
#endif
  };

  for (void* f : sc->cpp_frames_) {

View File

@@ -1,7 +1,12 @@
#pragma once
#ifndef BUILD_LITE_INTERPRETER
#include <torch/csrc/jit/runtime/interpreter.h>
#endif
#include <c10/core/Allocator.h>
#include <c10/util/Exception.h>
#include <torch/csrc/profiler/unwind/unwind.h>
#include <unordered_map>
namespace torch {
@@ -47,7 +52,9 @@ struct TORCH_API CapturedTraceback : public c10::GatheredContext {
 private:
  std::vector<PyFrame> frames_;
  std::vector<void*> cpp_frames_;
#ifndef BUILD_LITE_INTERPRETER
  std::vector<jit::StackEntry> script_frames_;
#endif

  friend TORCH_API SymbolizedTracebacks
  symbolize(const std::vector<CapturedTraceback*>& to_symbolize);