mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Context: During jit.script, the TorchScript frontend maintains a callstack of Python frames, which is used to present the corresponding user code when TorchScript reports an error. The callstack is maintained via ErrorReport::CallStack RAII guards. Before recursing into a function, an ErrorReport::CallStack guard is created, and the guard pushes the frame information onto a thread_local callstack (a list of calls); after exiting, the frame information is popped off the callstack. Note that CallStack guards are also sometimes used from Python via pybind bindings. The problem is that another thread can sometimes obtain a reference to a CallStack guard (if it is a Python CallStack guard). **This means that the destructor for a CallStack guard can be called from a different thread than the one that called the constructor**. When this happens, it causes a segfault. This PR makes the callstack vector thread-safe to access, and each CallStack guard stores a reference to the callstack vector onto which it pushed. When the CallStack guard is destructed, it pops from that same callstack vector. Although this could potentially lead to mangled callstacks, it should prevent segfaults. Added a test `test_thread_safe_error_stacks`, which segfaults prior to these changes and no longer segfaults with them. Differential Revision: [D80054972](https://our.internmc.facebook.com/intern/diff/D80054972) Pull Request resolved: https://github.com/pytorch/pytorch/pull/160386 Approved by: https://github.com/eellison
819 lines
23 KiB
Python
819 lines
23 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
# ruff: noqa: F841
|
|
|
|
import os
|
|
import re
|
|
import sys
|
|
import threading
|
|
import types
|
|
import typing
|
|
import typing_extensions
|
|
from collections import OrderedDict
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.jit.frontend
|
|
import torch.nn as nn
|
|
from torch import Tensor
|
|
from torch.testing import FileCheck
|
|
|
|
|
|
# Make the helper files in test/ importable: append the directory two levels
# above this file (the parent of the directory containing it) to sys.path.
pytorch_test_dir = os.path.abspath(
    os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir)
)
sys.path.append(pytorch_test_dir)
|
|
from torch.testing._internal.common_utils import raise_on_run_directly
|
|
from torch.testing._internal.jit_utils import (
|
|
_tmp_donotuse_dont_inline_everything,
|
|
JitTestCase,
|
|
)
|
|
|
|
|
|
class TestRecursiveScript(JitTestCase):
    """Tests for ``torch.jit.script`` applied recursively to ``nn.Module``
    instances, free functions and Python classes.

    Covers attribute/type inference, compilation of called functions,
    submodules and module containers, ``__prepare_scriptable__`` hooks, and the
    "is being compiled since it was called from" call stacks attached to
    compilation errors.
    """

    def test_inferred_nonetype(self):
        """An attribute initialized to ``None`` can be scripted and read."""

        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = None

            def forward(self):
                assert self.x is None

        m = torch.jit.script(M())
        self.checkModule(M(), ())

    def test_script_function_attribute(self):
        """Already-scripted functions stored as module attributes are callable."""

        @torch.jit.script
        def fn1(x):
            return x + x

        @torch.jit.script
        def fn2(x):
            return x - x

        class M(torch.nn.Module):
            def __init__(self, fn):
                super().__init__()
                self.fn = fn

            def forward(self, x):
                return self.fn(x)

        fn1_mod = M(fn1)
        fn2_mod = M(fn2)

        self.checkModule(fn1_mod, (torch.randn(2, 2),))
        self.checkModule(fn2_mod, (torch.randn(2, 2),))

    def test_python_function_attribute(self):
        """A torch function (``torch.sigmoid``) stored as an attribute is callable."""

        class M(torch.nn.Module):
            def __init__(self, fn):
                super().__init__()
                self.fn = fn

            def forward(self, x):
                return self.fn(x)

        mod = M(torch.sigmoid)

        self.checkModule(mod, (torch.randn(2, 2),))

    def test_failed_function_compilation(self):
        """A function attribute referencing an undefined name fails to compile."""

        def fn(x):
            return i_dont_exist  # noqa: F821

        class M(torch.nn.Module):
            def __init__(self, fn):
                super().__init__()
                self.fn = fn

            def forward(self, x):
                return self.fn(x)

        m = M(fn)
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "failed to compile", "i_dont_exist"
        ):
            torch.jit.script(m)

    def test_init_error(self):
        """A module whose ``__init__`` skips ``super().__init__()`` is rejected."""

        class M(nn.Module):
            def __init__(self) -> None:
                self.x = 2

            def forward(self):
                pass

        with self.assertRaisesRegex(RuntimeError, "has not been initialized"):
            torch.jit.script(M())

    def test_script_after_eval(self):
        """``training`` is captured at scripting time, not at call time."""

        class M(nn.Module):
            def forward(self):
                if self.training:
                    return 2
                else:
                    return 0

        m = M()
        sm1 = torch.jit.script(m)
        m.eval()
        sm2 = torch.jit.script(m)

        # m is in eval mode, training should be False
        self.assertFalse(m.training)

        # sm1 was created while m had training = True
        self.assertTrue(sm1.training)
        self.assertEqual(sm1.training, sm1._c.getattr("training"))
        self.assertEqual(sm1(), 2)

        # sm2 was created after m was eval'ed
        self.assertFalse(sm2.training)
        self.assertEqual(sm2.training, sm2._c.getattr("training"))
        self.assertEqual(sm2(), 0)

    def test_module_name(self):
        """The scripted graph references the original Python class name."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = 2

            def forward(self, t):
                return t + self.x

        m = torch.jit.script(MyModule())
        FileCheck().check("MyModule").run(m.graph)

    def test_repeated_error_stack(self):
        """A failed compile must not leave entries on the error call stack."""

        def d(x):
            return "a" - 2

        def c(x):
            return d(x)

        def b(x):
            return c(x)

        def a(x):
            return b(x)

        # assertRaisesRegex (rather than a bare try/except) makes the test fail
        # if scripting unexpectedly succeeds, instead of passing vacuously.
        with self.assertRaisesRegex(Exception, "is being compiled") as first:
            torch.jit.script(a)
        first_msg = str(first.exception)
        FileCheck().check_count("is being compiled", 2).run(first_msg)

        with self.assertRaisesRegex(Exception, "is being compiled") as second:
            torch.jit.script(a)
        second_msg = str(second.exception)
        # Make sure that no entries are left over from the previous failure:
        # a non-exact check_count cannot detect extra frames, so also require
        # the second report to contain exactly as many frames as the first.
        FileCheck().check_count("is being compiled", 2).run(second_msg)
        self.assertEqual(
            second_msg.count("is being compiled"),
            first_msg.count("is being compiled"),
        )

    def test_constants_with_final(self):
        """``Final`` from torch.jit, typing_extensions and typing all script."""

        class M1(torch.nn.Module):
            x: torch.jit.Final[int]

            def __init__(self) -> None:
                super().__init__()
                self.x = 2

            def forward(self, t):
                return t + self.x

        self.checkModule(M1(), (torch.randn(2, 2),))

        class M2(torch.nn.Module):
            x: typing_extensions.Final[int]

            def __init__(self) -> None:
                super().__init__()
                self.x = 2

            def forward(self, t):
                return t + self.x

        self.checkModule(M2(), (torch.randn(2, 2),))

        class M3(torch.nn.Module):
            x: typing.Final[int]

            def __init__(self) -> None:
                super().__init__()
                self.x = 2

            def forward(self, t):
                return t + self.x

        self.checkModule(M3(), (torch.randn(2, 2),))

    def test_ignore_class(self):
        """Instantiating a ``@torch.jit.ignore`` class in ``forward`` is an error."""

        @torch.jit.ignore
        class MyScriptClass:
            def unscriptable(self):
                return "a" + 200

        class TestModule(torch.nn.Module):
            def forward(self, x):
                return MyScriptClass()

        with self.assertRaisesRegexWithHighlight(
            torch.jit.frontend.FrontendError,
            "Cannot instantiate class",
            "MyScriptClass",
        ):
            t = torch.jit.script(TestModule())

    def test_method_call(self):
        """``forward`` may call other, non-exported methods of the module."""

        class M(nn.Module):
            def test(self, x):
                return x

            def forward(self, z):
                y = self.test(z)
                return z + 20 + y

        self.checkModule(M(), (torch.randn(2, 2),))

    def test_module_repr(self):
        """Printing a scripted module lists its submodules; ``original_name`` is kept."""

        class Submodule(nn.Module):
            def forward(self, x):
                return x

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(10, 10, 3)
                self.lin = nn.Linear(10, 10)
                self.sub = Submodule()

            def forward(self, x):
                return self.lin(x) + self.sub(x) + self.conv(x)

        m = torch.jit.script(MyModule())

        with self.capture_stdout() as out:
            print(m)

        f = FileCheck()
        f.check("MyModule")
        f.check("Conv2d")
        f.check("Linear")
        f.check("Submodule")
        f.run(out[0])

        self.assertEqual(m.original_name, "MyModule")

    def test_dir(self):
        """``dir()`` of a scripted module includes the eager module's names."""

        def test_module_dir(mod):
            dir_set = dir(mod)
            scripted_mod = torch.jit.script(mod)
            dir_scripted = set(dir(scripted_mod))
            # attributes that are not currently copied over to the scripted module
            ignore_set = {
                "training",
                "__delitem__",
                "__setitem__",
                "clear",
                "items",
                "keys",
                "pop",
                "update",
                "values",
            }
            for attr in dir_set:
                if attr in ignore_set:
                    continue
                self.assertIn(attr, dir_scripted, attr)

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(10, 10, 3)
                self.lin = nn.Linear(10, 10)

            def forward(self, x):
                return self.lin(x) + self.conv(x)

        test_module_dir(MyModule())

        # test custom __dir__ for containers
        conv = nn.Conv2d(10, 10, 3)
        linear = nn.Linear(10, 10)

        test_module_dir(nn.Sequential(conv, linear))
        test_module_dir(
            nn.ModuleDict(OrderedDict([("conv", conv), ("linear", linear)]))
        )

    def test_class_compile(self):
        """A Python class used in ``forward`` is compiled with the functions it calls."""

        def other_fn(a: int, b: Tensor) -> Tensor:
            return a * b

        class B:
            def __init__(self, x):
                self.x = 2

            def helper(self, a):
                return self.x + a + other_fn(self.x, a)

        class N(torch.nn.Module):
            def forward(self, x):
                b = B(x)
                return b.helper(x)

        self.checkModule(N(), (torch.randn(2, 2),))

    def test_error_stack(self):
        """A type error deep in a call chain names every caller in the report."""

        def d(x: int) -> int:
            return x + 10

        def c(x):
            return d("hello") + d(x)

        def b(x):
            return c(x)

        def a(x):
            return b(x)

        # Fail the test if scripting unexpectedly succeeds.
        with self.assertRaises(RuntimeError) as cm:
            torch.jit.script(a)
        checker = FileCheck()
        checker.check("Expected a value of type 'int'")
        checker.check("def c(x)")
        checker.check("def b(x)")
        checker.check("def a(x)")
        checker.run(str(cm.exception))

    def test_error_stack_module(self):
        """Error call stacks also cover functions reached through submodules."""

        def d(x: int) -> int:
            return x + 10

        def c(x):
            return d("hello") + d(x)

        def b(x):
            return c(x)

        class Submodule(torch.nn.Module):
            def forward(self, x):
                return b(x)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submodule = Submodule()

            def some_method(self, y):
                return y + self.submodule(y)

            def forward(self, x):
                return self.some_method(x)

        # Fail the test if scripting unexpectedly succeeds.
        with self.assertRaises(RuntimeError) as cm:
            torch.jit.script(M())
        checker = FileCheck()
        checker.check("Expected a value of type 'int'")
        checker.check("'c' is being compiled since it was called from 'b'")
        checker.check("'b' is being compiled since it was called from")
        checker.run(str(cm.exception))

    @_tmp_donotuse_dont_inline_everything
    def test_script_basic(self):
        """A called Python function appears as ``prim::CallFunction`` when not inlined."""

        def a_python_fn(a, b, c):
            return a + b + c

        @torch.jit.script
        def a_script_fn(d, e, f):
            return a_python_fn(d, e, f)

        graph = str(a_script_fn.graph)
        FileCheck().check("prim::CallFunction").run(graph)
        FileCheck().check_not("^a_python_fn").run(graph)
        t = torch.ones(2, 2)
        self.assertEqual(a_script_fn(t, t, t), t + t + t)

    def test_error_stack_class(self):
        """Errors while compiling a class used by a function carry the call stack."""

        class X:
            def bad_fn(self):
                import pdb  # noqa: F401

        def fn(x) -> X:
            return X(10)

        # Fail the test if scripting unexpectedly succeeds.
        with self.assertRaisesRegex(Exception, "import statements") as cm:
            torch.jit.script(fn)
        checker = FileCheck()
        checker.check("import statements")
        checker.check("is being compiled since it was called from")
        checker.run(str(cm.exception))

    def test_error_stack_annotation(self):
        """The error call stack shows the annotation that referenced the class."""

        class X:
            def bad_fn(self):
                import pdb  # noqa: F401

        def fn(x) -> X:
            return X(10)

        # Fail the test if scripting unexpectedly succeeds.
        with self.assertRaisesRegex(Exception, "import statements") as cm:
            torch.jit.script(fn)
        checker = FileCheck()
        checker.check("import statements")
        checker.check("is being compiled since it was called from")
        checker.check("-> X")
        checker.run(str(cm.exception))

    def test_module_basic(self):
        """A submodule with an uncalled, unscriptable method still scripts."""

        class Other(torch.nn.Module):
            __constants__ = ["x"]

            def __init__(self, x):
                super().__init__()
                self.x = x
                self.param = torch.nn.Parameter(torch.ones(2, 2))

            def some_unscriptable_method(self):
                a = 2
                a = [2]
                return a

            def forward(self, t):
                return t + self.x + self.param

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.other = Other(200)

            def forward(self, t):
                return self.other(t) * 2

        self.checkModule(M(), (torch.ones(2, 2),))

    def test_module_function_export(self):
        """A submodule with a ``@torch.jit.export`` method scripts correctly."""

        class Other(torch.nn.Module):
            __constants__ = ["x"]

            def __init__(self, x):
                super().__init__()
                self.x = x
                self.param = torch.nn.Parameter(torch.ones(2, 2))

            @torch.jit.export
            def some_entry_point(self, y):
                return y + 20

            def forward(self, t):
                return t + self.x + self.param

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.other = Other(200)

            def forward(self, t):
                return self.other(t) * 2

        self.checkModule(M(), (torch.ones(2, 2),))

    def test_iterable_modules(self):
        """Iterating ``nn.ModuleList`` and calling nested ``nn.Sequential`` script."""

        class Inner(torch.nn.Module):
            def forward(self, x):
                return x + 10

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sequential = nn.Sequential(
                    Inner(), Inner(), nn.Sequential(Inner(), Inner())
                )
                self.module_list = nn.ModuleList([Inner(), Inner()])

            def forward(self, x):
                for mod in self.module_list:
                    x += mod(x)
                x += self.sequential(x)
                return x

        self.checkModule(M(), (torch.randn(5, 5),))

    def test_prepare_scriptable_basic(self):
        """``__prepare_scriptable__`` swaps in a different module before scripting."""

        class SeluButReluWhenScripted(torch.nn.SELU):
            def __prepare_scriptable__(self):
                return nn.ReLU()

        t = torch.randn(5, 5)
        m = SeluButReluWhenScripted()
        sm = torch.jit.script(m)
        eager_out = m(t)
        script_out = sm(t)
        self.assertNotEqual(eager_out, script_out)

    def test_prepare_scriptable_iterable_modules(self):
        """``__prepare_scriptable__`` applies inside containers, incl. shared modules."""

        class SeluButReluWhenScripted(torch.nn.SELU):
            def __prepare_scriptable__(self):
                return nn.ReLU()

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                shared = SeluButReluWhenScripted()
                self.sequential = nn.Sequential(
                    SeluButReluWhenScripted(),
                    SeluButReluWhenScripted(),
                    nn.Sequential(
                        SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()
                    ),
                    shared,
                )
                self.module_list = nn.ModuleList(
                    [SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()]
                )

            def forward(self, x):
                for mod in self.module_list:
                    x += mod(x)
                x += self.sequential(x)
                return x

        t = torch.randn(5, 5)
        m = M()
        eager_out = m(t.clone())
        sm = torch.jit.script(m)
        script_out = sm(t.clone())
        self.assertNotEqual(eager_out, script_out)

    def test_prepare_scriptable_cycle(self):
        """A reference cycle stored directly in ``__dict__`` does not break scripting."""
        t = torch.randn(5, 5)
        c = torch.nn.Module()
        p = torch.nn.Module()
        # Writing to __dict__ bypasses nn.Module's submodule registration.
        c.__dict__["_p"] = p
        p.__dict__["_c"] = c

        sm = torch.jit.script(p)

    def test_prepare_scriptable_escape_hatch(self):
        """``__prepare_scriptable__`` may return a function for an unscriptable callable."""

        class NonJitableClass:
            def __call__(self, int1, int2, *args):
                total = int1 + int2
                for arg in args:
                    total += arg
                return total

        obj = NonJitableClass()

        self.assertEqual(obj(1, 2), 3)
        self.assertEqual(obj(1, 2, 3, 4), 10)
        with self.assertRaisesRegex(
            torch.jit.frontend.NotSupportedError,
            expected_regex="can't take variable number of arguments",
        ):
            torch.jit.script(obj)

        def escape_hatch(int1: int, int2: int) -> int:
            return int1 + int2

        class NonJitableClassWithEscapeHatch(NonJitableClass):
            def __prepare_scriptable__(self):
                return escape_hatch

        jit_obj = torch.jit.script(NonJitableClassWithEscapeHatch())

        self.assertEqual(jit_obj(1, 2), 3)
        with self.assertRaisesRegex(
            RuntimeError,
            expected_regex=re.escape(
                "expected at most 2 argument(s) but received 4 argument(s)"
            ),
        ):
            jit_obj(1, 2, 3, 4)

    def test_attributes(self):
        """Module attributes of many value types, inferred or annotated, script."""

        @torch.jit.script
        class Inner2:
            def __init__(self) -> None:
                self.b = "a string"

        @torch.jit.script
        class Foo:
            def __init__(self) -> None:
                self.a = 4
                self.inner = Inner2()

        @torch.jit.script
        class SFoo:
            def __init__(self) -> None:
                self.a = 4
                self.inner = Inner2()

            def __setstate__(self, obj: Tuple[int, Inner2]) -> None:
                a, inner = obj
                self.a = a
                self.inner = inner

            def __getstate__(self):
                return (self.a, self.inner)

        untyped_values = (
            ("my_dict", {"I": "am", "a test": "test"}),
            ("my_float", 2.3),
            ("my_int", 99),
            ("my_bool", False),
            ("my_tuple", (1, 2, 3, 4)),
            ("my_list", [(1, 2), (3, 4)]),
            # ('my_tensor', torch.randn(2, 2)),
            ("my_int_list", [1, 2, 3, 4]),
            # ('my_tensor_list', [torch.ones(2, 2) + i for i in range(4)]),
            ("my_bool_list", [True, True, False, True]),
            ("my_float_list", [1.0, 2.0, 3.0, 4.0]),
            ("my_str_list", ["hello", "bye"]),
        )
        typed_values = (
            ("my_empty_list", []),
            ("my_empty_dict", {}),
            ("my_none", None),
            ("my_object", Foo()),
            ("my_object2", SFoo()),
        )

        class M(torch.nn.Module):
            # TODO: re-enable this once this test is in a Python 3-only syntax
            # file
            # my_empty_list : List[int]
            # my_empty_dict : Dict[str, int]
            # my_none : Optional[int]

            def forward(self, x):
                return (
                    self.my_dict,
                    self.my_float,
                    self.my_int,
                    self.my_bool,
                    # self.my_tensor,
                    self.my_int_list,
                    # self.my_tensor_list,
                    self.my_bool_list,
                    self.my_float_list,
                    self.my_str_list,
                    self.my_empty_list,
                    self.my_empty_dict,
                    self.my_none,
                    self.my_object.a,
                    self.my_object.inner.b,
                    # previously a duplicated `self.my_object.a`; read the
                    # SFoo instance's scalar attribute as well
                    self.my_object2.a,
                    self.my_object2.inner.b,
                )

        # TODO: as a followup, fix this test
        # We can't define class attributes like we should be doing:
        #   class M(torch.nn.Module):
        #       my_empty_list : List[int]
        #       my_empty_dict : Dict[str, int]
        #       my_none : Optional[int]
        #       my_out_of_line_attribute: List[int] = [1, 2, 3]
        # since there's no string frontend for Python classes (so the `define`)
        # trick doesn't work.
        M.__annotations__ = {
            "my_empty_list": List[int],
            "my_empty_dict": Dict[str, int],
            "my_none": Optional[int],
            "my_object": Foo,
            "my_object2": SFoo,
        }

        m = M()
        for name, value in untyped_values + typed_values:
            setattr(m, name, value)

        self.checkModule(m, (torch.randn(5, 5),))

    def test_function_attribute_in_submodule(self):
        """An unused function attribute on a submodule does not break scripting."""

        class N(nn.Module):
            def __init__(self, norm):
                super().__init__()
                self.activation = torch.nn.functional.relu
                self.norm = norm

            def forward(self, src):
                output = src
                output = self.norm(output)
                return output

        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                encoder_norm = nn.ReLU()
                self.encoder = N(encoder_norm)

            def forward(self, x):
                return self.encoder(x)

        m = M()
        self.checkModule(m, (torch.randn(5, 5),))

    def test_inner_traced_module(self):
        """A traced module inside an ``nn.ModuleList`` can be iterated from script."""

        class Dummy(nn.Module):
            def forward(self, x):
                return x

        class Model(nn.Module):
            def __init__(self, dummies):
                super().__init__()
                self._dummies = dummies

            def forward(self, x):
                out = []
                for dummy in self._dummies:
                    out.append(dummy(x))
                return out

        dummy = torch.jit.trace(Dummy(), torch.randn(1, 2))
        dummies = nn.ModuleList([dummy])
        model = Model(dummies)
        self.checkModule(model, (torch.rand(5, 5),))

    def test_script_loaded_module(self):
        """
        Test that we can hold a loaded ScriptModule as a submodule.
        """

        class Dummy(nn.Module):
            def forward(self, x):
                return x

        dummy = torch.jit.script(Dummy())
        dummy = self.getExportImportCopy(dummy)

        class ContainsLoaded(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.encoder = dummy

            def forward(self, input):
                return self.encoder(input)

        self.checkModule(ContainsLoaded(), (torch.rand(2, 3),))

    def test_optional_module(self):
        """A submodule attribute may be ``None`` and guarded with ``is not None``."""

        class Dummy(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = nn.Linear(2, 2)

            def forward(self, x):
                if self.foo is not None:
                    return self.foo(x)
                return x

        mod = Dummy()
        self.checkModule(mod, (torch.rand(2, 2),))
        mod.foo = None
        self.checkModule(mod, (torch.rand(2, 2),))

    def test_thread_safe_error_stacks(self):
        """A CallStack guard may be destroyed on a different thread than it was created on."""
        # prior to #160386, this causes a segfault. See [Note: Thread-safe CallStack]
        callstacks = []

        def callstack_creator():
            factory = torch._C._jit_tree_views.SourceRangeFactory(
                "source code", "a.py", 1, 0
            )
            x = torch._C.CallStack("a", factory.make_range(1, 0, 1))
            callstacks.append(x)
            del x

        # Create the guard on a worker thread; `callstacks` keeps it alive
        # after the thread has exited.
        t = threading.Thread(target=callstack_creator)
        t.start()
        t.join()
        del t

        # The list holds the only remaining reference (`x` was deleted in the
        # worker), so this drops the guard on the main thread.
        del callstacks[0]
        self.assertEqual(len(callstacks), 0)

    def test_override_instance_method_ignore(self):
        """Instance-level overrides of ``@torch.jit.ignore`` methods are honored."""

        class M(torch.nn.Module):
            @torch.jit.ignore
            def i_am_ignored(self):
                return "old"

        m = M()

        # Override the ignored method by binding a new method to this instance.
        @torch.jit.ignore
        def i_am_ignored(self):
            return "new"

        m.i_am_ignored = types.MethodType(i_am_ignored, m)
        self.assertEqual(m.i_am_ignored(), "new")

        # ScriptModule should correctly reflect the override.
        s = torch.jit.script(m)
        self.assertEqual(s.i_am_ignored(), "new")
|
|
|
|
|
|
if __name__ == "__main__":
    # This module is not meant to be executed as a script; the helper is handed
    # the path of the suite that should be run instead (presumably it raises
    # with that guidance).
    raise_on_run_directly("test/test_jit.py")
|