From 8c25b9701bb8aa480678fe974d160020cc24538c Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Tue, 12 Jan 2021 16:45:16 -0800 Subject: [PATCH] Type annotations in test/jit (#50293) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50293 Switching to type annotations for improved safety and import tracking. Test Plan: Sandcastle tests Reviewed By: xush6528 Differential Revision: D25853949 fbshipit-source-id: fb873587bb521a0a55021ee4d34d1b05ea8f000d --- test/jit/test_async.py | 8 +- test/jit/test_builtins.py | 14 +- test/jit/test_list_dict.py | 255 ++++++++++-------------------- test/jit/test_module_interface.py | 178 +++++++-------------- test/jit/test_recursive_script.py | 15 +- test/jit/test_save_load.py | 28 ++-- test/jit/test_with.py | 98 ++++-------- 7 files changed, 206 insertions(+), 390 deletions(-) diff --git a/test/jit/test_async.py b/test/jit/test_async.py index 7a70a4c5a655..b4b6b8e294f7 100644 --- a/test/jit/test_async.py +++ b/test/jit/test_async.py @@ -5,7 +5,7 @@ import sys import torch import torch.nn as nn -from typing import Any +from typing import Any, Tuple # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) @@ -41,8 +41,7 @@ class TestAsync(JitTestCase): def test_async_parsing(self): @torch.jit.script - def foo(x): - # type: (Tensor) -> List[Tensor] + def foo(x: Tensor) -> List[Tensor]: return [torch.neg(x), x.t()] @torch.jit.script @@ -257,8 +256,7 @@ class TestAsync(JitTestCase): self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True) @torch.jit.script_method - def forward(self, x): - # type: (Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor] + def forward(self, x: Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]: future1 = torch.jit._fork(self.traced, x) future2 = torch.jit._fork(torch.neg, x) diff --git a/test/jit/test_builtins.py b/test/jit/test_builtins.py index 04991f72c352..b5a0dd8599a6 100644 --- a/test/jit/test_builtins.py +++ b/test/jit/test_builtins.py @@ -2,7 +2,7 @@ import os import sys import inspect import unittest -from typing import List +from typing import Dict, List import torch @@ -78,8 +78,7 @@ class TestBuiltins(JitTestCase): torch.jit.script(Mod()) def test_del(self): - def fn(x): - # type: (List[int]) -> List[int] + def fn(x: List[int]) -> List[int]: a = x * 2 del a return x @@ -109,16 +108,14 @@ class TestBuiltins(JitTestCase): return a def test_del_multiple_operands(self): - def fn(x): - # type: (List[int]) -> List[int] + def fn(x: List[int]) -> List[int]: a, b, c = x[0], x[1], x[2] del a, b, c return x self.checkScript(fn, ([1, 2, 3],)) - def del_list_multiple_operands(x): - # type: (List[int]) -> List[int] + def del_list_multiple_operands(x: List[int]) -> List[int]: del x[0], x[1] return x @@ -126,8 +123,7 @@ class TestBuiltins(JitTestCase): jit_out = torch.jit.script(del_list_multiple_operands)([0, 1, 2]) self.assertEquals(py_out, jit_out) - def del_dict_multiple_operands(x): - # type: (Dict[str, int]) -> Dict[str, int] + def del_dict_multiple_operands(x: Dict[str, int]) -> Dict[str, int]: del x['hi'], x['there'] return x diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py index 2bc24a57751d..9d6be7806628 100644 --- a/test/jit/test_list_dict.py +++ b/test/jit/test_list_dict.py @@ -1,10 +1,11 @@ import os import sys import inspect -from typing import Dict, List, Any +from typing import Any, Dict, List, Optional, Tuple from textwrap import dedent from collections import OrderedDict +from torch import Tensor import torch from torch.testing import FileCheck @@ -20,22 +21,19 @@ if __name__ == '__main__': class TestList(JitTestCase): def test_in_check(self): - def int_in(x): - # type: (List[int]) -> bool + def int_in(x: List[int]) -> bool: return 2 in x self.checkScript(int_in, ([1, 2, 3],)) self.checkScript(int_in, ([1, 3, 3],)) - def float_in(x): - # type: (List[float]) -> bool + def float_in(x: List[float]) -> bool: return 2. in x self.checkScript(float_in, ([1., 2., 3.],)) self.checkScript(float_in, ([1., 3., 3.],)) - def str_in(x): - # type: (List[str]) -> bool + def str_in(x: List[str]) -> bool: return 'hi' in x self.checkScript(str_in, (['not', 'here'],)) @@ -100,8 +98,7 @@ class TestList(JitTestCase): def inputs(): return [1, 2, 3, 4] - def fn(x): - # type: (List[int]) -> List[int] + def fn(x: List[int]) -> List[int]: del x[1] return x @@ -114,8 +111,7 @@ class TestList(JitTestCase): self.assertEqual(torch.jit.script(fn)(inputs()), python_out) @torch.jit.script - def fn2(x): - # type: (List[int]) -> List[int] + def fn2(x: List[int]) -> List[int]: del x[100] return x @@ -124,8 +120,7 @@ class TestList(JitTestCase): with self.assertRaisesRegex(RuntimeError, "deletion at a single index"): @torch.jit.script - def fn(x): - # type: (List[int]) -> List[int] + def fn(x: List[int]) -> List[int]: del x[1:3] return x @@ -149,23 +144,19 @@ class TestList(JitTestCase): FileCheck().check_count("aten::list", 2, exactly=True).run(torch.jit.script(foo3).graph) def test_min_bool_list(self): - def jit_min_list(a, b): - # type: (List[bool], List[bool]) -> List[bool] + def jit_min_list(a: List[bool], b: List[bool]) -> List[bool]: return min(a, b) self.checkScript(jit_min_list, ([True, False], [False, True])) def test_min_max_list(self): - def jit_min_list(a, b): - # type: (List[int], List[int]) -> List[int] + def jit_min_list(a: List[int], b: List[int]) -> List[int]: return min(a, b) - def jit_min_list_float(a, b): - # type: (List[float], List[float]) -> List[float] + def jit_min_list_float(a: List[float], b: List[float]) -> List[float]: return min(a, b) - def jit_min_list_bool(a, b): - # type: (List[bool], List[bool]) -> List[bool] + def jit_min_list_bool(a: List[bool], b: List[bool]) -> List[bool]: return min(a, b) def run_tests(func, a, b): @@ -186,16 +177,13 @@ class TestList(JitTestCase): [False, True], [False, False, True], [False, False, False]] run_tests(jit_min_list_bool, args_left_bool, args_right_bool) - def jit_max_list(a, b): - # type: (List[int], List[int]) -> List[int] + def jit_max_list(a: List[int], b: List[int]) -> List[int]: return max(a, b) - def jit_max_list_float(a, b): - # type: (List[float], List[float]) -> List[float] + def jit_max_list_float(a: List[float], b: List[float]) -> List[float]: return max(a, b) - def jit_max_list_bool(a, b): - # type: (List[bool], List[bool]) -> List[bool] + def jit_max_list_bool(a: List[bool], b: List[bool]) -> List[bool]: return max(a, b) args_left_int = [[1, 8, 8], [8, 1, 1], [], [1], [], [1, 2]] @@ -365,8 +353,7 @@ class TestList(JitTestCase): t2 = scope['func']() self.assertEqual(t1, t2) - def test_fail(x): - # type: (List[Tensor]) -> List[Tensor] + def test_fail(x: List[Tensor]) -> List[Tensor]: x.sort() return x @@ -472,8 +459,7 @@ class TestList(JitTestCase): self.checkScript(test_append, ()) def test_comprehensions_basic(self): - def comp(l): - # type: (List[int]) -> List[int] + def comp(l: List[int]) -> List[int]: n = [x * 3 for x in l] return n @@ -482,8 +468,7 @@ class TestList(JitTestCase): self.checkScript(comp, ([1, 2, 3],)) def test_comprehensions_basic_float(self): - def comp(l): - # type: (List[float]) -> List[float] + def comp(l: List[float]) -> List[float]: n = [x * 3 for x in l] return n @@ -492,8 +477,7 @@ class TestList(JitTestCase): def test_comprehensions_two_comps(self): @torch.jit.script - def comp(l1, l2): - # type: (List[int], List[int]) -> List[int] + def comp(l1: List[int], l2: List[int]) -> List[int]: n = [x * 3 for x in l1] n2 = [x + 2 for x in l2] @@ -502,8 +486,7 @@ class TestList(JitTestCase): self.assertEqual(comp([1, 2, 3], [4, 5]), [3, 6, 9, 6, 7]) def test_comprehension_out_type_not_in_type(self): - def list_cast(): - # type: () -> int + def list_cast() -> int: li = [int(i) for i in [torch.tensor(0), torch.tensor(1), torch.tensor(2)]] return li[0] + li[1] + li[2] @@ -513,15 +496,13 @@ class TestList(JitTestCase): def test_func(fn, inputs): self.assertEqual(fn(*inputs), torch.jit.script(fn)(*inputs)) - def foo(names, results): - # type: (List[int], List[int]) -> List[Tuple[int, int]] + def foo(names: List[int], results: List[int]) -> List[Tuple[int, int]]: return [(k + 5, v - 2) for k, v in zip(names, results)] test_func(foo, ([1, 2, 4], [4, 7, 9])) test_func(foo, ([5], [4, 7, 9])) - def fn(x): - # type: (int) -> List[int] + def fn(x: int) -> List[int]: return [i for i in range(x)] # noqa: C416 test_func(fn, (9,)) @@ -601,8 +582,7 @@ class TestList(JitTestCase): def test_mutable_list_function_inline(self): @torch.jit.script - def bar(y): - # type: (List[int]) -> None + def bar(y: List[int]) -> None: y.append(4) @torch.jit.script @@ -888,8 +868,7 @@ class TestList(JitTestCase): def test_extend_list_mutable(self): @torch.jit.script - def extend_list(a, b): - # type: (List[Tensor], List[Tensor]) -> List[Tensor] + def extend_list(a: List[Tensor], b: List[Tensor]) -> List[Tensor]: a.extend(b) return a @@ -900,8 +879,7 @@ class TestList(JitTestCase): def test_extend_list_immutable(self): @torch.jit.script - def extend_list(a, b): - # type: (List[int], List[int]) -> List[int] + def extend_list(a: List[int], b: List[int]) -> List[int]: a.extend(b) return a @@ -912,8 +890,7 @@ class TestList(JitTestCase): def test_copy_list_mutable(self): @torch.jit.script - def copy_list(a): - # type: (List[Tensor]) -> List[Tensor] + def copy_list(a: List[Tensor]) -> List[Tensor]: return a.copy() for l in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]: @@ -921,36 +898,29 @@ class TestList(JitTestCase): def test_copy_list_immutable(self): @torch.jit.script - def copy_list(a): - # type: (List[int]) -> List[int] + def copy_list(a: List[int]) -> List[int]: return a.copy() for l in [[], [1], [1, 2, 3]]: self.assertEqual(copy_list(l), l) def test_min_max_single_list(self): - def min_intlist(li): - # type: (List[int]) -> int + def min_intlist(li: List[int]) -> int: return min(li) - def max_intlist(li): - # type: (List[int]) -> int + def max_intlist(li: List[int]) -> int: return max(li) - def min_boollist(li): - # type: (List[bool]) -> bool + def min_boollist(li: List[bool]) -> bool: return min(li) - def max_boollist(li): - # type: (List[bool]) -> bool + def max_boollist(li: List[bool]) -> bool: return max(li) - def min_floatlist(li): - # type: (List[float]) -> float + def min_floatlist(li: List[float]) -> float: return min(li) - def max_floatlist(li): - # type: (List[float]) -> float + def max_floatlist(li: List[float]) -> float: return max(li) @@ -980,23 +950,19 @@ class TestList(JitTestCase): """ Boolean dtype unit tests. """ - def to_list_bool_0D(x): - # type: (torch.Tensor) -> bool + def to_list_bool_0D(x: torch.Tensor) -> bool: li = torch.jit.annotate(bool, x.tolist()) return li - def to_list_bool_1D(x): - # type: (torch.Tensor) -> List[bool] + def to_list_bool_1D(x: torch.Tensor) -> List[bool]: li = torch.jit.annotate(List[bool], x.tolist()) return li - def to_list_bool_2D(x): - # type: (torch.Tensor) -> List[List[bool]] + def to_list_bool_2D(x: torch.Tensor) -> List[List[bool]]: li = torch.jit.annotate(List[List[bool]], x.tolist()) return li - def to_list_bool_3D(x): - # type: (torch.Tensor) -> List[List[List[bool]]] + def to_list_bool_3D(x: torch.Tensor) -> List[List[List[bool]]]: li = torch.jit.annotate(List[List[List[bool]]], x.tolist()) return li @@ -1021,23 +987,19 @@ class TestList(JitTestCase): """ Int dtype unit tests. """ - def to_list_int_0D(x): - # type: (torch.Tensor) -> int + def to_list_int_0D(x: torch.Tensor) -> int: li = torch.jit.annotate(int, x.tolist()) return li - def to_list_int_1D(x): - # type: (torch.Tensor) -> List[int] + def to_list_int_1D(x: torch.Tensor) -> List[int]: li = torch.jit.annotate(List[int], x.tolist()) return li - def to_list_int_2D(x): - # type: (torch.Tensor) -> List[List[int]] + def to_list_int_2D(x: torch.Tensor) -> List[List[int]]: li = torch.jit.annotate(List[List[int]], x.tolist()) return li - def to_list_int_3D(x): - # type: (torch.Tensor) -> List[List[List[int]]] + def to_list_int_3D(x: torch.Tensor) -> List[List[List[int]]]: li = torch.jit.annotate(List[List[List[int]]], x.tolist()) return li @@ -1058,23 +1020,19 @@ class TestList(JitTestCase): """ Float dtype unit tests. """ - def to_list_float_0D(x): - # type: (torch.Tensor) -> float + def to_list_float_0D(x: torch.Tensor) -> float: li = torch.jit.annotate(float, x.tolist()) return li - def to_list_float_1D(x): - # type: (torch.Tensor) -> List[float] + def to_list_float_1D(x: torch.Tensor) -> List[float]: li = torch.jit.annotate(List[float], x.tolist()) return li - def to_list_float_2D(x): - # type: (torch.Tensor) -> List[List[float]] + def to_list_float_2D(x: torch.Tensor) -> List[List[float]]: li = torch.jit.annotate(List[List[float]], x.tolist()) return li - def to_list_float_3D(x): - # type: (torch.Tensor) -> List[List[List[float]]] + def to_list_float_3D(x: torch.Tensor) -> List[List[List[float]]]: li = torch.jit.annotate(List[List[List[float]]], x.tolist()) return li @@ -1099,28 +1057,23 @@ class TestList(JitTestCase): - type annotation with the wrong dimension - type annotation with scalar type that doesn't match the input scalar type """ - def to_list_missing_type_annotation(x): - # type: (torch.Tensor) -> List[float] + def to_list_missing_type_annotation(x: torch.Tensor) -> List[float]: li = x.tolist() return li - def to_list_incorrect_type_annotation(x): - # type: (torch.Tensor) -> List[float] + def to_list_incorrect_type_annotation(x: torch.Tensor) -> List[float]: li = torch.jit.annotate(float, x.tolist()) return li - def to_list_unsupported_type_annotation(x): - # type: (torch.Tensor) -> List[float] + def to_list_unsupported_type_annotation(x: torch.Tensor) -> List[float]: li = torch.jit.annotate(List[str], x.tolist()) return li - def to_list_type_annotation_wrong_dim(x): - # type: (torch.Tensor) -> List[List[float]] + def to_list_type_annotation_wrong_dim(x: torch.Tensor) -> List[List[float]]: li = torch.jit.annotate(List[List[float]], x.tolist()) return li - def to_list_type_annotation_incorrect_scalar_type(x): - # type: (torch.Tensor) -> List[float] + def to_list_type_annotation_incorrect_scalar_type(x: torch.Tensor) -> List[float]: li = torch.jit.annotate(List[float], x.tolist()) return li @@ -1164,18 +1117,15 @@ class TestList(JitTestCase): if not torch.cuda.is_available() or torch.cuda.device_count() == 0: self.skipTest("CUDA is not available") - def to_list_bool_1D(x): - # type: (torch.Tensor) -> List[bool] + def to_list_bool_1D(x: torch.Tensor) -> List[bool]: li = torch.jit.annotate(List[bool], x.tolist()) return li - def to_list_int_1D(x): - # type: (torch.Tensor) -> List[int] + def to_list_int_1D(x: torch.Tensor) -> List[int]: li = torch.jit.annotate(List[int], x.tolist()) return li - def to_list_float_1D(x): - # type: (torch.Tensor) -> List[float] + def to_list_float_1D(x: torch.Tensor) -> List[float]: li = torch.jit.annotate(List[float], x.tolist()) return li @@ -1187,8 +1137,7 @@ class TestList(JitTestCase): 5, dtype=torch.double).cuda(),)) def test_no_element_type_annotation(self): - def fn_with_comment(x): - # type: (torch.Tensor) -> List + def fn_with_comment(x: torch.Tensor) -> List: a: List = x.tolist() return a @@ -1230,8 +1179,7 @@ class TestDict(JitTestCase): def inputs(): return {'hi': 2, 'bye': 3} - def fn(x): - # type: (Dict[str, int]) -> Dict[str, int] + def fn(x: Dict[str, int]) -> Dict[str, int]: del x['hi'] return x @@ -1247,8 +1195,7 @@ class TestDict(JitTestCase): def test_keys(self): @torch.jit.script - def keys(x): - # type: (Dict[str, Tensor]) -> List[str] + def keys(x: Dict[str, Tensor]) -> List[str]: return list(x.keys()) self.assertEqual(set(keys(self.dict())), set(self.dict().keys())) @@ -1263,30 +1210,26 @@ class TestDict(JitTestCase): def test_values(self): @torch.jit.script - def values(x): - # type: (Dict[str, Tensor]) -> List[Tensor] + def values(x: Dict[str, Tensor]) -> List[Tensor]: return list(x.values()) the_dict = self.dict() self.assertEqual(set(values(the_dict)), set(the_dict.values())) def test_len(self): - def length(x): - # type: (Dict[str, Tensor]) -> int + def length(x: Dict[str, Tensor]) -> int: return len(x) self.checkScript(length, (self.dict(),)) def test_copy(self): - def func(x): - # type: (Dict[str, Tensor]) -> Dict[str, Tensor] + def func(x: Dict[str, Tensor]) -> Dict[str, Tensor]: return x.copy() self.checkScript(func, (self.dict(),)) def test_items(self): - def func(x): - # type: (Dict[str, Tensor]) -> List[Tuple[str, Tensor]] + def func(x: Dict[str, Tensor]) -> List[Tuple[str, Tensor]]: return x.items() # The value returned by Python is in arbitrary order, so we can't use @@ -1301,8 +1244,7 @@ class TestDict(JitTestCase): self.assertTrue(item in script_out) def test_pop(self): - def pop(x, key): - # type: (Dict[str, Tensor], str) -> Tuple[Tensor, Dict[str, Tensor]] + def pop(x: Dict[str, Tensor], key: str) -> Tuple[Tensor, Dict[str, Tensor]]: return x.pop(key), x # checkScript doesn't copy the inputs, so we can't use it since this mutates @@ -1318,16 +1260,14 @@ class TestDict(JitTestCase): torch.jit.script(pop)(self.dict(), 'x') - def default_pop(x, key, default): - # type: (Dict[str, Tensor], str, Tensor) -> Tuple[Tensor, Dict[str, Tensor]] + def default_pop(x: Dict[str, Tensor], key: str, default: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]: return x.pop(key, default), x tester(default_pop, 'a', torch.randn(2, 2)) tester(default_pop, 'x', torch.randn(2, 2)) def test_setdefault(self): - def setdefault(x, key, default): - # type: (Dict[str, Tensor], str, Tensor) -> Dict[str, Tensor] + def setdefault(x: Dict[str, Tensor], key: str, default: Tensor) -> Dict[str, Tensor]: x.setdefault(key, default) return x @@ -1335,8 +1275,7 @@ class TestDict(JitTestCase): self.checkScript(setdefault, (self.dict(), 'nonexistant', torch.randn(2, 2))) def test_update(self): - def update(a, b): - # type: (Dict[str, Tensor], Dict[str, Tensor]) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]] + def update(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]: a.update(b) return a, b @@ -1353,8 +1292,7 @@ class TestDict(JitTestCase): self.checkScript(foo, ()) def test_aug_assign(self): - def aug_assign_dict_tensor(a): - # type: (Dict[str, Tensor]) -> Dict[str, Tensor] + def aug_assign_dict_tensor(a: Dict[str, Tensor]) -> Dict[str, Tensor]: a['a'] += 1 a['b'] -= 12 a['c'] *= 122 @@ -1362,8 +1300,7 @@ class TestDict(JitTestCase): a['c'] %= 2 return a - def aug_assign_dict_prim(a): - # type: (Dict[str, float]) -> Dict[str, float] + def aug_assign_dict_prim(a: Dict[str, float]) -> Dict[str, float]: a['a'] += 3.4 a['b'] -= 2.4 a['c'] *= 3.0 @@ -1376,8 +1313,7 @@ class TestDict(JitTestCase): def test_popitem(self): @torch.jit.script - def popitem(x): - # type: (Dict[str, Tensor]) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor]] + def popitem(x: Dict[str, Tensor]) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor]]: item = x.popitem() return item, x @@ -1395,65 +1331,56 @@ class TestDict(JitTestCase): self.assertTrue(isinstance(script_out[0][1], torch.Tensor)) def test_clear(self): - def clear(x): - # type: (Dict[str, Tensor]) -> Dict[str, Tensor] + def clear(x: Dict[str, Tensor]) -> Dict[str, Tensor]: x.clear() return x self.checkScript(clear, (self.dict(),)) def test_get(self): - def get(x, key): - # type: (Dict[str, Tensor], str) -> Optional[Tensor] + def get(x: Dict[str, Tensor], key: str) -> Optional[Tensor]: return x.get(key) self.checkScript(get, (self.dict(), 'a')) self.checkScript(get, (self.dict(), "doesn't exist")) - def get_default(x, key): - # type: (Dict[str, Tensor], str) -> Optional[Tensor] + def get_default(x: Dict[str, Tensor], key: str) -> Optional[Tensor]: return x.get(key, torch.randn(2, 2)) self.checkScript(get, (self.dict(), 'a')) self.checkScript(get, (self.dict(), "doesn't exist")) def test_get_boolkey(self): - def get(x, key): - # type: (Dict[bool, int], bool) -> Optional[int] + def get(x: Dict[bool, int], key: bool) -> Optional[int]: return x.get(key) self.checkScript(get, (self.dict_bool(), True)) self.checkScript(get, (self.dict_bool(), False)) - def get_default(x, key): - # type: (Dict[bool, int], bool) -> int + def get_default(x: Dict[bool, int], key: bool) -> int: return x.get(key, 42) self.checkScript(get_default, (self.dict_bool(), True)) self.checkScript(get_default, (self.dict_bool(), False)) def test_basic(self): - def simple(x): - # type: (Dict[str, int]) -> Dict[str, int] + def simple(x: Dict[str, int]) -> Dict[str, int]: return x self.checkScript(simple, ({'item': 20, 'other_item': 120},)) - def index(x): - # type: (Dict[str, int]) -> int + def index(x: Dict[str, int]) -> int: return x['item'] self.checkScript(index, ({'item': 20, 'other_item': 120},)) - def type_default(): - # type: () -> Dict[str, Tensor] + def type_default() -> Dict[str, Tensor]: return {} self.checkScript(type_default, ()) @torch.jit.script - def missing_index(x): - # type: (Dict[str, int]) -> int + def missing_index(x: Dict[str, int]) -> int: return x['dne'] with self.assertRaisesRegex(RuntimeError, "KeyError"): @@ -1475,16 +1402,14 @@ class TestDict(JitTestCase): ''')) self.assertEqual({10: 1.2, 11: 1.3}, cu.literal3()) - def list_of_dicts(): - # type: () -> List[Dict[str, Tensor]] + def list_of_dicts() -> List[Dict[str, Tensor]]: return [{'word': torch.ones(2) + 3}, {'other word': torch.ones(1) + 2}] self.checkScript(list_of_dicts, ()) def test_mutability(self): @torch.jit.script - def fn(): - # type: () -> Dict[str, int] + def fn() -> Dict[str, int]: a = torch.jit.annotate(Dict[str, int], {}) a['ok'] = 10 return a @@ -1494,14 +1419,12 @@ class TestDict(JitTestCase): def test_key_type(self): with self.assertRaisesRegex(RuntimeError, "but instead found type"): @torch.jit.script - def fn(a): - # type: (Dict[str, int]) -> int + def fn(a: Dict[str, int]) -> int: return a[None] def test_loop(self): @torch.jit.script - def fn(x): - # type: (int) -> Dict[str, int] + def fn(x: int) -> Dict[str, int]: a = torch.jit.annotate(Dict[str, int], {}) for i in range(x): a['ok'] = i @@ -1520,16 +1443,14 @@ class TestDict(JitTestCase): self.checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3))) def test_membership(self): - def fn(x, y): - # type: (Dict[int, int], int) -> int + def fn(x: Dict[int, int], y: int) -> int: return x.get(y, 3) d = {1: 2, 3: 4} self.checkScript(fn, (d, 3)) self.checkScript(fn, (d, 2)) - def optional(x, y): - # type: (Dict[int, int], int) -> bool + def optional(x: Dict[int, int], y: int) -> bool: res = x.get(y) return res is None @@ -1538,18 +1459,15 @@ class TestDict(JitTestCase): with self.assertRaisesRegex(RuntimeError, "is actually of type Optional"): @torch.jit.script - def bad_types(x, y): - # type: (Dict[int, int], int) -> int + def bad_types(x: Dict[int, int], y: int) -> int: return x.get(y) # noqa: T484 def test_dict_to_python(self): @torch.jit.ignore - def python_lookup(my_dict, keys): - # type: (Dict[str, int], List[str]) -> List[int] + def python_lookup(my_dict: Dict[str, int], keys: List[str]) -> List[int]: return [my_dict[k] for k in keys] - def fn(my_dict, keys): - # type: (Dict[str, int], List[str]) -> List[int] + def fn(my_dict: Dict[str, int], keys: List[str]) -> List[int]: return python_lookup(my_dict, keys) a_dict = {'a': torch.ones(1), 'b': torch.ones(1) + 1, 'c': torch.ones(1) + 2} @@ -1601,8 +1519,7 @@ class TestDict(JitTestCase): key and value types produces an error. """ # This function uses a type comment. - def fn_with_comment(input): - # type: (Dict) -> Any + def fn_with_comment(input: Dict) -> Any: return input # This function uses Python3 style type annotations. diff --git a/test/jit/test_module_interface.py b/test/jit/test_module_interface.py index 8d56f770b40a..d0626b1068b4 100644 --- a/test/jit/test_module_interface.py +++ b/test/jit/test_module_interface.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn import os import sys +from torch import Tensor from torch.testing._internal.jit_utils import JitTestCase # Make the helper files in test/ importable @@ -22,36 +23,30 @@ class OrigModule(nn.Module): def __init__(self): super(OrigModule, self).__init__() - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: return inp1 + inp2 + 1 - def two(self, input): - # type: (Tensor) -> Tensor + def two(self, input: Tensor) -> Tensor: return input + 2 - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return input + self.one(input, input) + 1 class NewModule(nn.Module): def __init__(self): super(NewModule, self).__init__() - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: return inp1 * inp2 + 1 - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.one(input, input + 1) class TestModuleInterface(JitTestCase): def test_not_submodule_interface_call(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: pass class TestNotModuleInterfaceCall(nn.Module): @@ -61,8 +56,7 @@ class TestModuleInterface(JitTestCase): super(TestNotModuleInterfaceCall, self).__init__() self.proxy_mod = OrigModule() - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.proxy_mod.two(input) with self.assertRaisesRegex(RuntimeError, "Tried to access nonexistent attribute"): @@ -72,64 +66,51 @@ class TestModuleInterface(JitTestCase): global OneTwoModule, OneTwoClass @torch.jit.interface class OneTwoModule(nn.Module): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: Tensor, y: Tensor) -> Tensor: pass - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: Tensor) -> Tensor: pass - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: pass @torch.jit.interface class OneTwoClass(object): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: Tensor, y: Tensor) -> Tensor: pass - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: Tensor) -> Tensor: pass class FooMod(nn.Module): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: Tensor, y: Tensor) -> Tensor: return x + y - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: Tensor) -> Tensor: return 2 * x - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: return self.one(self.two(x), x) class BarMod(nn.Module): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: Tensor, y: Tensor) -> Tensor: return x * y - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: Tensor) -> Tensor: return 2 / x - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: return self.two(self.one(x, x)) @torch.jit.export - def forward2(self, x): - # type: (Tensor) -> Tensor + def forward2(self, x: Tensor) -> Tensor: return self.two(self.one(x, x)) + 1 def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor): return mod_list[0].forward(x) + mod_list[1].forward(x) - def use_class_interface(mod_list, x): - # type: (List[OneTwoClass], Tensor) -> Tensor + def use_class_interface(mod_list: List[OneTwoClass], x: Tensor) -> Tensor: return mod_list[0].two(x) + mod_list[1].one(x, x) scripted_foo_mod = torch.jit.script(FooMod()) @@ -139,8 +120,7 @@ class TestModuleInterface(JitTestCase): self.checkScript(use_class_interface, ([scripted_foo_mod, scripted_bar_mod], torch.rand(3, 4),)) - def call_module_interface_on_other_method(mod_interface, x): - # type: (OneTwoModule, Tensor) -> Tensor + def call_module_interface_on_other_method(mod_interface: OneTwoModule, x: Tensor) -> Tensor: return mod_interface.forward2(x) # ensure error out when we call the module on the method other than the interface specified. @@ -178,35 +158,28 @@ class TestModuleInterface(JitTestCase): global OneTwoModule @torch.jit.interface class OneTwoModule(nn.Module): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: Tensor, y: Tensor) -> Tensor: pass - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: Tensor) -> Tensor: pass - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: pass @torch.jit.script - def as_module_interface(x): - # type: (OneTwoModule) -> OneTwoModule + def as_module_interface(x: OneTwoModule) -> OneTwoModule: return x @torch.jit.script class Foo(object): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: Tensor, y: Tensor) -> Tensor: return x + y - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: Tensor) -> Tensor: return 2 * x - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: return self.one(self.two(x), x) # check class object is not a subtype of module interface @@ -214,12 +187,10 @@ class TestModuleInterface(JitTestCase): as_module_interface(Foo()) class WrongMod(nn.Module): - def two(self, x): - # type: (int) -> int + def two(self, x: int) -> int: return 2 * x - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: return x + torch.randn(3, self.two(3)) scripted_wrong_mod = torch.jit.script(WrongMod()) @@ -270,19 +241,16 @@ class TestModuleInterface(JitTestCase): with self.assertRaisesRegex(RuntimeError, "does not support inheritance yet. Please directly"): @torch.jit.interface class InheritMod(nn.ReLU): - def three(self, x): - # type: (Tensor) -> Tensor + def three(self, x: Tensor) -> Tensor: return 3 * x def test_module_swap(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: pass - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: pass class TestModule(nn.Module): @@ -292,8 +260,7 @@ class TestModuleInterface(JitTestCase): super(TestModule, self).__init__() self.proxy_mod = OrigModule() - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.proxy_mod.forward(input) scripted_mod = torch.jit.script(TestModule()) @@ -311,20 +278,17 @@ class TestModuleInterface(JitTestCase): def test_module_swap_wrong_module(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: pass - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: pass class NewModuleWrong(nn.Module): def __init__(self): super(NewModuleWrong, self).__init__() - def forward(self, input): - # type: (int) -> int + def forward(self, input: int) -> int: return input + 1 class TestModule(nn.Module): @@ -334,8 +298,7 @@ class TestModuleInterface(JitTestCase): super(TestModule, self).__init__() self.proxy_mod = OrigModule() - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.proxy_mod.forward(input) scripted_mod = torch.jit.script(TestModule()) @@ -346,12 +309,10 @@ class TestModuleInterface(JitTestCase): def test_module_swap_no_lazy_compile(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: pass - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: pass class TestModule(nn.Module): @@ -361,20 +322,17 @@ class TestModuleInterface(JitTestCase): super(TestModule, self).__init__() self.proxy_mod = OrigModule() - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.proxy_mod.forward(input) class NewModuleMethodNotLazyCompile(nn.Module): def __init__(self): super(NewModuleMethodNotLazyCompile, self).__init__() - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: return inp1 * inp2 + 1 - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return input + 1 scripted_mod = torch.jit.script(TestModule()) @@ -388,12 +346,10 @@ class TestModuleInterface(JitTestCase): super(NewModuleMethodManualExport, self).__init__() @torch.jit.export - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: return inp1 * inp2 + 1 - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return input + 1 scripted_mod.proxy_mod = torch.jit.script(NewModuleMethodManualExport()) @@ -407,8 +363,7 @@ class TestModuleInterface(JitTestCase): super(TestNoModuleInterface, self).__init__() self.proxy_mod = OrigModule() - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.proxy_mod(input) scripted_no_module_interface = torch.jit.script(TestNoModuleInterface()) @@ -423,12 +378,10 @@ class TestModuleInterface(JitTestCase): def test_script_module_as_interface_swap(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: pass - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: pass class OrigScriptModule(torch.jit.ScriptModule): @@ -436,13 +389,11 @@ class TestModuleInterface(JitTestCase): super(OrigScriptModule, self).__init__() @torch.jit.script_method - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: return inp1 + inp2 + 1 @torch.jit.script_method - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return input + self.one(input, input) + 1 class NewScriptModule(torch.jit.ScriptModule): @@ -450,13 +401,11 @@ class TestModuleInterface(JitTestCase): super(NewScriptModule, self).__init__() @torch.jit.script_method - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: return inp1 * inp2 + 1 @torch.jit.script_method - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.one(input, input + 1) class TestNNModuleWithScriptModule(nn.Module): @@ -466,8 +415,7 @@ class TestModuleInterface(JitTestCase): super(TestNNModuleWithScriptModule, self).__init__() self.proxy_mod = OrigScriptModule() - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.proxy_mod.forward(input) input = torch.randn(3, 4) @@ -498,8 +446,7 @@ class TestModuleInterface(JitTestCase): @torch.jit.interface class ModInterface(torch.nn.Module): - def forward(self, x): - # type: (Tensor) -> int + def forward(self, x: Tensor) -> int: pass class TestModule(torch.nn.Module): @@ -546,8 +493,7 @@ class TestModuleInterface(JitTestCase): @torch.jit.interface class ModInterface(torch.nn.Module): - def forward(self, x): - # type: (Tensor) -> int + def forward(self, x: Tensor) -> int: pass class TestModule(torch.nn.Module): @@ -590,8 +536,7 @@ class TestModuleInterface(JitTestCase): @torch.jit.interface class ModInterface(torch.nn.Module): - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: pass class TestModule(torch.nn.Module): @@ -636,8 +581,7 @@ class TestModuleInterface(JitTestCase): @torch.jit.interface class ModInterface(torch.nn.Module): - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: pass class TestModule(torch.nn.Module): @@ -679,8 +623,7 @@ class TestModuleInterface(JitTestCase): @torch.jit.interface class ModInterface(torch.nn.Module): - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: pass class TestModule(torch.nn.Module): @@ -714,8 +657,7 @@ class TestModuleInterface(JitTestCase): def test_module_apis_interface(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: pass class TestModule(nn.Module): diff --git a/test/jit/test_recursive_script.py b/test/jit/test_recursive_script.py index a84d9d7256b8..d18c4f6a3dab 100644 --- a/test/jit/test_recursive_script.py +++ b/test/jit/test_recursive_script.py @@ -2,10 +2,11 @@ import os import sys import typing import typing_extensions -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Tuple import torch import torch.nn as nn +from torch import Tensor from torch.testing import FileCheck from collections import OrderedDict @@ -284,8 +285,7 @@ class TestRecursiveScript(JitTestCase): test_module_dir(nn.ModuleDict(OrderedDict([("conv", conv), ("linear", linear)]))) def test_class_compile(self): - def other_fn(a, b): - # type: (int, Tensor) -> Tensor + def other_fn(a: int, b: Tensor) -> Tensor: return a * b class B(object): @@ -307,8 +307,7 @@ class TestRecursiveScript(JitTestCase): self.checkModule(N(), (torch.randn(2, 2),)) def test_error_stack(self): - def d(x): - # type: (int) -> int + def d(x: int) -> int: return x + 10 def c(x): @@ -331,8 +330,7 @@ class TestRecursiveScript(JitTestCase): checker.run(str(e)) def test_error_stack_module(self): - def d(x): - # type: (int) -> int + def d(x: int) -> int: return x + 10 def c(x): @@ -565,8 +563,7 @@ class TestRecursiveScript(JitTestCase): self.a = 4 self.inner = Inner2() - def __setstate__(self, obj): - # type: (Tuple[int, Inner2]) -> None + def __setstate__(self, obj: Tuple[int, Inner2]) -> None: a, inner = obj self.a = a self.inner = inner diff --git a/test/jit/test_save_load.py b/test/jit/test_save_load.py index eb14ca8350af..5136e50144f1 100644 --- a/test/jit/test_save_load.py +++ b/test/jit/test_save_load.py @@ -1,12 +1,14 @@ -import os -import io -import pathlib -import sys -import random -import torch from itertools import product as product -from torch.testing._internal.common_utils import TemporaryFileName from typing import NamedTuple, Optional +import io +import os +import pathlib +import random +import sys + +from torch import Tensor +from torch.testing._internal.common_utils import TemporaryFileName +import torch # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) @@ -680,8 +682,7 @@ class TestSaveLoad(JitTestCase): """ @torch.jit.interface class MyInterface(object): - def bar(self, x): - # type: (Tensor) -> Tensor + def bar(self, x: Tensor) -> Tensor: pass @torch.jit.script @@ -711,8 +712,7 @@ class TestSaveLoad(JitTestCase): @torch.jit.interface class MyInterface(object): - def not_bar(self, x): - # type: (Tensor) -> Tensor + def not_bar(self, x: Tensor) -> Tensor: pass @torch.jit.script # noqa: F811 @@ -767,8 +767,7 @@ class TestSaveLoad(JitTestCase): @torch.jit.interface class MyInterface(object): - def bar(self, x): - # type: (Tensor) -> Tensor + def bar(self, x: Tensor) -> Tensor: pass @torch.jit.script @@ -809,8 +808,7 @@ class TestSaveLoad(JitTestCase): @torch.jit.interface class MyInterface(object): - def not_bar(self, x): - # type: (Tensor) -> Tensor + def not_bar(self, x: Tensor) -> Tensor: pass @torch.jit.script # noqa F811 diff --git a/test/jit/test_with.py b/test/jit/test_with.py index ffd0631639f6..f958dc46c39a 100644 --- a/test/jit/test_with.py +++ b/test/jit/test_with.py @@ -1,7 +1,7 @@ import os import sys -from typing import Any +from typing import Any, List import torch from torch.testing._internal.jit_utils import JitTestCase @@ -50,8 +50,7 @@ class TestWith(JitTestCase): def __exit__(self, type: Any, value: Any, tb: Any): self.count.sub_(0.3) - def test_basic(x): - # type: (Tensor) -> Tensor + def test_basic(x: torch.Tensor) -> torch.Tensor: """Basic test with one with-statement.""" c = Context(1) @@ -62,8 +61,7 @@ class TestWith(JitTestCase): y *= c.count return y - def test_pass(x): - # type: (Tensor) -> Tensor + def test_pass(x: torch.Tensor) -> torch.Tensor: """ Test with a pass statement inside a with-statement. Although the body of the with is empty, __enter__ and __exit__ should @@ -77,8 +75,7 @@ class TestWith(JitTestCase): x *= c.count return x - def test_early_return(x, c): - # type: (Tensor, Context) -> Tensor + def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test that returning early from inside a with-statement works as expected. @@ -90,8 +87,7 @@ class TestWith(JitTestCase): x = y + y return x - def test_conditional_early_return(x, c): - # type: (Tensor, Context) -> Tensor + def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test that conditionally returning early from inside a with-statement works as expected. @@ -104,8 +100,7 @@ class TestWith(JitTestCase): x = y + y return x - def test_break(x, c, l): - # type: (Tensor, Context, List[int]) -> Tensor + def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: """ Test that breaking early from inside a with-statement works as expected. @@ -118,8 +113,7 @@ class TestWith(JitTestCase): return x - def test_continue(x, c, l): - # type: (Tensor, Context, List[int]) -> Tensor + def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: """ Test that using continue inside a with-statement works as expected. @@ -132,8 +126,7 @@ class TestWith(JitTestCase): return x - def test_serial(x): - # type: (Tensor) -> Tensor + def test_serial(x: torch.Tensor) -> torch.Tensor: """ Test two with-statements in a row. """ @@ -147,8 +140,7 @@ class TestWith(JitTestCase): return y - def test_nested(x): - # type: (Tensor) -> Tensor + def test_nested(x: torch.Tensor) -> torch.Tensor: """ Test nested with-statements. """ @@ -162,8 +154,7 @@ class TestWith(JitTestCase): return y - def test_combined(x): - # type: (Tensor) -> Tensor + def test_combined(x: torch.Tensor) -> torch.Tensor: """ Test a with-statement with multiple with items. """ @@ -215,8 +206,7 @@ class TestWith(JitTestCase): def __exit__(self, type: Any, value: Any, tb: Any): self.count.sub_(0.3) - def test_basic(x): - # type: (Tensor) -> Tensor + def test_basic(x: torch.Tensor) -> torch.Tensor: """Basic test with one with-statement.""" c = Context(1) @@ -227,8 +217,7 @@ class TestWith(JitTestCase): y *= c.count return y - def test_pass(x): - # type: (Tensor) -> Tensor + def test_pass(x: torch.Tensor) -> torch.Tensor: """ Test with a pass statement inside a with-statement. Although the body of the with is empty, __enter__ and __exit__ should @@ -242,8 +231,7 @@ class TestWith(JitTestCase): x *= c.count return x - def test_early_return(x, c): - # type: (Tensor, Context) -> Tensor + def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test that returning early from inside a with-statement works as expected. @@ -255,8 +243,7 @@ class TestWith(JitTestCase): x = y + y return x - def test_conditional_early_return(x, c): - # type: (Tensor, Context) -> Tensor + def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test that conditionally returning early from inside a with-statement works as expected. @@ -269,8 +256,7 @@ class TestWith(JitTestCase): x = y + y return x - def test_break(x, c, l): - # type: (Tensor, Context, List[int]) -> Tensor + def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: """ Test that breaking early from inside a with-statement works as expected. @@ -283,8 +269,7 @@ class TestWith(JitTestCase): return x - def test_continue(x, c, l): - # type: (Tensor, Context, List[int]) -> Tensor + def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: """ Test that using continue inside a with-statement works as expected. @@ -297,8 +282,7 @@ class TestWith(JitTestCase): return x - def test_serial(x): - # type: (Tensor) -> Tensor + def test_serial(x: torch.Tensor) -> torch.Tensor: """ Test two with-statements in a row. """ @@ -312,8 +296,7 @@ class TestWith(JitTestCase): return y - def test_nested(x): - # type: (Tensor) -> Tensor + def test_nested(x: torch.Tensor) -> torch.Tensor: """ Test nested with-statements. """ @@ -327,8 +310,7 @@ class TestWith(JitTestCase): return y - def test_combined(x): - # type: (Tensor) -> Tensor + def test_combined(x: torch.Tensor) -> torch.Tensor: """ Test a with-statement with multiple with items. """ @@ -381,13 +363,11 @@ class TestWith(JitTestCase): self.count.sub_(0.3) @torch.jit.script - def method_that_raises(): - # type: () -> Tensor + def method_that_raises() -> torch.Tensor: raise Exception("raised exception") @torch.jit.script - def test_exception(x, c): - # type: (Tensor, Context) -> Tensor + def test_exception(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test the case in which an exception is thrown while executing the body of a with-statement. """ @@ -397,8 +377,7 @@ class TestWith(JitTestCase): return x @torch.jit.script - def test_exception_nested(x, c): - # type: (Tensor, Context) -> Tensor + def test_exception_nested(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test the case in which an exception is thrown while executing the body of a nested with-statement. """ @@ -409,8 +388,7 @@ class TestWith(JitTestCase): return x @torch.jit.script - def with_that_raises(c): - # type: (Context) -> Tensor + def with_that_raises(c: Context) -> torch.Tensor: a = torch.tensor([1]) with c as _: @@ -419,8 +397,7 @@ class TestWith(JitTestCase): return a @torch.jit.script - def test_exception_fn_call(x, c): - # type: (Tensor, Context) -> Tensor + def test_exception_fn_call(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test the case in which an exception is thrown while there are active with-statements in two different frames. @@ -506,29 +483,25 @@ class TestWith(JitTestCase): def __exit__(self, type: Any, value: int, tb: int): pass - def test_no_enter_no_exit(x, c): - # type: (Tensor, NoEnterNoExit) -> Tensor + def test_no_enter_no_exit(x: torch.Tensor, c: NoEnterNoExit) -> torch.Tensor: with c as _: pass return x - def test_bad_enter(x, c): - # type: (Tensor, BadEnter) -> Tensor + def test_bad_enter(x: torch.Tensor, c: BadEnter) -> torch.Tensor: with c as _: pass return x - def test_bad_exit(x, c): - # type: (Tensor, BadExit) -> Tensor + def test_bad_exit(x: torch.Tensor, c: BadExit) -> torch.Tensor: with c as _: pass return x - def test_exit_incorrect_types(x, c): - # type: (Tensor, ExitIncorrectTypes) -> Tensor + def test_exit_incorrect_types(x: torch.Tensor, c: ExitIncorrectTypes) -> torch.Tensor: with c as _: pass @@ -565,8 +538,7 @@ class TestWith(JitTestCase): """ # Basic no_grad test. - def test_no_grad(x, y): - # type: (Tensor, Tensor) -> Tensor + def test_no_grad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: with torch.no_grad(): w = x + y @@ -583,8 +555,7 @@ class TestWith(JitTestCase): # Test assignment of a grad-less Tensor to a Tensor with gradients # in a no_grad block. - def test_no_grad_assignment(x, y): - # type: (Tensor, Tensor) -> Tensor + def test_no_grad_assignment(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: with torch.no_grad(): x[0] = y @@ -603,13 +574,11 @@ class TestWith(JitTestCase): super().__init__() @torch.jit.ignore - def adder(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def adder(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: w = x + y return w - def forward(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: with torch.no_grad(): w = self.adder(x, y) @@ -625,8 +594,7 @@ class TestWith(JitTestCase): Check that torch.autograd.profiler.record_function context manager is torchscriptable. """ - def with_rf(x, y): - # type: (Tensor, Tensor) -> Tensor + def with_rf(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: with torch.autograd.profiler.record_function("foo"): # Nested record_function. with torch.autograd.profiler.record_function("nested"):