#include <c10/util/Backtrace.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/function.h>

#include <mutex>

namespace torch::autograd {

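// Process-wide anomaly-mode state, flipped via AnomalyMode::set_enabled()
// (declared in anomaly_mode.h). _check_nan defaults to true so NaN checks
// run whenever the mode is turned on.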
bool AnomalyMode::_enabled = false;
bool AnomalyMode::_check_nan = true;

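// Function-local statics give thread-safe, lazy initialization of the mutex
// and the guard-nesting counter shared by all DetectAnomalyGuard instances.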
namespace {

std::mutex& get_anomaly_guard_lock() {
  static std::mutex anomaly_guard_lock{};
  return anomaly_guard_lock;
}

uint32_t& get_anomaly_counter() {
  static uint32_t counter = 0;
  return counter;
}

} // namespace

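// RAII guard: construction increments the nesting counter and enables anomaly
// mode; destruction decrements it and disables the mode only once the
// outermost guard is gone, restoring the previous check_nan setting.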
DetectAnomalyGuard::DetectAnomalyGuard(bool check_nan) {
  TORCH_WARN_ONCE(
      "This mode should be enabled only for debugging as the additional checks will slow down your program execution.");
  std::lock_guard<std::mutex> lock(get_anomaly_guard_lock());
  uint32_t& counter = get_anomaly_counter();
  counter++;
  this->prev_check_nan_ = AnomalyMode::should_check_nan();
  AnomalyMode::set_enabled(true, check_nan);
}

DetectAnomalyGuard::~DetectAnomalyGuard() {
  std::lock_guard<std::mutex> lock(get_anomaly_guard_lock());
  uint32_t& counter = get_anomaly_counter();
  counter--;
  AnomalyMode::set_enabled(counter > 0, this->prev_check_nan_);
}

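// Usage sketch: guards nest, so anomaly mode stays enabled until the
// outermost guard is destroyed (the C++ analogue of Python's
// torch.autograd.detect_anomaly()):
//
//   {
//     DetectAnomalyGuard outer;                        // counter 0 -> 1, enabled
//     {
//       DetectAnomalyGuard inner(/*check_nan=*/false); // counter 1 -> 2
//     }                                                // counter 2 -> 1, still enabled
//   }                                                  // counter 1 -> 0, disabled
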
AnomalyMetadata::~AnomalyMetadata() = default;

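// Capture the current (forward-call) C++ backtrace, skipping this frame, so
// it can be printed later if the corresponding backward node raises an error.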
void AnomalyMetadata::store_stack() {
  traceback_ = c10::get_backtrace(/* frames_to_skip */ 1);
}

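// Print the stored forward traceback for the failing node, then walk the
// parent_ chain and print the traceback of each ancestor computation in turn.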
void AnomalyMetadata::print_stack(const std::string& current_node_name) {
  TORCH_WARN(
      "Error detected in ",
      current_node_name,
      ". ",
      "Traceback of forward call that caused the error:\n",
      traceback_);

  auto& cur_parent = parent_;
  // if there is no "parent_" in the metadata, this metadata's node is the
  // root, and we stop printing the traceback
  while (cur_parent) {
    auto parent_metadata = cur_parent->metadata();
    TORCH_WARN(
        "\n\n",
        "Previous calculation was induced by ",
        cur_parent->name(),
        ". "
        "Traceback of forward call that induced the previous calculation:\n",
        parent_metadata->traceback_);
    // move to the parent of this node; if this node is a root, the parent is
    // simply null and the loop ends
    cur_parent = parent_metadata->parent_;
  }
}

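// Record the Node whose computation induced this one, so print_stack() can
// walk the full chain of forward calls.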
void AnomalyMetadata::assign_parent(const std::shared_ptr<Node>& parent_node) {
  parent_ = parent_node;
}

} // namespace torch::autograd