Type annotations in test/jit (#50293)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50293

Switching the test/jit suite from Python 2-style "# type:" comments to Python 3 type annotations, for improved type safety and import tracking.
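
As a minimal sketch of the mechanical pattern applied throughout (the foo_old/foo_new names are illustrative; the body is taken from the first hunk below, and the imports mirror the ones this diff adds):

from typing import List

import torch
from torch import Tensor  # the annotation form needs these names at module scope

# Before: Python 2-style type comment
def foo_old(x):
    # type: (Tensor) -> List[Tensor]
    return [torch.neg(x), x.t()]

# After: Python 3 annotations, visible to type checkers and import tracking
def foo_new(x: Tensor) -> List[Tensor]:
    return [torch.neg(x), x.t()]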

Test Plan: Sandcastle tests

Reviewed By: xush6528

Differential Revision: D25853949

fbshipit-source-id: fb873587bb521a0a55021ee4d34d1b05ea8f000d
Authored by Richard Barnes on 2021-01-12 16:45:16 -08:00, committed by Facebook GitHub Bot.
parent 4c97ef8d77
commit 8c25b9701b
7 changed files with 206 additions and 390 deletions
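
As a hedged local sanity check (assumes a PyTorch install; fn is copied from one of the converted TestList hunks below), torch.jit.script compiles the Python 3 annotations directly, with no type comment:

from typing import List

import torch

def fn(x: List[int]) -> List[int]:
    del x[1]
    return x

# Scripted and eager results agree; the annotated signature is all TorchScript needs.
scripted = torch.jit.script(fn)
assert scripted([1, 2, 3]) == fn([1, 2, 3]) == [1, 3]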

View File

@@ -5,7 +5,7 @@ import sys
import torch
import torch.nn as nn
from typing import Any
from typing import Any, Tuple
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@@ -41,8 +41,7 @@ class TestAsync(JitTestCase):
def test_async_parsing(self):
@torch.jit.script
def foo(x):
# type: (Tensor) -> List[Tensor]
def foo(x: Tensor) -> List[Tensor]:
return [torch.neg(x), x.t()]
@torch.jit.script
@@ -257,8 +256,7 @@ class TestAsync(JitTestCase):
self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True)
@torch.jit.script_method
def forward(self, x):
# type: (Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]
def forward(self, x: Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]:
future1 = torch.jit._fork(self.traced, x)
future2 = torch.jit._fork(torch.neg, x)

View File

@@ -2,7 +2,7 @@ import os
import sys
import inspect
import unittest
from typing import List
from typing import Dict, List
import torch
@@ -78,8 +78,7 @@ class TestBuiltins(JitTestCase):
torch.jit.script(Mod())
def test_del(self):
def fn(x):
# type: (List[int]) -> List[int]
def fn(x: List[int]) -> List[int]:
a = x * 2
del a
return x
@@ -109,16 +108,14 @@ class TestBuiltins(JitTestCase):
return a
def test_del_multiple_operands(self):
def fn(x):
# type: (List[int]) -> List[int]
def fn(x: List[int]) -> List[int]:
a, b, c = x[0], x[1], x[2]
del a, b, c
return x
self.checkScript(fn, ([1, 2, 3],))
def del_list_multiple_operands(x):
# type: (List[int]) -> List[int]
def del_list_multiple_operands(x: List[int]) -> List[int]:
del x[0], x[1]
return x
@@ -126,8 +123,7 @@ class TestBuiltins(JitTestCase):
jit_out = torch.jit.script(del_list_multiple_operands)([0, 1, 2])
self.assertEquals(py_out, jit_out)
def del_dict_multiple_operands(x):
# type: (Dict[str, int]) -> Dict[str, int]
def del_dict_multiple_operands(x: Dict[str, int]) -> Dict[str, int]:
del x['hi'], x['there']
return x

View File

@@ -1,10 +1,11 @@
import os
import sys
import inspect
from typing import Dict, List, Any
from typing import Any, Dict, List, Optional, Tuple
from textwrap import dedent
from collections import OrderedDict
from torch import Tensor
import torch
from torch.testing import FileCheck
@@ -20,22 +21,19 @@ if __name__ == '__main__':
class TestList(JitTestCase):
def test_in_check(self):
def int_in(x):
# type: (List[int]) -> bool
def int_in(x: List[int]) -> bool:
return 2 in x
self.checkScript(int_in, ([1, 2, 3],))
self.checkScript(int_in, ([1, 3, 3],))
def float_in(x):
# type: (List[float]) -> bool
def float_in(x: List[float]) -> bool:
return 2. in x
self.checkScript(float_in, ([1., 2., 3.],))
self.checkScript(float_in, ([1., 3., 3.],))
def str_in(x):
# type: (List[str]) -> bool
def str_in(x: List[str]) -> bool:
return 'hi' in x
self.checkScript(str_in, (['not', 'here'],))
@@ -100,8 +98,7 @@ class TestList(JitTestCase):
def inputs():
return [1, 2, 3, 4]
def fn(x):
# type: (List[int]) -> List[int]
def fn(x: List[int]) -> List[int]:
del x[1]
return x
@@ -114,8 +111,7 @@ class TestList(JitTestCase):
self.assertEqual(torch.jit.script(fn)(inputs()), python_out)
@torch.jit.script
def fn2(x):
# type: (List[int]) -> List[int]
def fn2(x: List[int]) -> List[int]:
del x[100]
return x
@@ -124,8 +120,7 @@ class TestList(JitTestCase):
with self.assertRaisesRegex(RuntimeError, "deletion at a single index"):
@torch.jit.script
def fn(x):
# type: (List[int]) -> List[int]
def fn(x: List[int]) -> List[int]:
del x[1:3]
return x
@@ -149,23 +144,19 @@ class TestList(JitTestCase):
FileCheck().check_count("aten::list", 2, exactly=True).run(torch.jit.script(foo3).graph)
def test_min_bool_list(self):
def jit_min_list(a, b):
# type: (List[bool], List[bool]) -> List[bool]
def jit_min_list(a: List[bool], b: List[bool]) -> List[bool]:
return min(a, b)
self.checkScript(jit_min_list, ([True, False], [False, True]))
def test_min_max_list(self):
def jit_min_list(a, b):
# type: (List[int], List[int]) -> List[int]
def jit_min_list(a: List[int], b: List[int]) -> List[int]:
return min(a, b)
def jit_min_list_float(a, b):
# type: (List[float], List[float]) -> List[float]
def jit_min_list_float(a: List[float], b: List[float]) -> List[float]:
return min(a, b)
def jit_min_list_bool(a, b):
# type: (List[bool], List[bool]) -> List[bool]
def jit_min_list_bool(a: List[bool], b: List[bool]) -> List[bool]:
return min(a, b)
def run_tests(func, a, b):
@@ -186,16 +177,13 @@ class TestList(JitTestCase):
[False, True], [False, False, True], [False, False, False]]
run_tests(jit_min_list_bool, args_left_bool, args_right_bool)
def jit_max_list(a, b):
# type: (List[int], List[int]) -> List[int]
def jit_max_list(a: List[int], b: List[int]) -> List[int]:
return max(a, b)
def jit_max_list_float(a, b):
# type: (List[float], List[float]) -> List[float]
def jit_max_list_float(a: List[float], b: List[float]) -> List[float]:
return max(a, b)
def jit_max_list_bool(a, b):
# type: (List[bool], List[bool]) -> List[bool]
def jit_max_list_bool(a: List[bool], b: List[bool]) -> List[bool]:
return max(a, b)
args_left_int = [[1, 8, 8], [8, 1, 1], [], [1], [], [1, 2]]
@@ -365,8 +353,7 @@ class TestList(JitTestCase):
t2 = scope['func']()
self.assertEqual(t1, t2)
def test_fail(x):
# type: (List[Tensor]) -> List[Tensor]
def test_fail(x: List[Tensor]) -> List[Tensor]:
x.sort()
return x
@@ -472,8 +459,7 @@ class TestList(JitTestCase):
self.checkScript(test_append, ())
def test_comprehensions_basic(self):
def comp(l):
# type: (List[int]) -> List[int]
def comp(l: List[int]) -> List[int]:
n = [x * 3 for x in l]
return n
@@ -482,8 +468,7 @@ class TestList(JitTestCase):
self.checkScript(comp, ([1, 2, 3],))
def test_comprehensions_basic_float(self):
def comp(l):
# type: (List[float]) -> List[float]
def comp(l: List[float]) -> List[float]:
n = [x * 3 for x in l]
return n
@@ -492,8 +477,7 @@ class TestList(JitTestCase):
def test_comprehensions_two_comps(self):
@torch.jit.script
def comp(l1, l2):
# type: (List[int], List[int]) -> List[int]
def comp(l1: List[int], l2: List[int]) -> List[int]:
n = [x * 3 for x in l1]
n2 = [x + 2 for x in l2]
@@ -502,8 +486,7 @@ class TestList(JitTestCase):
self.assertEqual(comp([1, 2, 3], [4, 5]), [3, 6, 9, 6, 7])
def test_comprehension_out_type_not_in_type(self):
def list_cast():
# type: () -> int
def list_cast() -> int:
li = [int(i) for i in [torch.tensor(0), torch.tensor(1), torch.tensor(2)]]
return li[0] + li[1] + li[2]
@@ -513,15 +496,13 @@ class TestList(JitTestCase):
def test_func(fn, inputs):
self.assertEqual(fn(*inputs), torch.jit.script(fn)(*inputs))
def foo(names, results):
# type: (List[int], List[int]) -> List[Tuple[int, int]]
def foo(names: List[int], results: List[int]) -> List[Tuple[int, int]]:
return [(k + 5, v - 2) for k, v in zip(names, results)]
test_func(foo, ([1, 2, 4], [4, 7, 9]))
test_func(foo, ([5], [4, 7, 9]))
def fn(x):
# type: (int) -> List[int]
def fn(x: int) -> List[int]:
return [i for i in range(x)] # noqa: C416
test_func(fn, (9,))
@@ -601,8 +582,7 @@ class TestList(JitTestCase):
def test_mutable_list_function_inline(self):
@torch.jit.script
def bar(y):
# type: (List[int]) -> None
def bar(y: List[int]) -> None:
y.append(4)
@torch.jit.script
@@ -888,8 +868,7 @@ class TestList(JitTestCase):
def test_extend_list_mutable(self):
@torch.jit.script
def extend_list(a, b):
# type: (List[Tensor], List[Tensor]) -> List[Tensor]
def extend_list(a: List[Tensor], b: List[Tensor]) -> List[Tensor]:
a.extend(b)
return a
@@ -900,8 +879,7 @@ class TestList(JitTestCase):
def test_extend_list_immutable(self):
@torch.jit.script
def extend_list(a, b):
# type: (List[int], List[int]) -> List[int]
def extend_list(a: List[int], b: List[int]) -> List[int]:
a.extend(b)
return a
@@ -912,8 +890,7 @@ class TestList(JitTestCase):
def test_copy_list_mutable(self):
@torch.jit.script
def copy_list(a):
# type: (List[Tensor]) -> List[Tensor]
def copy_list(a: List[Tensor]) -> List[Tensor]:
return a.copy()
for l in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]:
@@ -921,36 +898,29 @@ class TestList(JitTestCase):
def test_copy_list_immutable(self):
@torch.jit.script
def copy_list(a):
# type: (List[int]) -> List[int]
def copy_list(a: List[int]) -> List[int]:
return a.copy()
for l in [[], [1], [1, 2, 3]]:
self.assertEqual(copy_list(l), l)
def test_min_max_single_list(self):
def min_intlist(li):
# type: (List[int]) -> int
def min_intlist(li: List[int]) -> int:
return min(li)
def max_intlist(li):
# type: (List[int]) -> int
def max_intlist(li: List[int]) -> int:
return max(li)
def min_boollist(li):
# type: (List[bool]) -> bool
def min_boollist(li: List[bool]) -> bool:
return min(li)
def max_boollist(li):
# type: (List[bool]) -> bool
def max_boollist(li: List[bool]) -> bool:
return max(li)
def min_floatlist(li):
# type: (List[float]) -> float
def min_floatlist(li: List[float]) -> float:
return min(li)
def max_floatlist(li):
# type: (List[float]) -> float
def max_floatlist(li: List[float]) -> float:
return max(li)
@@ -980,23 +950,19 @@ class TestList(JitTestCase):
"""
Boolean dtype unit tests.
"""
def to_list_bool_0D(x):
# type: (torch.Tensor) -> bool
def to_list_bool_0D(x: torch.Tensor) -> bool:
li = torch.jit.annotate(bool, x.tolist())
return li
def to_list_bool_1D(x):
# type: (torch.Tensor) -> List[bool]
def to_list_bool_1D(x: torch.Tensor) -> List[bool]:
li = torch.jit.annotate(List[bool], x.tolist())
return li
def to_list_bool_2D(x):
# type: (torch.Tensor) -> List[List[bool]]
def to_list_bool_2D(x: torch.Tensor) -> List[List[bool]]:
li = torch.jit.annotate(List[List[bool]], x.tolist())
return li
def to_list_bool_3D(x):
# type: (torch.Tensor) -> List[List[List[bool]]]
def to_list_bool_3D(x: torch.Tensor) -> List[List[List[bool]]]:
li = torch.jit.annotate(List[List[List[bool]]], x.tolist())
return li
@@ -1021,23 +987,19 @@ class TestList(JitTestCase):
"""
Int dtype unit tests.
"""
def to_list_int_0D(x):
# type: (torch.Tensor) -> int
def to_list_int_0D(x: torch.Tensor) -> int:
li = torch.jit.annotate(int, x.tolist())
return li
def to_list_int_1D(x):
# type: (torch.Tensor) -> List[int]
def to_list_int_1D(x: torch.Tensor) -> List[int]:
li = torch.jit.annotate(List[int], x.tolist())
return li
def to_list_int_2D(x):
# type: (torch.Tensor) -> List[List[int]]
def to_list_int_2D(x: torch.Tensor) -> List[List[int]]:
li = torch.jit.annotate(List[List[int]], x.tolist())
return li
def to_list_int_3D(x):
# type: (torch.Tensor) -> List[List[List[int]]]
def to_list_int_3D(x: torch.Tensor) -> List[List[List[int]]]:
li = torch.jit.annotate(List[List[List[int]]], x.tolist())
return li
@@ -1058,23 +1020,19 @@ class TestList(JitTestCase):
"""
Float dtype unit tests.
"""
def to_list_float_0D(x):
# type: (torch.Tensor) -> float
def to_list_float_0D(x: torch.Tensor) -> float:
li = torch.jit.annotate(float, x.tolist())
return li
def to_list_float_1D(x):
# type: (torch.Tensor) -> List[float]
def to_list_float_1D(x: torch.Tensor) -> List[float]:
li = torch.jit.annotate(List[float], x.tolist())
return li
def to_list_float_2D(x):
# type: (torch.Tensor) -> List[List[float]]
def to_list_float_2D(x: torch.Tensor) -> List[List[float]]:
li = torch.jit.annotate(List[List[float]], x.tolist())
return li
def to_list_float_3D(x):
# type: (torch.Tensor) -> List[List[List[float]]]
def to_list_float_3D(x: torch.Tensor) -> List[List[List[float]]]:
li = torch.jit.annotate(List[List[List[float]]], x.tolist())
return li
@@ -1099,28 +1057,23 @@ class TestList(JitTestCase):
- type annotation with the wrong dimension
- type annotation with scalar type that doesn't match the input scalar type
"""
def to_list_missing_type_annotation(x):
# type: (torch.Tensor) -> List[float]
def to_list_missing_type_annotation(x: torch.Tensor) -> List[float]:
li = x.tolist()
return li
def to_list_incorrect_type_annotation(x):
# type: (torch.Tensor) -> List[float]
def to_list_incorrect_type_annotation(x: torch.Tensor) -> List[float]:
li = torch.jit.annotate(float, x.tolist())
return li
def to_list_unsupported_type_annotation(x):
# type: (torch.Tensor) -> List[float]
def to_list_unsupported_type_annotation(x: torch.Tensor) -> List[float]:
li = torch.jit.annotate(List[str], x.tolist())
return li
def to_list_type_annotation_wrong_dim(x):
# type: (torch.Tensor) -> List[List[float]]
def to_list_type_annotation_wrong_dim(x: torch.Tensor) -> List[List[float]]:
li = torch.jit.annotate(List[List[float]], x.tolist())
return li
def to_list_type_annotation_incorrect_scalar_type(x):
# type: (torch.Tensor) -> List[float]
def to_list_type_annotation_incorrect_scalar_type(x: torch.Tensor) -> List[float]:
li = torch.jit.annotate(List[float], x.tolist())
return li
@@ -1164,18 +1117,15 @@ class TestList(JitTestCase):
if not torch.cuda.is_available() or torch.cuda.device_count() == 0:
self.skipTest("CUDA is not available")
def to_list_bool_1D(x):
# type: (torch.Tensor) -> List[bool]
def to_list_bool_1D(x: torch.Tensor) -> List[bool]:
li = torch.jit.annotate(List[bool], x.tolist())
return li
def to_list_int_1D(x):
# type: (torch.Tensor) -> List[int]
def to_list_int_1D(x: torch.Tensor) -> List[int]:
li = torch.jit.annotate(List[int], x.tolist())
return li
def to_list_float_1D(x):
# type: (torch.Tensor) -> List[float]
def to_list_float_1D(x: torch.Tensor) -> List[float]:
li = torch.jit.annotate(List[float], x.tolist())
return li
@@ -1187,8 +1137,7 @@ class TestList(JitTestCase):
5, dtype=torch.double).cuda(),))
def test_no_element_type_annotation(self):
def fn_with_comment(x):
# type: (torch.Tensor) -> List
def fn_with_comment(x: torch.Tensor) -> List:
a: List = x.tolist()
return a
@@ -1230,8 +1179,7 @@ class TestDict(JitTestCase):
def inputs():
return {'hi': 2, 'bye': 3}
def fn(x):
# type: (Dict[str, int]) -> Dict[str, int]
def fn(x: Dict[str, int]) -> Dict[str, int]:
del x['hi']
return x
@@ -1247,8 +1195,7 @@ class TestDict(JitTestCase):
def test_keys(self):
@torch.jit.script
def keys(x):
# type: (Dict[str, Tensor]) -> List[str]
def keys(x: Dict[str, Tensor]) -> List[str]:
return list(x.keys())
self.assertEqual(set(keys(self.dict())), set(self.dict().keys()))
@@ -1263,30 +1210,26 @@ class TestDict(JitTestCase):
def test_values(self):
@torch.jit.script
def values(x):
# type: (Dict[str, Tensor]) -> List[Tensor]
def values(x: Dict[str, Tensor]) -> List[Tensor]:
return list(x.values())
the_dict = self.dict()
self.assertEqual(set(values(the_dict)), set(the_dict.values()))
def test_len(self):
def length(x):
# type: (Dict[str, Tensor]) -> int
def length(x: Dict[str, Tensor]) -> int:
return len(x)
self.checkScript(length, (self.dict(),))
def test_copy(self):
def func(x):
# type: (Dict[str, Tensor]) -> Dict[str, Tensor]
def func(x: Dict[str, Tensor]) -> Dict[str, Tensor]:
return x.copy()
self.checkScript(func, (self.dict(),))
def test_items(self):
def func(x):
# type: (Dict[str, Tensor]) -> List[Tuple[str, Tensor]]
def func(x: Dict[str, Tensor]) -> List[Tuple[str, Tensor]]:
return x.items()
# The value returned by Python is in arbitrary order, so we can't use
@@ -1301,8 +1244,7 @@ class TestDict(JitTestCase):
self.assertTrue(item in script_out)
def test_pop(self):
def pop(x, key):
# type: (Dict[str, Tensor], str) -> Tuple[Tensor, Dict[str, Tensor]]
def pop(x: Dict[str, Tensor], key: str) -> Tuple[Tensor, Dict[str, Tensor]]:
return x.pop(key), x
# checkScript doesn't copy the inputs, so we can't use it since this mutates
@@ -1318,16 +1260,14 @@ class TestDict(JitTestCase):
torch.jit.script(pop)(self.dict(), 'x')
def default_pop(x, key, default):
# type: (Dict[str, Tensor], str, Tensor) -> Tuple[Tensor, Dict[str, Tensor]]
def default_pop(x: Dict[str, Tensor], key: str, default: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
return x.pop(key, default), x
tester(default_pop, 'a', torch.randn(2, 2))
tester(default_pop, 'x', torch.randn(2, 2))
def test_setdefault(self):
def setdefault(x, key, default):
# type: (Dict[str, Tensor], str, Tensor) -> Dict[str, Tensor]
def setdefault(x: Dict[str, Tensor], key: str, default: Tensor) -> Dict[str, Tensor]:
x.setdefault(key, default)
return x
@@ -1335,8 +1275,7 @@ class TestDict(JitTestCase):
self.checkScript(setdefault, (self.dict(), 'nonexistant', torch.randn(2, 2)))
def test_update(self):
def update(a, b):
# type: (Dict[str, Tensor], Dict[str, Tensor]) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]
def update(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
a.update(b)
return a, b
@@ -1353,8 +1292,7 @@ class TestDict(JitTestCase):
self.checkScript(foo, ())
def test_aug_assign(self):
def aug_assign_dict_tensor(a):
# type: (Dict[str, Tensor]) -> Dict[str, Tensor]
def aug_assign_dict_tensor(a: Dict[str, Tensor]) -> Dict[str, Tensor]:
a['a'] += 1
a['b'] -= 12
a['c'] *= 122
@@ -1362,8 +1300,7 @@ class TestDict(JitTestCase):
a['c'] %= 2
return a
def aug_assign_dict_prim(a):
# type: (Dict[str, float]) -> Dict[str, float]
def aug_assign_dict_prim(a: Dict[str, float]) -> Dict[str, float]:
a['a'] += 3.4
a['b'] -= 2.4
a['c'] *= 3.0
@@ -1376,8 +1313,7 @@ class TestDict(JitTestCase):
def test_popitem(self):
@torch.jit.script
def popitem(x):
# type: (Dict[str, Tensor]) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor]]
def popitem(x: Dict[str, Tensor]) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor]]:
item = x.popitem()
return item, x
@@ -1395,65 +1331,56 @@ class TestDict(JitTestCase):
self.assertTrue(isinstance(script_out[0][1], torch.Tensor))
def test_clear(self):
def clear(x):
# type: (Dict[str, Tensor]) -> Dict[str, Tensor]
def clear(x: Dict[str, Tensor]) -> Dict[str, Tensor]:
x.clear()
return x
self.checkScript(clear, (self.dict(),))
def test_get(self):
def get(x, key):
# type: (Dict[str, Tensor], str) -> Optional[Tensor]
def get(x: Dict[str, Tensor], key: str) -> Optional[Tensor]:
return x.get(key)
self.checkScript(get, (self.dict(), 'a'))
self.checkScript(get, (self.dict(), "doesn't exist"))
def get_default(x, key):
# type: (Dict[str, Tensor], str) -> Optional[Tensor]
def get_default(x: Dict[str, Tensor], key: str) -> Optional[Tensor]:
return x.get(key, torch.randn(2, 2))
self.checkScript(get, (self.dict(), 'a'))
self.checkScript(get, (self.dict(), "doesn't exist"))
def test_get_boolkey(self):
def get(x, key):
# type: (Dict[bool, int], bool) -> Optional[int]
def get(x: Dict[bool, int], key: bool) -> Optional[int]:
return x.get(key)
self.checkScript(get, (self.dict_bool(), True))
self.checkScript(get, (self.dict_bool(), False))
def get_default(x, key):
# type: (Dict[bool, int], bool) -> int
def get_default(x: Dict[bool, int], key: bool) -> int:
return x.get(key, 42)
self.checkScript(get_default, (self.dict_bool(), True))
self.checkScript(get_default, (self.dict_bool(), False))
def test_basic(self):
def simple(x):
# type: (Dict[str, int]) -> Dict[str, int]
def simple(x: Dict[str, int]) -> Dict[str, int]:
return x
self.checkScript(simple, ({'item': 20, 'other_item': 120},))
def index(x):
# type: (Dict[str, int]) -> int
def index(x: Dict[str, int]) -> int:
return x['item']
self.checkScript(index, ({'item': 20, 'other_item': 120},))
def type_default():
# type: () -> Dict[str, Tensor]
def type_default() -> Dict[str, Tensor]:
return {}
self.checkScript(type_default, ())
@torch.jit.script
def missing_index(x):
# type: (Dict[str, int]) -> int
def missing_index(x: Dict[str, int]) -> int:
return x['dne']
with self.assertRaisesRegex(RuntimeError, "KeyError"):
@@ -1475,16 +1402,14 @@ class TestDict(JitTestCase):
'''))
self.assertEqual({10: 1.2, 11: 1.3}, cu.literal3())
def list_of_dicts():
# type: () -> List[Dict[str, Tensor]]
def list_of_dicts() -> List[Dict[str, Tensor]]:
return [{'word': torch.ones(2) + 3}, {'other word': torch.ones(1) + 2}]
self.checkScript(list_of_dicts, ())
def test_mutability(self):
@torch.jit.script
def fn():
# type: () -> Dict[str, int]
def fn() -> Dict[str, int]:
a = torch.jit.annotate(Dict[str, int], {})
a['ok'] = 10
return a
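
The torch.jit.annotate calls above are left untouched by this diff; for empty containers, TorchScript also accepts the PEP 526 variable-annotation spelling. A hedged sketch of the equivalent form (not part of this change; fn_annotated is an illustrative name):

from typing import Dict

import torch

@torch.jit.script
def fn_annotated() -> Dict[str, int]:
    # Variable annotation in place of torch.jit.annotate(Dict[str, int], {});
    # both tell the compiler the element types of the empty dict literal.
    a: Dict[str, int] = {}
    a['ok'] = 10
    return a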
@@ -1494,14 +1419,12 @@ class TestDict(JitTestCase):
def test_key_type(self):
with self.assertRaisesRegex(RuntimeError, "but instead found type"):
@torch.jit.script
def fn(a):
# type: (Dict[str, int]) -> int
def fn(a: Dict[str, int]) -> int:
return a[None]
def test_loop(self):
@torch.jit.script
def fn(x):
# type: (int) -> Dict[str, int]
def fn(x: int) -> Dict[str, int]:
a = torch.jit.annotate(Dict[str, int], {})
for i in range(x):
a['ok'] = i
@@ -1520,16 +1443,14 @@ class TestDict(JitTestCase):
self.checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3)))
def test_membership(self):
def fn(x, y):
# type: (Dict[int, int], int) -> int
def fn(x: Dict[int, int], y: int) -> int:
return x.get(y, 3)
d = {1: 2, 3: 4}
self.checkScript(fn, (d, 3))
self.checkScript(fn, (d, 2))
def optional(x, y):
# type: (Dict[int, int], int) -> bool
def optional(x: Dict[int, int], y: int) -> bool:
res = x.get(y)
return res is None
@@ -1538,18 +1459,15 @@ class TestDict(JitTestCase):
with self.assertRaisesRegex(RuntimeError, "is actually of type Optional"):
@torch.jit.script
def bad_types(x, y):
# type: (Dict[int, int], int) -> int
def bad_types(x: Dict[int, int], y: int) -> int:
return x.get(y) # noqa: T484
def test_dict_to_python(self):
@torch.jit.ignore
def python_lookup(my_dict, keys):
# type: (Dict[str, int], List[str]) -> List[int]
def python_lookup(my_dict: Dict[str, int], keys: List[str]) -> List[int]:
return [my_dict[k] for k in keys]
def fn(my_dict, keys):
# type: (Dict[str, int], List[str]) -> List[int]
def fn(my_dict: Dict[str, int], keys: List[str]) -> List[int]:
return python_lookup(my_dict, keys)
a_dict = {'a': torch.ones(1), 'b': torch.ones(1) + 1, 'c': torch.ones(1) + 2}
@@ -1601,8 +1519,7 @@ class TestDict(JitTestCase):
key and value types produces an error.
"""
# This function uses a type comment.
def fn_with_comment(input):
# type: (Dict) -> Any
def fn_with_comment(input: Dict) -> Any:
return input
# This function uses Python3 style type annotations.

View File

@@ -6,6 +6,7 @@ import torch
import torch.nn as nn
import os
import sys
from torch import Tensor
from torch.testing._internal.jit_utils import JitTestCase
# Make the helper files in test/ importable
@@ -22,36 +23,30 @@ class OrigModule(nn.Module):
def __init__(self):
super(OrigModule, self).__init__()
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
return inp1 + inp2 + 1
def two(self, input):
# type: (Tensor) -> Tensor
def two(self, input: Tensor) -> Tensor:
return input + 2
def forward(self, input):
# type: (Tensor) -> Tensor
def forward(self, input: Tensor) -> Tensor:
return input + self.one(input, input) + 1
class NewModule(nn.Module):
def __init__(self):
super(NewModule, self).__init__()
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
return inp1 * inp2 + 1
def forward(self, input):
# type: (Tensor) -> Tensor
def forward(self, input: Tensor) -> Tensor:
return self.one(input, input + 1)
class TestModuleInterface(JitTestCase):
def test_not_submodule_interface_call(self):
@torch.jit.interface
class ModuleInterface(nn.Module):
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
pass
class TestNotModuleInterfaceCall(nn.Module):
@@ -61,8 +56,7 @@ class TestModuleInterface(JitTestCase):
super(TestNotModuleInterfaceCall, self).__init__()
self.proxy_mod = OrigModule()
def forward(self, input):
# type: (Tensor) -> Tensor
def forward(self, input: Tensor) -> Tensor:
return self.proxy_mod.two(input)
with self.assertRaisesRegex(RuntimeError, "Tried to access nonexistent attribute"):
@@ -72,64 +66,51 @@ class TestModuleInterface(JitTestCase):
global OneTwoModule, OneTwoClass
@torch.jit.interface
class OneTwoModule(nn.Module):
def one(self, x, y):
# type: (Tensor, Tensor) -> Tensor
def one(self, x: Tensor, y: Tensor) -> Tensor:
pass
def two(self, x):
# type: (Tensor) -> Tensor
def two(self, x: Tensor) -> Tensor:
pass
def forward(self, x):
# type: (Tensor) -> Tensor
def forward(self, x: Tensor) -> Tensor:
pass
@torch.jit.interface
class OneTwoClass(object):
def one(self, x, y):
# type: (Tensor, Tensor) -> Tensor
def one(self, x: Tensor, y: Tensor) -> Tensor:
pass
def two(self, x):
# type: (Tensor) -> Tensor
def two(self, x: Tensor) -> Tensor:
pass
class FooMod(nn.Module):
def one(self, x, y):
# type: (Tensor, Tensor) -> Tensor
def one(self, x: Tensor, y: Tensor) -> Tensor:
return x + y
def two(self, x):
# type: (Tensor) -> Tensor
def two(self, x: Tensor) -> Tensor:
return 2 * x
def forward(self, x):
# type: (Tensor) -> Tensor
def forward(self, x: Tensor) -> Tensor:
return self.one(self.two(x), x)
class BarMod(nn.Module):
def one(self, x, y):
# type: (Tensor, Tensor) -> Tensor
def one(self, x: Tensor, y: Tensor) -> Tensor:
return x * y
def two(self, x):
# type: (Tensor) -> Tensor
def two(self, x: Tensor) -> Tensor:
return 2 / x
def forward(self, x):
# type: (Tensor) -> Tensor
def forward(self, x: Tensor) -> Tensor:
return self.two(self.one(x, x))
@torch.jit.export
def forward2(self, x):
# type: (Tensor) -> Tensor
def forward2(self, x: Tensor) -> Tensor:
return self.two(self.one(x, x)) + 1
def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
return mod_list[0].forward(x) + mod_list[1].forward(x)
def use_class_interface(mod_list, x):
# type: (List[OneTwoClass], Tensor) -> Tensor
def use_class_interface(mod_list: List[OneTwoClass], x: Tensor) -> Tensor:
return mod_list[0].two(x) + mod_list[1].one(x, x)
scripted_foo_mod = torch.jit.script(FooMod())
@@ -139,8 +120,7 @@ class TestModuleInterface(JitTestCase):
self.checkScript(use_class_interface,
([scripted_foo_mod, scripted_bar_mod], torch.rand(3, 4),))
def call_module_interface_on_other_method(mod_interface, x):
# type: (OneTwoModule, Tensor) -> Tensor
def call_module_interface_on_other_method(mod_interface: OneTwoModule, x: Tensor) -> Tensor:
return mod_interface.forward2(x)
# ensure error out when we call the module on the method other than the interface specified.
@@ -178,35 +158,28 @@ class TestModuleInterface(JitTestCase):
global OneTwoModule
@torch.jit.interface
class OneTwoModule(nn.Module):
def one(self, x, y):
# type: (Tensor, Tensor) -> Tensor
def one(self, x: Tensor, y: Tensor) -> Tensor:
pass
def two(self, x):
# type: (Tensor) -> Tensor
def two(self, x: Tensor) -> Tensor:
pass
def forward(self, x):
# type: (Tensor) -> Tensor
def forward(self, x: Tensor) -> Tensor:
pass
@torch.jit.script
def as_module_interface(x):
# type: (OneTwoModule) -> OneTwoModule
def as_module_interface(x: OneTwoModule) -> OneTwoModule:
return x
@torch.jit.script
class Foo(object):
def one(self, x, y):
# type: (Tensor, Tensor) -> Tensor
def one(self, x: Tensor, y: Tensor) -> Tensor:
return x + y
def two(self, x):
# type: (Tensor) -> Tensor
def two(self, x: Tensor) -> Tensor:
return 2 * x
def forward(self, x):
# type: (Tensor) -> Tensor
def forward(self, x: Tensor) -> Tensor:
return self.one(self.two(x), x)
# check class object is not a subtype of module interface
@@ -214,12 +187,10 @@ class TestModuleInterface(JitTestCase):
as_module_interface(Foo())
class WrongMod(nn.Module):
def two(self, x):
# type: (int) -> int
def two(self, x: int) -> int:
return 2 * x
def forward(self, x):
# type: (Tensor) -> Tensor
def forward(self, x: Tensor) -> Tensor:
return x + torch.randn(3, self.two(3))
scripted_wrong_mod = torch.jit.script(WrongMod())
@@ -270,19 +241,16 @@ class TestModuleInterface(JitTestCase):
with self.assertRaisesRegex(RuntimeError, "does not support inheritance yet. Please directly"):
@torch.jit.interface
class InheritMod(nn.ReLU):
def three(self, x):
# type: (Tensor) -> Tensor
def three(self, x: Tensor) -> Tensor:
return 3 * x
def test_module_swap(self):
@torch.jit.interface
class ModuleInterface(nn.Module):
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
pass
def forward(self, input):
# type: (Tensor) -> Tensor
def forward(self, input: Tensor) -> Tensor:
pass
class TestModule(nn.Module):
@@ -292,8 +260,7 @@ class TestModuleInterface(JitTestCase):
super(TestModule, self).__init__()
self.proxy_mod = OrigModule()
def forward(self, input):
# type: (Tensor) -> Tensor
def forward(self, input: Tensor) -> Tensor:
return self.proxy_mod.forward(input)
scripted_mod = torch.jit.script(TestModule())
@@ -311,20 +278,17 @@ class TestModuleInterface(JitTestCase):
def test_module_swap_wrong_module(self):
@torch.jit.interface
class ModuleInterface(nn.Module):
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
pass
def forward(self, input):
# type: (Tensor) -> Tensor
def forward(self, input: Tensor) -> Tensor:
pass
class NewModuleWrong(nn.Module):
def __init__(self):
super(NewModuleWrong, self).__init__()
def forward(self, input):
# type: (int) -> int
def forward(self, input: int) -> int:
return input + 1
class TestModule(nn.Module):
@@ -334,8 +298,7 @@ class TestModuleInterface(JitTestCase):
super(TestModule, self).__init__()
self.proxy_mod = OrigModule()
def forward(self, input):
# type: (Tensor) -> Tensor
def forward(self, input: Tensor) -> Tensor:
return self.proxy_mod.forward(input)
scripted_mod = torch.jit.script(TestModule())
@@ -346,12 +309,10 @@ class TestModuleInterface(JitTestCase):
def test_module_swap_no_lazy_compile(self):
@torch.jit.interface
class ModuleInterface(nn.Module):
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
pass
def forward(self, input):
# type: (Tensor) -> Tensor
def forward(self, input: Tensor) -> Tensor:
pass
class TestModule(nn.Module):
@@ -361,20 +322,17 @@ class TestModuleInterface(JitTestCase):
super(TestModule, self).__init__()
self.proxy_mod = OrigModule()
def forward(self, input):
# type: (Tensor) -> Tensor
def forward(self, input: Tensor) -> Tensor:
return self.proxy_mod.forward(input)
class NewModuleMethodNotLazyCompile(nn.Module):
def __init__(self):
super(NewModuleMethodNotLazyCompile, self).__init__()
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
return inp1 * inp2 + 1
def forward(self, input):
# type: (Tensor) -> Tensor
def forward(self, input: Tensor) -> Tensor:
return input + 1
scripted_mod = torch.jit.script(TestModule())
@@ -388,12 +346,10 @@ class TestModuleInterface(JitTestCase):
super(NewModuleMethodManualExport, self).__init__()
@torch.jit.export
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
return inp1 * inp2 + 1
def forward(self, input):
# type: (Tensor) -> Tensor
def forward(self, input: Tensor) -> Tensor:
return input + 1
scripted_mod.proxy_mod = torch.jit.script(NewModuleMethodManualExport())
@@ -407,8 +363,7 @@ class TestModuleInterface(JitTestCase):
super(TestNoModuleInterface, self).__init__()
self.proxy_mod = OrigModule()
def forward(self, input):
# type: (Tensor) -> Tensor
def forward(self, input: Tensor) -> Tensor:
return self.proxy_mod(input)
scripted_no_module_interface = torch.jit.script(TestNoModuleInterface())
@@ -423,12 +378,10 @@ class TestModuleInterface(JitTestCase):
def test_script_module_as_interface_swap(self):
@torch.jit.interface
class ModuleInterface(nn.Module):
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
pass
def forward(self, input):
# type: (Tensor) -> Tensor
def forward(self, input: Tensor) -> Tensor:
pass
class OrigScriptModule(torch.jit.ScriptModule):
@@ -436,13 +389,11 @@ class TestModuleInterface(JitTestCase):
super(OrigScriptModule, self).__init__()
@torch.jit.script_method
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
return inp1 + inp2 + 1
@torch.jit.script_method
def forward(self, input):
# type: (Tensor) -> Tensor
def forward(self, input: Tensor) -> Tensor:
return input + self.one(input, input) + 1
class NewScriptModule(torch.jit.ScriptModule):
@@ -450,13 +401,11 @@ class TestModuleInterface(JitTestCase):
super(NewScriptModule, self).__init__()
@torch.jit.script_method
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
return inp1 * inp2 + 1
@torch.jit.script_method
def forward(self, input):
# type: (Tensor) -> Tensor
def forward(self, input: Tensor) -> Tensor:
return self.one(input, input + 1)
class TestNNModuleWithScriptModule(nn.Module):
@@ -466,8 +415,7 @@ class TestModuleInterface(JitTestCase):
super(TestNNModuleWithScriptModule, self).__init__()
self.proxy_mod = OrigScriptModule()
def forward(self, input):
# type: (Tensor) -> Tensor
def forward(self, input: Tensor) -> Tensor:
return self.proxy_mod.forward(input)
input = torch.randn(3, 4)
@@ -498,8 +446,7 @@ class TestModuleInterface(JitTestCase):
@torch.jit.interface
class ModInterface(torch.nn.Module):
def forward(self, x):
# type: (Tensor) -> int
def forward(self, x: Tensor) -> int:
pass
class TestModule(torch.nn.Module):
@@ -546,8 +493,7 @@ class TestModuleInterface(JitTestCase):
@torch.jit.interface
class ModInterface(torch.nn.Module):
def forward(self, x):
# type: (Tensor) -> int
def forward(self, x: Tensor) -> int:
pass
class TestModule(torch.nn.Module):
@@ -590,8 +536,7 @@ class TestModuleInterface(JitTestCase):
@torch.jit.interface
class ModInterface(torch.nn.Module):
def forward(self, x):
# type: (Tensor) -> Tensor
def forward(self, x: Tensor) -> Tensor:
pass
class TestModule(torch.nn.Module):
@@ -636,8 +581,7 @@ class TestModuleInterface(JitTestCase):
@torch.jit.interface
class ModInterface(torch.nn.Module):
def forward(self, x):
# type: (Tensor) -> Tensor
def forward(self, x: Tensor) -> Tensor:
pass
class TestModule(torch.nn.Module):
@@ -679,8 +623,7 @@ class TestModuleInterface(JitTestCase):
@torch.jit.interface
class ModInterface(torch.nn.Module):
def forward(self, x):
# type: (Tensor) -> Tensor
def forward(self, x: Tensor) -> Tensor:
pass
class TestModule(torch.nn.Module):
@@ -714,8 +657,7 @@ class TestModuleInterface(JitTestCase):
def test_module_apis_interface(self):
@torch.jit.interface
class ModuleInterface(nn.Module):
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
pass
class TestModule(nn.Module):

View File

@@ -2,10 +2,11 @@ import os
import sys
import typing
import typing_extensions
from typing import List, Dict, Optional
from typing import List, Dict, Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from torch.testing import FileCheck
from collections import OrderedDict
@@ -284,8 +285,7 @@ class TestRecursiveScript(JitTestCase):
test_module_dir(nn.ModuleDict(OrderedDict([("conv", conv), ("linear", linear)])))
def test_class_compile(self):
def other_fn(a, b):
# type: (int, Tensor) -> Tensor
def other_fn(a: int, b: Tensor) -> Tensor:
return a * b
class B(object):
@@ -307,8 +307,7 @@ class TestRecursiveScript(JitTestCase):
self.checkModule(N(), (torch.randn(2, 2),))
def test_error_stack(self):
def d(x):
# type: (int) -> int
def d(x: int) -> int:
return x + 10
def c(x):
@@ -331,8 +330,7 @@ class TestRecursiveScript(JitTestCase):
checker.run(str(e))
def test_error_stack_module(self):
def d(x):
# type: (int) -> int
def d(x: int) -> int:
return x + 10
def c(x):
@@ -565,8 +563,7 @@ class TestRecursiveScript(JitTestCase):
self.a = 4
self.inner = Inner2()
def __setstate__(self, obj):
# type: (Tuple[int, Inner2]) -> None
def __setstate__(self, obj: Tuple[int, Inner2]) -> None:
a, inner = obj
self.a = a
self.inner = inner

View File

@@ -1,12 +1,14 @@
import os
import io
import pathlib
import sys
import random
import torch
from itertools import product as product
from torch.testing._internal.common_utils import TemporaryFileName
from typing import NamedTuple, Optional
import io
import os
import pathlib
import random
import sys
from torch import Tensor
from torch.testing._internal.common_utils import TemporaryFileName
import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@@ -680,8 +682,7 @@ class TestSaveLoad(JitTestCase):
"""
@torch.jit.interface
class MyInterface(object):
def bar(self, x):
# type: (Tensor) -> Tensor
def bar(self, x: Tensor) -> Tensor:
pass
@torch.jit.script
@@ -711,8 +712,7 @@ class TestSaveLoad(JitTestCase):
@torch.jit.interface
class MyInterface(object):
def not_bar(self, x):
# type: (Tensor) -> Tensor
def not_bar(self, x: Tensor) -> Tensor:
pass
@torch.jit.script # noqa: F811
@@ -767,8 +767,7 @@ class TestSaveLoad(JitTestCase):
@torch.jit.interface
class MyInterface(object):
def bar(self, x):
# type: (Tensor) -> Tensor
def bar(self, x: Tensor) -> Tensor:
pass
@torch.jit.script
@@ -809,8 +808,7 @@ class TestSaveLoad(JitTestCase):
@torch.jit.interface
class MyInterface(object):
def not_bar(self, x):
# type: (Tensor) -> Tensor
def not_bar(self, x: Tensor) -> Tensor:
pass
@torch.jit.script # noqa F811

View File

@@ -1,7 +1,7 @@
import os
import sys
from typing import Any
from typing import Any, List
import torch
from torch.testing._internal.jit_utils import JitTestCase
@@ -50,8 +50,7 @@ class TestWith(JitTestCase):
def __exit__(self, type: Any, value: Any, tb: Any):
self.count.sub_(0.3)
def test_basic(x):
# type: (Tensor) -> Tensor
def test_basic(x: torch.Tensor) -> torch.Tensor:
"""Basic test with one with-statement."""
c = Context(1)
@@ -62,8 +61,7 @@ class TestWith(JitTestCase):
y *= c.count
return y
def test_pass(x):
# type: (Tensor) -> Tensor
def test_pass(x: torch.Tensor) -> torch.Tensor:
"""
Test with a pass statement inside a with-statement. Although
the body of the with is empty, __enter__ and __exit__ should
@@ -77,8 +75,7 @@ class TestWith(JitTestCase):
x *= c.count
return x
def test_early_return(x, c):
# type: (Tensor, Context) -> Tensor
def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
"""
Test that returning early from inside a with-statement works
as expected.
@@ -90,8 +87,7 @@ class TestWith(JitTestCase):
x = y + y
return x
def test_conditional_early_return(x, c):
# type: (Tensor, Context) -> Tensor
def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
"""
Test that conditionally returning early from inside a with-statement works
as expected.
@@ -104,8 +100,7 @@ class TestWith(JitTestCase):
x = y + y
return x
def test_break(x, c, l):
# type: (Tensor, Context, List[int]) -> Tensor
def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
"""
Test that breaking early from inside a with-statement works
as expected.
@@ -118,8 +113,7 @@ class TestWith(JitTestCase):
return x
def test_continue(x, c, l):
# type: (Tensor, Context, List[int]) -> Tensor
def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
"""
Test that using continue inside a with-statement works
as expected.
@@ -132,8 +126,7 @@ class TestWith(JitTestCase):
return x
def test_serial(x):
# type: (Tensor) -> Tensor
def test_serial(x: torch.Tensor) -> torch.Tensor:
"""
Test two with-statements in a row.
"""
@@ -147,8 +140,7 @@ class TestWith(JitTestCase):
return y
def test_nested(x):
# type: (Tensor) -> Tensor
def test_nested(x: torch.Tensor) -> torch.Tensor:
"""
Test nested with-statements.
"""
@@ -162,8 +154,7 @@ class TestWith(JitTestCase):
return y
def test_combined(x):
# type: (Tensor) -> Tensor
def test_combined(x: torch.Tensor) -> torch.Tensor:
"""
Test a with-statement with multiple with items.
"""
@@ -215,8 +206,7 @@ class TestWith(JitTestCase):
def __exit__(self, type: Any, value: Any, tb: Any):
self.count.sub_(0.3)
def test_basic(x):
# type: (Tensor) -> Tensor
def test_basic(x: torch.Tensor) -> torch.Tensor:
"""Basic test with one with-statement."""
c = Context(1)
@@ -227,8 +217,7 @@ class TestWith(JitTestCase):
y *= c.count
return y
def test_pass(x):
# type: (Tensor) -> Tensor
def test_pass(x: torch.Tensor) -> torch.Tensor:
"""
Test with a pass statement inside a with-statement. Although
the body of the with is empty, __enter__ and __exit__ should
@@ -242,8 +231,7 @@ class TestWith(JitTestCase):
x *= c.count
return x
def test_early_return(x, c):
# type: (Tensor, Context) -> Tensor
def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
"""
Test that returning early from inside a with-statement works
as expected.
@@ -255,8 +243,7 @@ class TestWith(JitTestCase):
x = y + y
return x
def test_conditional_early_return(x, c):
# type: (Tensor, Context) -> Tensor
def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
"""
Test that conditionally returning early from inside a with-statement works
as expected.
@@ -269,8 +256,7 @@ class TestWith(JitTestCase):
x = y + y
return x
def test_break(x, c, l):
# type: (Tensor, Context, List[int]) -> Tensor
def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
"""
Test that breaking early from inside a with-statement works
as expected.
@@ -283,8 +269,7 @@ class TestWith(JitTestCase):
return x
def test_continue(x, c, l):
# type: (Tensor, Context, List[int]) -> Tensor
def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
"""
Test that using continue inside a with-statement works
as expected.
@@ -297,8 +282,7 @@ class TestWith(JitTestCase):
return x
def test_serial(x):
# type: (Tensor) -> Tensor
def test_serial(x: torch.Tensor) -> torch.Tensor:
"""
Test two with-statements in a row.
"""
@@ -312,8 +296,7 @@ class TestWith(JitTestCase):
return y
def test_nested(x):
# type: (Tensor) -> Tensor
def test_nested(x: torch.Tensor) -> torch.Tensor:
"""
Test nested with-statements.
"""
@@ -327,8 +310,7 @@ class TestWith(JitTestCase):
return y
def test_combined(x):
# type: (Tensor) -> Tensor
def test_combined(x: torch.Tensor) -> torch.Tensor:
"""
Test a with-statement with multiple with items.
"""
@@ -381,13 +363,11 @@ class TestWith(JitTestCase):
self.count.sub_(0.3)
@torch.jit.script
def method_that_raises():
# type: () -> Tensor
def method_that_raises() -> torch.Tensor:
raise Exception("raised exception")
@torch.jit.script
def test_exception(x, c):
# type: (Tensor, Context) -> Tensor
def test_exception(x: torch.Tensor, c: Context) -> torch.Tensor:
"""
Test the case in which an exception is thrown while executing the body of a with-statement.
"""
@@ -397,8 +377,7 @@ class TestWith(JitTestCase):
return x
@torch.jit.script
def test_exception_nested(x, c):
# type: (Tensor, Context) -> Tensor
def test_exception_nested(x: torch.Tensor, c: Context) -> torch.Tensor:
"""
Test the case in which an exception is thrown while executing the body of a nested with-statement.
"""
@@ -409,8 +388,7 @@ class TestWith(JitTestCase):
return x
@torch.jit.script
def with_that_raises(c):
# type: (Context) -> Tensor
def with_that_raises(c: Context) -> torch.Tensor:
a = torch.tensor([1])
with c as _:
@@ -419,8 +397,7 @@ class TestWith(JitTestCase):
return a
@torch.jit.script
def test_exception_fn_call(x, c):
# type: (Tensor, Context) -> Tensor
def test_exception_fn_call(x: torch.Tensor, c: Context) -> torch.Tensor:
"""
Test the case in which an exception is thrown while there are active with-statements in two different
frames.
@@ -506,29 +483,25 @@ class TestWith(JitTestCase):
def __exit__(self, type: Any, value: int, tb: int):
pass
def test_no_enter_no_exit(x, c):
# type: (Tensor, NoEnterNoExit) -> Tensor
def test_no_enter_no_exit(x: torch.Tensor, c: NoEnterNoExit) -> torch.Tensor:
with c as _:
pass
return x
def test_bad_enter(x, c):
# type: (Tensor, BadEnter) -> Tensor
def test_bad_enter(x: torch.Tensor, c: BadEnter) -> torch.Tensor:
with c as _:
pass
return x
def test_bad_exit(x, c):
# type: (Tensor, BadExit) -> Tensor
def test_bad_exit(x: torch.Tensor, c: BadExit) -> torch.Tensor:
with c as _:
pass
return x
def test_exit_incorrect_types(x, c):
# type: (Tensor, ExitIncorrectTypes) -> Tensor
def test_exit_incorrect_types(x: torch.Tensor, c: ExitIncorrectTypes) -> torch.Tensor:
with c as _:
pass
@@ -565,8 +538,7 @@ class TestWith(JitTestCase):
"""
# Basic no_grad test.
def test_no_grad(x, y):
# type: (Tensor, Tensor) -> Tensor
def test_no_grad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
w = x + y
@@ -583,8 +555,7 @@ class TestWith(JitTestCase):
# Test assignment of a grad-less Tensor to a Tensor with gradients
# in a no_grad block.
def test_no_grad_assignment(x, y):
# type: (Tensor, Tensor) -> Tensor
def test_no_grad_assignment(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
x[0] = y
@@ -603,13 +574,11 @@ class TestWith(JitTestCase):
super().__init__()
@torch.jit.ignore
def adder(self, x, y):
# type: (Tensor, Tensor) -> Tensor
def adder(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
w = x + y
return w
def forward(self, x, y):
# type: (Tensor, Tensor) -> Tensor
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
w = self.adder(x, y)
@@ -625,8 +594,7 @@ class TestWith(JitTestCase):
Check that torch.autograd.profiler.record_function context manager is
torchscriptable.
"""
def with_rf(x, y):
# type: (Tensor, Tensor) -> Tensor
def with_rf(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
with torch.autograd.profiler.record_function("foo"):
# Nested record_function.
with torch.autograd.profiler.record_function("nested"):