mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: In Python-3.9 and above, `inspect.getsource` of a local class does not work if it was marked as global; see https://bugs.python.org/issue42666 and https://github.com/pytorch/pytorch/issues/49617. Work around this by defining a `make_global` function that programmatically accomplishes the same. Partially addresses the issue raised in https://github.com/pytorch/pytorch/issues/49617. Pull Request resolved: https://github.com/pytorch/pytorch/pull/51088. Reviewed By: gmagogsfm. Differential Revision: D26069189. Pulled By: malfet. fbshipit-source-id: 7cf14b88ae5d2b95d2b0fd852717a9202b86356e
808 lines
25 KiB
Python
808 lines
25 KiB
Python
from collections import namedtuple
|
|
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
|
|
|
|
from torch.testing._internal.common_utils import run_tests
|
|
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
|
from torch.testing import FileCheck
|
|
from torch import jit
|
|
from textwrap import dedent
|
|
from jit.test_module_interface import TestModuleInterface # noqa: F401
|
|
import inspect
|
|
import unittest
|
|
import sys
|
|
import torch
|
|
import torch.testing._internal.jit_utils
|
|
import torch.nn as nn
|
|
import types
|
|
|
|
class TestScriptPy3(JitTestCase):
|
|
def test_joined_str(self):
    """Check that an f-string function behaves the same eagerly and scripted.

    The function is run once in eager mode and once through
    ``torch.jit.script``; both the returned tensor and the captured stdout
    must match.
    """
    def func(x):
        hello, test = "Hello", "test"
        print(f"{hello + ' ' + test}, I'm a {test}") # noqa E999
        print(f"format blank") # noqa F541
        hi = 'hi'
        print(f"stuff before {hi}")
        print(f"{hi} stuff after")
        return x + 1

    x = torch.arange(4., requires_grad=True)
    # TODO: Add support for f-strings in string parser frontend
    # self.checkScript(func, [x], optimize=True, capture_output=True)

    with self.capture_stdout() as captured:
        out = func(x)

    scripted = torch.jit.script(func)
    with self.capture_stdout() as captured_script:
        # Call the compiled function here. Calling `func` again would only
        # compare eager mode against itself and never run the scripted code.
        out_script = scripted(x)

    self.assertEqual(out, out_script)
    self.assertEqual(captured, captured_script)
|
|
|
|
@unittest.skipIf(sys.version_info[:2] < (3, 7), "`dataclasses` module not present on < 3.7")
def test_dataclass_error(self):
    """Scripting a function that constructs a dataclass is expected to raise OSError."""
    from dataclasses import dataclass

    @dataclass
    class NormalizationInfo(object):
        mean: float = 0.0

        def compute(self, total_rows):
            return self.mean

    # `fn` is only handed to torch.jit.script below; it is never called eagerly.
    def fn():
        return NormalizationInfo(1, 2, 3, 4, 5)

    with self.assertRaisesRegex(OSError, "could not get source code"):
        torch.jit.script(fn)
|
|
|
|
def test_optional_dict_construct(self):
    """Check a module that builds a Dict[str, Optional[Tensor]] from plain tensors."""
    class M(torch.nn.Module):
        def use(self, buffer: Dict[str, Optional[torch.Tensor]]):
            return buffer["prev_key"]

        def forward(self, x):
            prev_key = torch.rand(2, 3)
            next_key = torch.rand(2, 3)
            saved_state: Dict[str, Optional[torch.Tensor]] = {
                "prev_key": prev_key,
                "next_key": next_key,
            }

            return self.use(saved_state)

    self.checkModule(M(), (torch.rand(2, 2),))
