Revert D25717504: Clean up some type annotations in test/jit

Test Plan: revert-hammer

Differential Revision: D25717504 (a4f30d48d8)

Original commit changeset: 9a83c44db02e

fbshipit-source-id: e6e3a83bed22701d8125f5a293dfcd5093c1a2cd
commit 1bb7d8ff93 (parent f9f758e349)
Author: Heitor Schueroff
Date: 2021-01-08 12:13:00 -08:00
Committed by: Facebook GitHub Bot

8 changed files with 389 additions and 197 deletions
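
Every hunk below applies the same mechanical change: a Python 3 annotated signature is removed and replaced by a bare signature plus a mypy-style type comment. TorchScript accepts both spellings and compiles them identically; a minimal sketch of the two equivalent forms (function names are illustrative, not from the diff):

from typing import List

import torch
from torch import Tensor

# Python 3 annotations: the style the reverted commit had introduced.
@torch.jit.script
def foo_annotated(x: Tensor) -> List[Tensor]:
    return [torch.neg(x), x.t()]

# mypy-style type comment: the style this revert restores.
@torch.jit.script
def foo_commented(x):
    # type: (Tensor) -> List[Tensor]
    return [torch.neg(x), x.t()]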

test/jit/test_async.py

@@ -41,7 +41,8 @@ class TestAsync(JitTestCase):
def test_async_parsing(self):
@torch.jit.script
def foo(x: Tensor) -> List[Tensor]:
def foo(x):
# type: (Tensor) -> List[Tensor]
return [torch.neg(x), x.t()]
@torch.jit.script
@@ -256,7 +257,8 @@ class TestAsync(JitTestCase):
self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True)
@torch.jit.script_method
def forward(self, x: Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]:
def forward(self, x):
# type: (Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]
future1 = torch.jit._fork(self.traced, x)
future2 = torch.jit._fork(torch.neg, x)

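The forwards above launch asynchronous work with the private torch.jit._fork helper. For context, a sketch of the same pattern using the public fork/wait API (names are illustrative, not from the diff):

import torch
from torch import Tensor

@torch.jit.script
def negate(x):
    # type: (Tensor) -> Tensor
    return torch.neg(x)

@torch.jit.script
def run_async(x):
    # type: (Tensor) -> Tensor
    # torch.jit.fork is the public counterpart of the _fork call above;
    # it launches negate(x) asynchronously, and wait blocks on the result.
    fut = torch.jit.fork(negate, x)
    return torch.jit.wait(fut)
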
test/jit/test_builtins.py

@@ -2,7 +2,7 @@ import os
import sys
import inspect
import unittest
from typing import Dict, List
from typing import List
import torch
@@ -78,7 +78,8 @@ class TestBuiltins(JitTestCase):
torch.jit.script(Mod())
def test_del(self):
def fn(x: List[int]) -> List[int]:
def fn(x):
# type: (List[int]) -> List[int]
a = x * 2
del a
return x
@@ -108,14 +109,16 @@ class TestBuiltins(JitTestCase):
return a
def test_del_multiple_operands(self):
def fn(x: List[int]) -> List[int]:
def fn(x):
# type: (List[int]) -> List[int]
a, b, c = x[0], x[1], x[2]
del a, b, c
return x
self.checkScript(fn, ([1, 2, 3],))
def del_list_multiple_operands(x: List[int]) -> List[int]:
def del_list_multiple_operands(x):
# type: (List[int]) -> List[int]
del x[0], x[1]
return x
@@ -123,7 +126,8 @@ class TestBuiltins(JitTestCase):
jit_out = torch.jit.script(del_list_multiple_operands)([0, 1, 2])
self.assertEquals(py_out, jit_out)
def del_dict_multiple_operands(x: Dict[str, int]) -> Dict[str, int]:
def del_dict_multiple_operands(x):
# type: (Dict[str, int]) -> Dict[str, int]
del x['hi'], x['there']
return x

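The del tests above exercise TorchScript's support for deleting container elements. A minimal standalone sketch of the same behavior (function names are illustrative):

from typing import Dict, List

import torch

@torch.jit.script
def drop_first(x):
    # type: (List[int]) -> List[int]
    # Deletion at a single index is supported; slice deletion raises.
    del x[0]
    return x

@torch.jit.script
def drop_key(d):
    # type: (Dict[str, int]) -> Dict[str, int]
    del d["hi"]
    return d

print(drop_first([1, 2, 3]))        # [2, 3]
print(drop_key({"hi": 1, "x": 2}))  # {'x': 2}
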
test/jit/test_list_dict.py

@@ -20,19 +20,22 @@ if __name__ == '__main__':
class TestList(JitTestCase):
def test_in_check(self):
def int_in(x: List[int]) -> bool:
def int_in(x):
# type: (List[int]) -> bool
return 2 in x
self.checkScript(int_in, ([1, 2, 3],))
self.checkScript(int_in, ([1, 3, 3],))
def float_in(x: List[float]) -> bool:
def float_in(x):
# type: (List[float]) -> bool
return 2. in x
self.checkScript(float_in, ([1., 2., 3.],))
self.checkScript(float_in, ([1., 3., 3.],))
def str_in(x: List[str]) -> bool:
def str_in(x):
# type: (List[str]) -> bool
return 'hi' in x
self.checkScript(str_in, (['not', 'here'],))
@@ -97,7 +100,8 @@ class TestList(JitTestCase):
def inputs():
return [1, 2, 3, 4]
def fn(x: List[int]) -> List[int]:
def fn(x):
# type: (List[int]) -> List[int]
del x[1]
return x
@@ -110,7 +114,8 @@ class TestList(JitTestCase):
self.assertEqual(torch.jit.script(fn)(inputs()), python_out)
@torch.jit.script
def fn2(x: List[int]) -> List[int]:
def fn2(x):
# type: (List[int]) -> List[int]
del x[100]
return x
@@ -119,7 +124,8 @@ class TestList(JitTestCase):
with self.assertRaisesRegex(RuntimeError, "deletion at a single index"):
@torch.jit.script
def fn(x: List[int]) -> List[int]:
def fn(x):
# type: (List[int]) -> List[int]
del x[1:3]
return x
@@ -143,19 +149,23 @@ class TestList(JitTestCase):
FileCheck().check_count("aten::list", 2, exactly=True).run(torch.jit.script(foo3).graph)
def test_min_bool_list(self):
def jit_min_list(a: List[bool], b: List[bool]) -> List[bool]:
def jit_min_list(a, b):
# type: (List[bool], List[bool]) -> List[bool]
return min(a, b)
self.checkScript(jit_min_list, ([True, False], [False, True]))
def test_min_max_list(self):
def jit_min_list(a: List[int], b: List[int]) -> List[int]:
def jit_min_list(a, b):
# type: (List[int], List[int]) -> List[int]
return min(a, b)
def jit_min_list_float(a: List[float], b: List[float]) -> List[float]:
def jit_min_list_float(a, b):
# type: (List[float], List[float]) -> List[float]
return min(a, b)
def jit_min_list_bool(a: List[bool], b: List[bool]) -> List[bool]:
def jit_min_list_bool(a, b):
# type: (List[bool], List[bool]) -> List[bool]
return min(a, b)
def run_tests(func, a, b):
@@ -176,13 +186,16 @@ class TestList(JitTestCase):
[False, True], [False, False, True], [False, False, False]]
run_tests(jit_min_list_bool, args_left_bool, args_right_bool)
def jit_max_list(a: List[int], b: List[int]) -> List[int]:
def jit_max_list(a, b):
# type: (List[int], List[int]) -> List[int]
return max(a, b)
def jit_max_list_float(a: List[float], b: List[float]) -> List[float]:
def jit_max_list_float(a, b):
# type: (List[float], List[float]) -> List[float]
return max(a, b)
def jit_max_list_bool(a: List[bool], b: List[bool]) -> List[bool]:
def jit_max_list_bool(a, b):
# type: (List[bool], List[bool]) -> List[bool]
return max(a, b)
args_left_int = [[1, 8, 8], [8, 1, 1], [], [1], [], [1, 2]]
@@ -352,7 +365,8 @@ class TestList(JitTestCase):
t2 = scope['func']()
self.assertEqual(t1, t2)
def test_fail(x: List[Tensor]) -> List[Tensor]:
def test_fail(x):
# type: (List[Tensor]) -> List[Tensor]
x.sort()
return x
@@ -458,7 +472,8 @@ class TestList(JitTestCase):
self.checkScript(test_append, ())
def test_comprehensions_basic(self):
def comp(l: List[int]) -> List[int]:
def comp(l):
# type: (List[int]) -> List[int]
n = [x * 3 for x in l]
return n
@@ -467,7 +482,8 @@ class TestList(JitTestCase):
self.checkScript(comp, ([1, 2, 3],))
def test_comprehensions_basic_float(self):
def comp(l: List[float]) -> List[float]:
def comp(l):
# type: (List[float]) -> List[float]
n = [x * 3 for x in l]
return n
@@ -476,7 +492,8 @@ class TestList(JitTestCase):
def test_comprehensions_two_comps(self):
@torch.jit.script
def comp(l1: List[int], l2: List[int]) -> List[int]:
def comp(l1, l2):
# type: (List[int], List[int]) -> List[int]
n = [x * 3 for x in l1]
n2 = [x + 2 for x in l2]
@@ -485,7 +502,8 @@ class TestList(JitTestCase):
self.assertEqual(comp([1, 2, 3], [4, 5]), [3, 6, 9, 6, 7])
def test_comprehension_out_type_not_in_type(self):
def list_cast() -> int:
def list_cast():
# type: () -> int
li = [int(i) for i in [torch.tensor(0), torch.tensor(1), torch.tensor(2)]]
return li[0] + li[1] + li[2]
@@ -495,13 +513,15 @@ class TestList(JitTestCase):
def test_func(fn, inputs):
self.assertEqual(fn(*inputs), torch.jit.script(fn)(*inputs))
def foo(names: List[int], results: List[int]) -> List[Tuple[int, int]]:
def foo(names, results):
# type: (List[int], List[int]) -> List[Tuple[int, int]]
return [(k + 5, v - 2) for k, v in zip(names, results)]
test_func(foo, ([1, 2, 4], [4, 7, 9]))
test_func(foo, ([5], [4, 7, 9]))
def fn(x: int) -> List[int]:
def fn(x):
# type: (int) -> List[int]
return [i for i in range(x)] # noqa: C416
test_func(fn, (9,))
@@ -581,7 +601,8 @@ class TestList(JitTestCase):
def test_mutable_list_function_inline(self):
@torch.jit.script
def bar(y: List[int]) -> None:
def bar(y):
# type: (List[int]) -> None
y.append(4)
@torch.jit.script
@@ -867,7 +888,8 @@ class TestList(JitTestCase):
def test_extend_list_mutable(self):
@torch.jit.script
def extend_list(a: List[Tensor], b: List[Tensor]) -> List[Tensor]:
def extend_list(a, b):
# type: (List[Tensor], List[Tensor]) -> List[Tensor]
a.extend(b)
return a
@@ -878,7 +900,8 @@ class TestList(JitTestCase):
def test_extend_list_immutable(self):
@torch.jit.script
def extend_list(a: List[int], b: List[int]) -> List[int]:
def extend_list(a, b):
# type: (List[int], List[int]) -> List[int]
a.extend(b)
return a
@@ -889,7 +912,8 @@ class TestList(JitTestCase):
def test_copy_list_mutable(self):
@torch.jit.script
def copy_list(a: List[Tensor]) -> List[Tensor]:
def copy_list(a):
# type: (List[Tensor]) -> List[Tensor]
return a.copy()
for l in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]:
@@ -897,29 +921,36 @@ class TestList(JitTestCase):
def test_copy_list_immutable(self):
@torch.jit.script
def copy_list(a: List[int]) -> List[int]:
def copy_list(a):
# type: (List[int]) -> List[int]
return a.copy()
for l in [[], [1], [1, 2, 3]]:
self.assertEqual(copy_list(l), l)
def test_min_max_single_list(self):
def min_intlist(li: List[int]) -> int:
def min_intlist(li):
# type: (List[int]) -> int
return min(li)
def max_intlist(li: List[int]) -> int:
def max_intlist(li):
# type: (List[int]) -> int
return max(li)
def min_boollist(li: List[bool]) -> bool:
def min_boollist(li):
# type: (List[bool]) -> bool
return min(li)
def max_boollist(li: List[bool]) -> bool:
def max_boollist(li):
# type: (List[bool]) -> bool
return max(li)
def min_floatlist(li: List[float]) -> float:
def min_floatlist(li):
# type: (List[float]) -> float
return min(li)
def max_floatlist(li: List[float]) -> float:
def max_floatlist(li):
# type: (List[float]) -> float
return max(li)
@@ -949,19 +980,23 @@ class TestList(JitTestCase):
"""
Boolean dtype unit tests.
"""
def to_list_bool_0D(x: torch.Tensor) -> bool:
def to_list_bool_0D(x):
# type: (torch.Tensor) -> bool
li = torch.jit.annotate(bool, x.tolist())
return li
def to_list_bool_1D(x: torch.Tensor) -> List[bool]:
def to_list_bool_1D(x):
# type: (torch.Tensor) -> List[bool]
li = torch.jit.annotate(List[bool], x.tolist())
return li
def to_list_bool_2D(x: torch.Tensor) -> List[List[bool]]:
def to_list_bool_2D(x):
# type: (torch.Tensor) -> List[List[bool]]
li = torch.jit.annotate(List[List[bool]], x.tolist())
return li
def to_list_bool_3D(x: torch.Tensor) -> List[List[List[bool]]]:
def to_list_bool_3D(x):
# type: (torch.Tensor) -> List[List[List[bool]]]
li = torch.jit.annotate(List[List[List[bool]]], x.tolist())
return li
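
The torch.jit.annotate calls in these tolist tests are needed because Tensor.tolist() has a dtype- and dimension-dependent return type that TorchScript cannot infer on its own; the annotation pins the static type, and the later hunks check that a wrong dimension or scalar type is rejected. A minimal runnable sketch of the pattern:

from typing import List

import torch

@torch.jit.script
def tolist_2d(x):
    # type: (torch.Tensor) -> List[List[float]]
    # The annotation tells the compiler the nesting depth and element type.
    li = torch.jit.annotate(List[List[float]], x.tolist())
    return li

print(tolist_2d(torch.rand(2, 3)))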
@@ -986,19 +1021,23 @@ class TestList(JitTestCase):
"""
Int dtype unit tests.
"""
def to_list_int_0D(x: torch.Tensor) -> int:
def to_list_int_0D(x):
# type: (torch.Tensor) -> int
li = torch.jit.annotate(int, x.tolist())
return li
def to_list_int_1D(x: torch.Tensor) -> List[int]:
def to_list_int_1D(x):
# type: (torch.Tensor) -> List[int]
li = torch.jit.annotate(List[int], x.tolist())
return li
def to_list_int_2D(x: torch.Tensor) -> List[List[int]]:
def to_list_int_2D(x):
# type: (torch.Tensor) -> List[List[int]]
li = torch.jit.annotate(List[List[int]], x.tolist())
return li
def to_list_int_3D(x: torch.Tensor) -> List[List[List[int]]]:
def to_list_int_3D(x):
# type: (torch.Tensor) -> List[List[List[int]]]
li = torch.jit.annotate(List[List[List[int]]], x.tolist())
return li
@@ -1019,19 +1058,23 @@ class TestList(JitTestCase):
"""
Float dtype unit tests.
"""
def to_list_float_0D(x: torch.Tensor) -> float:
def to_list_float_0D(x):
# type: (torch.Tensor) -> float
li = torch.jit.annotate(float, x.tolist())
return li
def to_list_float_1D(x: torch.Tensor) -> List[float]:
def to_list_float_1D(x):
# type: (torch.Tensor) -> List[float]
li = torch.jit.annotate(List[float], x.tolist())
return li
def to_list_float_2D(x: torch.Tensor) -> List[List[float]]:
def to_list_float_2D(x):
# type: (torch.Tensor) -> List[List[float]]
li = torch.jit.annotate(List[List[float]], x.tolist())
return li
def to_list_float_3D(x: torch.Tensor) -> List[List[List[float]]]:
def to_list_float_3D(x):
# type: (torch.Tensor) -> List[List[List[float]]]
li = torch.jit.annotate(List[List[List[float]]], x.tolist())
return li
@@ -1056,23 +1099,28 @@ class TestList(JitTestCase):
- type annotation with the wrong dimension
- type annotation with scalar type that doesn't match the input scalar type
"""
def to_list_missing_type_annotation(x: torch.Tensor) -> List[float]:
def to_list_missing_type_annotation(x):
# type: (torch.Tensor) -> List[float]
li = x.tolist()
return li
def to_list_incorrect_type_annotation(x: torch.Tensor) -> List[float]:
def to_list_incorrect_type_annotation(x):
# type: (torch.Tensor) -> List[float]
li = torch.jit.annotate(float, x.tolist())
return li
def to_list_unsupported_type_annotation(x: torch.Tensor) -> List[float]:
def to_list_unsupported_type_annotation(x):
# type: (torch.Tensor) -> List[float]
li = torch.jit.annotate(List[str], x.tolist())
return li
def to_list_type_annotation_wrong_dim(x: torch.Tensor) -> List[List[float]]:
def to_list_type_annotation_wrong_dim(x):
# type: (torch.Tensor) -> List[List[float]]
li = torch.jit.annotate(List[List[float]], x.tolist())
return li
def to_list_type_annotation_incorrect_scalar_type(x: torch.Tensor) -> List[float]:
def to_list_type_annotation_incorrect_scalar_type(x):
# type: (torch.Tensor) -> List[float]
li = torch.jit.annotate(List[float], x.tolist())
return li
@@ -1116,15 +1164,18 @@ class TestList(JitTestCase):
if not torch.cuda.is_available() or torch.cuda.device_count() == 0:
self.skipTest("CUDA is not available")
def to_list_bool_1D(x: torch.Tensor) -> List[bool]:
def to_list_bool_1D(x):
# type: (torch.Tensor) -> List[bool]
li = torch.jit.annotate(List[bool], x.tolist())
return li
def to_list_int_1D(x: torch.Tensor) -> List[int]:
def to_list_int_1D(x):
# type: (torch.Tensor) -> List[int]
li = torch.jit.annotate(List[int], x.tolist())
return li
def to_list_float_1D(x: torch.Tensor) -> List[float]:
def to_list_float_1D(x):
# type: (torch.Tensor) -> List[float]
li = torch.jit.annotate(List[float], x.tolist())
return li
@@ -1136,7 +1187,8 @@ class TestList(JitTestCase):
5, dtype=torch.double).cuda(),))
def test_no_element_type_annotation(self):
def fn_with_comment(x: torch.Tensor) -> List:
def fn_with_comment(x):
# type: (torch.Tensor) -> List
a: List = x.tolist()
return a
@@ -1178,7 +1230,8 @@ class TestDict(JitTestCase):
def inputs():
return {'hi': 2, 'bye': 3}
def fn(x: Dict[str, int]) -> Dict[str, int]:
def fn(x):
# type: (Dict[str, int]) -> Dict[str, int]
del x['hi']
return x
@@ -1194,7 +1247,8 @@ class TestDict(JitTestCase):
def test_keys(self):
@torch.jit.script
def keys(x: Dict[str, Tensor]) -> List[str]:
def keys(x):
# type: (Dict[str, Tensor]) -> List[str]
return list(x.keys())
self.assertEqual(set(keys(self.dict())), set(self.dict().keys()))
@@ -1209,26 +1263,30 @@ class TestDict(JitTestCase):
def test_values(self):
@torch.jit.script
def values(x: Dict[str, Tensor]) -> List[Tensor]:
def values(x):
# type: (Dict[str, Tensor]) -> List[Tensor]
return list(x.values())
the_dict = self.dict()
self.assertEqual(set(values(the_dict)), set(the_dict.values()))
def test_len(self):
def length(x: Dict[str, Tensor]) -> int:
def length(x):
# type: (Dict[str, Tensor]) -> int
return len(x)
self.checkScript(length, (self.dict(),))
def test_copy(self):
def func(x: Dict[str, Tensor]) -> Dict[str, Tensor]:
def func(x):
# type: (Dict[str, Tensor]) -> Dict[str, Tensor]
return x.copy()
self.checkScript(func, (self.dict(),))
def test_items(self):
def func(x: Dict[str, Tensor]) -> List[Tuple[str, Tensor]]:
def func(x):
# type: (Dict[str, Tensor]) -> List[Tuple[str, Tensor]]
return x.items()
# The value returned by Python is in arbitrary order, so we can't use
@@ -1243,7 +1301,8 @@ class TestDict(JitTestCase):
self.assertTrue(item in script_out)
def test_pop(self):
def pop(x: Dict[str, Tensor], key: str) -> Tuple[Tensor, Dict[str, Tensor]]:
def pop(x, key):
# type: (Dict[str, Tensor], str) -> Tuple[Tensor, Dict[str, Tensor]]
return x.pop(key), x
# checkScript doesn't copy the inputs, so we can't use it since this mutates
@@ -1259,14 +1318,16 @@ class TestDict(JitTestCase):
torch.jit.script(pop)(self.dict(), 'x')
def default_pop(x: Dict[str, Tensor], key: str, default: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
def default_pop(x, key, default):
# type: (Dict[str, Tensor], str, Tensor) -> Tuple[Tensor, Dict[str, Tensor]]
return x.pop(key, default), x
tester(default_pop, 'a', torch.randn(2, 2))
tester(default_pop, 'x', torch.randn(2, 2))
def test_setdefault(self):
def setdefault(x: Dict[str, Tensor], key: str, default: Tensor) -> Dict[str, Tensor]:
def setdefault(x, key, default):
# type: (Dict[str, Tensor], str, Tensor) -> Dict[str, Tensor]
x.setdefault(key, default)
return x
@@ -1274,7 +1335,8 @@ class TestDict(JitTestCase):
self.checkScript(setdefault, (self.dict(), 'nonexistant', torch.randn(2, 2)))
def test_update(self):
def update(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
def update(a, b):
# type: (Dict[str, Tensor], Dict[str, Tensor]) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]
a.update(b)
return a, b
@@ -1291,7 +1353,8 @@ class TestDict(JitTestCase):
self.checkScript(foo, ())
def test_aug_assign(self):
def aug_assign_dict_tensor(a: Dict[str, Tensor]) -> Dict[str, Tensor]:
def aug_assign_dict_tensor(a):
# type: (Dict[str, Tensor]) -> Dict[str, Tensor]
a['a'] += 1
a['b'] -= 12
a['c'] *= 122
@@ -1299,7 +1362,8 @@ class TestDict(JitTestCase):
a['c'] %= 2
return a
def aug_assign_dict_prim(a: Dict[str, float]) -> Dict[str, float]:
def aug_assign_dict_prim(a):
# type: (Dict[str, float]) -> Dict[str, float]
a['a'] += 3.4
a['b'] -= 2.4
a['c'] *= 3.0
@@ -1312,7 +1376,8 @@ class TestDict(JitTestCase):
def test_popitem(self):
@torch.jit.script
def popitem(x: Dict[str, Tensor]) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor]]:
def popitem(x):
# type: (Dict[str, Tensor]) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor]]
item = x.popitem()
return item, x
@@ -1330,56 +1395,65 @@ class TestDict(JitTestCase):
self.assertTrue(isinstance(script_out[0][1], torch.Tensor))
def test_clear(self):
def clear(x: Dict[str, Tensor]) -> Dict[str, Tensor]:
def clear(x):
# type: (Dict[str, Tensor]) -> Dict[str, Tensor]
x.clear()
return x
self.checkScript(clear, (self.dict(),))
def test_get(self):
def get(x: Dict[str, Tensor], key: str) -> Optional[Tensor]:
def get(x, key):
# type: (Dict[str, Tensor], str) -> Optional[Tensor]
return x.get(key)
self.checkScript(get, (self.dict(), 'a'))
self.checkScript(get, (self.dict(), "doesn't exist"))
def get_default(x: Dict[str, Tensor], key: str) -> Optional[Tensor]:
def get_default(x, key):
# type: (Dict[str, Tensor], str) -> Optional[Tensor]
return x.get(key, torch.randn(2, 2))
self.checkScript(get, (self.dict(), 'a'))
self.checkScript(get, (self.dict(), "doesn't exist"))
def test_get_boolkey(self):
def get(x: Dict[bool, int], key: bool) -> Optional[int]:
def get(x, key):
# type: (Dict[bool, int], bool) -> Optional[int]
return x.get(key)
self.checkScript(get, (self.dict_bool(), True))
self.checkScript(get, (self.dict_bool(), False))
def get_default(x: Dict[bool, int], key: bool) -> int:
def get_default(x, key):
# type: (Dict[bool, int], bool) -> int
return x.get(key, 42)
self.checkScript(get_default, (self.dict_bool(), True))
self.checkScript(get_default, (self.dict_bool(), False))
def test_basic(self):
def simple(x: Dict[str, int]) -> Dict[str, int]:
def simple(x):
# type: (Dict[str, int]) -> Dict[str, int]
return x
self.checkScript(simple, ({'item': 20, 'other_item': 120},))
def index(x: Dict[str, int]) -> int:
def index(x):
# type: (Dict[str, int]) -> int
return x['item']
self.checkScript(index, ({'item': 20, 'other_item': 120},))
def type_default() -> Dict[str, Tensor]:
def type_default():
# type: () -> Dict[str, Tensor]
return {}
self.checkScript(type_default, ())
@torch.jit.script
def missing_index(x: Dict[str, int]) -> int:
def missing_index(x):
# type: (Dict[str, int]) -> int
return x['dne']
with self.assertRaisesRegex(RuntimeError, "KeyError"):
@@ -1401,14 +1475,16 @@ class TestDict(JitTestCase):
'''))
self.assertEqual({10: 1.2, 11: 1.3}, cu.literal3())
def list_of_dicts() -> List[Dict[str, Tensor]]:
def list_of_dicts():
# type: () -> List[Dict[str, Tensor]]
return [{'word': torch.ones(2) + 3}, {'other word': torch.ones(1) + 2}]
self.checkScript(list_of_dicts, ())
def test_mutability(self):
@torch.jit.script
def fn() -> Dict[str, int]:
def fn():
# type: () -> Dict[str, int]
a = torch.jit.annotate(Dict[str, int], {})
a['ok'] = 10
return a
@@ -1418,12 +1494,14 @@ class TestDict(JitTestCase):
def test_key_type(self):
with self.assertRaisesRegex(RuntimeError, "but instead found type"):
@torch.jit.script
def fn(a: Dict[str, int]) -> int:
def fn(a):
# type: (Dict[str, int]) -> int
return a[None]
def test_loop(self):
@torch.jit.script
def fn(x: int) -> Dict[str, int]:
def fn(x):
# type: (int) -> Dict[str, int]
a = torch.jit.annotate(Dict[str, int], {})
for i in range(x):
a['ok'] = i
@@ -1442,14 +1520,16 @@ class TestDict(JitTestCase):
self.checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3)))
def test_membership(self):
def fn(x: Dict[int, int], y: int) -> int:
def fn(x, y):
# type: (Dict[int, int], int) -> int
return x.get(y, 3)
d = {1: 2, 3: 4}
self.checkScript(fn, (d, 3))
self.checkScript(fn, (d, 2))
def optional(x: Dict[int, int], y: int) -> bool:
def optional(x, y):
# type: (Dict[int, int], int) -> bool
res = x.get(y)
return res is None
@@ -1458,15 +1538,18 @@ class TestDict(JitTestCase):
with self.assertRaisesRegex(RuntimeError, "is actually of type Optional"):
@torch.jit.script
def bad_types(x: Dict[int, int], y: int) -> int:
def bad_types(x, y):
# type: (Dict[int, int], int) -> int
return x.get(y) # noqa: T484
def test_dict_to_python(self):
@torch.jit.ignore
def python_lookup(my_dict: Dict[str, int], keys: List[str]) -> List[int]:
def python_lookup(my_dict, keys):
# type: (Dict[str, int], List[str]) -> List[int]
return [my_dict[k] for k in keys]
def fn(my_dict: Dict[str, int], keys: List[str]) -> List[int]:
def fn(my_dict, keys):
# type: (Dict[str, int], List[str]) -> List[int]
return python_lookup(my_dict, keys)
a_dict = {'a': torch.ones(1), 'b': torch.ones(1) + 1, 'c': torch.ones(1) + 2}
@@ -1518,7 +1601,8 @@ class TestDict(JitTestCase):
key and value types produces an error.
"""
# This function uses a type comment.
def fn_with_comment(input: Dict) -> Any:
def fn_with_comment(input):
# type: (Dict) -> Any
return input
# This function uses Python3 style type annotations.

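A pattern that recurs throughout TestDict above is torch.jit.annotate on an empty literal, since {} alone gives TorchScript no key or value type to infer. A small sketch (names are illustrative):

from typing import Dict

import torch

@torch.jit.script
def build(n):
    # type: (int) -> Dict[str, int]
    # The annotation supplies the key/value types of the empty literal.
    d = torch.jit.annotate(Dict[str, int], {})
    for i in range(n):
        d["key" + str(i)] = i
    return d

print(build(2))  # {'key0': 0, 'key1': 1}
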
test/jit/test_module_interface.py

@@ -6,7 +6,6 @@ import torch
import torch.nn as nn
import os
import sys
from torch import Tensor
from torch.testing._internal.jit_utils import JitTestCase
# Make the helper files in test/ importable
@@ -23,30 +22,36 @@ class OrigModule(nn.Module):
def __init__(self):
super(OrigModule, self).__init__()
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
return inp1 + inp2 + 1
def two(self, input: Tensor) -> Tensor:
def two(self, input):
# type: (Tensor) -> Tensor
return input + 2
def forward(self, input: Tensor) -> Tensor:
def forward(self, input):
# type: (Tensor) -> Tensor
return input + self.one(input, input) + 1
class NewModule(nn.Module):
def __init__(self):
super(NewModule, self).__init__()
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
return inp1 * inp2 + 1
def forward(self, input: Tensor) -> Tensor:
def forward(self, input):
# type: (Tensor) -> Tensor
return self.one(input, input + 1)
class TestModuleInterface(JitTestCase):
def test_not_submodule_interface_call(self):
@torch.jit.interface
class ModuleInterface(nn.Module):
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
pass
class TestNotModuleInterfaceCall(nn.Module):
@@ -56,7 +61,8 @@ class TestModuleInterface(JitTestCase):
super(TestNotModuleInterfaceCall, self).__init__()
self.proxy_mod = OrigModule()
def forward(self, input: Tensor) -> Tensor:
def forward(self, input):
# type: (Tensor) -> Tensor
return self.proxy_mod.two(input)
with self.assertRaisesRegex(RuntimeError, "Tried to access nonexistent attribute"):
@@ -66,51 +72,64 @@ class TestModuleInterface(JitTestCase):
global OneTwoModule, OneTwoClass
@torch.jit.interface
class OneTwoModule(nn.Module):
def one(self, x: Tensor, y: Tensor) -> Tensor:
def one(self, x, y):
# type: (Tensor, Tensor) -> Tensor
pass
def two(self, x: Tensor) -> Tensor:
def two(self, x):
# type: (Tensor) -> Tensor
pass
def forward(self, x: Tensor) -> Tensor:
def forward(self, x):
# type: (Tensor) -> Tensor
pass
@torch.jit.interface
class OneTwoClass(object):
def one(self, x: Tensor, y: Tensor) -> Tensor:
def one(self, x, y):
# type: (Tensor, Tensor) -> Tensor
pass
def two(self, x: Tensor) -> Tensor:
def two(self, x):
# type: (Tensor) -> Tensor
pass
class FooMod(nn.Module):
def one(self, x: Tensor, y: Tensor) -> Tensor:
def one(self, x, y):
# type: (Tensor, Tensor) -> Tensor
return x + y
def two(self, x: Tensor) -> Tensor:
def two(self, x):
# type: (Tensor) -> Tensor
return 2 * x
def forward(self, x: Tensor) -> Tensor:
def forward(self, x):
# type: (Tensor) -> Tensor
return self.one(self.two(x), x)
class BarMod(nn.Module):
def one(self, x: Tensor, y: Tensor) -> Tensor:
def one(self, x, y):
# type: (Tensor, Tensor) -> Tensor
return x * y
def two(self, x: Tensor) -> Tensor:
def two(self, x):
# type: (Tensor) -> Tensor
return 2 / x
def forward(self, x: Tensor) -> Tensor:
def forward(self, x):
# type: (Tensor) -> Tensor
return self.two(self.one(x, x))
@torch.jit.export
def forward2(self, x: Tensor) -> Tensor:
def forward2(self, x):
# type: (Tensor) -> Tensor
return self.two(self.one(x, x)) + 1
def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
return mod_list[0].forward(x) + mod_list[1].forward(x)
def use_class_interface(mod_list: List[OneTwoClass], x: Tensor) -> Tensor:
def use_class_interface(mod_list, x):
# type: (List[OneTwoClass], Tensor) -> Tensor
return mod_list[0].two(x) + mod_list[1].one(x, x)
scripted_foo_mod = torch.jit.script(FooMod())
@@ -120,7 +139,8 @@ class TestModuleInterface(JitTestCase):
self.checkScript(use_class_interface,
([scripted_foo_mod, scripted_bar_mod], torch.rand(3, 4),))
def call_module_interface_on_other_method(mod_interface: OneTwoModule, x: Tensor) -> Tensor:
def call_module_interface_on_other_method(mod_interface, x):
# type: (OneTwoModule, Tensor) -> Tensor
return mod_interface.forward2(x)
# ensure error out when we call the module on the method other than the interface specified.
@@ -132,28 +152,35 @@ class TestModuleInterface(JitTestCase):
global OneTwoModule
@torch.jit.interface
class OneTwoModule(nn.Module):
def one(self, x: Tensor, y: Tensor) -> Tensor:
def one(self, x, y):
# type: (Tensor, Tensor) -> Tensor
pass
def two(self, x: Tensor) -> Tensor:
def two(self, x):
# type: (Tensor) -> Tensor
pass
def forward(self, x: Tensor) -> Tensor:
def forward(self, x):
# type: (Tensor) -> Tensor
pass
@torch.jit.script
def as_module_interface(x: OneTwoModule) -> OneTwoModule:
def as_module_interface(x):
# type: (OneTwoModule) -> OneTwoModule
return x
@torch.jit.script
class Foo(object):
def one(self, x: Tensor, y: Tensor) -> Tensor:
def one(self, x, y):
# type: (Tensor, Tensor) -> Tensor
return x + y
def two(self, x: Tensor) -> Tensor:
def two(self, x):
# type: (Tensor) -> Tensor
return 2 * x
def forward(self, x: Tensor) -> Tensor:
def forward(self, x):
# type: (Tensor) -> Tensor
return self.one(self.two(x), x)
# check class object is not a subtype of module interface
@@ -161,10 +188,12 @@ class TestModuleInterface(JitTestCase):
as_module_interface(Foo())
class WrongMod(nn.Module):
def two(self, x: int) -> int:
def two(self, x):
# type: (int) -> int
return 2 * x
def forward(self, x: Tensor) -> Tensor:
def forward(self, x):
# type: (Tensor) -> Tensor
return x + torch.randn(3, self.two(3))
scripted_wrong_mod = torch.jit.script(WrongMod())
@@ -215,16 +244,19 @@ class TestModuleInterface(JitTestCase):
with self.assertRaisesRegex(RuntimeError, "does not support inheritance yet. Please directly"):
@torch.jit.interface
class InheritMod(nn.ReLU):
def three(self, x: Tensor) -> Tensor:
def three(self, x):
# type: (Tensor) -> Tensor
return 3 * x
def test_module_swap(self):
@torch.jit.interface
class ModuleInterface(nn.Module):
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
pass
def forward(self, input: Tensor) -> Tensor:
def forward(self, input):
# type: (Tensor) -> Tensor
pass
class TestModule(nn.Module):
@@ -234,7 +266,8 @@ class TestModuleInterface(JitTestCase):
super(TestModule, self).__init__()
self.proxy_mod = OrigModule()
def forward(self, input: Tensor) -> Tensor:
def forward(self, input):
# type: (Tensor) -> Tensor
return self.proxy_mod.forward(input)
scripted_mod = torch.jit.script(TestModule())
@@ -252,17 +285,20 @@ class TestModuleInterface(JitTestCase):
def test_module_swap_wrong_module(self):
@torch.jit.interface
class ModuleInterface(nn.Module):
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
pass
def forward(self, input: Tensor) -> Tensor:
def forward(self, input):
# type: (Tensor) -> Tensor
pass
class NewModuleWrong(nn.Module):
def __init__(self):
super(NewModuleWrong, self).__init__()
def forward(self, input: int) -> int:
def forward(self, input):
# type: (int) -> int
return input + 1
class TestModule(nn.Module):
@@ -272,7 +308,8 @@ class TestModuleInterface(JitTestCase):
super(TestModule, self).__init__()
self.proxy_mod = OrigModule()
def forward(self, input: Tensor) -> Tensor:
def forward(self, input):
# type: (Tensor) -> Tensor
return self.proxy_mod.forward(input)
scripted_mod = torch.jit.script(TestModule())
@@ -283,10 +320,12 @@ class TestModuleInterface(JitTestCase):
def test_module_swap_no_lazy_compile(self):
@torch.jit.interface
class ModuleInterface(nn.Module):
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
pass
def forward(self, input: Tensor) -> Tensor:
def forward(self, input):
# type: (Tensor) -> Tensor
pass
class TestModule(nn.Module):
@@ -296,17 +335,20 @@ class TestModuleInterface(JitTestCase):
super(TestModule, self).__init__()
self.proxy_mod = OrigModule()
def forward(self, input: Tensor) -> Tensor:
def forward(self, input):
# type: (Tensor) -> Tensor
return self.proxy_mod.forward(input)
class NewModuleMethodNotLazyCompile(nn.Module):
def __init__(self):
super(NewModuleMethodNotLazyCompile, self).__init__()
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
return inp1 * inp2 + 1
def forward(self, input: Tensor) -> Tensor:
def forward(self, input):
# type: (Tensor) -> Tensor
return input + 1
scripted_mod = torch.jit.script(TestModule())
@@ -320,10 +362,12 @@ class TestModuleInterface(JitTestCase):
super(NewModuleMethodManualExport, self).__init__()
@torch.jit.export
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
return inp1 * inp2 + 1
def forward(self, input: Tensor) -> Tensor:
def forward(self, input):
# type: (Tensor) -> Tensor
return input + 1
scripted_mod.proxy_mod = torch.jit.script(NewModuleMethodManualExport())
@@ -337,7 +381,8 @@ class TestModuleInterface(JitTestCase):
super(TestNoModuleInterface, self).__init__()
self.proxy_mod = OrigModule()
def forward(self, input: Tensor) -> Tensor:
def forward(self, input):
# type: (Tensor) -> Tensor
return self.proxy_mod(input)
scripted_no_module_interface = torch.jit.script(TestNoModuleInterface())
@@ -352,10 +397,12 @@ class TestModuleInterface(JitTestCase):
def test_script_module_as_interface_swap(self):
@torch.jit.interface
class ModuleInterface(nn.Module):
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
pass
def forward(self, input: Tensor) -> Tensor:
def forward(self, input):
# type: (Tensor) -> Tensor
pass
class OrigScriptModule(torch.jit.ScriptModule):
@@ -363,11 +410,13 @@ class TestModuleInterface(JitTestCase):
super(OrigScriptModule, self).__init__()
@torch.jit.script_method
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
return inp1 + inp2 + 1
@torch.jit.script_method
def forward(self, input: Tensor) -> Tensor:
def forward(self, input):
# type: (Tensor) -> Tensor
return input + self.one(input, input) + 1
class NewScriptModule(torch.jit.ScriptModule):
@@ -375,11 +424,13 @@ class TestModuleInterface(JitTestCase):
super(NewScriptModule, self).__init__()
@torch.jit.script_method
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
return inp1 * inp2 + 1
@torch.jit.script_method
def forward(self, input: Tensor) -> Tensor:
def forward(self, input):
# type: (Tensor) -> Tensor
return self.one(input, input + 1)
class TestNNModuleWithScriptModule(nn.Module):
@@ -389,7 +440,8 @@ class TestModuleInterface(JitTestCase):
super(TestNNModuleWithScriptModule, self).__init__()
self.proxy_mod = OrigScriptModule()
def forward(self, input: Tensor) -> Tensor:
def forward(self, input):
# type: (Tensor) -> Tensor
return self.proxy_mod.forward(input)
input = torch.randn(3, 4)
@@ -420,7 +472,8 @@ class TestModuleInterface(JitTestCase):
@torch.jit.interface
class ModInterface(torch.nn.Module):
def forward(self, x: Tensor) -> int:
def forward(self, x):
# type: (Tensor) -> int
pass
class TestModule(torch.nn.Module):
@@ -467,7 +520,8 @@ class TestModuleInterface(JitTestCase):
@torch.jit.interface
class ModInterface(torch.nn.Module):
def forward(self, x: Tensor) -> int:
def forward(self, x):
# type: (Tensor) -> int
pass
class TestModule(torch.nn.Module):
@@ -510,7 +564,8 @@ class TestModuleInterface(JitTestCase):
@torch.jit.interface
class ModInterface(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
def forward(self, x):
# type: (Tensor) -> Tensor
pass
class TestModule(torch.nn.Module):
@@ -555,7 +610,8 @@ class TestModuleInterface(JitTestCase):
@torch.jit.interface
class ModInterface(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
def forward(self, x):
# type: (Tensor) -> Tensor
pass
class TestModule(torch.nn.Module):
@@ -597,7 +653,8 @@ class TestModuleInterface(JitTestCase):
@torch.jit.interface
class ModInterface(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
def forward(self, x):
# type: (Tensor) -> Tensor
pass
class TestModule(torch.nn.Module):
@@ -631,7 +688,8 @@ class TestModuleInterface(JitTestCase):
def test_module_apis_interface(self):
@torch.jit.interface
class ModuleInterface(nn.Module):
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
pass
class TestModule(nn.Module):

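The interface tests above all follow one shape: declare a @torch.jit.interface whose methods carry only signatures, type a submodule attribute as that interface, and then swap in any concrete module that conforms. A minimal sketch (module names are illustrative):

import torch
import torch.nn as nn
from torch import Tensor

@torch.jit.interface
class ModuleInterface(nn.Module):
    def one(self, inp1, inp2):
        # type: (Tensor, Tensor) -> Tensor
        pass

class Impl(nn.Module):
    def one(self, inp1, inp2):
        # type: (Tensor, Tensor) -> Tensor
        return inp1 + inp2

class Wrapper(nn.Module):
    # Typing the attribute as the interface is what permits module swapping.
    proxy_mod: ModuleInterface

    def __init__(self):
        super().__init__()
        self.proxy_mod = Impl()

    def forward(self, x):
        # type: (Tensor) -> Tensor
        return self.proxy_mod.one(x, x)

scripted = torch.jit.script(Wrapper())
print(scripted(torch.ones(2)))  # tensor([2., 2.])
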
test/jit/test_recursive_script.py

@@ -284,7 +284,8 @@ class TestRecursiveScript(JitTestCase):
test_module_dir(nn.ModuleDict(OrderedDict([("conv", conv), ("linear", linear)])))
def test_class_compile(self):
def other_fn(a: int, b: Tensor) -> Tensor:
def other_fn(a, b):
# type: (int, Tensor) -> Tensor
return a * b
class B(object):
@@ -306,7 +307,8 @@ class TestRecursiveScript(JitTestCase):
self.checkModule(N(), (torch.randn(2, 2),))
def test_error_stack(self):
def d(x: int) -> int:
def d(x):
# type: (int) -> int
return x + 10
def c(x):
@@ -329,7 +331,8 @@ class TestRecursiveScript(JitTestCase):
checker.run(str(e))
def test_error_stack_module(self):
def d(x: int) -> int:
def d(x):
# type: (int) -> int
return x + 10
def c(x):
@@ -562,7 +565,8 @@ class TestRecursiveScript(JitTestCase):
self.a = 4
self.inner = Inner2()
def __setstate__(self, obj: Tuple[int, Inner2]) -> None:
def __setstate__(self, obj):
# type: (Tuple[int, Inner2]) -> None
a, inner = obj
self.a = a
self.inner = inner

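TestRecursiveScript checks that torch.jit.script compiles everything a scripted function touches; a small sketch of that recursive behavior (names are illustrative):

import torch
from torch import Tensor

def helper(a, b):
    # type: (int, Tensor) -> Tensor
    return a * b

@torch.jit.script
def entry(x):
    # type: (Tensor) -> Tensor
    # helper is compiled automatically because entry calls it.
    return helper(2, x)

print(entry(torch.ones(2)))  # tensor([2., 2.])
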
test/jit/test_save_load.py

@@ -680,7 +680,8 @@ class TestSaveLoad(JitTestCase):
"""
@torch.jit.interface
class MyInterface(object):
def bar(self, x: Tensor) -> Tensor:
def bar(self, x):
# type: (Tensor) -> Tensor
pass
@torch.jit.script
@@ -710,7 +711,8 @@ class TestSaveLoad(JitTestCase):
@torch.jit.interface
class MyInterface(object):
def not_bar(self, x: Tensor) -> Tensor:
def not_bar(self, x):
# type: (Tensor) -> Tensor
pass
@torch.jit.script # noqa: F811
@@ -765,7 +767,8 @@ class TestSaveLoad(JitTestCase):
@torch.jit.interface
class MyInterface(object):
def bar(self, x: Tensor) -> Tensor:
def bar(self, x):
# type: (Tensor) -> Tensor
pass
@torch.jit.script
@@ -806,7 +809,8 @@ class TestSaveLoad(JitTestCase):
@torch.jit.interface
class MyInterface(object):
def not_bar(self, x: Tensor) -> Tensor:
def not_bar(self, x):
# type: (Tensor) -> Tensor
pass
@torch.jit.script # noqa F811

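The save/load tests above script code that uses interface types and round-trips it through serialization. The basic round trip looks like this (class name is illustrative):

import io

import torch
import torch.nn as nn
from torch import Tensor

class Doubler(nn.Module):
    def forward(self, x):
        # type: (Tensor) -> Tensor
        return x * 2

scripted = torch.jit.script(Doubler())
buf = io.BytesIO()
torch.jit.save(scripted, buf)   # serializes code and parameters together
buf.seek(0)
restored = torch.jit.load(buf)  # no Python source needed to load
print(restored(torch.ones(3)))  # tensor([2., 2., 2.])
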
test/jit/test_tracer.py

@@ -25,7 +25,7 @@ from torch import Tensor
# Standard library
from collections import namedtuple
from itertools import chain
from typing import Dict, Optional
from typing import Dict
import warnings
if __name__ == '__main__':
@@ -1862,11 +1862,13 @@ class TestTracer(JitTestCase):
class TestMixTracingScripting(JitTestCase):
def test_trace_script(self):
@torch.jit.script
def func1(x: Tuple[Tensor, Tensor]) -> Tensor:
def func1(x):
# type: (Tuple[Tensor, Tensor]) -> Tensor
return x[0] + x[1]
@torch.jit.script
def func2(x: List[Tensor]) -> Tensor:
def func2(x):
# type: (List[Tensor]) -> Tensor
return x[0] + x[1]
a = torch.randn(5)
@@ -1876,7 +1878,8 @@ class TestMixTracingScripting(JitTestCase):
self.checkTrace(func2, ((a, b),))
@torch.jit.script
def func3(x: Tensor, method: str = 'bilinear', align_corners: bool = True) -> Tensor:
def func3(x, method='bilinear', align_corners=True):
# type: (Tensor, str, bool) -> Tensor
hw = x.shape[2:4]
return F.interpolate(x, hw, mode=method, align_corners=align_corners)
@@ -1884,7 +1887,8 @@ class TestMixTracingScripting(JitTestCase):
self.checkTrace(func3, (inp,))
@torch.jit.script
def func4(x: Tensor, a: List[Optional[str]]) -> Tensor:
def func4(x, a):
# type: (Tensor, List[Optional[str]]) -> Tensor
if len(a) == 2:
return x + 2
else:

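TestMixTracingScripting covers calling scripted functions from traced code; the scripted callee keeps its data-dependent control flow inside the trace. A minimal sketch (names are illustrative):

import torch
from torch import Tensor

@torch.jit.script
def scripted_branch(x):
    # type: (Tensor) -> Tensor
    if x.sum() > 0:
        return x + 1
    return x - 1

def traced_fn(x):
    # The scripted call is preserved as a call in the trace,
    # so the branch above survives tracing.
    return scripted_branch(x) * 2

traced = torch.jit.trace(traced_fn, (torch.ones(2),))
print(traced(-torch.ones(2)))  # tensor([-4., -4.])
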
test/jit/test_with.py

@@ -1,7 +1,7 @@
import os
import sys
from typing import Any, List
from typing import Any
import torch
from torch.testing._internal.jit_utils import JitTestCase
@@ -50,7 +50,8 @@ class TestWith(JitTestCase):
def __exit__(self, type: Any, value: Any, tb: Any):
self.count.sub_(0.3)
def test_basic(x: torch.Tensor) -> torch.Tensor:
def test_basic(x):
# type: (Tensor) -> Tensor
"""Basic test with one with-statement."""
c = Context(1)
@@ -61,7 +62,8 @@ class TestWith(JitTestCase):
y *= c.count
return y
def test_pass(x: torch.Tensor) -> torch.Tensor:
def test_pass(x):
# type: (Tensor) -> Tensor
"""
Test with a pass statement inside a with-statement. Although
the body of the with is empty, __enter__ and __exit__ should
@@ -75,7 +77,8 @@ class TestWith(JitTestCase):
x *= c.count
return x
def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
def test_early_return(x, c):
# type: (Tensor, Context) -> Tensor
"""
Test that returning early from inside a with-statement works
as expected.
@@ -87,7 +90,8 @@ class TestWith(JitTestCase):
x = y + y
return x
def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
def test_conditional_early_return(x, c):
# type: (Tensor, Context) -> Tensor
"""
Test that conditionally returning early from inside a with-statement works
as expected.
@@ -100,7 +104,8 @@ class TestWith(JitTestCase):
x = y + y
return x
def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
def test_break(x, c, l):
# type: (Tensor, Context, List[int]) -> Tensor
"""
Test that breaking early from inside a with-statement works
as expected.
@@ -113,7 +118,8 @@ class TestWith(JitTestCase):
return x
def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
def test_continue(x, c, l):
# type: (Tensor, Context, List[int]) -> Tensor
"""
Test that using continue inside a with-statement works
as expected.
@@ -126,7 +132,8 @@ class TestWith(JitTestCase):
return x
def test_serial(x: torch.Tensor) -> torch.Tensor:
def test_serial(x):
# type: (Tensor) -> Tensor
"""
Test two with-statements in a row.
"""
@@ -140,7 +147,8 @@ class TestWith(JitTestCase):
return y
def test_nested(x: torch.Tensor) -> torch.Tensor:
def test_nested(x):
# type: (Tensor) -> Tensor
"""
Test nested with-statements.
"""
@@ -154,7 +162,8 @@ class TestWith(JitTestCase):
return y
def test_combined(x: torch.Tensor) -> torch.Tensor:
def test_combined(x):
# type: (Tensor) -> Tensor
"""
Test a with-statement with multiple with items.
"""
@@ -206,7 +215,8 @@ class TestWith(JitTestCase):
def __exit__(self, type: Any, value: Any, tb: Any):
self.count.sub_(0.3)
def test_basic(x: torch.Tensor) -> torch.Tensor:
def test_basic(x):
# type: (Tensor) -> Tensor
"""Basic test with one with-statement."""
c = Context(1)
@@ -217,7 +227,8 @@ class TestWith(JitTestCase):
y *= c.count
return y
def test_pass(x: torch.Tensor) -> torch.Tensor:
def test_pass(x):
# type: (Tensor) -> Tensor
"""
Test with a pass statement inside a with-statement. Although
the body of the with is empty, __enter__ and __exit__ should
@@ -231,7 +242,8 @@ class TestWith(JitTestCase):
x *= c.count
return x
def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
def test_early_return(x, c):
# type: (Tensor, Context) -> Tensor
"""
Test that returning early from inside a with-statement works
as expected.
@@ -243,7 +255,8 @@ class TestWith(JitTestCase):
x = y + y
return x
def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
def test_conditional_early_return(x, c):
# type: (Tensor, Context) -> Tensor
"""
Test that conditionally returning early from inside a with-statement works
as expected.
@@ -256,7 +269,8 @@ class TestWith(JitTestCase):
x = y + y
return x
def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
def test_break(x, c, l):
# type: (Tensor, Context, List[int]) -> Tensor
"""
Test that breaking early from inside a with-statement works
as expected.
@@ -269,7 +283,8 @@ class TestWith(JitTestCase):
return x
def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
def test_continue(x, c, l):
# type: (Tensor, Context, List[int]) -> Tensor
"""
Test that using continue inside a with-statement works
as expected.
@@ -282,7 +297,8 @@ class TestWith(JitTestCase):
return x
def test_serial(x: torch.Tensor) -> torch.Tensor:
def test_serial(x):
# type: (Tensor) -> Tensor
"""
Test two with-statements in a row.
"""
@@ -296,7 +312,8 @@ class TestWith(JitTestCase):
return y
def test_nested(x: torch.Tensor) -> torch.Tensor:
def test_nested(x):
# type: (Tensor) -> Tensor
"""
Test nested with-statements.
"""
@@ -310,7 +327,8 @@ class TestWith(JitTestCase):
return y
def test_combined(x: torch.Tensor) -> torch.Tensor:
def test_combined(x):
# type: (Tensor) -> Tensor
"""
Test a with-statement with multiple with items.
"""
@@ -363,11 +381,13 @@ class TestWith(JitTestCase):
self.count.sub_(0.3)
@torch.jit.script
def method_that_raises() -> torch.Tensor:
def method_that_raises():
# type: () -> Tensor
raise Exception("raised exception")
@torch.jit.script
def test_exception(x: torch.Tensor, c: Context) -> torch.Tensor:
def test_exception(x, c):
# type: (Tensor, Context) -> Tensor
"""
Test the case in which an exception is thrown while executing the body of a with-statement.
"""
@@ -377,7 +397,8 @@ class TestWith(JitTestCase):
return x
@torch.jit.script
def test_exception_nested(x: torch.Tensor, c: Context) -> torch.Tensor:
def test_exception_nested(x, c):
# type: (Tensor, Context) -> Tensor
"""
Test the case in which an exception is thrown while executing the body of a nested with-statement.
"""
@@ -388,7 +409,8 @@ class TestWith(JitTestCase):
return x
@torch.jit.script
def with_that_raises(c: Context) -> torch.Tensor:
def with_that_raises(c):
# type: (Context) -> Tensor
a = torch.tensor([1])
with c as _:
@@ -397,7 +419,8 @@ class TestWith(JitTestCase):
return a
@torch.jit.script
def test_exception_fn_call(x: torch.Tensor, c: Context) -> torch.Tensor:
def test_exception_fn_call(x, c):
# type: (Tensor, Context) -> Tensor
"""
Test the case in which an exception is thrown while there are active with-statements in two different
frames.
@@ -483,25 +506,29 @@ class TestWith(JitTestCase):
def __exit__(self, type: Any, value: int, tb: int):
pass
def test_no_enter_no_exit(x: torch.Tensor, c: NoEnterNoExit) -> torch.Tensor:
def test_no_enter_no_exit(x, c):
# type: (Tensor, NoEnterNoExit) -> Tensor
with c as _:
pass
return x
def test_bad_enter(x: torch.Tensor, c: BadEnter) -> torch.Tensor:
def test_bad_enter(x, c):
# type: (Tensor, BadEnter) -> Tensor
with c as _:
pass
return x
def test_bad_exit(x: torch.Tensor, c: BadExit) -> torch.Tensor:
def test_bad_exit(x, c):
# type: (Tensor, BadExit) -> Tensor
with c as _:
pass
return x
def test_exit_incorrect_types(x: torch.Tensor, c: ExitIncorrectTypes) -> torch.Tensor:
def test_exit_incorrect_types(x, c):
# type: (Tensor, ExitIncorrectTypes) -> Tensor
with c as _:
pass
@@ -538,7 +565,8 @@ class TestWith(JitTestCase):
"""
# Basic no_grad test.
def test_no_grad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
def test_no_grad(x, y):
# type: (Tensor, Tensor) -> Tensor
with torch.no_grad():
w = x + y
@@ -555,7 +583,8 @@ class TestWith(JitTestCase):
# Test assignment of a grad-less Tensor to a Tensor with gradients
# in a no_grad block.
def test_no_grad_assignment(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
def test_no_grad_assignment(x, y):
# type: (Tensor, Tensor) -> Tensor
with torch.no_grad():
x[0] = y
@@ -574,11 +603,13 @@ class TestWith(JitTestCase):
super().__init__()
@torch.jit.ignore
def adder(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
def adder(self, x, y):
# type: (Tensor, Tensor) -> Tensor
w = x + y
return w
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
def forward(self, x, y):
# type: (Tensor, Tensor) -> Tensor
with torch.no_grad():
w = self.adder(x, y)
@@ -594,7 +625,8 @@ class TestWith(JitTestCase):
Check that torch.autograd.profiler.record_function context manager is
torchscriptable.
"""
def with_rf(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
def with_rf(x, y):
# type: (Tensor, Tensor) -> Tensor
with torch.autograd.profiler.record_function("foo"):
# Nested record_function.
with torch.autograd.profiler.record_function("nested"):
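
All of the with-statement tests above drive a context manager whose __enter__ and __exit__ are compiled as part of a scripted class; __exit__ must take three Any-typed arguments, which is why those annotations survive even in type-comment style. A minimal sketch under those assumptions (class name is illustrative):

from typing import Any

import torch
from torch import Tensor

@torch.jit.script
class Counter(object):
    def __init__(self):
        self.count = torch.tensor([0.0])

    def __enter__(self):
        self.count.add_(1.0)
        return self.count

    def __exit__(self, type: Any, value: Any, tb: Any):
        self.count.sub_(1.0)

@torch.jit.script
def use_counter(x):
    # type: (Tensor) -> Tensor
    c = Counter()
    with c as n:
        x = x + n
    return x

print(use_counter(torch.zeros(1)))  # tensor([1.])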