pytorch/test/jit/test_recursive_script.py
David Berard 78a2fe1d42 [TorchScript] thread-safe ErrorReport::CallStack (#160386)
Context: During jit.script, the TorchScript frontend maintains a callstack of Python frames, which is used to present the corresponding user code when a TorchScript error occurs. The callstack is maintained via ErrorReport::CallStack RAII guards: before recursing into a function, an ErrorReport::CallStack guard is created, which pushes the frame information onto a thread_local callstack (a list of calls); after exiting, the frame information is popped off the callstack. Note that the CallStack guards are also sometimes used from Python via pybindings.
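
To make the guard mechanism concrete, here is a minimal Python sketch of the RAII pattern described above (all names here are hypothetical; the real guard is the C++ `ErrorReport::CallStack` class):

```python
import threading

_frames = threading.local()  # stands in for the C++ thread_local callstack


class CallStackGuard:
    """Hypothetical model of the pre-fix guard: push on construction,
    pop on destruction."""

    def __init__(self, name):
        if not hasattr(_frames, "stack"):
            _frames.stack = []
        _frames.stack.append(name)

    def __del__(self):
        # Pops from whichever thread happens to run the destructor; if that
        # thread never pushed anything, its stack is missing or empty -- the
        # Python analogue of the C++ segfault described below.
        _frames.stack.pop()
```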

The problem is that sometimes another thread can obtain a reference to the CallStack guard (if it's a Python CallStack guard). **This means that the destructor for a CallStack guard can be called from a different thread than the one on which the constructor was called.** When this happens, the destructor pops from the destroying thread's thread_local callstack rather than the one that was pushed onto, which causes a segfault.

This PR makes the callstack vector thread-safe to access, and each CallStack guard now stores a reference to the callstack vector it pushed onto. When the CallStack guard is destructed, it pops from that same vector, regardless of which thread runs the destructor. Although this could potentially lead to mangled callstacks, it should prevent segfaults.
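
Continuing the hypothetical sketch above, the fix amounts to remembering the exact vector that was pushed onto and serializing access to it:

```python
_lock = threading.Lock()  # serializes access to the callstack vectors


class SafeCallStackGuard:
    """Hypothetical model of the post-fix guard."""

    def __init__(self, name):
        if not hasattr(_frames, "stack"):
            _frames.stack = []
        # Remember the exact vector we pushed onto, instead of relying on a
        # thread_local lookup at destruction time.
        self._stack = _frames.stack
        with _lock:
            self._stack.append(name)

    def __del__(self):
        # Pop from the remembered vector under the lock; safe even when a
        # different thread runs the destructor.
        with _lock:
            self._stack.pop()
```

If guards are destroyed out of construction order across threads, a pop can remove a different frame than the one its guard pushed; that is the "mangled callstacks" caveat above.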

Added a test, `test_thread_safe_error_stacks`, which segfaults prior to these changes and no longer segfaults afterward.

Differential Revision: [D80054972](https://our.internmc.facebook.com/intern/diff/D80054972)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160386
Approved by: https://github.com/eellison
2025-08-12 21:59:04 +00:00


# Owner(s): ["oncall: jit"]
# ruff: noqa: F841
import os
import re
import sys
import threading
import types
import typing
import typing_extensions
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
import torch
import torch.jit.frontend
import torch.nn as nn
from torch import Tensor
from torch.testing import FileCheck
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import (
_tmp_donotuse_dont_inline_everything,
JitTestCase,
)
class TestRecursiveScript(JitTestCase):
def test_inferred_nonetype(self):
class M(nn.Module):
def __init__(self) -> None:
super().__init__()
self.x = None
def forward(self):
assert self.x is None
m = torch.jit.script(M())
self.checkModule(M(), ())
def test_script_function_attribute(self):
@torch.jit.script
def fn1(x):
return x + x
@torch.jit.script
def fn2(x):
return x - x
class M(torch.nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x)
fn1_mod = M(fn1)
fn2_mod = M(fn2)
self.checkModule(fn1_mod, (torch.randn(2, 2),))
self.checkModule(fn2_mod, (torch.randn(2, 2),))
def test_python_function_attribute(self):
class M(torch.nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x)
mod = M(torch.sigmoid)
self.checkModule(mod, (torch.randn(2, 2),))
def test_failed_function_compilation(self):
def fn(x):
return i_dont_exist # noqa: F821
class M(torch.nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x)
m = M(fn)
with self.assertRaisesRegexWithHighlight(
RuntimeError, "failed to compile", "i_dont_exist"
):
torch.jit.script(m)
def test_init_error(self):
class M(nn.Module):
def __init__(self) -> None:
self.x = 2
def forward(self):
pass
with self.assertRaisesRegex(RuntimeError, "has not been initialized"):
torch.jit.script(M())
def test_script_after_eval(self):
class M(nn.Module):
def forward(self):
if self.training:
return 2
else:
return 0
m = M()
sm1 = torch.jit.script(m)
m.eval()
sm2 = torch.jit.script(m)
# m is in eval mode, training should be False
self.assertFalse(m.training)
# sm1 was created while m had training = True
self.assertTrue(sm1.training)
self.assertEqual(sm1.training, sm1._c.getattr("training"))
self.assertEqual(sm1(), 2)
# sm2 was created after m was eval'ed
self.assertFalse(sm2.training)
self.assertEqual(sm2.training, sm2._c.getattr("training"))
self.assertEqual(sm2(), 0)
def test_module_name(self):
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.x = 2
def forward(self, t):
return t + self.x
m = torch.jit.script(MyModule())
FileCheck().check("MyModule").run(m.graph)
def test_repeated_error_stack(self):
def d(x):
return "a" - 2
def c(x):
return d(x)
def b(x):
return c(x)
def a(x):
return b(x)
try:
torch.jit.script(a)
except Exception as e:
FileCheck().check_count("is being compiled", 2).run(str(e))
try:
torch.jit.script(a)
except Exception as e:
# Make sure that no entries are left over from the previous failure
FileCheck().check_count("is being compiled", 2).run(str(e))
def test_constants_with_final(self):
class M1(torch.nn.Module):
x: torch.jit.Final[int]
def __init__(self) -> None:
super().__init__()
self.x = 2
def forward(self, t):
return t + self.x
self.checkModule(M1(), (torch.randn(2, 2),))
class M2(torch.nn.Module):
x: typing_extensions.Final[int]
def __init__(self) -> None:
super().__init__()
self.x = 2
def forward(self, t):
return t + self.x
self.checkModule(M2(), (torch.randn(2, 2),))
class M3(torch.nn.Module):
x: typing.Final[int]
def __init__(self) -> None:
super().__init__()
self.x = 2
def forward(self, t):
return t + self.x
self.checkModule(M3(), (torch.randn(2, 2),))
def test_ignore_class(self):
@torch.jit.ignore
class MyScriptClass:
def unscriptable(self):
return "a" + 200
class TestModule(torch.nn.Module):
def forward(self, x):
return MyScriptClass()
with self.assertRaisesRegexWithHighlight(
torch.jit.frontend.FrontendError,
"Cannot instantiate class",
"MyScriptClass",
):
t = torch.jit.script(TestModule())
def test_method_call(self):
class M(nn.Module):
def test(self, x):
return x
def forward(self, z):
y = self.test(z)
return z + 20 + y
self.checkModule(M(), (torch.randn(2, 2),))
def test_module_repr(self):
class Submodule(nn.Module):
def forward(self, x):
return x
class MyModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(10, 10, 3)
self.lin = nn.Linear(10, 10)
self.sub = Submodule()
def forward(self, x):
return self.lin(x) + self.sub(x) + self.conv(x)
m = torch.jit.script(MyModule())
with self.capture_stdout() as out:
print(m)
f = FileCheck()
f.check("MyModule")
f.check("Conv2d")
f.check("Linear")
f.check("Submodule")
f.run(out[0])
self.assertEqual(m.original_name, "MyModule")
def test_dir(self):
def test_module_dir(mod):
dir_set = dir(mod)
scripted_mod = torch.jit.script(mod)
dir_scripted = set(dir(scripted_mod))
# set not currently copied over
ignore_set = [
"training",
"__delitem__",
"__setitem__",
"clear",
"items",
"keys",
"pop",
"update",
"values",
]
for attr in dir_set:
if attr in ignore_set:
continue
self.assertTrue(attr in dir_scripted, attr)
class MyModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(10, 10, 3)
self.lin = nn.Linear(10, 10)
def forward(self, x):
return self.lin(x) + self.conv(x)
test_module_dir(MyModule())
# test custom __dir__ for containers
conv = nn.Conv2d(10, 10, 3)
linear = nn.Linear(10, 10)
test_module_dir(nn.Sequential(conv, linear))
test_module_dir(
nn.ModuleDict(OrderedDict([("conv", conv), ("linear", linear)]))
)
def test_class_compile(self):
def other_fn(a: int, b: Tensor) -> Tensor:
return a * b
class B:
def __init__(self, x):
self.x = 2
def helper(self, a):
return self.x + a + other_fn(self.x, a)
class N(torch.nn.Module):
def forward(self, x):
b = B(x)
return b.helper(x)
self.checkModule(N(), (torch.randn(2, 2),))
def test_error_stack(self):
def d(x: int) -> int:
return x + 10
def c(x):
return d("hello") + d(x)
def b(x):
return c(x)
def a(x):
return b(x)
try:
scripted = torch.jit.script(a)
except RuntimeError as e:
checker = FileCheck()
checker.check("Expected a value of type 'int'")
checker.check("def c(x)")
checker.check("def b(x)")
checker.check("def a(x)")
checker.run(str(e))
def test_error_stack_module(self):
def d(x: int) -> int:
return x + 10
def c(x):
return d("hello") + d(x)
def b(x):
return c(x)
class Submodule(torch.nn.Module):
def forward(self, x):
return b(x)
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.submodule = Submodule()
def some_method(self, y):
return y + self.submodule(y)
def forward(self, x):
return self.some_method(x)
try:
scripted = torch.jit.script(M())
except RuntimeError as e:
checker = FileCheck()
checker.check("Expected a value of type 'int'")
checker.check("'c' is being compiled since it was called from 'b'")
checker.check("'b' is being compiled since it was called from")
checker.run(str(e))
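    # The decorator below keeps calls as prim::CallFunction nodes instead of
    # inlining them, so the graph checks in this test can see the call.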
@_tmp_donotuse_dont_inline_everything
def test_script_basic(self):
def a_python_fn(a, b, c):
return a + b + c
@torch.jit.script
def a_script_fn(d, e, f):
return a_python_fn(d, e, f)
graph = str(a_script_fn.graph)
FileCheck().check("prim::CallFunction").run(graph)
FileCheck().check_not("^a_python_fn").run(graph)
t = torch.ones(2, 2)
self.assertEqual(a_script_fn(t, t, t), t + t + t)
def test_error_stack_class(self):
class X:
def bad_fn(self):
import pdb # noqa: F401
def fn(x) -> X:
return X(10)
try:
torch.jit.script(fn)
except Exception as e:
checker = FileCheck()
checker.check("import statements")
checker.check("is being compiled since it was called from")
checker.run(str(e))
def test_error_stack_annotation(self):
class X:
def bad_fn(self):
import pdb # noqa: F401
def fn(x) -> X:
return X(10)
try:
torch.jit.script(fn)
except Exception as e:
checker = FileCheck()
checker.check("import statements")
checker.check("is being compiled since it was called from")
checker.check("-> X")
checker.run(str(e))
def test_module_basic(self):
class Other(torch.nn.Module):
__constants__ = ["x"]
def __init__(self, x):
super().__init__()
self.x = x
self.param = torch.nn.Parameter(torch.ones(2, 2))
def some_unscriptable_method(self):
a = 2
a = [2]
return a
def forward(self, t):
return t + self.x + self.param
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.other = Other(200)
def forward(self, t):
return self.other(t) * 2
self.checkModule(M(), (torch.ones(2, 2),))
def test_module_function_export(self):
class Other(torch.nn.Module):
__constants__ = ["x"]
def __init__(self, x):
super().__init__()
self.x = x
self.param = torch.nn.Parameter(torch.ones(2, 2))
@torch.jit.export
def some_entry_point(self, y):
return y + 20
def forward(self, t):
return t + self.x + self.param
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.other = Other(200)
def forward(self, t):
return self.other(t) * 2
self.checkModule(M(), (torch.ones(2, 2),))
def test_iterable_modules(self):
class Inner(torch.nn.Module):
def forward(self, x):
return x + 10
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.sequential = nn.Sequential(
Inner(), Inner(), nn.Sequential(Inner(), Inner())
)
self.module_list = nn.ModuleList([Inner(), Inner()])
def forward(self, x):
for mod in self.module_list:
x += mod(x)
x += self.sequential(x)
return x
self.checkModule(M(), (torch.randn(5, 5),))
def test_prepare_scriptable_basic(self):
class SeluButReluWhenScripted(torch.nn.SELU):
def __prepare_scriptable__(self):
return nn.ReLU()
t = torch.randn(5, 5)
m = SeluButReluWhenScripted()
sm = torch.jit.script(m)
eager_out = m(t)
script_out = sm(t)
self.assertNotEqual(eager_out, script_out)
def test_prepare_scriptable_iterable_modules(self):
class SeluButReluWhenScripted(torch.nn.SELU):
def __prepare_scriptable__(self):
return nn.ReLU()
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
shared = SeluButReluWhenScripted()
self.sequential = nn.Sequential(
SeluButReluWhenScripted(),
SeluButReluWhenScripted(),
nn.Sequential(
SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()
),
shared,
)
self.module_list = nn.ModuleList(
[SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()]
)
def forward(self, x):
for mod in self.module_list:
x += mod(x)
x += self.sequential(x)
return x
t = torch.randn(5, 5)
m = M()
eager_out = m(t.clone())
sm = torch.jit.script(m)
script_out = sm(t.clone())
self.assertNotEqual(eager_out, script_out)
def test_prepare_scriptable_cycle(self):
t = torch.randn(5, 5)
c = torch.nn.Module()
p = torch.nn.Module()
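        # Assigning via __dict__ creates a reference cycle without triggering
        # nn.Module submodule registration; scripting should not recurse forever.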
c.__dict__["_p"] = p
p.__dict__["_c"] = c
sm = torch.jit.script(p)
def test_prepare_scriptable_escape_hatch(self):
class NonJitableClass:
def __call__(self, int1, int2, *args):
total = int1 + int2
for arg in args:
total += arg
return total
obj = NonJitableClass()
self.assertEqual(obj(1, 2), 3)
self.assertEqual(obj(1, 2, 3, 4), 10)
with self.assertRaisesRegex(
torch.jit.frontend.NotSupportedError,
expected_regex="can't take variable number of arguments",
):
torch.jit.script(obj)
def escape_hatch(int1: int, int2: int) -> int:
return int1 + int2
class NonJitableClassWithEscapeHatch(NonJitableClass):
def __prepare_scriptable__(self):
return escape_hatch
jit_obj = torch.jit.script(NonJitableClassWithEscapeHatch())
self.assertEqual(jit_obj(1, 2), 3)
with self.assertRaisesRegex(
RuntimeError,
expected_regex=re.escape(
"expected at most 2 argument(s) but received 4 argument(s)"
),
):
jit_obj(1, 2, 3, 4)
def test_attributes(self):
@torch.jit.script
class Inner2:
def __init__(self) -> None:
self.b = "a string"
@torch.jit.script
class Foo:
def __init__(self) -> None:
self.a = 4
self.inner = Inner2()
@torch.jit.script
class SFoo:
def __init__(self) -> None:
self.a = 4
self.inner = Inner2()
def __setstate__(self, obj: Tuple[int, Inner2]) -> None:
a, inner = obj
self.a = a
self.inner = inner
def __getstate__(self):
return (self.a, self.inner)
untyped_values = (
("my_dict", {"I": "am", "a test": "test"}),
("my_float", 2.3),
("my_int", 99),
("my_bool", False),
("my_tuple", (1, 2, 3, 4)),
("my_list", [(1, 2), (3, 4)]),
# ('my_tensor', torch.randn(2, 2)),
("my_int_list", [1, 2, 3, 4]),
# ('my_tensor_list', [torch.ones(2, 2) + i for i in range(4)]),
("my_bool_list", [True, True, False, True]),
("my_float_list", [1.0, 2.0, 3.0, 4.0]),
("my_str_list", ["hello", "bye"]),
)
typed_values = (
("my_empty_list", []),
("my_empty_dict", {}),
("my_none", None),
("my_object", Foo()),
("my_object2", SFoo()),
)
class M(torch.nn.Module):
# TODO: re-enable this once this test is in a Python 3-only syntax
# file
# my_empty_list : List[int]
# my_empty_dict : Dict[str, int]
# my_none : Optional[int]
def forward(self, x):
return (
self.my_dict,
self.my_float,
self.my_int,
self.my_bool,
# self.my_tensor,
self.my_int_list,
# self.my_tensor_list,
self.my_bool_list,
self.my_float_list,
self.my_str_list,
self.my_empty_list,
self.my_empty_dict,
self.my_none,
self.my_object.a,
self.my_object.inner.b,
self.my_object.a,
self.my_object2.inner.b,
)
# TODO: as a followup, fix this test
# We can't define class attributes like we should be doing:
# class M(torch.nn.Module):
# my_empty_list : List[int]
# my_empty_dict : Dict[str, int]
# my_none : Optional[int]
# my_out_of_line_attribute: List[int] = [1, 2, 3]
# since there's no string frontend for Python classes (so the `define`)
# trick doesn't work.
M.__annotations__ = {
"my_empty_list": List[int],
"my_empty_dict": Dict[str, int],
"my_none": Optional[int],
"my_object": Foo,
"my_object2": SFoo,
}
m = M()
for name, value in untyped_values + typed_values:
setattr(m, name, value)
self.checkModule(m, (torch.randn(5, 5),))
def test_function_attribute_in_submodule(self):
class N(nn.Module):
def __init__(self, norm):
super().__init__()
self.activation = torch.nn.functional.relu
self.norm = norm
def forward(self, src):
output = src
output = self.norm(output)
return output
class M(nn.Module):
def __init__(self) -> None:
super().__init__()
encoder_norm = nn.ReLU()
self.encoder = N(encoder_norm)
def forward(self, x):
return self.encoder(x)
m = M()
self.checkModule(m, (torch.randn(5, 5),))
def test_inner_traced_module(self):
class Dummy(nn.Module):
def forward(self, x):
return x
class Model(nn.Module):
def __init__(self, dummies):
super().__init__()
self._dummies = dummies
def forward(self, x):
out = []
for dummy in self._dummies:
out.append(dummy(x))
return out
dummy = torch.jit.trace(Dummy(), torch.randn(1, 2))
dummies = nn.ModuleList([dummy])
model = Model(dummies)
self.checkModule(model, (torch.rand(5, 5),))
def test_script_loaded_module(self):
"""
Test that we can hold a loaded ScriptModule as a submodule.
"""
class Dummy(nn.Module):
def forward(self, x):
return x
dummy = torch.jit.script(Dummy())
dummy = self.getExportImportCopy(dummy)
class ContainsLoaded(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.encoder = dummy
def forward(self, input):
return self.encoder(input)
self.checkModule(ContainsLoaded(), (torch.rand(2, 3),))
def test_optional_module(self):
class Dummy(nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = nn.Linear(2, 2)
def forward(self, x):
if self.foo is not None:
return self.foo(x)
return x
mod = Dummy()
self.checkModule(mod, (torch.rand(2, 2),))
mod.foo = None
self.checkModule(mod, (torch.rand(2, 2),))
def test_thread_safe_error_stacks(self):
        # Prior to #160386, this test segfaulted. See [Note: Thread-safe CallStack]
callstacks = []
def callstack_creator():
factory = torch._C._jit_tree_views.SourceRangeFactory(
"source code", "a.py", 1, 0
)
x = torch._C.CallStack("a", factory.make_range(1, 0, 1))
callstacks.append(x)
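            # Drop the worker thread's local reference; `callstacks` still
            # keeps the guard alive past this thread's lifetime.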
del x
t = threading.Thread(target=callstack_creator)
t.start()
t.join()
del t
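        # The guard was constructed on the worker thread but is destroyed here
        # on the main thread; before #160386 this cross-thread destruction
        # segfaulted.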
del callstacks[0]
self.assertTrue(len(callstacks) == 0)
def test_override_instance_method_ignore(self):
class M(torch.nn.Module):
@torch.jit.ignore
def i_am_ignored(self):
return "old"
m = M()
# Override the ignored method by binding a new method to this instance.
@torch.jit.ignore
def i_am_ignored(self):
return "new"
m.i_am_ignored = types.MethodType(i_am_ignored, m)
self.assertEqual(m.i_am_ignored(), "new")
# ScriptModule should correctly reflect the override.
s = torch.jit.script(m)
self.assertEqual(s.i_am_ignored(), "new")
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")