Clean up some type annotations in caffe2/test (#49943)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49943

Upgrades type annotations from Python2 to Python3

Test Plan: Sandcastle tests

Reviewed By: xush6528

Differential Revision: D25717534

fbshipit-source-id: 5aedea4db07efca126ffb6daee79617c30a67146
This commit is contained in:
Richard Barnes
2021-01-13 09:55:15 -08:00
committed by Facebook GitHub Bot
parent 7d0eecc666
commit a4383a69d4
9 changed files with 89 additions and 143 deletions

View File

@ -1,22 +1,24 @@
import numpy as np
from typing import Tuple
import io
import itertools
import sys
import unittest
import itertools
import torch.onnx
import torch.onnx.operators
from torch.onnx import ExportTypes
import numpy as np
from debug_embed_params import run_embed_params
from torch import nn
from torch.autograd import Variable, function
import torch.utils.model_zoo as model_zoo
from torch.nn.utils import rnn as rnn_utils
from debug_embed_params import run_embed_params
import io
from torch.onnx import ExportTypes
import torch.onnx
import torch.onnx.operators
import torch.utils.model_zoo as model_zoo
# Import various models for testing
from torchvision.models.alexnet import alexnet
from torchvision.models.inception import inception_v3
from torchvision.models.densenet import densenet121
from torchvision.models.inception import inception_v3
from torchvision.models.resnet import resnet50
from torchvision.models.vgg import vgg16, vgg16_bn, vgg19, vgg19_bn
@ -1981,8 +1983,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
def test_tuple_input_output(self):
class TupleModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, a):
# type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
def forward(self, a: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
return a
x = (torch.randn(3, 4), torch.randn(4, 3))
@ -1992,8 +1993,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
def test_nested_tuple_input_output(self):
class NestedTupleModel(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, a, b):
# type: (Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]) -> Tensor
def forward(self, a: torch.Tensor, b: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) -> torch.Tensor:
return a + b[0] + b[1][0] + b[1][1]
x = torch.randn(4, 5)

View File

@ -2007,8 +2007,9 @@ class TestDeprecatedJitQuantized(JitTestCase):
self.cell = cell
@torch.jit.script_method
def forward(self, x, hiddens):
# type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
def forward(self, x: torch.Tensor,
hiddens: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
return self.cell(x, hiddens)
else:
@ -2018,8 +2019,7 @@ class TestDeprecatedJitQuantized(JitTestCase):
self.cell = cell
@torch.jit.script_method
def forward(self, x, hiddens):
# type: (torch.Tensor, torch.Tensor) -> torch.Tensor
def forward(self, x: torch.Tensor, hiddens: torch.Tensor) -> torch.Tensor:
return self.cell(x, hiddens)
cell = ScriptWrapper(cell)
@ -2131,8 +2131,7 @@ class TestDeprecatedJitQuantized(JitTestCase):
self.cell = cell
@torch.jit.script_method
def forward(self, x, hiddens):
# type: (torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
def forward(self, x: torch.Tensor, hiddens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return self.cell(x, hiddens)
compare_quantized_unquantized(ScriptWrapper, cell)

View File

@ -74,9 +74,10 @@ from torch.testing._internal.jit_utils import get_forward_graph
from torch.jit._recursive import wrap_cpp_module
# Standard library
from typing import List, Tuple
import io
import itertools
import unittest
import io
class TestQuantizeJitPasses(QuantizationTestCase):
""" Test graph mode quantization passes used by quantize_jit
@ -742,8 +743,7 @@ class TestQuantizeJitPasses(QuantizationTestCase):
.run(m.graph)
def test_insert_observers_propagate_observed_for_function(self):
def channel_shuffle(x, groups):
# type: (torch.Tensor, int) -> torch.Tensor
def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor:
batchsize, num_channels, height, width = x.data.size()
channels_per_group = num_channels // groups
# reshape
@ -1126,8 +1126,7 @@ class TestQuantizeJitPasses(QuantizationTestCase):
self.conv = torch.nn.Conv2d(3, 3, 1).float()
self.use_skip = True
def forward(self, x, cond):
# type: (Tensor, bool) -> Tensor
def forward(self, x: torch.Tensor, cond: bool) -> torch.Tensor:
# to avoid being frozen
self.use_skip = cond
if self.use_skip:
@ -1227,8 +1226,7 @@ class TestQuantizeJitPasses(QuantizationTestCase):
super(ComplexModel, self).__init__()
self.layers = torch.nn.ModuleList([SimpleLinearLayer() for i in range(2)])
def forward(self, x):
# type: (torch.Tensor) -> List[torch.Tensor]
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
states = []
for layer in self.layers:
val = layer(x)
@ -1324,8 +1322,7 @@ class TestQuantizeJitPasses(QuantizationTestCase):
@torch.jit.interface
class ModInterface(torch.nn.Module):
def forward(self, x, y):
# type: (Tensor, Tensor) -> Tensor
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
pass
class TestModule(torch.nn.Module):
@ -2428,8 +2425,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
self.conv2 = torch.nn.Conv2d(3, 3, 3).float()
def forward(self, x):
# type: (Tensor) -> Tuple[Tensor, Tensor]
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
x1 = self.conv1(x)
x2 = self.conv2(x)
return x1, x2
@ -2919,8 +2915,7 @@ class TestQuantizeDynamicJitPasses(QuantizationTestCase):
super(Res, self).__init__()
self.weight = torch.nn.Parameter(torch.ones(5, 5))
def forward(self, x, cond):
# type: (Tensor, bool) -> Tensor
def forward(self, x: torch.Tensor, cond: bool) -> torch.Tensor:
if cond:
return torch.nn.functional.linear(x, self.weight)
else:

View File

@ -557,8 +557,7 @@ class TestFuser(JitTestCase):
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
def test_scalar_arg_cuda(self):
def fn_test_scalar_arg(x, p):
# type: (Tensor, float) -> Tensor
def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor:
return p * (x * x + x)
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
@ -570,8 +569,7 @@ class TestFuser(JitTestCase):
# use another function otherwise we will bailout
# and won't be able to do fused checks
def fn_test_scalar_arg_requires_grad(x, p):
# type: (Tensor, float) -> Tensor
def fn_test_scalar_arg_requires_grad(x: torch.Tensor, p: float) -> torch.Tensor:
return p * (x * x + x)
scripted = torch.jit.script(fn_test_scalar_arg_requires_grad)

View File

@ -763,8 +763,7 @@ class TestTEFuser(JitTestCase):
def test_scalar_arg(self):
for device in self.devices:
def fn_test_scalar_arg(x, p):
# type: (Tensor, float) -> Tensor
def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor:
return p * (x * x + x)
x = torch.randn(4, 4, dtype=torch.float, device=device)
@ -776,8 +775,7 @@ class TestTEFuser(JitTestCase):
# use another function otherwise we will bailout
# and won't be able to do fused checks
def fn_test_scalar_arg_requires_grad(x, p):
# type: (Tensor, float) -> Tensor
def fn_test_scalar_arg_requires_grad(x: torch.Tensor, p: float) -> torch.Tensor:
return p * (x * x + x)
scripted = torch.jit.script(fn_test_scalar_arg_requires_grad)

View File

@ -1,10 +1,11 @@
from collections import namedtuple
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing import FileCheck
from torch import jit
from textwrap import dedent
from typing import NamedTuple, List, Optional, Dict, Tuple, Any
from jit.test_module_interface import TestModuleInterface # noqa: F401
import inspect
import unittest
@ -189,8 +190,7 @@ class TestScriptPy3(JitTestCase):
super().__init__()
@torch.jit.ignore
def foo(self, x, z):
# type: (Tensor, Tensor) -> Tuple[GG, GG]
def foo(self, x: torch.Tensor, z: torch.Tensor) -> Tuple[GG, GG]:
return GG(x, z), GG(x, z)
def forward(self, x, z):
@ -412,8 +412,7 @@ class TestScriptPy3(JitTestCase):
"""
Test that using an optional with no contained types produces an error.
"""
def fn_with_comment(x):
# type: (torch.Tensor) -> Optional
def fn_with_comment(x: torch.Tensor) -> Optional:
return (x, x)
def annotated_fn(x: torch.Tensor) -> Optional:
@ -437,8 +436,7 @@ class TestScriptPy3(JitTestCase):
"""
Test that using a tuple with no contained types produces an error.
"""
def fn_with_comment(x):
# type: (torch.Tensor) -> Tuple
def fn_with_comment(x: torch.Tensor) -> Tuple:
return (x, x)
def annotated_fn(x: torch.Tensor) -> Tuple:
@ -733,42 +731,33 @@ class TestScriptPy3(JitTestCase):
@torch.jit.interface
class OneTwoModule(nn.Module):
def one(self, x, y):
# type: (Tensor, Tensor) -> Tensor
def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
pass
def two(self, x):
# type: (Tensor) -> Tensor
def two(self, x: torch.Tensor) -> torch.Tensor:
pass
def forward(self, x):
# type: (Tensor) -> Tensor
def forward(self, x: torch.Tensor) -> torch.Tensor:
pass
class FooMod(nn.Module):
def one(self, x, y):
# type: (Tensor, Tensor) -> Tensor
def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y
def two(self, x):
# type: (Tensor) -> Tensor
def two(self, x: torch.Tensor) -> torch.Tensor:
return 2 * x
def forward(self, x):
# type: (Tensor) -> Tensor
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.one(self.two(x), x)
class BarMod(nn.Module):
def one(self, x, y):
# type: (Tensor, Tensor) -> Tensor
def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x * y
def two(self, x):
# type: (Tensor) -> Tensor
def two(self, x: torch.Tensor) -> torch.Tensor:
return 2 / x
def forward(self, x):
# type: (Tensor) -> Tensor
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.two(self.one(x, x))
class M(nn.Module):
@ -778,8 +767,7 @@ class TestScriptPy3(JitTestCase):
super(M, self).__init__()
self.sub = BarMod()
def forward(self, x):
# type: (Tensor) -> Tensor
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.sub.forward(x)
def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):

View File

@ -1,28 +1,26 @@
from test_jit import JitTestCase
from torch.testing._internal.common_utils import run_tests
from typing import List, Tuple
class TestScript(JitTestCase):
def test_str_ops(self):
def test_str_is(s):
# type: (str) -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]
def test_str_is(s: str) -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]:
return s.isupper(), s.islower(), s.isdigit(), s.isspace(), \
s.isalnum(), s.isalpha(), s.isdecimal(), s.isnumeric(), \
s.isidentifier(), s.istitle(), s.isprintable()
def test_str_to(s):
# type: (str) -> Tuple[str, str, str, str, str]
def test_str_to(s: str) -> Tuple[str, str, str, str, str]:
return s.upper(), s.lower(), s.capitalize(), s.title(), s.swapcase()
def test_str_strip(s):
# type: (str) -> Tuple[str, str, str]
def test_str_strip(s: str) -> Tuple[str, str, str]:
return (
s.lstrip(),
s.rstrip(),
s.strip(),
)
def test_str_strip_char_set(s, char_set):
# type: (str, str) -> Tuple[str, str, str]
def test_str_strip_char_set(s: str, char_set: str) -> Tuple[str, str, str]:
return (
s.lstrip(char_set),
s.rstrip(char_set),
@ -34,44 +32,34 @@ class TestScript(JitTestCase):
"more strings with spaces", "Titular Strings", "\x0acan'tprintthis",
"spaces at the end ", " begin"]
def test_str_center(i, s):
# type: (int, str) -> str
def test_str_center(i: int, s: str) -> str:
return s.center(i)
def test_str_center_fc(i, s):
# type: (int, str) -> str
def test_str_center_fc(i: int, s: str) -> str:
return s.center(i, '*')
def test_str_center_error(s):
# type: (str) -> str
def test_str_center_error(s: str) -> str:
return s.center(10, '**')
def test_ljust(s, i):
# type: (str, int) -> str
def test_ljust(s: str, i: int) -> str:
return s.ljust(i)
def test_ljust_fc(s, i, fc):
# type: (str, int, str) -> str
def test_ljust_fc(s: str, i: int, fc: str) -> str:
return s.ljust(i, fc)
def test_ljust_fc_err(s):
# type: (str) -> str
def test_ljust_fc_err(s: str) -> str:
return s.ljust(10, '**')
def test_rjust(s, i):
# type: (str, int) -> str
def test_rjust(s: str, i: int) -> str:
return s.rjust(i)
def test_rjust_fc(s, i, fc):
# type: (str, int, str) -> str
def test_rjust_fc(s: str, i: int, fc: str) -> str:
return s.rjust(i, fc)
def test_rjust_fc_err(s):
# type: (str) -> str
def test_rjust_fc_err(s: str) -> str:
return s.rjust(10, '**')
def test_zfill(s, i):
# type: (str, int) -> str
def test_zfill(s: str, i: int) -> str:
return s.zfill(i)
for input in inputs:
@ -93,8 +81,7 @@ class TestScript(JitTestCase):
test_str_center_error("error")
test_ljust("error")
def test_count():
# type: () -> Tuple[int, int, int, int, int, int, int, int, int, int, int, int]
def test_count() -> Tuple[int, int, int, int, int, int, int, int, int, int, int, int]:
return (
"hello".count("h"),
"hello".count("h", 0, 1),
@ -111,8 +98,7 @@ class TestScript(JitTestCase):
)
self.checkScript(test_count, ())
def test_endswith():
# type: () -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]
def test_endswith() -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]:
return (
"hello".endswith("lo"),
"hello".endswith("lo", 0),
@ -131,8 +117,7 @@ class TestScript(JitTestCase):
)
self.checkScript(test_endswith, ())
def test_startswith():
# type: () -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]
def test_startswith() -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]:
return (
"hello".startswith("lo"),
"hello".startswith("lo", 0),
@ -151,8 +136,7 @@ class TestScript(JitTestCase):
)
self.checkScript(test_startswith, ())
def test_expandtabs():
# type: () -> Tuple[str, str, str, str, str, str]
def test_expandtabs() -> Tuple[str, str, str, str, str, str]:
return (
'xyz\t82345\tabc'.expandtabs(),
'xyz\t32345\tabc'.expandtabs(3),
@ -163,8 +147,7 @@ class TestScript(JitTestCase):
)
self.checkScript(test_expandtabs, ())
def test_rfind():
# type: () -> Tuple[int, int, int, int, int, int, int, int, int]
def test_rfind() -> Tuple[int, int, int, int, int, int, int, int, int]:
return (
"hello123abc".rfind("llo"),
"hello123abc".rfind("12"),
@ -178,8 +161,7 @@ class TestScript(JitTestCase):
)
self.checkScript(test_rfind, ())
def test_find():
# type: () -> Tuple[int, int, int, int, int, int, int, int, int]
def test_find() -> Tuple[int, int, int, int, int, int, int, int, int]:
return (
"hello123abc".find("llo"),
"hello123abc".find("12"),
@ -193,8 +175,7 @@ class TestScript(JitTestCase):
)
self.checkScript(test_find, ())
def test_index():
# type: () -> Tuple[int, int, int, int, int, int]
def test_index() -> Tuple[int, int, int, int, int, int]:
return (
"hello123abc".index("llo"),
"hello123abc".index("12"),
@ -205,8 +186,7 @@ class TestScript(JitTestCase):
)
self.checkScript(test_index, ())
def test_rindex():
# type: () -> Tuple[int, int, int, int, int, int]
def test_rindex() -> Tuple[int, int, int, int, int, int]:
return (
"hello123abc".rindex("llo"),
"hello123abc".rindex("12"),
@ -217,8 +197,7 @@ class TestScript(JitTestCase):
)
self.checkScript(test_rindex, ())
def test_replace():
# type: () -> Tuple[str, str, str, str, str, str, str]
def test_replace() -> Tuple[str, str, str, str, str, str, str]:
return (
"hello123abc".replace("llo", "sdf"),
"ff".replace("f", "ff"),
@ -230,11 +209,9 @@ class TestScript(JitTestCase):
)
self.checkScript(test_replace, ())
def test_partition():
"""
type: () -> Tuple[Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str],
Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str]]
"""
def test_partition() -> Tuple[Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str],
Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str],
Tuple[str, str, str]]:
return (
"hello123abc".partition("llo"),
"ff".partition("f"),
@ -246,11 +223,9 @@ class TestScript(JitTestCase):
)
self.checkScript(test_partition, ())
def test_rpartition():
"""
type: () -> Tuple[Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str],
Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str]]
"""
def test_rpartition() -> Tuple[Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str],
Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str],
Tuple[str, str, str]]:
return (
"hello123abc".rpartition("llo"),
"ff".rpartition("f"),
@ -262,11 +237,8 @@ class TestScript(JitTestCase):
)
self.checkScript(test_rpartition, ())
def test_split():
"""
type: () -> Tuple[List[str], List[str], List[str], List[str], List[str],
List[str], List[str], List[str], List[str], List[str], List[str]]
"""
def test_split() -> Tuple[List[str], List[str], List[str], List[str], List[str],
List[str], List[str], List[str], List[str], List[str], List[str]]:
return (
"a a a a a".split(),
"a a a a a".split(),
@ -290,8 +262,8 @@ class TestScript(JitTestCase):
self.checkScriptRaisesRegex(test_split_empty_separator, (), Exception,
"empty separator")
def test_rsplit():
# type: () -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str], List[str], List[str]]
def test_rsplit() -> Tuple[List[str], List[str], List[str], List[str], List[str],
List[str], List[str], List[str], List[str]]:
return (
"a a a a a".rsplit(),
" a a a a a ".rsplit(" "),
@ -305,8 +277,8 @@ class TestScript(JitTestCase):
)
self.checkScript(test_rsplit, ())
def test_splitlines():
# type: () -> Tuple[ List[str], List[str], List[str], List[str], List[str], List[str] ]
def test_splitlines() -> Tuple[List[str], List[str], List[str], List[str],
List[str], List[str]]:
return (
"hello\ntest".splitlines(),
"hello\n\ntest\n".splitlines(),
@ -317,8 +289,7 @@ class TestScript(JitTestCase):
)
self.checkScript(test_splitlines, ())
def test_str_cmp(a, b):
# type: (str, str) -> Tuple[bool, bool, bool, bool, bool, bool]
def test_str_cmp(a: str, b: str) -> Tuple[bool, bool, bool, bool, bool, bool]:
return a != b, a == b, a < b, a > b, a <= b, a >= b
for i in range(len(inputs) - 1):

View File

@ -3,6 +3,7 @@ import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests
from typing import Dict, Optional
class StaticRuntime:
def __init__(self, scripted):
@ -27,8 +28,7 @@ class StaticRuntime:
)
def linear_shim(input, weight, bias=None):
# type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
def linear_shim(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
output = input.matmul(weight.t())
if bias is not None:
output += bias
@ -116,7 +116,7 @@ def loop_graph(a, b, iters : int):
def output_graph(a, b, c, iters : int):
s = torch.tensor([[3, 3], [3, 3]])
k = a + b * c + s
d : Dict[int, Tensor] = {}
d : Dict[int, torch.Tensor] = {}
for i in range(iters):
d[i] = k + i
return d

View File

@ -1163,13 +1163,11 @@ class TestTensorExprFuser(BaseTestClass):
def test_scalar(self):
@torch.jit.script
def test_float(x, y, z, a, b):
# type: (Tensor, Tensor, Tensor, float, float) -> Tensor
def test_float(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: float, b: float) -> torch.Tensor:
return torch.add(torch.add(x, y, alpha=a), z, alpha=b)
@torch.jit.script
def test_int(x, y, z, a, b):
# type: (Tensor, Tensor, Tensor, int, int) -> Tensor
def test_int(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: int, b: int) -> torch.Tensor:
return torch.add(torch.add(x, y, alpha=a), z, alpha=b)
for test in (test_float, test_int):
@ -1186,8 +1184,7 @@ class TestTensorExprFuser(BaseTestClass):
def test_loop(self):
@torch.jit.script
def test(x, y, z):
# type: (Tensor, Tensor, int) -> Tensor
def test(x: torch.Tensor, y: torch.Tensor, z: int) -> torch.Tensor:
b = y
for i in range(0, z):
a = x + y