|
|
|
|
def test_kwarg_support(self):
    """Check keyword-only arguments on a scripted module's forward.

    A keyword-only argument with a default value must be rejected at
    script time; without defaults the module scripts, reports a missing
    argument when called with none, and returns its inputs when called.
    """
    with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "variable number of arguments"):
        class M(torch.nn.Module):
            def forward(self, *, n_tokens: int, device_name: str = 2):
                pass
        torch.jit.script(M())

    class M(torch.nn.Module):
        def forward(self, *, n_tokens: int, device_name: str):
            return n_tokens, device_name

    sm = torch.jit.script(M())

    with self.assertRaisesRegex(RuntimeError, "missing value for argument 'n_tokens'"):
        sm()

    # NOTE(review): `input` shadows the builtin of the same name within this test.
    input = (3, 'hello')
    self.assertEqual(sm(*input), input)
|
|
|
|
def test_named_tuple(self):
    """Construct a NamedTuple inside a scripted function and read its fields."""
    class FeatureVector(NamedTuple):
        float_features: float
        sequence_features: List[float]
        time_since_first: float

    @torch.jit.script
    def foo(x) -> float:
        fv = FeatureVector(3.0, [3.0], 3.0)  # noqa
        rv = fv.float_features
        for val in fv.sequence_features:
            rv += val
        rv *= fv.time_since_first
        return rv

    # (3.0 + 3.0) * 3.0 == 18.0; the tensor argument is not used by `foo`.
    self.assertEqual(foo(torch.rand(3, 4)), 18.0)
|
|
|
|
def test_named_tuple_constant(self):
    """A scripted function returning a NamedTuple of literals equals the eager value."""
    class Tup(NamedTuple):
        a: int
        b: int

    @torch.jit.script
    def foo():
        return Tup(1, 2)

    self.assertEqual(foo(), Tup(1, 2))
|
|
|
|
def test_dict_preserves_order(self):
    """Check that a dict built in a scripted function keeps insertion order.

    The scripted function inserts keys 0..999 in order; the returned dict
    must yield exactly those (key, key + 1) pairs in that order.
    """
    def dict_ordering():
        a : Dict[int, int] = {}
        for i in range(1000):
            a[i] = i + 1
        return a

    self.checkScript(dict_ordering, ())
    di = torch.jit.script(dict_ordering)()
    res = list(di.items())
    # Check the size explicitly so unexpected extra entries are detected too.
    self.assertEqual(len(res), 1000)
    for i, (key, value) in enumerate(res):
        # Separate assertEqual calls report the offending key/value on failure.
        self.assertEqual(key, i)
        self.assertEqual(value, i + 1)
|
|
|
|
def test_list_unification_hint(self):
    """Annotating a list literal with a non-List type must fail at script time."""
    with self.assertRaisesRegex(RuntimeError, "Expected a List type hint"):
        @torch.jit.script
        def x():
            b : int = [2, 3]
            return b
|
|
|
|
def test_return_named_tuple(self):
    """A scripted function can return a NamedTuple whose fields are readable by name."""
    class FeatureVector(NamedTuple):
        float_features: float
        sequence_features: List[float]
        time_since_first: float

    @torch.jit.script
    def foo(x):
        fv = FeatureVector(3.0, [3.0], 3.0)
        return fv

    # The function is invoked twice, as in the original; only the second
    # result is inspected.
    result = None
    for _ in range(2):
        result = foo(torch.rand(3, 4))

    expected_fields = (
        ("float_features", 3.0),
        ("sequence_features", [3.0]),
        ("time_since_first", 3.0),
    )
    for field_name, expected in expected_fields:
        self.assertEqual(getattr(result, field_name), expected)
|
|
|
|
def test_named_tuple_as_attr(self):
    """Script a module holding a Dict[int, NamedTuple] attribute.

    Only successful scripting is checked; the scripted module `s` is
    never called.
    """
    class Config(NamedTuple):
        size: int

    class MyMod(nn.Module):
        configs: Dict[int, Config]

        def __init__(self, configs):
            super().__init__()
            self.configs = configs

        def forward(self, x):
            for _id, config in self.configs.items():
                x += config.size
            return x

    s = torch.jit.script(MyMod({0: Config(size=16)}))
