mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: This PR replaces the https://github.com/pytorch/pytorch/pull/53180 PR stack, which contains all the review discussions. A replacement was needed because of a messy Sandcastle issue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/64234 Reviewed By: gmagogsfm Differential Revision: D30656444 Pulled By: ansley fbshipit-source-id: 77536c8bcc88162e2c72636026ca3c16891d669a
672 lines
22 KiB
Python
import os
|
|
import sys
|
|
|
|
import torch
|
|
from torch.testing import FileCheck
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
from torch.testing._internal.common_utils import IS_WINDOWS
|
|
from collections import namedtuple
|
|
from typing import List, Tuple, Optional, Dict
|
|
|
|
# Make the helper files in test/ importable: the two dirname() calls walk from
# this file up to the parent of its directory (presumably test/), which is then
# appended to sys.path.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

# These tests are collected through test/test_jit.py; refuse direct execution
# and point the user at the supported entry point.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestTyping(JitTestCase):
|
|
def test_dict_in_not_in(self):
|
|
def test_in_dict(x):
|
|
# type: (Dict[str, int]) -> bool
|
|
return 'hi' in x
|
|
|
|
self.checkScript(test_in_dict, ({'hi': 2, 'bye': 3},))
|
|
self.checkScript(test_in_dict, ({'bye': 3},))
|
|
|
|
# Check evaluation order
|
|
@torch.jit.script
|
|
def a():
|
|
print("a")
|
|
return 3
|
|
|
|
@torch.jit.script
|
|
def b():
|
|
print("b")
|
|
return {3: 2, 4: 1}
|
|
|
|
@torch.jit.script
|
|
def fn():
|
|
return a() in b()
|
|
|
|
with self.capture_stdout() as captured:
|
|
self.assertTrue(fn())
|
|
if not IS_WINDOWS:
|
|
# no stdout capturing on windows
|
|
self.assertEqual(captured[0], "a\nb\n")
|
|
|
|
def test_not_in_dict(a):
|
|
# type: (Dict[str, int]) -> bool
|
|
if "hello" not in a:
|
|
return False
|
|
else:
|
|
return True
|
|
|
|
self.checkScript(test_not_in_dict, ({"hello": 1, "world": 2}, ))
|
|
self.checkScript(test_not_in_dict, ({"world": 2}, ))
|
|
|
|
def test_dict_tensor_key(a, t):
|
|
# type: (Dict[Tensor, int], Tensor) -> bool
|
|
if t in a:
|
|
return True
|
|
else:
|
|
return False
|
|
|
|
inp1 = torch.tensor(3)
|
|
inp2 = torch.tensor(5)
|
|
dict_a = {inp1: 1, inp2: 3}
|
|
self.checkScript(test_dict_tensor_key, (dict_a, torch.tensor(4)))
|
|
self.checkScript(test_dict_tensor_key, (dict_a, torch.tensor(3)))
|
|
self.checkScript(test_dict_tensor_key, (dict_a, inp1))
|
|
self.checkScript(test_dict_tensor_key, (dict_a, inp2))
|
|
|
|
def test_list_type_refinement_defaults_to_Any_list_creation(self):
|
|
def fn(x):
|
|
tup1 = ("foo", torch.tensor(2))
|
|
tup2 = ("bar", {"23": torch.tensor(3)})
|
|
tup3 = ("baz", x)
|
|
l = list((tup1, tup2)) # noqa: C410
|
|
l.append(tup3)
|
|
tup4 = l[0]
|
|
if torch.jit.isinstance(tup4, Tuple[str, torch.Tensor]):
|
|
t = tup4[1]
|
|
if isinstance(t, torch.Tensor):
|
|
l[0] = (tup4[0], torch.add(t, t))
|
|
return l
|
|
|
|
self.checkScript(fn, (torch.arange(5),))
|
|
|
|
graph = torch.jit.script(fn).graph
|
|
|
|
# Check that we're making a `List[Tuple[str, Any]]`
|
|
FileCheck().check("(str, Union[Tensor, Dict(str, Tensor)])"
|
|
"[] = prim::ListConstruct()").run(graph)
|
|
|
|
def test_list_type_refinement_defaults_to_Any_list_comprehension(self):
|
|
def fn(x):
|
|
tup1 = ("foo", torch.tensor(2))
|
|
tup2 = ("bar", {"23": torch.tensor(3)})
|
|
tup3 = ("baz", x)
|
|
l_ = [tup1, tup2]
|
|
l = [t for t in l_] # noqa: C416
|
|
l.append(tup3)
|
|
tup4 = l[0]
|
|
if torch.jit.isinstance(tup4, Tuple[str, torch.Tensor]):
|
|
t = tup4[1]
|
|
if isinstance(t, torch.Tensor):
|
|
l[0] = (tup4[0], torch.add(t, t))
|
|
return l
|
|
|
|
self.checkScript(fn, (torch.arange(5),))
|
|
|
|
graph = torch.jit.script(fn).graph
|
|
|
|
# Check that we're making a `List[Tuple[str, Any]]`
|
|
FileCheck().check("(str, Union[Tensor, Dict(str, Tensor)])"
|
|
"[] = prim::ListConstruct()").run(graph)
|
|
|
|
def test_list_type_refinement_annotation_element_mismatch(self):
|
|
def fn():
|
|
l: List[int] = [1, 2, "foo", 3]
|
|
return l
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "List type annotation"
|
|
r" `List\[int\]` did not match the "
|
|
"types of the given list elements"):
|
|
torch.jit.script(fn)
|
|
|
|
def test_dict_type_refinement_defaults_to_Any_dict_creation(self):
|
|
def fn(x):
|
|
d = dict(foo=torch.tensor(2),
|
|
bar={"23": torch.tensor(3)})
|
|
d["baz"] = x
|
|
t = d["foo"]
|
|
if isinstance(t, torch.Tensor):
|
|
d["bar"] = torch.add(t, t)
|
|
return d
|
|
|
|
self.checkScript(fn, (torch.arange(5),))
|
|
|
|
graph = torch.jit.script(fn).graph
|
|
|
|
FileCheck().check("Dict(str, Union[Tensor, Dict(str, Tensor)])"
|
|
" = prim::DictConstruct").run(graph)
|
|
|
|
def test_dict_type_refinement_defaults_to_Any_dict_comprehension(self):
|
|
def fn(x):
|
|
d = {"foo": torch.tensor(2),
|
|
"bar": {"23": torch.tensor(3)}}
|
|
d["baz"] = x
|
|
t = d["foo"]
|
|
if isinstance(t, torch.Tensor):
|
|
d["bar"] = torch.add(t, t)
|
|
return d
|
|
|
|
self.checkScript(fn, (torch.arange(5),))
|
|
|
|
graph = torch.jit.script(fn).graph
|
|
|
|
FileCheck().check("Dict(str, Union[Tensor, Dict(str, Tensor)])"
|
|
" = prim::DictConstruct").run(graph)
|
|
|
|
def test_dict_type_refinement_annotation_key_mismatch(self):
|
|
def fn():
|
|
l1 = [1, 2, "foo", 3]
|
|
l2 = ["foo", "bar", "baz", "qux"]
|
|
d: Dict[int, str] = {k : v for k, v in zip(l1, l2)}
|
|
return l
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Dict type annotation"
|
|
r" `Dict\[int, str\]` did not match"
|
|
" the type of an actual key type"):
|
|
torch.jit.script(fn)
|
|
|
|
def test_dict_type_refinement_annotation_value_mismatch(self):
|
|
def fn():
|
|
l1 = ["foo", "bar", "baz", "qux"]
|
|
l2 = [1, 2, "foo", 3]
|
|
d: Dict[str, int] = {k : v for k, v in zip(l1, l2)}
|
|
return l
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Dict type annotation"
|
|
r" `Dict\[str, int\]` did not match"
|
|
" the type of an actual value "
|
|
"type"):
|
|
torch.jit.script(fn)
|
|
|
|
def test_dict_invalid_annotations(self):
|
|
# Check for invalid value type annotation
|
|
def wrong_value_type(dictionary: Dict[str, torch.jit.ScriptModule]):
|
|
return
|
|
with self.assertRaisesRegex(ValueError, "Unknown type annotation"):
|
|
torch.jit.script(wrong_value_type)
|
|
|
|
# Check for invalid key type annotation
|
|
def wrong_key_type(dictionary: Dict[torch.jit.ScriptModule, str]):
|
|
return
|
|
with self.assertRaisesRegex(ValueError, "Unknown type annotation"):
|
|
torch.jit.script(wrong_key_type)
|
|
|
|
# Check for invalid key and value type annotation
|
|
def wrong_key_value_type(dictionary: Dict[torch.jit.ScriptModule, torch.jit.ScriptModule]):
|
|
return
|
|
with self.assertRaisesRegex(ValueError, "Unknown type annotation"):
|
|
torch.jit.script(wrong_key_value_type)
|
|
|
|
def test_tuple_specialization(self):
|
|
@torch.jit.script
|
|
def f(t, s):
|
|
# type: (Tuple[Tensor, Tuple[int, Tensor]], str) -> Tensor
|
|
x, t2 = t
|
|
_, y = t2
|
|
return x + y
|
|
|
|
t = torch.randn(2, 2), (1, torch.randn(2, 2)),
|
|
f(t, "hi")
|
|
graph = f.graph_for(t, "hi")
|
|
input_types = list(next(graph.inputs()).type().elements())
|
|
w = input_types[0]
|
|
self.assertEqual(input_types[0].kind(), 'TensorType')
|
|
self.assertEqual(input_types[1].elements()[1].kind(), 'TensorType')
|
|
|
|
def test_tuple_io(self):
|
|
def stuff(x):
|
|
# type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
|
|
a, b = x
|
|
return b, a
|
|
|
|
a = (torch.rand(3), torch.rand(3))
|
|
self.checkScript(stuff, (a,))
|
|
|
|
def test_tuple_keyword(self):
|
|
def bar():
|
|
f = tuple((1, 2)) # noqa: C409
|
|
return f
|
|
|
|
self.checkScript(bar, ())
|
|
|
|
def foo():
|
|
return tuple(1, 2)
|
|
|
|
self.checkScriptRaisesRegex(foo, (), Exception,
|
|
"1 argument")
|
|
|
|
def cant_infer_size():
|
|
return tuple([1, 2, 3]) # noqa: C409
|
|
|
|
with self.assertRaisesRegex(Exception, "cannot statically infer the expected"):
|
|
torch.jit.script(cant_infer_size)
|
|
|
|
def test_tuple_create_return(self):
|
|
def stuff2(x):
|
|
# type: (int) -> Tuple[Tensor, Tensor]
|
|
a = (torch.ones(x), torch.zeros(x))
|
|
return a
|
|
self.checkScript(stuff2, (3,))
|
|
|
|
def test_list_io(self):
|
|
def stuff3(x):
|
|
# type: (List[int]) -> Tuple[Tensor, List[int]]
|
|
return torch.ones(x), x
|
|
self.checkScript(stuff3, ([3, 2],))
|
|
|
|
def test_bool_list_io(self):
|
|
@torch.jit.script
|
|
def stuff4(x):
|
|
# type: (List[bool]) -> Tuple[List[bool], List[bool], List[List[bool]]]
|
|
return x, [True, False], [[True]]
|
|
|
|
li_1, li_2, li_3 = stuff4([True])
|
|
li_3 = li_3[0]
|
|
for li in [li_1, li_2, li_3]:
|
|
self.assertTrue(type(li[0]) == type(True))
|
|
|
|
def test_nested_list(self):
|
|
def foo(z):
|
|
# type: (Tuple[int, List[List[int]]]) -> int
|
|
x, y = z
|
|
return y[0][1]
|
|
self.checkScript(foo, ((1, [[1, 2], [3, 4]]),))
|
|
|
|
def test_list_sum(self):
|
|
def fn(x: List[int]) -> int:
|
|
return sum(x)
|
|
|
|
def fn1(x: List[float]):
|
|
return sum(x)
|
|
|
|
def fn2(x: List[bool]):
|
|
return sum(x)
|
|
|
|
self.checkScript(fn, ([1, 2, 3], ))
|
|
self.checkScript(fn1, ([1.0, 2.0, 3.0], ))
|
|
self.checkScript(fn1, ([1, 2.8, 3], ))
|
|
self.checkScript(fn2, ([True, False, False], ))
|
|
self.checkScript(fn2, ([False, False, False], ))
|
|
self.checkScript(fn2, ([0, 1, 1, 0], ))
|
|
|
|
def test_list_unification(self):
|
|
def fn():
|
|
return [1, None, 2]
|
|
|
|
def fn2(x):
|
|
return [torch.ones(2, 2), None, x]
|
|
|
|
self.checkScript(fn, [])
|
|
self.checkScript(fn2, (torch.ones(2, 2),))
|
|
|
|
# to avoid defining sum_list in multiple tests
|
|
def get_sum_list_fn(self):
|
|
def sum_list(a):
|
|
# type: (List[int]) -> int
|
|
sum = 0
|
|
for i in a:
|
|
sum += i
|
|
|
|
return sum
|
|
|
|
return sum_list
|
|
|
|
def test_sum_list_diff_elms(self):
|
|
self.checkScript(self.get_sum_list_fn(), ([1, 2, 3, 4, 5],))
|
|
|
|
def test_sum_list_empty(self):
|
|
self.checkScript(self.get_sum_list_fn(), ([],))
|
|
|
|
def test_sum_list_one(self):
|
|
self.checkScript(self.get_sum_list_fn(), ([1],))
|
|
|
|
def test_sum_list_literal(self):
|
|
|
|
def sum_list():
|
|
# type: () -> int
|
|
sum = 0
|
|
for i in [1, 2, 3, 4, 5]:
|
|
sum += i
|
|
|
|
return sum
|
|
|
|
self.checkScript(sum_list, ())
|
|
|
|
def test_sum_list_wrong_type(self):
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
|
|
@torch.jit.script
|
|
def sum_list(a):
|
|
# type: (int) -> int
|
|
sum = 0
|
|
for i in a: # noqa: T484
|
|
sum += i
|
|
|
|
return sum
|
|
|
|
sum_list(1)
|
|
|
|
def test_list_iterables(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'List of iterables is not supported currently'):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def list_iterables(x):
|
|
for i, j in [2, 3, 4], [5, 6, 7]:
|
|
x += i
|
|
x += j
|
|
return x
|
|
''')
|
|
|
|
def test_for_in_string(self):
|
|
def test_strings(x):
|
|
# type: (str) -> str
|
|
reverse = ""
|
|
for c in x:
|
|
reverse = c + reverse
|
|
return reverse
|
|
|
|
self.checkScript(test_strings, ("hello",))
|
|
self.checkScript(test_strings, ("",))
|
|
|
|
def test_list_strings(x):
|
|
# type: (List[str]) -> str
|
|
result = ""
|
|
for sub_str in x:
|
|
result += sub_str
|
|
return result
|
|
|
|
self.checkScript(test_list_strings, (["hello", "world"],))
|
|
self.checkScript(test_list_strings, (["hello", " ", "world", ""],))
|
|
|
|
def test_for_in_dict(self):
|
|
def test_dicts(x):
|
|
# type: (Dict[str, int]) -> int
|
|
sum = 0
|
|
for key in x:
|
|
sum += x[key]
|
|
return sum
|
|
|
|
self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},))
|
|
|
|
def test_dict_keys_values(x):
|
|
# type: (Dict[str, int]) -> Tuple[str, int]
|
|
key_str = ""
|
|
sum = 0
|
|
for key in x.keys():
|
|
key_str += key
|
|
for val in x.values():
|
|
sum += val
|
|
return key_str, sum
|
|
|
|
self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},))
|
|
|
|
def test_for_tuple_unpack(self):
|
|
def for_tuple_unpack(x, y):
|
|
for i, j in [[3, 4], [5, 6], [7, 8]]:
|
|
x += i
|
|
y += j
|
|
return x, y
|
|
|
|
self.checkScript(for_tuple_unpack, (torch.tensor(3), torch.tensor(5)))
|
|
|
|
def nested_tuple_unpack(x, y):
|
|
# type: (List[int], List[int]) -> int
|
|
sum = 0
|
|
for i, (j, k), v in zip(x, enumerate(x), y):
|
|
sum += i + j + k + v
|
|
return sum
|
|
|
|
self.checkScript(nested_tuple_unpack, ([1, 3, 5], [2, 4, 6]))
|
|
|
|
def test_dict_comprehension(self):
|
|
def fn():
|
|
return {i : chr(i + 65) for i in range(4)}
|
|
self.checkScript(fn, ())
|
|
|
|
def test_dict_comprehension_with_type_annotation(self):
|
|
def fn():
|
|
d: Dict[int, str] = {i : chr(i + 65) for i in range(4)}
|
|
return d
|
|
self.checkScript(fn, ())
|
|
|
|
with self.assertRaisesRegex(RuntimeError, ""):
|
|
with self.assertRaisesRegex(AssertionError, "Expected Dict "
|
|
"type annotation for dict "
|
|
"comprehension, found "
|
|
"Tuple[int, str]"):
|
|
@torch.jit.script
|
|
def fn():
|
|
d: Tuple[int, str] = {i : chr(i + 65) for i in range(4)}
|
|
return d
|
|
|
|
def test_dict_comprehension_scope(self):
|
|
def comprehension_can_access_outer_scope_variables():
|
|
lst = ["foo", "bar", "baz"]
|
|
return {l : len(l) for l in lst}
|
|
|
|
self.checkScript(comprehension_can_access_outer_scope_variables, ())
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "undefined value i"):
|
|
@torch.jit.script
|
|
def outer_scope_cannot_access_comprehension_variables():
|
|
d = {i : chr(i + 65) for i in range(4)}
|
|
i = i + 1
|
|
|
|
def test_for_tuple_assign(self):
|
|
def test_simple_assign(x):
|
|
# type: (Tuple[int, float]) -> float
|
|
sum = 0.0
|
|
for a in x:
|
|
sum += float(a)
|
|
return sum
|
|
|
|
self.checkScript(test_simple_assign, ((1, 2.5),))
|
|
|
|
def test_tuple_assign(x):
|
|
# type: (Tuple[Tuple[int, int], Tuple[int, int]]) -> int
|
|
sum = 0
|
|
for a in x:
|
|
sum += a[0]
|
|
sum += a[1]
|
|
return sum
|
|
|
|
self.checkScript(test_tuple_assign, (((1, 2), (4, 7)), ))
|
|
|
|
def test_single_starred_lhs(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear on the lhs within the presence'
|
|
' of another non-starred expression'):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def single_starred_lhs(x):
|
|
a = (x, x, x)
|
|
*b, = a
|
|
return b
|
|
''')
|
|
|
|
def test_singleton_tuple_unpack(self):
|
|
def foo(a):
|
|
b, = (a,)
|
|
return b + 1
|
|
self.checkScript(foo, (torch.rand(3),))
|
|
|
|
def test_tuple_assignments(self):
|
|
def var_tuple_assign(x, y):
|
|
# type: (Tuple[Tensor, Tensor], Tensor) -> Tensor
|
|
(a, b), c = x, y
|
|
return a + b + c
|
|
|
|
tuple_inputs = (torch.randn(1, 4), torch.randn(3, 4))
|
|
self.checkScript(var_tuple_assign, (tuple_inputs, torch.randn(3, 4)))
|
|
|
|
def nested_tuple_assign(x, y, z):
|
|
# type: (int, Tuple[int, Tuple[int, int]], Tuple[int, int]) -> int
|
|
a, (b, (c, d)), (e, f) = x, y, z
|
|
return a + b + c + d + e + f
|
|
|
|
self.checkScript(nested_tuple_assign, ((1, (2, (3, 4)), (5, 6))))
|
|
|
|
def subscript_tuple_assign(a, x, i):
|
|
# type: (List[int], Tensor, int) -> Tuple[int, Tensor, int]
|
|
a[i], (x[i], b) = 1, (2, 3)
|
|
return a[i] + 1, x + 5, b
|
|
|
|
self.checkScript(subscript_tuple_assign, ([12, 7, 9, 11], torch.tensor((3, 13, 17)), 0))
|
|
|
|
def star_tuple_assign():
|
|
# type: () -> Tuple[int, int, Tuple[int, int], Tuple[int, int]]
|
|
a, (b, *c), *d = 1, (2, 3, 4), 5, 6
|
|
return a, b, c, d
|
|
|
|
self.checkScript(star_tuple_assign, ())
|
|
|
|
def subscript_tuple_augmented_assign(a):
|
|
# type: (Tuple[int, int]) -> Tuple[int, int]
|
|
a[0] += 1
|
|
return a
|
|
|
|
with self.assertRaisesRegex(RuntimeError, 'does not support augmented assign'):
|
|
scripted_aug_assign = torch.jit.script(subscript_tuple_augmented_assign)
|
|
|
|
def test_multiple_assign(self):
|
|
def test():
|
|
a = b, c = d, f = (1, 1)
|
|
|
|
# side effect
|
|
ten = torch.tensor(1)
|
|
ten1 = ten2 = ten.add_(1)
|
|
|
|
# ordering
|
|
x = 1
|
|
y = 3
|
|
x, y = y, x + y
|
|
|
|
return a, b, c, d, f, ten, ten1, ten2, x, y
|
|
|
|
self.checkScript(test, ())
|
|
|
|
def test_opt_opt_refinement(self):
|
|
@torch.jit.script
|
|
def test_unify(weight, bias):
|
|
# type: (Optional[int], Optional[int]) -> Optional[int]
|
|
if weight is not None:
|
|
opt = None
|
|
else:
|
|
if bias is not None:
|
|
opt = 1
|
|
else:
|
|
opt = None
|
|
|
|
return opt
|
|
|
|
def test_optional_refinement(self):
|
|
@torch.jit.script
|
|
def test_if_none_assignment(x):
|
|
# type: (Optional[int]) -> int
|
|
if x is None:
|
|
x = 1
|
|
return x + 1
|
|
|
|
self.assertEqual(test_if_none_assignment(1), 2)
|
|
|
|
def test_optional_conversion(self):
|
|
@torch.jit.script
|
|
def other_fn(x=None):
|
|
# type: (Optional[int]) -> int
|
|
return torch.jit._unwrap_optional(x)
|
|
|
|
|
|
@torch.jit.script
|
|
def fn(x):
|
|
# type: (int) -> int
|
|
return other_fn(x)
|
|
|
|
self.assertEqual(fn(2), 2)
|
|
|
|
@torch.jit.script
|
|
def unify_to_optional(x):
|
|
# type: (bool) -> Optional[int]
|
|
if x:
|
|
a = None
|
|
else:
|
|
a = 2
|
|
return a
|
|
|
|
self.assertEqual(unify_to_optional(True), None)
|
|
self.assertEqual(unify_to_optional(False), 2)
|
|
|
|
@torch.jit.script
|
|
def opt_list(x):
|
|
# type: (Optional[List[float]]) -> int
|
|
return 2
|
|
|
|
@torch.jit.script
|
|
def broadcast_opt_list(x):
|
|
# type: (Optional[BroadcastingList2[float]]) -> int
|
|
return 2
|
|
|
|
@torch.jit.script
|
|
def opt_list_tuple_caller(x):
|
|
# type: (Tuple[float, float]) -> int
|
|
return opt_list(x) + broadcast_opt_list(x)
|
|
|
|
self.assertEqual(opt_list_tuple_caller((2., 3.)), 4)
|
|
|
|
def test_optional_tuple(self):
|
|
def fn(x=None):
|
|
# type: (Optional[Tuple[int, int]]) -> Tuple[int, int]
|
|
if x is None:
|
|
new_x = (1, 2)
|
|
else:
|
|
new_x = x
|
|
return new_x
|
|
|
|
self.checkScript(fn, ((3, 4),))
|
|
self.checkScript(fn, ())
|
|
|
|
def test_namedtuple_redefine(self):
|
|
global _1, _2
|
|
_1 = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])
|
|
_2 = namedtuple('GoogLeNetOutputs', ['different'])
|
|
|
|
with self.assertRaisesRegex(RuntimeError, r'redefine'):
|
|
@torch.jit.script
|
|
def foo(x, y):
|
|
# type: (_1, _2) -> _1
|
|
return x
|
|
|
|
def test_namedtuple_py2(self):
|
|
global _GoogLeNetOutputs # see [local resolution in python]
|
|
_GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])
|
|
|
|
@torch.jit.script
|
|
def foo(x):
|
|
# type: (_GoogLeNetOutputs) -> _GoogLeNetOutputs
|
|
return x
|
|
|
|
vals = torch.rand(3), torch.rand(4), torch.rand(5)
|
|
out = foo(_GoogLeNetOutputs(logits=vals[0], aux_logits2=vals[1], aux_logits1=vals[2]))
|
|
self.assertEqual(out.logits, vals[0])
|
|
self.assertEqual(out.aux_logits2, vals[1])
|
|
self.assertEqual(out.aux_logits1, vals[2])
|
|
|
|
def test_namedtuple_good_error(self):
|
|
global _GoogLeNetOutputs # see [local resolution in python]
|
|
_GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])
|
|
|
|
@torch.jit.script
|
|
def foo(x):
|
|
# type: (_GoogLeNetOutputs) -> _GoogLeNetOutputs
|
|
return x
|
|
|
|
with self.assertRaisesRegex(RuntimeError,
|
|
r'aka NamedTuple\(logits, aux_logits2, aux_logits1\)'):
|
|
out = foo(_GoogLeNetOutputs(logits=3, aux_logits2=4, aux_logits1=5))
|