mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Revert D25717504: Clean up some type annotations in test/jit
Test Plan: revert-hammer
Differential Revision:
D25717504 (a4f30d48d8
)
Original commit changeset: 9a83c44db02e
fbshipit-source-id: e6e3a83bed22701d8125f5a293dfcd5093c1a2cd
This commit is contained in:
committed by
Facebook GitHub Bot
parent
f9f758e349
commit
1bb7d8ff93
@ -41,7 +41,8 @@ class TestAsync(JitTestCase):
|
||||
|
||||
def test_async_parsing(self):
|
||||
@torch.jit.script
|
||||
def foo(x: Tensor) -> List[Tensor]:
|
||||
def foo(x):
|
||||
# type: (Tensor) -> List[Tensor]
|
||||
return [torch.neg(x), x.t()]
|
||||
|
||||
@torch.jit.script
|
||||
@ -256,7 +257,8 @@ class TestAsync(JitTestCase):
|
||||
self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True)
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, x: Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]:
|
||||
def forward(self, x):
|
||||
# type: (Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]
|
||||
future1 = torch.jit._fork(self.traced, x)
|
||||
future2 = torch.jit._fork(torch.neg, x)
|
||||
|
||||
|
@ -2,7 +2,7 @@ import os
|
||||
import sys
|
||||
import inspect
|
||||
import unittest
|
||||
from typing import Dict, List
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
@ -78,7 +78,8 @@ class TestBuiltins(JitTestCase):
|
||||
torch.jit.script(Mod())
|
||||
|
||||
def test_del(self):
|
||||
def fn(x: List[int]) -> List[int]:
|
||||
def fn(x):
|
||||
# type: (List[int]) -> List[int]
|
||||
a = x * 2
|
||||
del a
|
||||
return x
|
||||
@ -108,14 +109,16 @@ class TestBuiltins(JitTestCase):
|
||||
return a
|
||||
|
||||
def test_del_multiple_operands(self):
|
||||
def fn(x: List[int]) -> List[int]:
|
||||
def fn(x):
|
||||
# type: (List[int]) -> List[int]
|
||||
a, b, c = x[0], x[1], x[2]
|
||||
del a, b, c
|
||||
return x
|
||||
|
||||
self.checkScript(fn, ([1, 2, 3],))
|
||||
|
||||
def del_list_multiple_operands(x: List[int]) -> List[int]:
|
||||
def del_list_multiple_operands(x):
|
||||
# type: (List[int]) -> List[int]
|
||||
del x[0], x[1]
|
||||
return x
|
||||
|
||||
@ -123,7 +126,8 @@ class TestBuiltins(JitTestCase):
|
||||
jit_out = torch.jit.script(del_list_multiple_operands)([0, 1, 2])
|
||||
self.assertEquals(py_out, jit_out)
|
||||
|
||||
def del_dict_multiple_operands(x: Dict[str, int]) -> Dict[str, int]:
|
||||
def del_dict_multiple_operands(x):
|
||||
# type: (Dict[str, int]) -> Dict[str, int]
|
||||
del x['hi'], x['there']
|
||||
return x
|
||||
|
||||
|
@ -20,19 +20,22 @@ if __name__ == '__main__':
|
||||
|
||||
class TestList(JitTestCase):
|
||||
def test_in_check(self):
|
||||
def int_in(x: List[int]) -> bool:
|
||||
def int_in(x):
|
||||
# type: (List[int]) -> bool
|
||||
return 2 in x
|
||||
|
||||
self.checkScript(int_in, ([1, 2, 3],))
|
||||
self.checkScript(int_in, ([1, 3, 3],))
|
||||
|
||||
def float_in(x: List[float]) -> bool:
|
||||
def float_in(x):
|
||||
# type: (List[float]) -> bool
|
||||
return 2. in x
|
||||
|
||||
self.checkScript(float_in, ([1., 2., 3.],))
|
||||
self.checkScript(float_in, ([1., 3., 3.],))
|
||||
|
||||
def str_in(x: List[str]) -> bool:
|
||||
def str_in(x):
|
||||
# type: (List[str]) -> bool
|
||||
return 'hi' in x
|
||||
|
||||
self.checkScript(str_in, (['not', 'here'],))
|
||||
@ -97,7 +100,8 @@ class TestList(JitTestCase):
|
||||
def inputs():
|
||||
return [1, 2, 3, 4]
|
||||
|
||||
def fn(x: List[int]) -> List[int]:
|
||||
def fn(x):
|
||||
# type: (List[int]) -> List[int]
|
||||
del x[1]
|
||||
return x
|
||||
|
||||
@ -110,7 +114,8 @@ class TestList(JitTestCase):
|
||||
self.assertEqual(torch.jit.script(fn)(inputs()), python_out)
|
||||
|
||||
@torch.jit.script
|
||||
def fn2(x: List[int]) -> List[int]:
|
||||
def fn2(x):
|
||||
# type: (List[int]) -> List[int]
|
||||
del x[100]
|
||||
return x
|
||||
|
||||
@ -119,7 +124,8 @@ class TestList(JitTestCase):
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "deletion at a single index"):
|
||||
@torch.jit.script
|
||||
def fn(x: List[int]) -> List[int]:
|
||||
def fn(x):
|
||||
# type: (List[int]) -> List[int]
|
||||
del x[1:3]
|
||||
return x
|
||||
|
||||
@ -143,19 +149,23 @@ class TestList(JitTestCase):
|
||||
FileCheck().check_count("aten::list", 2, exactly=True).run(torch.jit.script(foo3).graph)
|
||||
|
||||
def test_min_bool_list(self):
|
||||
def jit_min_list(a: List[bool], b: List[bool]) -> List[bool]:
|
||||
def jit_min_list(a, b):
|
||||
# type: (List[bool], List[bool]) -> List[bool]
|
||||
return min(a, b)
|
||||
|
||||
self.checkScript(jit_min_list, ([True, False], [False, True]))
|
||||
|
||||
def test_min_max_list(self):
|
||||
def jit_min_list(a: List[int], b: List[int]) -> List[int]:
|
||||
def jit_min_list(a, b):
|
||||
# type: (List[int], List[int]) -> List[int]
|
||||
return min(a, b)
|
||||
|
||||
def jit_min_list_float(a: List[float], b: List[float]) -> List[float]:
|
||||
def jit_min_list_float(a, b):
|
||||
# type: (List[float], List[float]) -> List[float]
|
||||
return min(a, b)
|
||||
|
||||
def jit_min_list_bool(a: List[bool], b: List[bool]) -> List[bool]:
|
||||
def jit_min_list_bool(a, b):
|
||||
# type: (List[bool], List[bool]) -> List[bool]
|
||||
return min(a, b)
|
||||
|
||||
def run_tests(func, a, b):
|
||||
@ -176,13 +186,16 @@ class TestList(JitTestCase):
|
||||
[False, True], [False, False, True], [False, False, False]]
|
||||
run_tests(jit_min_list_bool, args_left_bool, args_right_bool)
|
||||
|
||||
def jit_max_list(a: List[int], b: List[int]) -> List[int]:
|
||||
def jit_max_list(a, b):
|
||||
# type: (List[int], List[int]) -> List[int]
|
||||
return max(a, b)
|
||||
|
||||
def jit_max_list_float(a: List[float], b: List[float]) -> List[float]:
|
||||
def jit_max_list_float(a, b):
|
||||
# type: (List[float], List[float]) -> List[float]
|
||||
return max(a, b)
|
||||
|
||||
def jit_max_list_bool(a: List[bool], b: List[bool]) -> List[bool]:
|
||||
def jit_max_list_bool(a, b):
|
||||
# type: (List[bool], List[bool]) -> List[bool]
|
||||
return max(a, b)
|
||||
|
||||
args_left_int = [[1, 8, 8], [8, 1, 1], [], [1], [], [1, 2]]
|
||||
@ -352,7 +365,8 @@ class TestList(JitTestCase):
|
||||
t2 = scope['func']()
|
||||
self.assertEqual(t1, t2)
|
||||
|
||||
def test_fail(x: List[Tensor]) -> List[Tensor]:
|
||||
def test_fail(x):
|
||||
# type: (List[Tensor]) -> List[Tensor]
|
||||
x.sort()
|
||||
return x
|
||||
|
||||
@ -458,7 +472,8 @@ class TestList(JitTestCase):
|
||||
self.checkScript(test_append, ())
|
||||
|
||||
def test_comprehensions_basic(self):
|
||||
def comp(l: List[int]) -> List[int]:
|
||||
def comp(l):
|
||||
# type: (List[int]) -> List[int]
|
||||
|
||||
n = [x * 3 for x in l]
|
||||
return n
|
||||
@ -467,7 +482,8 @@ class TestList(JitTestCase):
|
||||
self.checkScript(comp, ([1, 2, 3],))
|
||||
|
||||
def test_comprehensions_basic_float(self):
|
||||
def comp(l: List[float]) -> List[float]:
|
||||
def comp(l):
|
||||
# type: (List[float]) -> List[float]
|
||||
|
||||
n = [x * 3 for x in l]
|
||||
return n
|
||||
@ -476,7 +492,8 @@ class TestList(JitTestCase):
|
||||
|
||||
def test_comprehensions_two_comps(self):
|
||||
@torch.jit.script
|
||||
def comp(l1: List[int], l2: List[int]) -> List[int]:
|
||||
def comp(l1, l2):
|
||||
# type: (List[int], List[int]) -> List[int]
|
||||
|
||||
n = [x * 3 for x in l1]
|
||||
n2 = [x + 2 for x in l2]
|
||||
@ -485,7 +502,8 @@ class TestList(JitTestCase):
|
||||
self.assertEqual(comp([1, 2, 3], [4, 5]), [3, 6, 9, 6, 7])
|
||||
|
||||
def test_comprehension_out_type_not_in_type(self):
|
||||
def list_cast() -> int:
|
||||
def list_cast():
|
||||
# type: () -> int
|
||||
li = [int(i) for i in [torch.tensor(0), torch.tensor(1), torch.tensor(2)]]
|
||||
return li[0] + li[1] + li[2]
|
||||
|
||||
@ -495,13 +513,15 @@ class TestList(JitTestCase):
|
||||
def test_func(fn, inputs):
|
||||
self.assertEqual(fn(*inputs), torch.jit.script(fn)(*inputs))
|
||||
|
||||
def foo(names: List[int], results: List[int]) -> List[Tuple[int, int]]:
|
||||
def foo(names, results):
|
||||
# type: (List[int], List[int]) -> List[Tuple[int, int]]
|
||||
return [(k + 5, v - 2) for k, v in zip(names, results)]
|
||||
|
||||
test_func(foo, ([1, 2, 4], [4, 7, 9]))
|
||||
test_func(foo, ([5], [4, 7, 9]))
|
||||
|
||||
def fn(x: int) -> List[int]:
|
||||
def fn(x):
|
||||
# type: (int) -> List[int]
|
||||
return [i for i in range(x)] # noqa: C416
|
||||
|
||||
test_func(fn, (9,))
|
||||
@ -581,7 +601,8 @@ class TestList(JitTestCase):
|
||||
|
||||
def test_mutable_list_function_inline(self):
|
||||
@torch.jit.script
|
||||
def bar(y: List[int]) -> None:
|
||||
def bar(y):
|
||||
# type: (List[int]) -> None
|
||||
y.append(4)
|
||||
|
||||
@torch.jit.script
|
||||
@ -867,7 +888,8 @@ class TestList(JitTestCase):
|
||||
|
||||
def test_extend_list_mutable(self):
|
||||
@torch.jit.script
|
||||
def extend_list(a: List[Tensor], b: List[Tensor]) -> List[Tensor]:
|
||||
def extend_list(a, b):
|
||||
# type: (List[Tensor], List[Tensor]) -> List[Tensor]
|
||||
|
||||
a.extend(b)
|
||||
return a
|
||||
@ -878,7 +900,8 @@ class TestList(JitTestCase):
|
||||
|
||||
def test_extend_list_immutable(self):
|
||||
@torch.jit.script
|
||||
def extend_list(a: List[int], b: List[int]) -> List[int]:
|
||||
def extend_list(a, b):
|
||||
# type: (List[int], List[int]) -> List[int]
|
||||
|
||||
a.extend(b)
|
||||
return a
|
||||
@ -889,7 +912,8 @@ class TestList(JitTestCase):
|
||||
|
||||
def test_copy_list_mutable(self):
|
||||
@torch.jit.script
|
||||
def copy_list(a: List[Tensor]) -> List[Tensor]:
|
||||
def copy_list(a):
|
||||
# type: (List[Tensor]) -> List[Tensor]
|
||||
return a.copy()
|
||||
|
||||
for l in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]:
|
||||
@ -897,29 +921,36 @@ class TestList(JitTestCase):
|
||||
|
||||
def test_copy_list_immutable(self):
|
||||
@torch.jit.script
|
||||
def copy_list(a: List[int]) -> List[int]:
|
||||
def copy_list(a):
|
||||
# type: (List[int]) -> List[int]
|
||||
return a.copy()
|
||||
|
||||
for l in [[], [1], [1, 2, 3]]:
|
||||
self.assertEqual(copy_list(l), l)
|
||||
|
||||
def test_min_max_single_list(self):
|
||||
def min_intlist(li: List[int]) -> int:
|
||||
def min_intlist(li):
|
||||
# type: (List[int]) -> int
|
||||
return min(li)
|
||||
|
||||
def max_intlist(li: List[int]) -> int:
|
||||
def max_intlist(li):
|
||||
# type: (List[int]) -> int
|
||||
return max(li)
|
||||
|
||||
def min_boollist(li: List[bool]) -> bool:
|
||||
def min_boollist(li):
|
||||
# type: (List[bool]) -> bool
|
||||
return min(li)
|
||||
|
||||
def max_boollist(li: List[bool]) -> bool:
|
||||
def max_boollist(li):
|
||||
# type: (List[bool]) -> bool
|
||||
return max(li)
|
||||
|
||||
def min_floatlist(li: List[float]) -> float:
|
||||
def min_floatlist(li):
|
||||
# type: (List[float]) -> float
|
||||
return min(li)
|
||||
|
||||
def max_floatlist(li: List[float]) -> float:
|
||||
def max_floatlist(li):
|
||||
# type: (List[float]) -> float
|
||||
return max(li)
|
||||
|
||||
|
||||
@ -949,19 +980,23 @@ class TestList(JitTestCase):
|
||||
"""
|
||||
Boolean dtype unit tests.
|
||||
"""
|
||||
def to_list_bool_0D(x: torch.Tensor) -> bool:
|
||||
def to_list_bool_0D(x):
|
||||
# type: (torch.Tensor) -> bool
|
||||
li = torch.jit.annotate(bool, x.tolist())
|
||||
return li
|
||||
|
||||
def to_list_bool_1D(x: torch.Tensor) -> List[bool]:
|
||||
def to_list_bool_1D(x):
|
||||
# type: (torch.Tensor) -> List[bool]
|
||||
li = torch.jit.annotate(List[bool], x.tolist())
|
||||
return li
|
||||
|
||||
def to_list_bool_2D(x: torch.Tensor) -> List[List[bool]]:
|
||||
def to_list_bool_2D(x):
|
||||
# type: (torch.Tensor) -> List[List[bool]]
|
||||
li = torch.jit.annotate(List[List[bool]], x.tolist())
|
||||
return li
|
||||
|
||||
def to_list_bool_3D(x: torch.Tensor) -> List[List[List[bool]]]:
|
||||
def to_list_bool_3D(x):
|
||||
# type: (torch.Tensor) -> List[List[List[bool]]]
|
||||
li = torch.jit.annotate(List[List[List[bool]]], x.tolist())
|
||||
return li
|
||||
|
||||
@ -986,19 +1021,23 @@ class TestList(JitTestCase):
|
||||
"""
|
||||
Int dtype unit tests.
|
||||
"""
|
||||
def to_list_int_0D(x: torch.Tensor) -> int:
|
||||
def to_list_int_0D(x):
|
||||
# type: (torch.Tensor) -> int
|
||||
li = torch.jit.annotate(int, x.tolist())
|
||||
return li
|
||||
|
||||
def to_list_int_1D(x: torch.Tensor) -> List[int]:
|
||||
def to_list_int_1D(x):
|
||||
# type: (torch.Tensor) -> List[int]
|
||||
li = torch.jit.annotate(List[int], x.tolist())
|
||||
return li
|
||||
|
||||
def to_list_int_2D(x: torch.Tensor) -> List[List[int]]:
|
||||
def to_list_int_2D(x):
|
||||
# type: (torch.Tensor) -> List[List[int]]
|
||||
li = torch.jit.annotate(List[List[int]], x.tolist())
|
||||
return li
|
||||
|
||||
def to_list_int_3D(x: torch.Tensor) -> List[List[List[int]]]:
|
||||
def to_list_int_3D(x):
|
||||
# type: (torch.Tensor) -> List[List[List[int]]]
|
||||
li = torch.jit.annotate(List[List[List[int]]], x.tolist())
|
||||
return li
|
||||
|
||||
@ -1019,19 +1058,23 @@ class TestList(JitTestCase):
|
||||
"""
|
||||
Float dtype unit tests.
|
||||
"""
|
||||
def to_list_float_0D(x: torch.Tensor) -> float:
|
||||
def to_list_float_0D(x):
|
||||
# type: (torch.Tensor) -> float
|
||||
li = torch.jit.annotate(float, x.tolist())
|
||||
return li
|
||||
|
||||
def to_list_float_1D(x: torch.Tensor) -> List[float]:
|
||||
def to_list_float_1D(x):
|
||||
# type: (torch.Tensor) -> List[float]
|
||||
li = torch.jit.annotate(List[float], x.tolist())
|
||||
return li
|
||||
|
||||
def to_list_float_2D(x: torch.Tensor) -> List[List[float]]:
|
||||
def to_list_float_2D(x):
|
||||
# type: (torch.Tensor) -> List[List[float]]
|
||||
li = torch.jit.annotate(List[List[float]], x.tolist())
|
||||
return li
|
||||
|
||||
def to_list_float_3D(x: torch.Tensor) -> List[List[List[float]]]:
|
||||
def to_list_float_3D(x):
|
||||
# type: (torch.Tensor) -> List[List[List[float]]]
|
||||
li = torch.jit.annotate(List[List[List[float]]], x.tolist())
|
||||
return li
|
||||
|
||||
@ -1056,23 +1099,28 @@ class TestList(JitTestCase):
|
||||
- type annotation with the wrong dimension
|
||||
- type annotation with scalar type that doesn't match the input scalar type
|
||||
"""
|
||||
def to_list_missing_type_annotation(x: torch.Tensor) -> List[float]:
|
||||
def to_list_missing_type_annotation(x):
|
||||
# type: (torch.Tensor) -> List[float]
|
||||
li = x.tolist()
|
||||
return li
|
||||
|
||||
def to_list_incorrect_type_annotation(x: torch.Tensor) -> List[float]:
|
||||
def to_list_incorrect_type_annotation(x):
|
||||
# type: (torch.Tensor) -> List[float]
|
||||
li = torch.jit.annotate(float, x.tolist())
|
||||
return li
|
||||
|
||||
def to_list_unsupported_type_annotation(x: torch.Tensor) -> List[float]:
|
||||
def to_list_unsupported_type_annotation(x):
|
||||
# type: (torch.Tensor) -> List[float]
|
||||
li = torch.jit.annotate(List[str], x.tolist())
|
||||
return li
|
||||
|
||||
def to_list_type_annotation_wrong_dim(x: torch.Tensor) -> List[List[float]]:
|
||||
def to_list_type_annotation_wrong_dim(x):
|
||||
# type: (torch.Tensor) -> List[List[float]]
|
||||
li = torch.jit.annotate(List[List[float]], x.tolist())
|
||||
return li
|
||||
|
||||
def to_list_type_annotation_incorrect_scalar_type(x: torch.Tensor) -> List[float]:
|
||||
def to_list_type_annotation_incorrect_scalar_type(x):
|
||||
# type: (torch.Tensor) -> List[float]
|
||||
li = torch.jit.annotate(List[float], x.tolist())
|
||||
return li
|
||||
|
||||
@ -1116,15 +1164,18 @@ class TestList(JitTestCase):
|
||||
if not torch.cuda.is_available() or torch.cuda.device_count() == 0:
|
||||
self.skipTest("CUDA is not available")
|
||||
|
||||
def to_list_bool_1D(x: torch.Tensor) -> List[bool]:
|
||||
def to_list_bool_1D(x):
|
||||
# type: (torch.Tensor) -> List[bool]
|
||||
li = torch.jit.annotate(List[bool], x.tolist())
|
||||
return li
|
||||
|
||||
def to_list_int_1D(x: torch.Tensor) -> List[int]:
|
||||
def to_list_int_1D(x):
|
||||
# type: (torch.Tensor) -> List[int]
|
||||
li = torch.jit.annotate(List[int], x.tolist())
|
||||
return li
|
||||
|
||||
def to_list_float_1D(x: torch.Tensor) -> List[float]:
|
||||
def to_list_float_1D(x):
|
||||
# type: (torch.Tensor) -> List[float]
|
||||
li = torch.jit.annotate(List[float], x.tolist())
|
||||
return li
|
||||
|
||||
@ -1136,7 +1187,8 @@ class TestList(JitTestCase):
|
||||
5, dtype=torch.double).cuda(),))
|
||||
|
||||
def test_no_element_type_annotation(self):
|
||||
def fn_with_comment(x: torch.Tensor) -> List:
|
||||
def fn_with_comment(x):
|
||||
# type: (torch.Tensor) -> List
|
||||
a: List = x.tolist()
|
||||
return a
|
||||
|
||||
@ -1178,7 +1230,8 @@ class TestDict(JitTestCase):
|
||||
def inputs():
|
||||
return {'hi': 2, 'bye': 3}
|
||||
|
||||
def fn(x: Dict[str, int]) -> Dict[str, int]:
|
||||
def fn(x):
|
||||
# type: (Dict[str, int]) -> Dict[str, int]
|
||||
del x['hi']
|
||||
return x
|
||||
|
||||
@ -1194,7 +1247,8 @@ class TestDict(JitTestCase):
|
||||
|
||||
def test_keys(self):
|
||||
@torch.jit.script
|
||||
def keys(x: Dict[str, Tensor]) -> List[str]:
|
||||
def keys(x):
|
||||
# type: (Dict[str, Tensor]) -> List[str]
|
||||
return list(x.keys())
|
||||
|
||||
self.assertEqual(set(keys(self.dict())), set(self.dict().keys()))
|
||||
@ -1209,26 +1263,30 @@ class TestDict(JitTestCase):
|
||||
|
||||
def test_values(self):
|
||||
@torch.jit.script
|
||||
def values(x: Dict[str, Tensor]) -> List[Tensor]:
|
||||
def values(x):
|
||||
# type: (Dict[str, Tensor]) -> List[Tensor]
|
||||
return list(x.values())
|
||||
|
||||
the_dict = self.dict()
|
||||
self.assertEqual(set(values(the_dict)), set(the_dict.values()))
|
||||
|
||||
def test_len(self):
|
||||
def length(x: Dict[str, Tensor]) -> int:
|
||||
def length(x):
|
||||
# type: (Dict[str, Tensor]) -> int
|
||||
return len(x)
|
||||
|
||||
self.checkScript(length, (self.dict(),))
|
||||
|
||||
def test_copy(self):
|
||||
def func(x: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
def func(x):
|
||||
# type: (Dict[str, Tensor]) -> Dict[str, Tensor]
|
||||
return x.copy()
|
||||
|
||||
self.checkScript(func, (self.dict(),))
|
||||
|
||||
def test_items(self):
|
||||
def func(x: Dict[str, Tensor]) -> List[Tuple[str, Tensor]]:
|
||||
def func(x):
|
||||
# type: (Dict[str, Tensor]) -> List[Tuple[str, Tensor]]
|
||||
return x.items()
|
||||
|
||||
# The value returned by Python is in arbitrary order, so we can't use
|
||||
@ -1243,7 +1301,8 @@ class TestDict(JitTestCase):
|
||||
self.assertTrue(item in script_out)
|
||||
|
||||
def test_pop(self):
|
||||
def pop(x: Dict[str, Tensor], key: str) -> Tuple[Tensor, Dict[str, Tensor]]:
|
||||
def pop(x, key):
|
||||
# type: (Dict[str, Tensor], str) -> Tuple[Tensor, Dict[str, Tensor]]
|
||||
return x.pop(key), x
|
||||
|
||||
# checkScript doesn't copy the inputs, so we can't use it since this mutates
|
||||
@ -1259,14 +1318,16 @@ class TestDict(JitTestCase):
|
||||
torch.jit.script(pop)(self.dict(), 'x')
|
||||
|
||||
|
||||
def default_pop(x: Dict[str, Tensor], key: str, default: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
|
||||
def default_pop(x, key, default):
|
||||
# type: (Dict[str, Tensor], str, Tensor) -> Tuple[Tensor, Dict[str, Tensor]]
|
||||
return x.pop(key, default), x
|
||||
|
||||
tester(default_pop, 'a', torch.randn(2, 2))
|
||||
tester(default_pop, 'x', torch.randn(2, 2))
|
||||
|
||||
def test_setdefault(self):
|
||||
def setdefault(x: Dict[str, Tensor], key: str, default: Tensor) -> Dict[str, Tensor]:
|
||||
def setdefault(x, key, default):
|
||||
# type: (Dict[str, Tensor], str, Tensor) -> Dict[str, Tensor]
|
||||
x.setdefault(key, default)
|
||||
return x
|
||||
|
||||
@ -1274,7 +1335,8 @@ class TestDict(JitTestCase):
|
||||
self.checkScript(setdefault, (self.dict(), 'nonexistant', torch.randn(2, 2)))
|
||||
|
||||
def test_update(self):
|
||||
def update(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
|
||||
def update(a, b):
|
||||
# type: (Dict[str, Tensor], Dict[str, Tensor]) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]
|
||||
a.update(b)
|
||||
return a, b
|
||||
|
||||
@ -1291,7 +1353,8 @@ class TestDict(JitTestCase):
|
||||
self.checkScript(foo, ())
|
||||
|
||||
def test_aug_assign(self):
|
||||
def aug_assign_dict_tensor(a: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
def aug_assign_dict_tensor(a):
|
||||
# type: (Dict[str, Tensor]) -> Dict[str, Tensor]
|
||||
a['a'] += 1
|
||||
a['b'] -= 12
|
||||
a['c'] *= 122
|
||||
@ -1299,7 +1362,8 @@ class TestDict(JitTestCase):
|
||||
a['c'] %= 2
|
||||
return a
|
||||
|
||||
def aug_assign_dict_prim(a: Dict[str, float]) -> Dict[str, float]:
|
||||
def aug_assign_dict_prim(a):
|
||||
# type: (Dict[str, float]) -> Dict[str, float]
|
||||
a['a'] += 3.4
|
||||
a['b'] -= 2.4
|
||||
a['c'] *= 3.0
|
||||
@ -1312,7 +1376,8 @@ class TestDict(JitTestCase):
|
||||
|
||||
def test_popitem(self):
|
||||
@torch.jit.script
|
||||
def popitem(x: Dict[str, Tensor]) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor]]:
|
||||
def popitem(x):
|
||||
# type: (Dict[str, Tensor]) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor]]
|
||||
item = x.popitem()
|
||||
return item, x
|
||||
|
||||
@ -1330,56 +1395,65 @@ class TestDict(JitTestCase):
|
||||
self.assertTrue(isinstance(script_out[0][1], torch.Tensor))
|
||||
|
||||
def test_clear(self):
|
||||
def clear(x: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
def clear(x):
|
||||
# type: (Dict[str, Tensor]) -> Dict[str, Tensor]
|
||||
x.clear()
|
||||
return x
|
||||
|
||||
self.checkScript(clear, (self.dict(),))
|
||||
|
||||
def test_get(self):
|
||||
def get(x: Dict[str, Tensor], key: str) -> Optional[Tensor]:
|
||||
def get(x, key):
|
||||
# type: (Dict[str, Tensor], str) -> Optional[Tensor]
|
||||
return x.get(key)
|
||||
|
||||
self.checkScript(get, (self.dict(), 'a'))
|
||||
self.checkScript(get, (self.dict(), "doesn't exist"))
|
||||
|
||||
def get_default(x: Dict[str, Tensor], key: str) -> Optional[Tensor]:
|
||||
def get_default(x, key):
|
||||
# type: (Dict[str, Tensor], str) -> Optional[Tensor]
|
||||
return x.get(key, torch.randn(2, 2))
|
||||
|
||||
self.checkScript(get, (self.dict(), 'a'))
|
||||
self.checkScript(get, (self.dict(), "doesn't exist"))
|
||||
|
||||
def test_get_boolkey(self):
|
||||
def get(x: Dict[bool, int], key: bool) -> Optional[int]:
|
||||
def get(x, key):
|
||||
# type: (Dict[bool, int], bool) -> Optional[int]
|
||||
return x.get(key)
|
||||
|
||||
self.checkScript(get, (self.dict_bool(), True))
|
||||
self.checkScript(get, (self.dict_bool(), False))
|
||||
|
||||
def get_default(x: Dict[bool, int], key: bool) -> int:
|
||||
def get_default(x, key):
|
||||
# type: (Dict[bool, int], bool) -> int
|
||||
return x.get(key, 42)
|
||||
|
||||
self.checkScript(get_default, (self.dict_bool(), True))
|
||||
self.checkScript(get_default, (self.dict_bool(), False))
|
||||
|
||||
def test_basic(self):
|
||||
def simple(x: Dict[str, int]) -> Dict[str, int]:
|
||||
def simple(x):
|
||||
# type: (Dict[str, int]) -> Dict[str, int]
|
||||
return x
|
||||
|
||||
self.checkScript(simple, ({'item': 20, 'other_item': 120},))
|
||||
|
||||
def index(x: Dict[str, int]) -> int:
|
||||
def index(x):
|
||||
# type: (Dict[str, int]) -> int
|
||||
return x['item']
|
||||
|
||||
self.checkScript(index, ({'item': 20, 'other_item': 120},))
|
||||
|
||||
def type_default() -> Dict[str, Tensor]:
|
||||
def type_default():
|
||||
# type: () -> Dict[str, Tensor]
|
||||
return {}
|
||||
|
||||
self.checkScript(type_default, ())
|
||||
|
||||
@torch.jit.script
|
||||
def missing_index(x: Dict[str, int]) -> int:
|
||||
def missing_index(x):
|
||||
# type: (Dict[str, int]) -> int
|
||||
return x['dne']
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "KeyError"):
|
||||
@ -1401,14 +1475,16 @@ class TestDict(JitTestCase):
|
||||
'''))
|
||||
self.assertEqual({10: 1.2, 11: 1.3}, cu.literal3())
|
||||
|
||||
def list_of_dicts() -> List[Dict[str, Tensor]]:
|
||||
def list_of_dicts():
|
||||
# type: () -> List[Dict[str, Tensor]]
|
||||
return [{'word': torch.ones(2) + 3}, {'other word': torch.ones(1) + 2}]
|
||||
|
||||
self.checkScript(list_of_dicts, ())
|
||||
|
||||
def test_mutability(self):
|
||||
@torch.jit.script
|
||||
def fn() -> Dict[str, int]:
|
||||
def fn():
|
||||
# type: () -> Dict[str, int]
|
||||
a = torch.jit.annotate(Dict[str, int], {})
|
||||
a['ok'] = 10
|
||||
return a
|
||||
@ -1418,12 +1494,14 @@ class TestDict(JitTestCase):
|
||||
def test_key_type(self):
|
||||
with self.assertRaisesRegex(RuntimeError, "but instead found type"):
|
||||
@torch.jit.script
|
||||
def fn(a: Dict[str, int]) -> int:
|
||||
def fn(a):
|
||||
# type: (Dict[str, int]) -> int
|
||||
return a[None]
|
||||
|
||||
def test_loop(self):
|
||||
@torch.jit.script
|
||||
def fn(x: int) -> Dict[str, int]:
|
||||
def fn(x):
|
||||
# type: (int) -> Dict[str, int]
|
||||
a = torch.jit.annotate(Dict[str, int], {})
|
||||
for i in range(x):
|
||||
a['ok'] = i
|
||||
@ -1442,14 +1520,16 @@ class TestDict(JitTestCase):
|
||||
self.checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3)))
|
||||
|
||||
def test_membership(self):
|
||||
def fn(x: Dict[int, int], y: int) -> int:
|
||||
def fn(x, y):
|
||||
# type: (Dict[int, int], int) -> int
|
||||
return x.get(y, 3)
|
||||
|
||||
d = {1: 2, 3: 4}
|
||||
self.checkScript(fn, (d, 3))
|
||||
self.checkScript(fn, (d, 2))
|
||||
|
||||
def optional(x: Dict[int, int], y: int) -> bool:
|
||||
def optional(x, y):
|
||||
# type: (Dict[int, int], int) -> bool
|
||||
res = x.get(y)
|
||||
return res is None
|
||||
|
||||
@ -1458,15 +1538,18 @@ class TestDict(JitTestCase):
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "is actually of type Optional"):
|
||||
@torch.jit.script
|
||||
def bad_types(x: Dict[int, int], y: int) -> int:
|
||||
def bad_types(x, y):
|
||||
# type: (Dict[int, int], int) -> int
|
||||
return x.get(y) # noqa: T484
|
||||
|
||||
def test_dict_to_python(self):
|
||||
@torch.jit.ignore
|
||||
def python_lookup(my_dict: Dict[str, int], keys: List[str]) -> List[int]:
|
||||
def python_lookup(my_dict, keys):
|
||||
# type: (Dict[str, int], List[str]) -> List[int]
|
||||
return [my_dict[k] for k in keys]
|
||||
|
||||
def fn(my_dict: Dict[str, int], keys: List[str]) -> List[int]:
|
||||
def fn(my_dict, keys):
|
||||
# type: (Dict[str, int], List[str]) -> List[int]
|
||||
return python_lookup(my_dict, keys)
|
||||
|
||||
a_dict = {'a': torch.ones(1), 'b': torch.ones(1) + 1, 'c': torch.ones(1) + 2}
|
||||
@ -1518,7 +1601,8 @@ class TestDict(JitTestCase):
|
||||
key and value types produces an error.
|
||||
"""
|
||||
# This function uses a type comment.
|
||||
def fn_with_comment(input: Dict) -> Any:
|
||||
def fn_with_comment(input):
|
||||
# type: (Dict) -> Any
|
||||
return input
|
||||
|
||||
# This function uses Python3 style type annotations.
|
||||
|
@ -6,7 +6,6 @@ import torch
|
||||
import torch.nn as nn
|
||||
import os
|
||||
import sys
|
||||
from torch import Tensor
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
@ -23,30 +22,36 @@ class OrigModule(nn.Module):
|
||||
def __init__(self):
|
||||
super(OrigModule, self).__init__()
|
||||
|
||||
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
|
||||
def one(self, inp1, inp2):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
return inp1 + inp2 + 1
|
||||
|
||||
def two(self, input: Tensor) -> Tensor:
|
||||
def two(self, input):
|
||||
# type: (Tensor) -> Tensor
|
||||
return input + 2
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
def forward(self, input):
|
||||
# type: (Tensor) -> Tensor
|
||||
return input + self.one(input, input) + 1
|
||||
|
||||
class NewModule(nn.Module):
|
||||
def __init__(self):
|
||||
super(NewModule, self).__init__()
|
||||
|
||||
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
|
||||
def one(self, inp1, inp2):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
return inp1 * inp2 + 1
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
def forward(self, input):
|
||||
# type: (Tensor) -> Tensor
|
||||
return self.one(input, input + 1)
|
||||
|
||||
class TestModuleInterface(JitTestCase):
|
||||
def test_not_submodule_interface_call(self):
|
||||
@torch.jit.interface
|
||||
class ModuleInterface(nn.Module):
|
||||
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
|
||||
def one(self, inp1, inp2):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
class TestNotModuleInterfaceCall(nn.Module):
|
||||
@ -56,7 +61,8 @@ class TestModuleInterface(JitTestCase):
|
||||
super(TestNotModuleInterfaceCall, self).__init__()
|
||||
self.proxy_mod = OrigModule()
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
def forward(self, input):
|
||||
# type: (Tensor) -> Tensor
|
||||
return self.proxy_mod.two(input)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Tried to access nonexistent attribute"):
|
||||
@ -66,51 +72,64 @@ class TestModuleInterface(JitTestCase):
|
||||
global OneTwoModule, OneTwoClass
|
||||
@torch.jit.interface
|
||||
class OneTwoModule(nn.Module):
|
||||
def one(self, x: Tensor, y: Tensor) -> Tensor:
|
||||
def one(self, x, y):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
def two(self, x: Tensor) -> Tensor:
|
||||
def two(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
def forward(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
@torch.jit.interface
|
||||
class OneTwoClass(object):
|
||||
def one(self, x: Tensor, y: Tensor) -> Tensor:
|
||||
def one(self, x, y):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
def two(self, x: Tensor) -> Tensor:
|
||||
def two(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
class FooMod(nn.Module):
|
||||
def one(self, x: Tensor, y: Tensor) -> Tensor:
|
||||
def one(self, x, y):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
return x + y
|
||||
|
||||
def two(self, x: Tensor) -> Tensor:
|
||||
def two(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
return 2 * x
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
def forward(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
return self.one(self.two(x), x)
|
||||
|
||||
class BarMod(nn.Module):
|
||||
def one(self, x: Tensor, y: Tensor) -> Tensor:
|
||||
def one(self, x, y):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
return x * y
|
||||
|
||||
def two(self, x: Tensor) -> Tensor:
|
||||
def two(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
return 2 / x
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
def forward(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
return self.two(self.one(x, x))
|
||||
|
||||
@torch.jit.export
|
||||
def forward2(self, x: Tensor) -> Tensor:
|
||||
def forward2(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
return self.two(self.one(x, x)) + 1
|
||||
|
||||
def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
|
||||
return mod_list[0].forward(x) + mod_list[1].forward(x)
|
||||
|
||||
def use_class_interface(mod_list: List[OneTwoClass], x: Tensor) -> Tensor:
|
||||
def use_class_interface(mod_list, x):
|
||||
# type: (List[OneTwoClass], Tensor) -> Tensor
|
||||
return mod_list[0].two(x) + mod_list[1].one(x, x)
|
||||
|
||||
scripted_foo_mod = torch.jit.script(FooMod())
|
||||
@ -120,7 +139,8 @@ class TestModuleInterface(JitTestCase):
|
||||
self.checkScript(use_class_interface,
|
||||
([scripted_foo_mod, scripted_bar_mod], torch.rand(3, 4),))
|
||||
|
||||
def call_module_interface_on_other_method(mod_interface: OneTwoModule, x: Tensor) -> Tensor:
|
||||
def call_module_interface_on_other_method(mod_interface, x):
|
||||
# type: (OneTwoModule, Tensor) -> Tensor
|
||||
return mod_interface.forward2(x)
|
||||
|
||||
# ensure error out when we call the module on the method other than the interface specified.
|
||||
@ -132,28 +152,35 @@ class TestModuleInterface(JitTestCase):
|
||||
global OneTwoModule
|
||||
@torch.jit.interface
|
||||
class OneTwoModule(nn.Module):
|
||||
def one(self, x: Tensor, y: Tensor) -> Tensor:
|
||||
def one(self, x, y):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
def two(self, x: Tensor) -> Tensor:
|
||||
def two(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
def forward(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
@torch.jit.script
|
||||
def as_module_interface(x: OneTwoModule) -> OneTwoModule:
|
||||
def as_module_interface(x):
|
||||
# type: (OneTwoModule) -> OneTwoModule
|
||||
return x
|
||||
|
||||
@torch.jit.script
|
||||
class Foo(object):
|
||||
def one(self, x: Tensor, y: Tensor) -> Tensor:
|
||||
def one(self, x, y):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
return x + y
|
||||
|
||||
def two(self, x: Tensor) -> Tensor:
|
||||
def two(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
return 2 * x
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
def forward(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
return self.one(self.two(x), x)
|
||||
|
||||
# check class object is not a subtype of module interface
|
||||
@ -161,10 +188,12 @@ class TestModuleInterface(JitTestCase):
|
||||
as_module_interface(Foo())
|
||||
|
||||
class WrongMod(nn.Module):
|
||||
def two(self, x: int) -> int:
|
||||
def two(self, x):
|
||||
# type: (int) -> int
|
||||
return 2 * x
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
def forward(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
return x + torch.randn(3, self.two(3))
|
||||
|
||||
scripted_wrong_mod = torch.jit.script(WrongMod())
|
||||
@ -215,16 +244,19 @@ class TestModuleInterface(JitTestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "does not support inheritance yet. Please directly"):
|
||||
@torch.jit.interface
|
||||
class InheritMod(nn.ReLU):
|
||||
def three(self, x: Tensor) -> Tensor:
|
||||
def three(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
return 3 * x
|
||||
|
||||
def test_module_swap(self):
|
||||
@torch.jit.interface
|
||||
class ModuleInterface(nn.Module):
|
||||
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
|
||||
def one(self, inp1, inp2):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
def forward(self, input):
|
||||
# type: (Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
class TestModule(nn.Module):
|
||||
@ -234,7 +266,8 @@ class TestModuleInterface(JitTestCase):
|
||||
super(TestModule, self).__init__()
|
||||
self.proxy_mod = OrigModule()
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
def forward(self, input):
|
||||
# type: (Tensor) -> Tensor
|
||||
return self.proxy_mod.forward(input)
|
||||
|
||||
scripted_mod = torch.jit.script(TestModule())
|
||||
@ -252,17 +285,20 @@ class TestModuleInterface(JitTestCase):
|
||||
def test_module_swap_wrong_module(self):
|
||||
@torch.jit.interface
|
||||
class ModuleInterface(nn.Module):
|
||||
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
|
||||
def one(self, inp1, inp2):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
def forward(self, input):
|
||||
# type: (Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
class NewModuleWrong(nn.Module):
|
||||
def __init__(self):
|
||||
super(NewModuleWrong, self).__init__()
|
||||
|
||||
def forward(self, input: int) -> int:
|
||||
def forward(self, input):
|
||||
# type: (int) -> int
|
||||
return input + 1
|
||||
|
||||
class TestModule(nn.Module):
|
||||
@ -272,7 +308,8 @@ class TestModuleInterface(JitTestCase):
|
||||
super(TestModule, self).__init__()
|
||||
self.proxy_mod = OrigModule()
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
def forward(self, input):
|
||||
# type: (Tensor) -> Tensor
|
||||
return self.proxy_mod.forward(input)
|
||||
|
||||
scripted_mod = torch.jit.script(TestModule())
|
||||
@ -283,10 +320,12 @@ class TestModuleInterface(JitTestCase):
|
||||
def test_module_swap_no_lazy_compile(self):
|
||||
@torch.jit.interface
|
||||
class ModuleInterface(nn.Module):
|
||||
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
|
||||
def one(self, inp1, inp2):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
def forward(self, input):
|
||||
# type: (Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
class TestModule(nn.Module):
|
||||
@ -296,17 +335,20 @@ class TestModuleInterface(JitTestCase):
|
||||
super(TestModule, self).__init__()
|
||||
self.proxy_mod = OrigModule()
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
def forward(self, input):
|
||||
# type: (Tensor) -> Tensor
|
||||
return self.proxy_mod.forward(input)
|
||||
|
||||
class NewModuleMethodNotLazyCompile(nn.Module):
|
||||
def __init__(self):
|
||||
super(NewModuleMethodNotLazyCompile, self).__init__()
|
||||
|
||||
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
|
||||
def one(self, inp1, inp2):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
return inp1 * inp2 + 1
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
def forward(self, input):
|
||||
# type: (Tensor) -> Tensor
|
||||
return input + 1
|
||||
|
||||
scripted_mod = torch.jit.script(TestModule())
|
||||
@ -320,10 +362,12 @@ class TestModuleInterface(JitTestCase):
|
||||
super(NewModuleMethodManualExport, self).__init__()
|
||||
|
||||
@torch.jit.export
|
||||
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
|
||||
def one(self, inp1, inp2):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
return inp1 * inp2 + 1
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
def forward(self, input):
|
||||
# type: (Tensor) -> Tensor
|
||||
return input + 1
|
||||
|
||||
scripted_mod.proxy_mod = torch.jit.script(NewModuleMethodManualExport())
|
||||
@ -337,7 +381,8 @@ class TestModuleInterface(JitTestCase):
|
||||
super(TestNoModuleInterface, self).__init__()
|
||||
self.proxy_mod = OrigModule()
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
def forward(self, input):
|
||||
# type: (Tensor) -> Tensor
|
||||
return self.proxy_mod(input)
|
||||
|
||||
scripted_no_module_interface = torch.jit.script(TestNoModuleInterface())
|
||||
@ -352,10 +397,12 @@ class TestModuleInterface(JitTestCase):
|
||||
def test_script_module_as_interface_swap(self):
|
||||
@torch.jit.interface
|
||||
class ModuleInterface(nn.Module):
|
||||
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
|
||||
def one(self, inp1, inp2):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
def forward(self, input):
|
||||
# type: (Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
class OrigScriptModule(torch.jit.ScriptModule):
|
||||
@ -363,11 +410,13 @@ class TestModuleInterface(JitTestCase):
|
||||
super(OrigScriptModule, self).__init__()
|
||||
|
||||
@torch.jit.script_method
|
||||
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
|
||||
def one(self, inp1, inp2):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
return inp1 + inp2 + 1
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
def forward(self, input):
|
||||
# type: (Tensor) -> Tensor
|
||||
return input + self.one(input, input) + 1
|
||||
|
||||
class NewScriptModule(torch.jit.ScriptModule):
|
||||
@ -375,11 +424,13 @@ class TestModuleInterface(JitTestCase):
|
||||
super(NewScriptModule, self).__init__()
|
||||
|
||||
@torch.jit.script_method
|
||||
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
|
||||
def one(self, inp1, inp2):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
return inp1 * inp2 + 1
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
def forward(self, input):
|
||||
# type: (Tensor) -> Tensor
|
||||
return self.one(input, input + 1)
|
||||
|
||||
class TestNNModuleWithScriptModule(nn.Module):
|
||||
@ -389,7 +440,8 @@ class TestModuleInterface(JitTestCase):
|
||||
super(TestNNModuleWithScriptModule, self).__init__()
|
||||
self.proxy_mod = OrigScriptModule()
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
def forward(self, input):
|
||||
# type: (Tensor) -> Tensor
|
||||
return self.proxy_mod.forward(input)
|
||||
|
||||
input = torch.randn(3, 4)
|
||||
@ -420,7 +472,8 @@ class TestModuleInterface(JitTestCase):
|
||||
|
||||
@torch.jit.interface
|
||||
class ModInterface(torch.nn.Module):
|
||||
def forward(self, x: Tensor) -> int:
|
||||
def forward(self, x):
|
||||
# type: (Tensor) -> int
|
||||
pass
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
@ -467,7 +520,8 @@ class TestModuleInterface(JitTestCase):
|
||||
|
||||
@torch.jit.interface
|
||||
class ModInterface(torch.nn.Module):
|
||||
def forward(self, x: Tensor) -> int:
|
||||
def forward(self, x):
|
||||
# type: (Tensor) -> int
|
||||
pass
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
@ -510,7 +564,8 @@ class TestModuleInterface(JitTestCase):
|
||||
|
||||
@torch.jit.interface
|
||||
class ModInterface(torch.nn.Module):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
def forward(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
@ -555,7 +610,8 @@ class TestModuleInterface(JitTestCase):
|
||||
|
||||
@torch.jit.interface
|
||||
class ModInterface(torch.nn.Module):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
def forward(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
@ -597,7 +653,8 @@ class TestModuleInterface(JitTestCase):
|
||||
|
||||
@torch.jit.interface
|
||||
class ModInterface(torch.nn.Module):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
def forward(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
@ -631,7 +688,8 @@ class TestModuleInterface(JitTestCase):
|
||||
def test_module_apis_interface(self):
|
||||
@torch.jit.interface
|
||||
class ModuleInterface(nn.Module):
|
||||
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
|
||||
def one(self, inp1, inp2):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
class TestModule(nn.Module):
|
||||
|
@ -284,7 +284,8 @@ class TestRecursiveScript(JitTestCase):
|
||||
test_module_dir(nn.ModuleDict(OrderedDict([("conv", conv), ("linear", linear)])))
|
||||
|
||||
def test_class_compile(self):
|
||||
def other_fn(a: int, b: Tensor) -> Tensor:
|
||||
def other_fn(a, b):
|
||||
# type: (int, Tensor) -> Tensor
|
||||
return a * b
|
||||
|
||||
class B(object):
|
||||
@ -306,7 +307,8 @@ class TestRecursiveScript(JitTestCase):
|
||||
self.checkModule(N(), (torch.randn(2, 2),))
|
||||
|
||||
def test_error_stack(self):
|
||||
def d(x: int) -> int:
|
||||
def d(x):
|
||||
# type: (int) -> int
|
||||
return x + 10
|
||||
|
||||
def c(x):
|
||||
@ -329,7 +331,8 @@ class TestRecursiveScript(JitTestCase):
|
||||
checker.run(str(e))
|
||||
|
||||
def test_error_stack_module(self):
|
||||
def d(x: int) -> int:
|
||||
def d(x):
|
||||
# type: (int) -> int
|
||||
return x + 10
|
||||
|
||||
def c(x):
|
||||
@ -562,7 +565,8 @@ class TestRecursiveScript(JitTestCase):
|
||||
self.a = 4
|
||||
self.inner = Inner2()
|
||||
|
||||
def __setstate__(self, obj: Tuple[int, Inner2]) -> None:
|
||||
def __setstate__(self, obj):
|
||||
# type: (Tuple[int, Inner2]) -> None
|
||||
a, inner = obj
|
||||
self.a = a
|
||||
self.inner = inner
|
||||
|
@ -680,7 +680,8 @@ class TestSaveLoad(JitTestCase):
|
||||
"""
|
||||
@torch.jit.interface
|
||||
class MyInterface(object):
|
||||
def bar(self, x: Tensor) -> Tensor:
|
||||
def bar(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
@torch.jit.script
|
||||
@ -710,7 +711,8 @@ class TestSaveLoad(JitTestCase):
|
||||
|
||||
@torch.jit.interface
|
||||
class MyInterface(object):
|
||||
def not_bar(self, x: Tensor) -> Tensor:
|
||||
def not_bar(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
@torch.jit.script # noqa: F811
|
||||
@ -765,7 +767,8 @@ class TestSaveLoad(JitTestCase):
|
||||
|
||||
@torch.jit.interface
|
||||
class MyInterface(object):
|
||||
def bar(self, x: Tensor) -> Tensor:
|
||||
def bar(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
@torch.jit.script
|
||||
@ -806,7 +809,8 @@ class TestSaveLoad(JitTestCase):
|
||||
|
||||
@torch.jit.interface
|
||||
class MyInterface(object):
|
||||
def not_bar(self, x: Tensor) -> Tensor:
|
||||
def not_bar(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
@torch.jit.script # noqa F811
|
||||
|
@ -25,7 +25,7 @@ from torch import Tensor
|
||||
# Standard library
|
||||
from collections import namedtuple
|
||||
from itertools import chain
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict
|
||||
import warnings
|
||||
|
||||
if __name__ == '__main__':
|
||||
@ -1862,11 +1862,13 @@ class TestTracer(JitTestCase):
|
||||
class TestMixTracingScripting(JitTestCase):
|
||||
def test_trace_script(self):
|
||||
@torch.jit.script
|
||||
def func1(x: Tuple[Tensor, Tensor]) -> Tensor:
|
||||
def func1(x):
|
||||
# type: (Tuple[Tensor, Tensor]) -> Tensor
|
||||
return x[0] + x[1]
|
||||
|
||||
@torch.jit.script
|
||||
def func2(x: List[Tensor]) -> Tensor:
|
||||
def func2(x):
|
||||
# type: (List[Tensor]) -> Tensor
|
||||
return x[0] + x[1]
|
||||
|
||||
a = torch.randn(5)
|
||||
@ -1876,7 +1878,8 @@ class TestMixTracingScripting(JitTestCase):
|
||||
self.checkTrace(func2, ((a, b),))
|
||||
|
||||
@torch.jit.script
|
||||
def func3(x: Tensor, method: str = 'bilinear', align_corners: bool = True) -> Tensor:
|
||||
def func3(x, method='bilinear', align_corners=True):
|
||||
# type: (Tensor, str, bool) -> Tensor
|
||||
hw = x.shape[2:4]
|
||||
return F.interpolate(x, hw, mode=method, align_corners=align_corners)
|
||||
|
||||
@ -1884,7 +1887,8 @@ class TestMixTracingScripting(JitTestCase):
|
||||
self.checkTrace(func3, (inp,))
|
||||
|
||||
@torch.jit.script
|
||||
def func4(x: Tensor, a: List[Optional[str]]) -> Tensor:
|
||||
def func4(x, a):
|
||||
# type: (Tensor, List[Optional[str]]) -> Tensor
|
||||
if len(a) == 2:
|
||||
return x + 2
|
||||
else:
|
||||
|
@ -1,7 +1,7 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from typing import Any, List
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
@ -50,7 +50,8 @@ class TestWith(JitTestCase):
|
||||
def __exit__(self, type: Any, value: Any, tb: Any):
|
||||
self.count.sub_(0.3)
|
||||
|
||||
def test_basic(x: torch.Tensor) -> torch.Tensor:
|
||||
def test_basic(x):
|
||||
# type: (Tensor) -> Tensor
|
||||
"""Basic test with one with-statement."""
|
||||
|
||||
c = Context(1)
|
||||
@ -61,7 +62,8 @@ class TestWith(JitTestCase):
|
||||
y *= c.count
|
||||
return y
|
||||
|
||||
def test_pass(x: torch.Tensor) -> torch.Tensor:
|
||||
def test_pass(x):
|
||||
# type: (Tensor) -> Tensor
|
||||
"""
|
||||
Test with a pass statement inside a with-statement. Although
|
||||
the body of the with is empty, __enter__ and __exit__ should
|
||||
@ -75,7 +77,8 @@ class TestWith(JitTestCase):
|
||||
x *= c.count
|
||||
return x
|
||||
|
||||
def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
|
||||
def test_early_return(x, c):
|
||||
# type: (Tensor, Context) -> Tensor
|
||||
"""
|
||||
Test that returning early from inside a with-statement works
|
||||
as expected.
|
||||
@ -87,7 +90,8 @@ class TestWith(JitTestCase):
|
||||
x = y + y
|
||||
return x
|
||||
|
||||
def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
|
||||
def test_conditional_early_return(x, c):
|
||||
# type: (Tensor, Context) -> Tensor
|
||||
"""
|
||||
Test that conditionally returning early from inside a with-statement works
|
||||
as expected.
|
||||
@ -100,7 +104,8 @@ class TestWith(JitTestCase):
|
||||
x = y + y
|
||||
return x
|
||||
|
||||
def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
|
||||
def test_break(x, c, l):
|
||||
# type: (Tensor, Context, List[int]) -> Tensor
|
||||
"""
|
||||
Test that breaking early from inside a with-statement works
|
||||
as expected.
|
||||
@ -113,7 +118,8 @@ class TestWith(JitTestCase):
|
||||
|
||||
return x
|
||||
|
||||
def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
|
||||
def test_continue(x, c, l):
|
||||
# type: (Tensor, Context, List[int]) -> Tensor
|
||||
"""
|
||||
Test that using continue inside a with-statement works
|
||||
as expected.
|
||||
@ -126,7 +132,8 @@ class TestWith(JitTestCase):
|
||||
|
||||
return x
|
||||
|
||||
def test_serial(x: torch.Tensor) -> torch.Tensor:
|
||||
def test_serial(x):
|
||||
# type: (Tensor) -> Tensor
|
||||
"""
|
||||
Test two with-statements in a row.
|
||||
"""
|
||||
@ -140,7 +147,8 @@ class TestWith(JitTestCase):
|
||||
|
||||
return y
|
||||
|
||||
def test_nested(x: torch.Tensor) -> torch.Tensor:
|
||||
def test_nested(x):
|
||||
# type: (Tensor) -> Tensor
|
||||
"""
|
||||
Test nested with-statements.
|
||||
"""
|
||||
@ -154,7 +162,8 @@ class TestWith(JitTestCase):
|
||||
|
||||
return y
|
||||
|
||||
def test_combined(x: torch.Tensor) -> torch.Tensor:
|
||||
def test_combined(x):
|
||||
# type: (Tensor) -> Tensor
|
||||
"""
|
||||
Test a with-statement with multiple with items.
|
||||
"""
|
||||
@ -206,7 +215,8 @@ class TestWith(JitTestCase):
|
||||
def __exit__(self, type: Any, value: Any, tb: Any):
|
||||
self.count.sub_(0.3)
|
||||
|
||||
def test_basic(x: torch.Tensor) -> torch.Tensor:
|
||||
def test_basic(x):
|
||||
# type: (Tensor) -> Tensor
|
||||
"""Basic test with one with-statement."""
|
||||
|
||||
c = Context(1)
|
||||
@ -217,7 +227,8 @@ class TestWith(JitTestCase):
|
||||
y *= c.count
|
||||
return y
|
||||
|
||||
def test_pass(x: torch.Tensor) -> torch.Tensor:
|
||||
def test_pass(x):
|
||||
# type: (Tensor) -> Tensor
|
||||
"""
|
||||
Test with a pass statement inside a with-statement. Although
|
||||
the body of the with is empty, __enter__ and __exit__ should
|
||||
@ -231,7 +242,8 @@ class TestWith(JitTestCase):
|
||||
x *= c.count
|
||||
return x
|
||||
|
||||
def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
|
||||
def test_early_return(x, c):
|
||||
# type: (Tensor, Context) -> Tensor
|
||||
"""
|
||||
Test that returning early from inside a with-statement works
|
||||
as expected.
|
||||
@ -243,7 +255,8 @@ class TestWith(JitTestCase):
|
||||
x = y + y
|
||||
return x
|
||||
|
||||
def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
|
||||
def test_conditional_early_return(x, c):
|
||||
# type: (Tensor, Context) -> Tensor
|
||||
"""
|
||||
Test that conditionally returning early from inside a with-statement works
|
||||
as expected.
|
||||
@ -256,7 +269,8 @@ class TestWith(JitTestCase):
|
||||
x = y + y
|
||||
return x
|
||||
|
||||
def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
|
||||
def test_break(x, c, l):
|
||||
# type: (Tensor, Context, List[int]) -> Tensor
|
||||
"""
|
||||
Test that breaking early from inside a with-statement works
|
||||
as expected.
|
||||
@ -269,7 +283,8 @@ class TestWith(JitTestCase):
|
||||
|
||||
return x
|
||||
|
||||
def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
|
||||
def test_continue(x, c, l):
|
||||
# type: (Tensor, Context, List[int]) -> Tensor
|
||||
"""
|
||||
Test that using continue inside a with-statement works
|
||||
as expected.
|
||||
@ -282,7 +297,8 @@ class TestWith(JitTestCase):
|
||||
|
||||
return x
|
||||
|
||||
def test_serial(x: torch.Tensor) -> torch.Tensor:
|
||||
def test_serial(x):
|
||||
# type: (Tensor) -> Tensor
|
||||
"""
|
||||
Test two with-statements in a row.
|
||||
"""
|
||||
@ -296,7 +312,8 @@ class TestWith(JitTestCase):
|
||||
|
||||
return y
|
||||
|
||||
def test_nested(x: torch.Tensor) -> torch.Tensor:
|
||||
def test_nested(x):
|
||||
# type: (Tensor) -> Tensor
|
||||
"""
|
||||
Test nested with-statements.
|
||||
"""
|
||||
@ -310,7 +327,8 @@ class TestWith(JitTestCase):
|
||||
|
||||
return y
|
||||
|
||||
def test_combined(x: torch.Tensor) -> torch.Tensor:
|
||||
def test_combined(x):
|
||||
# type: (Tensor) -> Tensor
|
||||
"""
|
||||
Test a with-statement with multiple with items.
|
||||
"""
|
||||
@ -363,11 +381,13 @@ class TestWith(JitTestCase):
|
||||
self.count.sub_(0.3)
|
||||
|
||||
@torch.jit.script
|
||||
def method_that_raises() -> torch.Tensor:
|
||||
def method_that_raises():
|
||||
# type: () -> Tensor
|
||||
raise Exception("raised exception")
|
||||
|
||||
@torch.jit.script
|
||||
def test_exception(x: torch.Tensor, c: Context) -> torch.Tensor:
|
||||
def test_exception(x, c):
|
||||
# type: (Tensor, Context) -> Tensor
|
||||
"""
|
||||
Test the case in which an exception is thrown while executing the body of a with-statement.
|
||||
"""
|
||||
@ -377,7 +397,8 @@ class TestWith(JitTestCase):
|
||||
return x
|
||||
|
||||
@torch.jit.script
|
||||
def test_exception_nested(x: torch.Tensor, c: Context) -> torch.Tensor:
|
||||
def test_exception_nested(x, c):
|
||||
# type: (Tensor, Context) -> Tensor
|
||||
"""
|
||||
Test the case in which an exception is thrown while executing the body of a nested with-statement.
|
||||
"""
|
||||
@ -388,7 +409,8 @@ class TestWith(JitTestCase):
|
||||
return x
|
||||
|
||||
@torch.jit.script
|
||||
def with_that_raises(c: Context) -> torch.Tensor:
|
||||
def with_that_raises(c):
|
||||
# type: (Context) -> Tensor
|
||||
a = torch.tensor([1])
|
||||
|
||||
with c as _:
|
||||
@ -397,7 +419,8 @@ class TestWith(JitTestCase):
|
||||
return a
|
||||
|
||||
@torch.jit.script
|
||||
def test_exception_fn_call(x: torch.Tensor, c: Context) -> torch.Tensor:
|
||||
def test_exception_fn_call(x, c):
|
||||
# type: (Tensor, Context) -> Tensor
|
||||
"""
|
||||
Test the case in which an exception is thrown while there are active with-statements in two different
|
||||
frames.
|
||||
@ -483,25 +506,29 @@ class TestWith(JitTestCase):
|
||||
def __exit__(self, type: Any, value: int, tb: int):
|
||||
pass
|
||||
|
||||
def test_no_enter_no_exit(x: torch.Tensor, c: NoEnterNoExit) -> torch.Tensor:
|
||||
def test_no_enter_no_exit(x, c):
|
||||
# type: (Tensor, NoEnterNoExit) -> Tensor
|
||||
with c as _:
|
||||
pass
|
||||
|
||||
return x
|
||||
|
||||
def test_bad_enter(x: torch.Tensor, c: BadEnter) -> torch.Tensor:
|
||||
def test_bad_enter(x, c):
|
||||
# type: (Tensor, BadEnter) -> Tensor
|
||||
with c as _:
|
||||
pass
|
||||
|
||||
return x
|
||||
|
||||
def test_bad_exit(x: torch.Tensor, c: BadExit) -> torch.Tensor:
|
||||
def test_bad_exit(x, c):
|
||||
# type: (Tensor, BadExit) -> Tensor
|
||||
with c as _:
|
||||
pass
|
||||
|
||||
return x
|
||||
|
||||
def test_exit_incorrect_types(x: torch.Tensor, c: ExitIncorrectTypes) -> torch.Tensor:
|
||||
def test_exit_incorrect_types(x, c):
|
||||
# type: (Tensor, ExitIncorrectTypes) -> Tensor
|
||||
with c as _:
|
||||
pass
|
||||
|
||||
@ -538,7 +565,8 @@ class TestWith(JitTestCase):
|
||||
"""
|
||||
|
||||
# Basic no_grad test.
|
||||
def test_no_grad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
def test_no_grad(x, y):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
with torch.no_grad():
|
||||
w = x + y
|
||||
|
||||
@ -555,7 +583,8 @@ class TestWith(JitTestCase):
|
||||
|
||||
# Test assignment of a grad-less Tensor to a Tensor with gradients
|
||||
# in a no_grad block.
|
||||
def test_no_grad_assignment(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
def test_no_grad_assignment(x, y):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
with torch.no_grad():
|
||||
x[0] = y
|
||||
|
||||
@ -574,11 +603,13 @@ class TestWith(JitTestCase):
|
||||
super().__init__()
|
||||
|
||||
@torch.jit.ignore
|
||||
def adder(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
def adder(self, x, y):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
w = x + y
|
||||
return w
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, x, y):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
with torch.no_grad():
|
||||
w = self.adder(x, y)
|
||||
|
||||
@ -594,7 +625,8 @@ class TestWith(JitTestCase):
|
||||
Check that torch.autograd.profiler.record_function context manager is
|
||||
torchscriptable.
|
||||
"""
|
||||
def with_rf(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
def with_rf(x, y):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
with torch.autograd.profiler.record_function("foo"):
|
||||
# Nested record_function.
|
||||
with torch.autograd.profiler.record_function("nested"):
|
||||
|
Reference in New Issue
Block a user