import unittest
import os
import sys
from typing import List, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.testing import FileCheck

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from jit_utils import JitTestCase, _tmp_donotuse_dont_inline_everything


class TestRecursiveScript(JitTestCase):
    """Tests that call ``torch.jit.script`` on modules, functions and classes.

    Most tests define a small ``nn.Module`` locally and hand it to
    ``self.checkModule`` (provided by ``JitTestCase``), or script something
    that is expected to fail and inspect the resulting error text.
    """

    def _script_failure_message(self, target, exc_type=RuntimeError):
        """Script ``target``, require that it raises, and return the error text.

        Args:
            target: object passed to ``torch.jit.script``.
            exc_type: exception class that scripting is expected to raise.

        Returns:
            ``str()`` of the raised exception.

        A plain ``try/except`` around the scripting call would let a test pass
        without running any of its checks if no exception were raised;
        ``assertRaises`` turns that situation into a test failure.
        """
        with self.assertRaises(exc_type) as ctx:
            torch.jit.script(target)
        return str(ctx.exception)

    def test_inferred_nonetype(self):
        # A module attribute set to None, compared against None in forward().
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.x = None

            def forward(self):
                assert self.x is None

        torch.jit.script(M())
        self.checkModule(M(), ())

    def test_script_function_attribute(self):
        # Already-scripted functions stored as module attributes.
        @torch.jit.script
        def fn1(x):
            return x + x

        @torch.jit.script
        def fn2(x):
            return x - x

        class M(torch.nn.Module):
            def __init__(self, fn):
                super(M, self).__init__()
                self.fn = fn

            def forward(self, x):
                return self.fn(x)

        fn1_mod = M(fn1)
        fn2_mod = M(fn2)

        self.checkModule(fn1_mod, (torch.randn(2, 2),))
        self.checkModule(fn2_mod, (torch.randn(2, 2),))

    def test_python_function_attribute(self):
        # A plain Python callable (F.sigmoid) stored as a module attribute.
        class M(torch.nn.Module):
            def __init__(self, fn):
                super(M, self).__init__()
                self.fn = fn

            def forward(self, x):
                return self.fn(x)

        mod = M(F.sigmoid)

        self.checkModule(mod, (torch.randn(2, 2),))

    def test_failed_function_compilation(self):
        # `fn` refers to a name that is not defined anywhere, so scripting a
        # module that calls it must raise.
        def fn(x):
            return i_dont_exist

        class M(torch.nn.Module):
            def __init__(self, fn):
                super(M, self).__init__()
                self.fn = fn

            def forward(self, x):
                return self.fn(x)

        m = M(fn)
        with self.assertRaisesRegex(RuntimeError, "failed to compile"):
            torch.jit.script(m)

    def test_init_error(self):
        # __init__ deliberately omits the super().__init__() call.
        class M(nn.Module):
            def __init__(self):
                self.x = 2

            def forward(self):
                pass

        with self.assertRaisesRegex(RuntimeError, "has not been initialized"):
            torch.jit.script(M())

    def test_script_after_eval(self):
        # Script the same module before and after calling eval() on it.
        class M(nn.Module):
            def forward(self):
                if self.training:
                    return 2
                else:
                    return 0

        m = M()
        sm1 = torch.jit.script(m)
        m.eval()
        sm2 = torch.jit.script(m)

        # m is in eval mode, training should be False
        self.assertFalse(m.training)

        # sm1 was created while m had training = True
        self.assertTrue(sm1.training)
        self.assertEqual(sm1.training, sm1._c._get_attribute('training'))
        self.assertEqual(sm1(), 2)

        # sm2 was created after m was eval'ed
        self.assertFalse(sm2.training)
        self.assertEqual(sm2.training, sm2._c._get_attribute('training'))
        self.assertEqual(sm2(), 0)

    def test_module_name(self):
        # The scripted module's graph text must contain "ClassType".
        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.x = 2

            def forward(self, t):
                return t + self.x

        m = torch.jit.script(MyModule())
        FileCheck().check("ClassType").run(m.graph)

    def test_repeated_error_stack(self):
        # a -> b -> c -> d, where d contains an invalid expression.
        def d(x):
            return "a" - 2

        def c(x):
            return d(x)

        def b(x):
            return c(x)

        def a(x):
            return b(x)

        first = self._script_failure_message(a, Exception)
        FileCheck().check_count("is being compiled", 2).run(first)

        second = self._script_failure_message(a, Exception)
        # Make sure that no entries are left over from the previous failure
        FileCheck().check_count("is being compiled", 2).run(second)

    @unittest.skipIf(True, "Class annotations are a thing in > 3.5, need to fix for < 3.7")
    def test_constants_with_final(self):
        class M(torch.nn.Module):
            # TODO: Use this (see below)
            # x : torch.jit.Final[int]

            def __init__(self):
                super(M, self).__init__()
                self.x = 2

            def forward(self, t):
                return t + self.x

        # TODO: Fix this test so that we can actually define the class like
        #   class M(torch.nn.Module):
        #       x : torch.jit.Final[int]
        M.__annotations__ = {'x': torch.jit.Final[int]}

        self.checkModule(M(), (torch.randn(2, 2),))

    def test_ignore_class(self):
        # Instantiating a @torch.jit.ignore'd class inside forward() must be
        # rejected when the module is scripted.
        @torch.jit.ignore
        class MyScriptClass(object):
            def unscriptable(self):
                return "a" + 200

        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()

            def forward(self, x):
                return MyScriptClass()

        with self.assertRaisesRegex(torch.jit.frontend.FrontendError, "Cannot instantiate class"):
            torch.jit.script(TestModule())

    def test_method_call(self):
        # forward() calls another, undecorated method on the same module.
        class M(nn.Module):
            def test(self, x):
                return x

            def forward(self, z):
                y = self.test(z)
                return z + 20 + y

        self.checkModule(M(), (torch.randn(2, 2),))

    def test_module_repr(self):
        # print() of a scripted module must mention the module and each of its
        # submodules, in this order.
        class Submodule(nn.Module):
            def forward(self, x):
                return x

        class MyModule(nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.conv = nn.Conv2d(10, 10, 3)
                self.lin = nn.Linear(10, 10)
                self.sub = Submodule()

            def forward(self, x):
                return self.lin(x) + self.sub(x) + self.conv(x)

        m = torch.jit.script(MyModule())

        with self.capture_stdout() as out:
            print(m)

        f = FileCheck()
        f.check('MyModule')
        f.check('Conv2d')
        f.check('Linear')
        f.check('Submodule')
        f.run(out[0])

        self.assertEqual(m.original_name, 'MyModule')

    def test_class_compile(self):
        # forward() instantiates a plain Python class whose method calls a
        # free function annotated with a type comment.
        def other_fn(a, b):
            # type: (int, Tensor) -> Tensor
            return a * b

        class B(object):
            def __init__(self, x):
                self.x = 2

            def helper(self, a):
                return self.x + a + other_fn(self.x, a)

        class N(torch.nn.Module):
            def __init__(self):
                super(N, self).__init__()

            def forward(self, x):
                b = B(x)
                return b.helper(x)

        self.checkModule(N(), (torch.randn(2, 2),))

    def test_error_stack(self):
        # c passes a str to d, which is annotated as taking an int; the error
        # text must mention the type mismatch and each caller's definition.
        def d(x):
            # type: (int) -> int
            return x + 10

        def c(x):
            return d("hello") + d(x)

        def b(x):
            return c(x)

        def a(x):
            return b(x)

        message = self._script_failure_message(a)

        checker = FileCheck()
        checker.check("Expected a value of type 'int'")
        checker.check("def c(x)")
        checker.check("def b(x)")
        checker.check("def a(x)")
        checker.run(message)

    def test_error_stack_module(self):
        # Same failing call chain as test_error_stack, reached through a
        # submodule's forward() and a helper method on the outer module.
        def d(x):
            # type: (int) -> int
            return x + 10

        def c(x):
            return d("hello") + d(x)

        def b(x):
            return c(x)

        class Submodule(torch.nn.Module):
            def __init__(self):
                super(Submodule, self).__init__()

            def forward(self, x):
                return b(x)

        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.submodule = Submodule()

            def some_method(self, y):
                return y + self.submodule(y)

            def forward(self, x):
                return self.some_method(x)

        message = self._script_failure_message(M())

        checker = FileCheck()
        checker.check("Expected a value of type 'int'")
        checker.check("'c' is being compiled since it was called from 'b'")
        checker.check("'b' is being compiled since it was called from")
        checker.run(message)

    @_tmp_donotuse_dont_inline_everything
    def test_script_basic(self):
        # A scripted function that calls an undecorated Python function.
        def a_python_fn(a, b, c):
            return a + b + c

        @torch.jit.script
        def a_script_fn(d, e, f):
            return a_python_fn(d, e, f)

        graph = str(a_script_fn.graph)
        FileCheck().check("prim::CallFunction").run(graph)
        FileCheck().check_not("^a_python_fn").run(graph)
        t = torch.ones(2, 2)
        self.assertEqual(a_script_fn(t, t, t), t + t + t)

    def test_error_stack_class(self):
        # X has a method containing an import statement; scripting a function
        # that instantiates X must fail and mention both the import and the
        # call chain.
        class X(object):
            def bad_fn(self):
                import pdb  # noqa

        def fn(x):
            return X(10)

        message = self._script_failure_message(fn, Exception)

        checker = FileCheck()
        checker.check("import statements")
        checker.check("is being compiled since it was called from")
        checker.run(message)

    def test_module_basic(self):
        class Other(torch.nn.Module):
            __constants__ = ['x']

            def __init__(self, x):
                super(Other, self).__init__()
                self.x = x
                self.param = torch.nn.Parameter(torch.ones(2, 2))

            def some_unscriptable_method(self):
                # Rebinds `a` from an int to a list; never called by forward().
                a = 2
                a = [2]
                return a

            def forward(self, t):
                return t + self.x + self.param

        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.other = Other(200)

            def forward(self, t):
                return self.other(t) * 2

        self.checkModule(M(), (torch.ones(2, 2),))

    def test_module_function_export(self):
        # Like test_module_basic, but the submodule has an extra method marked
        # with @torch.jit.export that forward() does not call.
        class Other(torch.nn.Module):
            __constants__ = ['x']

            def __init__(self, x):
                super(Other, self).__init__()
                self.x = x
                self.param = torch.nn.Parameter(torch.ones(2, 2))

            @torch.jit.export
            def some_entry_point(self, y):
                return y + 20

            def forward(self, t):
                return t + self.x + self.param

        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.other = Other(200)

            def forward(self, t):
                return self.other(t) * 2

        self.checkModule(M(), (torch.ones(2, 2),))

    def test_iterable_modules(self):
        # forward() iterates an nn.ModuleList and calls a nested nn.Sequential.
        class Inner(torch.nn.Module):
            def forward(self, x):
                return x + 10

        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.sequential = nn.Sequential(
                    Inner(),
                    Inner(),
                    nn.Sequential(Inner(), Inner())
                )
                self.module_list = nn.ModuleList([Inner(), Inner()])

            def forward(self, x):
                for mod in self.module_list:
                    x += mod(x)

                x += self.sequential(x)
                return x

        self.checkModule(M(), (torch.randn(5, 5),))

    def test_attributes(self):
        # Attach attributes of many Python types to a module with setattr and
        # return them all from forward().
        @torch.jit.script
        class Inner(object):
            def __init__(self):
                self.b = "a string"

        @torch.jit.script
        class Foo(object):
            def __init__(self):
                self.a = 4
                self.inner = Inner()

        # Same fields as Foo, plus explicit __getstate__/__setstate__.
        @torch.jit.script
        class SFoo(object):
            def __init__(self):
                self.a = 4
                self.inner = Inner()

            def __setstate__(self, obj):
                # type: (Tuple[int, Inner]) -> None
                a, inner = obj
                self.a = a
                self.inner = inner

            def __getstate__(self):
                return (self.a, self.inner)

        # Values that are attached without an entry in M.__annotations__.
        untyped_values = (
            ('my_dict', {"I": "am", "a test": "test"}),
            ('my_float', 2.3),
            ('my_int', 99),
            ('my_bool', False),
            ('my_tuple', (1, 2, 3, 4)),
            ('my_list', [(1, 2), (3, 4)]),
            # ('my_tensor', torch.randn(2, 2)),
            ('my_int_list', [1, 2, 3, 4]),
            # ('my_tensor_list', [torch.ones(2, 2) + i for i in range(4)]),
            ('my_bool_list', [True, True, False, True]),
            ('my_float_list', [1., 2., 3., 4.]),
            ('my_str_list', ['hello', 'bye']),
        )
        # Values whose types are listed in M.__annotations__ below.
        typed_values = (
            ('my_empty_list', []),
            ('my_empty_dict', {}),
            ('my_none', None),
            ('my_object', Foo()),
            ('my_object2', SFoo()),
        )

        class M(torch.nn.Module):
            # TODO: re-enable this once this test is in a Python 3-only syntax
            # file
            # my_empty_list : List[int]
            # my_empty_dict : Dict[str, int]
            # my_none : Optional[int]

            def __init__(self):
                super(M, self).__init__()

            def forward(self, x):
                return (
                    self.my_dict,
                    self.my_float,
                    self.my_int,
                    self.my_bool,
                    # self.my_tensor,
                    self.my_int_list,
                    # self.my_tensor_list,
                    self.my_bool_list,
                    self.my_float_list,
                    self.my_str_list,
                    self.my_empty_list,
                    self.my_empty_dict,
                    self.my_none,
                    self.my_object.a,
                    self.my_object.inner.b,
                    # Read `a` from my_object2 as well, so both fields of both
                    # objects are covered.
                    self.my_object2.a,
                    self.my_object2.inner.b,
                )

        # TODO: as a followup, fix this test
        # We can't define class attributes like we should be doing:
        #   class M(torch.nn.Module):
        #       my_empty_list : List[int]
        #       my_empty_dict : Dict[str, int]
        #       my_none : Optional[int]
        #       my_out_of_line_attribute: List[int] = [1, 2, 3]
        # since there's no string frontend for Python classes (so the `define`)
        # trick doesn't work.
        M.__annotations__ = {
            'my_empty_list': List[int],
            'my_empty_dict': Dict[str, int],
            'my_none': Optional[int],
            'my_object': Foo,
            'my_object2': SFoo,
        }

        m = M()
        for name, value in untyped_values + typed_values:
            setattr(m, name, value)

        self.checkModule(m, (torch.randn(5, 5),))


if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")