mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[TorchScript] thread-safe ErrorReport::CallStack (#160386)
Context: During jit.script, the TorchScript frontend maintains a callstack of Python frames, which is used to present the corresponding user code in case TorchScript errors. The callstack is maintained via ErrorReport::CallStack RAII guards. Before recursing into a function, an ErrorReport::CallStack guard is created and the CallStack guard pushes the frame information onto a thread_local callstack (a list of calls); and after exiting, the frame information is popped off the callstack. Note that the CallStack guards are also sometimes used in python via pybindings. The problem is that sometimes another thread can obtain a reference to the CallStack guard (if it's a Python CallStack guard). **This means that the destructor for a CallStack guard can be called from a different thread than the constructor was called**. When this happens, it causes a segfault. This PR makes the callstack vector thread-safe to access, and each CallStack guard will store a reference to the callstack vector onto which it pushed. When the CallStack guard is destructed, it pops off the appropriate callstack vector. Although this could potentially lead to mangled callstacks, it should prevent segfaults. Added a test `test_thread_safe_error_stacks` which segfaults prior to these changes, and no longer segfaults. Differential Revision: [D80054972](https://our.internmc.facebook.com/intern/diff/D80054972) Pull Request resolved: https://github.com/pytorch/pytorch/pull/160386 Approved by: https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
f8f0414a59
commit
78a2fe1d42
@ -4,6 +4,7 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
import types
|
||||
import typing
|
||||
import typing_extensions
|
||||
@ -773,6 +774,25 @@ class TestRecursiveScript(JitTestCase):
|
||||
mod.foo = None
|
||||
self.checkModule(mod, (torch.rand(2, 2),))
|
||||
|
||||
def test_thread_safe_error_stacks(self):
|
||||
# prior to #160386, this causes a segfault. See [Note: Thread-safe CallStack]
|
||||
callstacks = []
|
||||
|
||||
def callstack_creator():
|
||||
factory = torch._C._jit_tree_views.SourceRangeFactory(
|
||||
"source code", "a.py", 1, 0
|
||||
)
|
||||
x = torch._C.CallStack("a", factory.make_range(1, 0, 1))
|
||||
callstacks.append(x)
|
||||
del x
|
||||
|
||||
t = threading.Thread(target=callstack_creator)
|
||||
t.start()
|
||||
t.join()
|
||||
del t
|
||||
del callstacks[0]
|
||||
self.assertTrue(len(callstacks) == 0)
|
||||
|
||||
def test_override_instance_method_ignore(self):
|
||||
class M(torch.nn.Module):
|
||||
@torch.jit.ignore
|
||||
|
@ -6,7 +6,34 @@ namespace torch::jit {
|
||||
|
||||
// Avoid storing objects with destructor in thread_local for mobile build.
|
||||
#ifndef C10_MOBILE
|
||||
static thread_local std::vector<Call> calls;
|
||||
// [NOTE: Thread-safe CallStack]
|
||||
// `calls` maintains a stack of Python calls that resulted in the
|
||||
// currently compiled TorchScript code. RAII ErrorReport::CallStack
|
||||
// push and pop from the `calls` object during compilation to track
|
||||
// these stacks so that they can be used to report compilation errors
|
||||
//
|
||||
// Q: Why can't this just be a thread_local vector<Call> (as it was previously)?
|
||||
//
|
||||
// A: Sometimes a CallStack RAII guard is created in Python in a given
|
||||
// thread (say, thread A). Then later, someone can call
|
||||
// sys._current_frames() from another thread (thread B), which causes
|
||||
// thread B to hold references to the CallStack guard. e.g.
|
||||
// 1. CallStack RAII guard created by thread A
|
||||
// 2. CallStack guard now has a reference from thread B
|
||||
// 3. thread A releases guard, but thread B still holds a reference
|
||||
// 4. thread B releases guard, refcount goes to 0, and we
|
||||
// call the destructor
|
||||
// under this situation, **we pop an element off the wrong `call`
|
||||
// object (from the wrong thread!)
|
||||
//
|
||||
// To fix this:
|
||||
// * in CallStack, store a reference to which thread's `calls`
|
||||
// the CallStack corresponds to, so you can pop from the correct
|
||||
// `calls` object.
|
||||
// * make it a shared_ptr and add a mutex to make this thread safe
|
||||
// (since now multiple threads access a given thread_local calls object)
|
||||
static thread_local std::shared_ptr<ErrorReport::Calls> calls =
|
||||
std::make_shared<ErrorReport::Calls>();
|
||||
#endif // C10_MOBILE
|
||||
|
||||
ErrorReport::ErrorReport(const ErrorReport& e)
|
||||
@ -17,20 +44,23 @@ ErrorReport::ErrorReport(const ErrorReport& e)
|
||||
|
||||
#ifndef C10_MOBILE
|
||||
ErrorReport::ErrorReport(const SourceRange& r)
|
||||
: context(r), error_stack(calls.begin(), calls.end()) {}
|
||||
: context(r), error_stack(calls->get_stack()) {}
|
||||
|
||||
void ErrorReport::CallStack::update_pending_range(const SourceRange& range) {
|
||||
calls.back().caller_range = range;
|
||||
calls->update_pending_range(range);
|
||||
}
|
||||
|
||||
ErrorReport::CallStack::CallStack(
|
||||
const std::string& name,
|
||||
const SourceRange& range) {
|
||||
calls.push_back({name, range});
|
||||
source_callstack_ = calls;
|
||||
source_callstack_->push_back({name, range});
|
||||
}
|
||||
|
||||
ErrorReport::CallStack::~CallStack() {
|
||||
calls.pop_back();
|
||||
if (source_callstack_) {
|
||||
source_callstack_->pop_back();
|
||||
}
|
||||
}
|
||||
#else // defined C10_MOBILE
|
||||
ErrorReport::ErrorReport(const SourceRange& r) : context(r) {}
|
||||
@ -61,7 +91,7 @@ static std::string get_stacked_errors(const std::vector<Call>& error_stack) {
|
||||
|
||||
std::string ErrorReport::current_call_stack() {
|
||||
#ifndef C10_MOBILE
|
||||
return get_stacked_errors(calls);
|
||||
return get_stacked_errors(calls->get_stack());
|
||||
#else
|
||||
TORCH_CHECK(false, "Call stack not supported on mobile");
|
||||
#endif // C10_MOBILE
|
||||
|
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/frontend/tree.h>
|
||||
#include <mutex>
|
||||
|
||||
namespace torch::jit {
|
||||
|
||||
@ -18,6 +19,38 @@ struct TORCH_API ErrorReport : public std::exception {
|
||||
|
||||
const char* what() const noexcept override;
|
||||
|
||||
class TORCH_API Calls {
|
||||
private:
|
||||
std::vector<Call> calls_;
|
||||
mutable std::mutex mutex_;
|
||||
|
||||
public:
|
||||
void push_back(Call call) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
calls_.push_back(std::move(call));
|
||||
}
|
||||
|
||||
void pop_back() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
calls_.pop_back();
|
||||
}
|
||||
|
||||
bool empty() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return calls_.empty();
|
||||
}
|
||||
|
||||
void update_pending_range(const SourceRange& range) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
calls_.back().caller_range = range;
|
||||
}
|
||||
|
||||
std::vector<Call> get_stack() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return calls_;
|
||||
}
|
||||
};
|
||||
|
||||
struct TORCH_API CallStack {
|
||||
// These functions are used to report why a function was being compiled
|
||||
// (i.e. what was the call stack of user functions at compilation time that
|
||||
@ -28,6 +61,9 @@ struct TORCH_API ErrorReport : public std::exception {
|
||||
// Change the range that is relevant for the current function (i.e. after
|
||||
// each successful expression compilation, change it to the next expression)
|
||||
static void update_pending_range(const SourceRange& range);
|
||||
|
||||
private:
|
||||
std::shared_ptr<Calls> source_callstack_;
|
||||
};
|
||||
|
||||
static std::string current_call_stack();
|
||||
|
Reference in New Issue
Block a user