|
|
|
|
def test_types_as_values(self):
    """Check type objects used as annotations, including a namedtuple in an ignored method."""
    def fn(m: torch.Tensor) -> torch.device:
        return m.device

    self.checkScript(fn, [torch.randn(2, 2)])

    GG = namedtuple('GG', ['f', 'g'])

    # Ignored method with fully annotated parameters, returning two GG values.
    class Foo(torch.nn.Module):
        def __init__(self):
            super().__init__()

        @torch.jit.ignore
        def foo(self, x: torch.Tensor, z: torch.Tensor) -> Tuple[GG, GG]:
            return GG(x, z), GG(x, z)

        def forward(self, x, z):
            return self.foo(x, z)

    foo = torch.jit.script(Foo())
    y = foo(torch.randn(2, 2), torch.randn(2, 2))

    # Same module again, but the ignored method has unannotated parameters and
    # returns a single GG although it is annotated as Tuple[GG, GG].
    # The results `y` are not asserted on; the calls only need to complete.
    class Foo(torch.nn.Module):
        def __init__(self):
            super().__init__()

        @torch.jit.ignore
        def foo(self, x, z) -> Tuple[GG, GG]:
            return GG(x, z)

        def forward(self, x, z):
            return self.foo(x, z)

    foo = torch.jit.script(Foo())
    y = foo(torch.randn(2, 2), torch.randn(2, 2))
|
|
|
|
|
|
def test_named_tuple_resolution(self):
    """Resolve a NamedTuple type that is reached through attribute access on a module object."""
    class TheType(NamedTuple):
        t: int

    # A module object whose every attribute lookup returns TheType.
    class MyModule(types.ModuleType):
        def __init__(self):
            super(MyModule, self).__init__('MyModule')

        def __getattr__(self, attr):
            return TheType

    some_module = MyModule()

    def fn() -> some_module.Type:
        return some_module.Type(1)

    self.checkScript(fn, [])
