[TorchScript] thread-safe ErrorReport::CallStack (#160386)

Context: During jit.script, the TorchScript frontend maintains a callstack of Python frames, which is used to present the corresponding user code in case TorchScript errors. The callstack is maintained via ErrorReport::CallStack RAII guards. Before recursing into a function, an ErrorReport::CallStack guard is created and the CallStack guard pushes the frame information onto a thread_local callstack (a list of calls); and after exiting, the frame information is popped off the callstack. Note that the CallStack guards are also sometimes used in python via pybindings.

The problem is that sometimes another thread can obtain a reference to the CallStack guard (if it's a Python CallStack guard). **This means that the destructor for a CallStack guard can be called from a different thread than the constructor was called**. When this happens, it causes a segfault.

This PR makes the callstack vector thread-safe to access, and each CallStack guard will store a reference to the callstack vector onto which it pushed. When the CallStack guard is destructed, it pops off the appropriate callstack vector. Although this could potentially lead to mangled callstacks, it should prevent segfaults.

Added a test `test_thread_safe_error_stacks` which segfaults prior to these changes, and no longer segfaults.

Differential Revision: [D80054972](https://our.internmc.facebook.com/intern/diff/D80054972)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160386
Approved by: https://github.com/eellison
This commit is contained in:
David Berard
2025-08-12 11:47:04 -07:00
committed by PyTorch MergeBot
parent f8f0414a59
commit 78a2fe1d42
3 changed files with 92 additions and 6 deletions

View File

@ -4,6 +4,7 @@
import os
import re
import sys
import threading
import types
import typing
import typing_extensions
@ -773,6 +774,25 @@ class TestRecursiveScript(JitTestCase):
mod.foo = None
self.checkModule(mod, (torch.rand(2, 2),))
def test_thread_safe_error_stacks(self):
# prior to #160386, this causes a segfault. See [Note: Thread-safe CallStack]
callstacks = []
def callstack_creator():
factory = torch._C._jit_tree_views.SourceRangeFactory(
"source code", "a.py", 1, 0
)
x = torch._C.CallStack("a", factory.make_range(1, 0, 1))
callstacks.append(x)
del x
t = threading.Thread(target=callstack_creator)
t.start()
t.join()
del t
del callstacks[0]
self.assertTrue(len(callstacks) == 0)
def test_override_instance_method_ignore(self):
class M(torch.nn.Module):
@torch.jit.ignore

View File

@ -6,7 +6,34 @@ namespace torch::jit {
// Avoid storing objects with destructor in thread_local for mobile build.
#ifndef C10_MOBILE
static thread_local std::vector<Call> calls;
// [NOTE: Thread-safe CallStack]
// `calls` maintains a stack of Python calls that resulted in the
// currently compiled TorchScript code. RAII ErrorReport::CallStack
// push and pop from the `calls` object during compilation to track
// these stacks so that they can be used to report compilation errors
//
// Q: Why can't this just be a thread_local vector<Call> (as it was previously)?
//
// A: Sometimes a CallStack RAII guard is created in Python in a given
// thread (say, thread A). Then later, someone can call
// sys._current_frames() from another thread (thread B), which causes
// thread B to hold references to the CallStack guard. e.g.
// 1. CallStack RAII guard created by thread A
// 2. CallStack guard now has a reference from thread B
// 3. thread A releases guard, but thread B still holds a reference
// 4. thread B releases guard, refcount goes to 0, and we
// call the destructor
// under this situation, **we pop an element off the wrong `call`
// object (from the wrong thread!)
//
// To fix this:
// * in CallStack, store a reference to which thread's `calls`
// the CallStack corresponds to, so you can pop from the correct
// `calls` object.
// * make it a shared_ptr and add a mutex to make this thread safe
// (since now multiple threads access a given thread_local calls object)
static thread_local std::shared_ptr<ErrorReport::Calls> calls =
std::make_shared<ErrorReport::Calls>();
#endif // C10_MOBILE
ErrorReport::ErrorReport(const ErrorReport& e)
@ -17,20 +44,23 @@ ErrorReport::ErrorReport(const ErrorReport& e)
#ifndef C10_MOBILE
ErrorReport::ErrorReport(const SourceRange& r)
: context(r), error_stack(calls.begin(), calls.end()) {}
: context(r), error_stack(calls->get_stack()) {}
void ErrorReport::CallStack::update_pending_range(const SourceRange& range) {
calls.back().caller_range = range;
calls->update_pending_range(range);
}
ErrorReport::CallStack::CallStack(
const std::string& name,
const SourceRange& range) {
calls.push_back({name, range});
source_callstack_ = calls;
source_callstack_->push_back({name, range});
}
ErrorReport::CallStack::~CallStack() {
calls.pop_back();
if (source_callstack_) {
source_callstack_->pop_back();
}
}
#else // defined C10_MOBILE
ErrorReport::ErrorReport(const SourceRange& r) : context(r) {}
@ -61,7 +91,7 @@ static std::string get_stacked_errors(const std::vector<Call>& error_stack) {
std::string ErrorReport::current_call_stack() {
#ifndef C10_MOBILE
return get_stacked_errors(calls);
return get_stacked_errors(calls->get_stack());
#else
TORCH_CHECK(false, "Call stack not supported on mobile");
#endif // C10_MOBILE

View File

@ -1,6 +1,7 @@
#pragma once
#include <torch/csrc/jit/frontend/tree.h>
#include <mutex>
namespace torch::jit {
@ -18,6 +19,38 @@ struct TORCH_API ErrorReport : public std::exception {
const char* what() const noexcept override;
class TORCH_API Calls {
private:
std::vector<Call> calls_;
mutable std::mutex mutex_;
public:
void push_back(Call call) {
std::lock_guard<std::mutex> lock(mutex_);
calls_.push_back(std::move(call));
}
void pop_back() {
std::lock_guard<std::mutex> lock(mutex_);
calls_.pop_back();
}
bool empty() const {
std::lock_guard<std::mutex> lock(mutex_);
return calls_.empty();
}
void update_pending_range(const SourceRange& range) {
std::lock_guard<std::mutex> lock(mutex_);
calls_.back().caller_range = range;
}
std::vector<Call> get_stack() const {
std::lock_guard<std::mutex> lock(mutex_);
return calls_;
}
};
struct TORCH_API CallStack {
// These functions are used to report why a function was being compiled
// (i.e. what was the call stack of user functions at compilation time that
@ -28,6 +61,9 @@ struct TORCH_API ErrorReport : public std::exception {
// Change the range that is relevant for the current function (i.e. after
// each successful expression compilation, change it to the next expression)
static void update_pending_range(const SourceRange& range);
private:
std::shared_ptr<Calls> source_callstack_;
};
static std::string current_call_stack();