|
|
|
|
def test_ignore_with_types(self):
    """Script a module that calls ignored callables taking Dict[str, Optional[Tensor]]."""
    @torch.jit.ignore
    def fn(x: Dict[str, Optional[torch.Tensor]]):
        return x + 10

    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()

        def forward(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> torch.Tensor:
            self.dropout_modality(in_batch)
            fn(in_batch)
            return torch.tensor(1)

        @torch.jit.ignore
        def dropout_modality(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> Dict[str, Optional[torch.Tensor]]:
            return in_batch

    sm = torch.jit.script(M())
    # The scripted module is not executed; only the graph text is inspected.
    FileCheck().check("dropout_modality").check("in_batch").run(str(sm.graph))
|
|
|
|
def test_python_callable(self):
    """Call an instance whose ``__call__`` is ignored from a scripted function."""
    class MyPythonClass(object):
        @torch.jit.ignore
        def __call__(self, *args) -> str:
            return str(type(args[0]))

    the_class = MyPythonClass()

    @torch.jit.script
    def fn(x):
        return the_class(x)

    # This doesn't involve the string frontend, so don't use checkScript
    x = torch.ones(2)
    self.assertEqual(fn(x), the_class(x))
|
|
|
|
def test_bad_types(self):
    """Passing a str to an unannotated ignored function must fail at script time."""
    @torch.jit.ignore
    def fn(my_arg):
        return my_arg + 10

    with self.assertRaisesRegex(RuntimeError, "argument 'my_arg'"):
        @torch.jit.script
        def other_fn(x):
            return fn('2')
|
|
|
|
def test_named_tuple_slice_unpack(self):
    """Slice and unpack a NamedTuple inside a scripted function."""
    class MyCoolNamedTuple(NamedTuple):
        a : int
        b : float
        c : List[int]

    @torch.jit.script
    def foo(a : int, b : float, c : List[int]):
        tup = MyCoolNamedTuple(a, b, c)  # noqa
        my_a, my_b, my_c = tup
        return tup[:1], my_a, my_c

    self.assertEqual(foo(3, 3.5, [6]), ((3,), 3, [6]))
|
|
|
|
def test_named_tuple_lower(self):
    """TupleConstruct appears in the graph before the lower-all-tuples pass and not after."""
    class MyCoolNamedTuple(NamedTuple):
        a : int
        b : float
        c : List[int]

    @torch.jit.script
    def foo(a : int):
        tup = MyCoolNamedTuple(a, 3.14, [9])  # noqa
        return tup

    FileCheck().check('TupleConstruct').run(foo.graph)
    torch._C._jit_pass_lower_all_tuples(foo.graph)
    FileCheck().check_not('TupleConstruct').run(foo.graph)
|
|
|
|
def test_named_tuple_type_annotation(self):
    """A NamedTuple used as parameter and return annotation passes through a scripted function."""
    global MyCoolNamedTuple  # see [local resolution in python]

    class MyCoolNamedTuple(NamedTuple):
        a : int
        b : float
        c : List[int]

    @torch.jit.script
    def foo(x : MyCoolNamedTuple) -> MyCoolNamedTuple:
        return x

    mnt = MyCoolNamedTuple(42, 420.0, [666])
    self.assertEqual(foo(mnt), mnt)
|
|
|
|
def test_named_tuple_wrong_types(self):
    """Constructing a NamedTuple with wrongly typed arguments must fail at script time."""
    class MyCoolNamedTuple(NamedTuple):
        a : int
        b : float
        c : List[int]

    with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int' for argument 'a'"
                                              " but instead found type 'str'"):
        @torch.jit.script
        def foo():
            tup = MyCoolNamedTuple('foo', 'bar', 'baz')  # noqa
            return tup
|
|
|
|
def test_named_tuple_kwarg_construct(self):
    """A NamedTuple can be built with out-of-order keyword arguments in a scripted function."""
    class MyCoolNamedTuple(NamedTuple):
        a : int
        b : float
        c : List[int]

    @torch.jit.script
    def foo():
        tup = MyCoolNamedTuple(c=[1, 2, 3], b=3.5, a=9)  # noqa
        return tup

    result = foo()
    # Verify each field landed in the slot named by its keyword.
    for field_name, expected in (("a", 9), ("b", 3.5), ("c", [1, 2, 3])):
        self.assertEqual(getattr(result, field_name), expected)
|
|
|
|
def test_named_tuple_default_error(self):
    """A NamedTuple field with a default value must be rejected at script time."""
    class MyCoolNamedTuple(NamedTuple):
        a : int
        b : float
        c : List[int] = [3, 4, 5]

    with self.assertRaisesRegex(RuntimeError, 'Default values are currently not supported'):
        @torch.jit.script
        def foo():
            tup = MyCoolNamedTuple(c=[1, 2, 3], b=3.5, a=9)  # noqa
            return tup
|
|
|
|
@unittest.skipIf(True, "broken while these tests were not in CI")
def test_named_tuple_serialization(self):
    """Save and reload a ScriptModule returning a NamedTuple and compare field values.

    Currently always skiped by the decorator above.
    """
    class MyCoolNamedTuple(NamedTuple):
        a : int
        b : float
        c : List[int]

    class MyMod(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self):
            return MyCoolNamedTuple(3, 3.5, [3, 4, 5])

    mm = MyMod()
    # NOTE(review): 'foo.zip' is a relative path and is not removed by this test.
    mm.save('foo.zip')
    torch.testing._internal.jit_utils.clear_class_registry()
    loaded = torch.jit.load('foo.zip')

    out = mm()
    out_loaded = loaded()

    for name in ['a', 'b', 'c']:
        self.assertEqual(getattr(out_loaded, name), getattr(out, name))
|
|
|
|
def test_type_annotate_py3(self):
    """Check annotated local variables, and a list literal that contradicts its annotation."""
    def fn():
        a : List[int] = []
        b : torch.Tensor = torch.ones(2, 2)
        c : Optional[torch.Tensor] = None
        d : Optional[torch.Tensor] = torch.ones(3, 4)
        for _ in range(10):
            a.append(4)
            c = torch.ones(2, 2)
            d = None
        return a, b, c, d

    self.checkScript(fn, ())

    # A float element inside a variable annotated List[int] must be rejected.
    def wrong_type():
        wrong : List[int] = [0.5]
        return wrong

    with self.assertRaisesRegex(RuntimeError, "Lists must contain only a single type"):
        torch.jit.script(wrong_type)
|
|
|
|
def test_optional_no_element_type_annotation(self):
    """
    Test that using an optional with no contained types produces an error.
    """
    # NOTE(review): the two functions below differ only in name; both use a
    # return annotation despite the `fn_with_comment` name.
    def fn_with_comment(x: torch.Tensor) -> Optional:
        return (x, x)

    def annotated_fn(x: torch.Tensor) -> Optional:
        return (x, x)

    # First through CompilationUnit.define on the function source text ...
    with self.assertRaisesRegex(RuntimeError, r"Attempted to use Optional without a contained type"):
        cu = torch.jit.CompilationUnit()
        cu.define(dedent(inspect.getsource(fn_with_comment)))

    with self.assertRaisesRegex(RuntimeError, r"Attempted to use Optional without a contained type"):
        cu = torch.jit.CompilationUnit()
        cu.define(dedent(inspect.getsource(annotated_fn)))

    # ... then through torch.jit.script on the function objects.
    with self.assertRaisesRegex(RuntimeError, r"Attempted to use Optional without a contained type"):
        torch.jit.script(fn_with_comment)

    with self.assertRaisesRegex(RuntimeError, r"Attempted to use Optional without a contained type"):
        torch.jit.script(annotated_fn)
|
|
|
|
def test_tuple_no_element_type_annotation(self):
    """
    Test that using a tuple with no contained types produces an error.
    """
    # NOTE(review): the two functions below differ only in name; both use a
    # return annotation despite the `fn_with_comment` name.
    def fn_with_comment(x: torch.Tensor) -> Tuple:
        return (x, x)

    def annotated_fn(x: torch.Tensor) -> Tuple:
        return (x, x)

    # First through CompilationUnit.define on the function source text ...
    with self.assertRaisesRegex(RuntimeError, r"Attempted to use Tuple without a contained type"):
        cu = torch.jit.CompilationUnit()
        cu.define(dedent(inspect.getsource(fn_with_comment)))

    with self.assertRaisesRegex(RuntimeError, r"Attempted to use Tuple without a contained type"):
        cu = torch.jit.CompilationUnit()
        cu.define(dedent(inspect.getsource(annotated_fn)))

    # ... then through torch.jit.script on the function objects.
    with self.assertRaisesRegex(RuntimeError, r"Attempted to use Tuple without a contained type"):
        torch.jit.script(fn_with_comment)

    with self.assertRaisesRegex(RuntimeError, r"Attempted to use Tuple without a contained type"):
        torch.jit.script(annotated_fn)
|
|
|
|
def test_tuple_subscripted_assign(self):
    """Plain and augmented assignment to a tuple element must both fail at script time."""
    with self.assertRaisesRegex(RuntimeError, "subscripted assignment"):
        @torch.jit.script
        def foo(a: Tuple[int, int]) -> None:
            a[0] = a[1]

    with self.assertRaisesRegex(RuntimeError, "augmented assignment"):
        @torch.jit.script
        def bar(a: Tuple[int, int]) -> None:
            a[0] += a[1]
|
|
|
|
def test_subexpression_List_Future(self):
    """A Future[int] nested inside a List annotation appears as 'Future[int]' in the graph."""

    @torch.jit.script
    def fn(x: List[torch.jit.Future[int]]) -> torch.jit.Future[int]:
        return x[0]

    FileCheck().check('Future[int]').check('Future[int]').run(fn.graph)
|
|
|
|
def test_subexpression_Future_annotate(self):
    """A local variable annotated List[Future[int]] appears as 'Future[int][]' in the graph."""
    # The function is only compiled and its graph inspected; it is never run.
    @torch.jit.script
    def fn() -> torch.jit.Future[int]:
        x: List[torch.jit.Future[int]] = []
        return x[0]

    FileCheck().check("Future[int][]").run(fn.graph)
|
|
|
|
def test_future_isinstance(self):
    """An isinstance check against jit.Future[int] on an Any argument compiles."""
    # The function is only compiled and its graph inspected; it is never run.
    @torch.jit.script
    def fn(x: Any) -> torch.jit.Future[int]:
        assert isinstance(x, jit.Future[int])
        return x

    FileCheck().check("Future[int]").run(fn.graph)
|
|
|
|
def test_str_refine_any(self):
    """An Any argument can be returned as str after an isinstance(x, str) check."""
    def forward(x: Any) -> str:
        if isinstance(x, str):
            return x
        return "foo"
    # Rebind the name to the scripted version before exercising it.
    forward = torch.jit.script(forward)
    self.assertEqual(forward(1), "foo")
    self.assertEqual(forward("bar"), "bar")
|
|
|
|
def test_subexpression_Tuple_int_int_Future(self):
    """Future[int] nested inside Tuple annotations is rendered in the graph's tuple types."""

    @torch.jit.script
    def fn(x: Tuple[int, int, torch.jit.Future[int]]) -> Tuple[int, torch.jit.Future[int]]:
        return x[0], x[2]

    FileCheck().check('(int, int, Future[int])').check('(int, Future[int])').run(fn.graph)
|
|
|
|
def test_subexpression_Dict_int_Future(self):
    """Future[int] nested inside a Dict annotation is rendered in the graph's dict type."""

    @torch.jit.script
    def fn(x: Dict[int, torch.jit.Future[int]], y: int) -> torch.jit.Future[int]:
        return x[y]

    FileCheck().check('Dict(int, Future(int))').check('Future[int]').run(fn.graph)
|
|
|
|
def test_subexpression_Optional(self):
    """An Optional wrapping a Dict of Future[int] is rendered with a trailing '?' in the graph."""

    @torch.jit.script
    def fn(x: Optional[Dict[int, torch.jit.Future[int]]]) -> Optional[torch.jit.Future[int]]:
        if x is not None:
            return x[0]
        else:
            return None

    FileCheck().check('Dict(int, Future(int))?').run(fn.graph)
|
|
|
|
def test_unimported_type_resolution(self):
    """The `number` type named in a type comment shows up as 'Scalar' in the graph."""
    # verify fallback from the python resolver to the c++ resolver

    @ torch.jit.script
    def fn(x):
        # type: (number) -> number
        return x + 1

    FileCheck().check('Scalar').run(fn.graph)
|
|
|
|
def test_parser_bug(self):
    """Define a function whose parameter is annotated Optional[torch.Tensor]."""
    # NOTE(review): `parser_bug` is neither scripted nor called, so this test
    # only shows that Python itself accepts the definition — confirm intent.
    def parser_bug(o: Optional[torch.Tensor]):
        pass
|
|
|
|
def test_mismatched_annotation(self):
    """Assigning an int to a variable annotated as str must fail at script time."""
    with self.assertRaisesRegex(RuntimeError, 'annotated with type'):
        @torch.jit.script
        def foo():
            x : str = 4
            return x
|
|
|
|
def test_reannotate(self):
    """Annotating a variable that was already assigned must fail at script time."""
    with self.assertRaisesRegex(RuntimeError, 'declare and annotate'):
        @torch.jit.script
        def foo():
            x = 5
            if 1 == 1:
                x : Optional[int] = 7
|
|
|
|
def test_module_inplace_construct(self):
    """Reconstruct one scripted module in place from another's underlying C++ module.

    After ``n._reconstruct(m._c)`` both modules must produce the same
    output, and ``n``'s ignored method must still be callable.
    """
    class M(nn.Module):
        def __init__(self, start: int):
            super().__init__()
            self.linear = nn.Linear(3, 3)
            self.attribute = start
            self.parameter = nn.Parameter(torch.tensor(3, dtype=torch.float))

        def method(self) -> int:
            return self.attribute

        @torch.jit.unused
        def unused_method(self):
            return self.attribute + self.attribute

        def forward(self, x):
            return self.linear(self.linear(x))


    class N(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)

        @torch.jit.ignore
        def ignored_method(self, x):
            return x

        def forward(self, x):
            return self.linear(x)

    m = torch.jit.script(M(3))
    n = torch.jit.script(N())

    n._reconstruct(m._c)

    # Size-3 input matches M's Linear(3, 3), not N's original Linear(4, 4).
    inp = torch.rand((3))

    # Check that both modules produce the same output.
    with torch.no_grad():
        m_out = m(inp)
        n_out = n(inp)
        self.assertEqual(m_out, n_out)

    # Check that ignored method is still intact.
    self.assertEqual(inp, n.ignored_method(inp))
|
|
|
|
def test_if_returning_any(self):
    """
    Check that an if statement can return different
    types early from each branch when the return
    type of the function is Any.
    """
    def if_function(inp: torch.Tensor) -> Any:
        if inp.shape[0] == 1:
            return inp * inp
        else:
            return "str"

    # The input has shape[0] == 5, so the call goes through the else branch.
    self.checkScript(if_function, (torch.randn(5),))
|
|
|
|
def test_module_properties(self):
    """Check properties on scripted modules, with and without setters, and unused properties."""
    class ModuleWithProperties(torch.nn.Module):
        __jit_unused_properties__ = ["ignored_attr"]

        def __init__(self, a: int):
            super().__init__()
            self.a = a

        def forward(self, a: int, b: int):
            self.attr = a + b
            return self.attr

        @property
        def attr(self):
            return self.a

        @property
        def ignored_attr(self):
            return sum([self.a])

        @torch.jit.unused
        @property
        def ignored_attr_2(self):
            return sum([self.a])

        @ignored_attr_2.setter
        def ignored_attr_2(self, value):
            self.a = sum([self.a])

        @attr.setter
        def attr(self, a: int):
            if a > 0:
                self.a = a
            else:
                self.a = 0

    # NOTE(review): forward below evaluates `self.attr + a + b` but does not
    # return it, so the module's output is None — confirm this is intended.
    class ModuleWithNoSetter(torch.nn.Module):
        def __init__(self, a: int):
            super().__init__()
            self.a = a

        def forward(self, a: int, b: int):
            self.attr + a + b

        @property
        def attr(self):
            return self.a + 1

    # Positive and negative inputs reach both branches of the `attr` setter.
    self.checkModule(ModuleWithProperties(5), (5, 6,))
    self.checkModule(ModuleWithProperties(5), (-5, -6,))
    self.checkModule(ModuleWithNoSetter(5), (5, 6,))
    self.checkModule(ModuleWithNoSetter(5), (-5, -6,))

    mod = ModuleWithProperties(3)
    scripted_mod = torch.jit.script(mod)

    # A property listed in __jit_unused_properties__ is absent on the scripted module.
    with self.assertRaisesRegex(AttributeError, "has no attribute"):
        scripted_mod.ignored_attr
|
|
|
|
def test_ignoring_module_attributes(self):
    """
    Test that module attributes can be ignored.
    """
    class Sub(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, a: int) -> int:
            return sum([a])

    # The ignored attributes are only touched from an ignored method here.
    class ModuleWithIgnoredAttr(torch.nn.Module):
        __jit_ignored_attributes__ = ["a", "sub"]

        def __init__(self, a: int, b: int):
            super().__init__()
            self.a = a
            self.b = b
            self.sub = Sub()

        def forward(self) -> int:
            return self.b

        @torch.jit.ignore
        def ignored_fn(self) -> int:
            return self.sub.forward(self.a)

    mod = ModuleWithIgnoredAttr(1, 4)
    scripted_mod = torch.jit.script(mod)
    self.assertEqual(scripted_mod(), 4)
    self.assertEqual(scripted_mod.ignored_fn(), 1)

    # Test the error message for ignored attributes.
    class ModuleUsesIgnoredAttr(torch.nn.Module):
        __jit_ignored_attributes__ = ["a", "sub"]

        def __init__(self, a: int):
            super().__init__()
            self.a = a
            self.sub = Sub()

        def forward(self) -> int:
            return self.sub(self.b)

    mod = ModuleUsesIgnoredAttr(1)

    with self.assertRaisesRegexWithHighlight(RuntimeError, r"attribute was ignored during compilation", "self.sub"):
        scripted_mod = torch.jit.script(mod)
|
|
|
|
|
|
def test_export_opnames_interface(self):
    """Check torch.jit.export_opnames on a module whose submodule is typed as an interface."""

    @torch.jit.interface
    class OneTwoModule(nn.Module):
        def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            pass

        def two(self, x: torch.Tensor) -> torch.Tensor:
            pass

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            pass

    class FooMod(nn.Module):
        def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y

        def two(self, x: torch.Tensor) -> torch.Tensor:
            return 2 * x

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.one(self.two(x), x)

    class BarMod(nn.Module):
        def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x * y

        def two(self, x: torch.Tensor) -> torch.Tensor:
            return 2 / x

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.two(self.one(x, x))

    # Called before the definitions below that use OneTwoModule as an annotation.
    make_global(OneTwoModule)

    class M(nn.Module):
        sub : OneTwoModule

        def __init__(self):
            super(M, self).__init__()
            self.sub = BarMod()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.sub.forward(x)

    # NOTE(review): `use_module_interface` is defined but never used in this test.
    def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
        return mod_list[0].forward(x) + mod_list[1].forward(x)

    scripted_M_mod = torch.jit.script(M())
    # Temporarily test empty output because lite interpreter does not support interface call
    # Replace it with the issubset call when interface call is supported.
    self.assertTrue(len(torch.jit.export_opnames(scripted_M_mod)) == 0)
    # self.assertTrue(set(['aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal']).issubset(
    #     set(torch.jit.export_opnames(scripted_M_mod))))

    # Swap the submodule for a different implementation of the interface.
    scripted_M_mod.sub = torch.jit.script(FooMod())
    self.assertTrue(len(torch.jit.export_opnames(scripted_M_mod)) == 0)
    # self.assertTrue(set(['aten::add.Tensor', 'aten::mul.Scalar']).issubset(
    #     set(torch.jit.export_opnames(scripted_M_mod))))
|
|
|
|
def test_broadcasting_list(self):
    """
    Test BroadcastingList and torch.nn._size_N_t alias
    """
    from torch._jit_internal import BroadcastingList2
    from torch.nn.common_types import _size_2_t

    def sum_i(x: _size_2_t) -> int:
        return x[0] + x[1]

    def sum_f(x: BroadcastingList2[float]) -> float:
        return x[0] + x[1]

    # A single scalar argument is accepted and both elements read it:
    # 4 + 4 == 8 and 4.5 + 4.5 == 9.
    # assertEqual (rather than assertTrue(a == b)) reports the actual value on failure.
    self.assertEqual(torch.jit.script(sum_i)(4), 8)
    self.assertEqual(torch.jit.script(sum_f)(4.5), 9.)
|
|
|
|
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    run_tests()
|