mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Clean up some type annotations in caffe2/test (#49943)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49943 Upgrades type annotations from Python2 to Python3 Test Plan: Sandcastle tests Reviewed By: xush6528 Differential Revision: D25717534 fbshipit-source-id: 5aedea4db07efca126ffb6daee79617c30a67146
This commit is contained in:
committed by
Facebook GitHub Bot
parent
7d0eecc666
commit
a4383a69d4
@ -1,22 +1,24 @@
|
||||
import numpy as np
|
||||
from typing import Tuple
|
||||
import io
|
||||
import itertools
|
||||
import sys
|
||||
import unittest
|
||||
import itertools
|
||||
|
||||
import torch.onnx
|
||||
import torch.onnx.operators
|
||||
from torch.onnx import ExportTypes
|
||||
import numpy as np
|
||||
|
||||
from debug_embed_params import run_embed_params
|
||||
from torch import nn
|
||||
from torch.autograd import Variable, function
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
from torch.nn.utils import rnn as rnn_utils
|
||||
from debug_embed_params import run_embed_params
|
||||
import io
|
||||
from torch.onnx import ExportTypes
|
||||
import torch.onnx
|
||||
import torch.onnx.operators
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
|
||||
# Import various models for testing
|
||||
from torchvision.models.alexnet import alexnet
|
||||
from torchvision.models.inception import inception_v3
|
||||
from torchvision.models.densenet import densenet121
|
||||
from torchvision.models.inception import inception_v3
|
||||
from torchvision.models.resnet import resnet50
|
||||
from torchvision.models.vgg import vgg16, vgg16_bn, vgg19, vgg19_bn
|
||||
|
||||
@ -1981,8 +1983,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
def test_tuple_input_output(self):
|
||||
class TupleModel(torch.jit.ScriptModule):
|
||||
@torch.jit.script_method
|
||||
def forward(self, a):
|
||||
# type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
|
||||
def forward(self, a: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return a
|
||||
|
||||
x = (torch.randn(3, 4), torch.randn(4, 3))
|
||||
@ -1992,8 +1993,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
def test_nested_tuple_input_output(self):
|
||||
class NestedTupleModel(torch.jit.ScriptModule):
|
||||
@torch.jit.script_method
|
||||
def forward(self, a, b):
|
||||
# type: (Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]) -> Tensor
|
||||
def forward(self, a: torch.Tensor, b: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) -> torch.Tensor:
|
||||
return a + b[0] + b[1][0] + b[1][1]
|
||||
|
||||
x = torch.randn(4, 5)
|
||||
|
@ -2007,8 +2007,9 @@ class TestDeprecatedJitQuantized(JitTestCase):
|
||||
self.cell = cell
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, x, hiddens):
|
||||
# type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
|
||||
def forward(self, x: torch.Tensor,
|
||||
hiddens: Tuple[torch.Tensor, torch.Tensor]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.cell(x, hiddens)
|
||||
else:
|
||||
|
||||
@ -2018,8 +2019,7 @@ class TestDeprecatedJitQuantized(JitTestCase):
|
||||
self.cell = cell
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, x, hiddens):
|
||||
# type: (torch.Tensor, torch.Tensor) -> torch.Tensor
|
||||
def forward(self, x: torch.Tensor, hiddens: torch.Tensor) -> torch.Tensor:
|
||||
return self.cell(x, hiddens)
|
||||
|
||||
cell = ScriptWrapper(cell)
|
||||
@ -2131,8 +2131,7 @@ class TestDeprecatedJitQuantized(JitTestCase):
|
||||
self.cell = cell
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, x, hiddens):
|
||||
# type: (torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
|
||||
def forward(self, x: torch.Tensor, hiddens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.cell(x, hiddens)
|
||||
|
||||
compare_quantized_unquantized(ScriptWrapper, cell)
|
||||
|
@ -74,9 +74,10 @@ from torch.testing._internal.jit_utils import get_forward_graph
|
||||
from torch.jit._recursive import wrap_cpp_module
|
||||
|
||||
# Standard library
|
||||
from typing import List, Tuple
|
||||
import io
|
||||
import itertools
|
||||
import unittest
|
||||
import io
|
||||
|
||||
class TestQuantizeJitPasses(QuantizationTestCase):
|
||||
""" Test graph mode quantization passes used by quantize_jit
|
||||
@ -742,8 +743,7 @@ class TestQuantizeJitPasses(QuantizationTestCase):
|
||||
.run(m.graph)
|
||||
|
||||
def test_insert_observers_propagate_observed_for_function(self):
|
||||
def channel_shuffle(x, groups):
|
||||
# type: (torch.Tensor, int) -> torch.Tensor
|
||||
def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor:
|
||||
batchsize, num_channels, height, width = x.data.size()
|
||||
channels_per_group = num_channels // groups
|
||||
# reshape
|
||||
@ -1126,8 +1126,7 @@ class TestQuantizeJitPasses(QuantizationTestCase):
|
||||
self.conv = torch.nn.Conv2d(3, 3, 1).float()
|
||||
self.use_skip = True
|
||||
|
||||
def forward(self, x, cond):
|
||||
# type: (Tensor, bool) -> Tensor
|
||||
def forward(self, x: torch.Tensor, cond: bool) -> torch.Tensor:
|
||||
# to avoid being frozen
|
||||
self.use_skip = cond
|
||||
if self.use_skip:
|
||||
@ -1227,8 +1226,7 @@ class TestQuantizeJitPasses(QuantizationTestCase):
|
||||
super(ComplexModel, self).__init__()
|
||||
self.layers = torch.nn.ModuleList([SimpleLinearLayer() for i in range(2)])
|
||||
|
||||
def forward(self, x):
|
||||
# type: (torch.Tensor) -> List[torch.Tensor]
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
states = []
|
||||
for layer in self.layers:
|
||||
val = layer(x)
|
||||
@ -1324,8 +1322,7 @@ class TestQuantizeJitPasses(QuantizationTestCase):
|
||||
|
||||
@torch.jit.interface
|
||||
class ModInterface(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
@ -2428,8 +2425,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
|
||||
self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
|
||||
self.conv2 = torch.nn.Conv2d(3, 3, 3).float()
|
||||
|
||||
def forward(self, x):
|
||||
# type: (Tensor) -> Tuple[Tensor, Tensor]
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
x1 = self.conv1(x)
|
||||
x2 = self.conv2(x)
|
||||
return x1, x2
|
||||
@ -2919,8 +2915,7 @@ class TestQuantizeDynamicJitPasses(QuantizationTestCase):
|
||||
super(Res, self).__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones(5, 5))
|
||||
|
||||
def forward(self, x, cond):
|
||||
# type: (Tensor, bool) -> Tensor
|
||||
def forward(self, x: torch.Tensor, cond: bool) -> torch.Tensor:
|
||||
if cond:
|
||||
return torch.nn.functional.linear(x, self.weight)
|
||||
else:
|
||||
|
@ -557,8 +557,7 @@ class TestFuser(JitTestCase):
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
def test_scalar_arg_cuda(self):
|
||||
def fn_test_scalar_arg(x, p):
|
||||
# type: (Tensor, float) -> Tensor
|
||||
def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor:
|
||||
return p * (x * x + x)
|
||||
|
||||
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
||||
@ -570,8 +569,7 @@ class TestFuser(JitTestCase):
|
||||
|
||||
# use another function otherwise we will bailout
|
||||
# and won't be able to do fused checks
|
||||
def fn_test_scalar_arg_requires_grad(x, p):
|
||||
# type: (Tensor, float) -> Tensor
|
||||
def fn_test_scalar_arg_requires_grad(x: torch.Tensor, p: float) -> torch.Tensor:
|
||||
return p * (x * x + x)
|
||||
|
||||
scripted = torch.jit.script(fn_test_scalar_arg_requires_grad)
|
||||
|
@ -763,8 +763,7 @@ class TestTEFuser(JitTestCase):
|
||||
|
||||
def test_scalar_arg(self):
|
||||
for device in self.devices:
|
||||
def fn_test_scalar_arg(x, p):
|
||||
# type: (Tensor, float) -> Tensor
|
||||
def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor:
|
||||
return p * (x * x + x)
|
||||
|
||||
x = torch.randn(4, 4, dtype=torch.float, device=device)
|
||||
@ -776,8 +775,7 @@ class TestTEFuser(JitTestCase):
|
||||
|
||||
# use another function otherwise we will bailout
|
||||
# and won't be able to do fused checks
|
||||
def fn_test_scalar_arg_requires_grad(x, p):
|
||||
# type: (Tensor, float) -> Tensor
|
||||
def fn_test_scalar_arg_requires_grad(x: torch.Tensor, p: float) -> torch.Tensor:
|
||||
return p * (x * x + x)
|
||||
|
||||
scripted = torch.jit.script(fn_test_scalar_arg_requires_grad)
|
||||
|
@ -1,10 +1,11 @@
|
||||
from collections import namedtuple
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
|
||||
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing import FileCheck
|
||||
from torch import jit
|
||||
from textwrap import dedent
|
||||
from typing import NamedTuple, List, Optional, Dict, Tuple, Any
|
||||
from jit.test_module_interface import TestModuleInterface # noqa: F401
|
||||
import inspect
|
||||
import unittest
|
||||
@ -189,8 +190,7 @@ class TestScriptPy3(JitTestCase):
|
||||
super().__init__()
|
||||
|
||||
@torch.jit.ignore
|
||||
def foo(self, x, z):
|
||||
# type: (Tensor, Tensor) -> Tuple[GG, GG]
|
||||
def foo(self, x: torch.Tensor, z: torch.Tensor) -> Tuple[GG, GG]:
|
||||
return GG(x, z), GG(x, z)
|
||||
|
||||
def forward(self, x, z):
|
||||
@ -412,8 +412,7 @@ class TestScriptPy3(JitTestCase):
|
||||
"""
|
||||
Test that using an optional with no contained types produces an error.
|
||||
"""
|
||||
def fn_with_comment(x):
|
||||
# type: (torch.Tensor) -> Optional
|
||||
def fn_with_comment(x: torch.Tensor) -> Optional:
|
||||
return (x, x)
|
||||
|
||||
def annotated_fn(x: torch.Tensor) -> Optional:
|
||||
@ -437,8 +436,7 @@ class TestScriptPy3(JitTestCase):
|
||||
"""
|
||||
Test that using a tuple with no contained types produces an error.
|
||||
"""
|
||||
def fn_with_comment(x):
|
||||
# type: (torch.Tensor) -> Tuple
|
||||
def fn_with_comment(x: torch.Tensor) -> Tuple:
|
||||
return (x, x)
|
||||
|
||||
def annotated_fn(x: torch.Tensor) -> Tuple:
|
||||
@ -733,42 +731,33 @@ class TestScriptPy3(JitTestCase):
|
||||
|
||||
@torch.jit.interface
|
||||
class OneTwoModule(nn.Module):
|
||||
def one(self, x, y):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
def two(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
def two(self, x: torch.Tensor) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
def forward(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
class FooMod(nn.Module):
|
||||
def one(self, x, y):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x + y
|
||||
|
||||
def two(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
def two(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return 2 * x
|
||||
|
||||
def forward(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.one(self.two(x), x)
|
||||
|
||||
class BarMod(nn.Module):
|
||||
def one(self, x, y):
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x * y
|
||||
|
||||
def two(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
def two(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return 2 / x
|
||||
|
||||
def forward(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.two(self.one(x, x))
|
||||
|
||||
class M(nn.Module):
|
||||
@ -778,8 +767,7 @@ class TestScriptPy3(JitTestCase):
|
||||
super(M, self).__init__()
|
||||
self.sub = BarMod()
|
||||
|
||||
def forward(self, x):
|
||||
# type: (Tensor) -> Tensor
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.sub.forward(x)
|
||||
|
||||
def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
|
||||
|
@ -1,28 +1,26 @@
|
||||
from test_jit import JitTestCase
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
class TestScript(JitTestCase):
|
||||
def test_str_ops(self):
|
||||
def test_str_is(s):
|
||||
# type: (str) -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]
|
||||
def test_str_is(s: str) -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]:
|
||||
return s.isupper(), s.islower(), s.isdigit(), s.isspace(), \
|
||||
s.isalnum(), s.isalpha(), s.isdecimal(), s.isnumeric(), \
|
||||
s.isidentifier(), s.istitle(), s.isprintable()
|
||||
|
||||
def test_str_to(s):
|
||||
# type: (str) -> Tuple[str, str, str, str, str]
|
||||
def test_str_to(s: str) -> Tuple[str, str, str, str, str]:
|
||||
return s.upper(), s.lower(), s.capitalize(), s.title(), s.swapcase()
|
||||
|
||||
def test_str_strip(s):
|
||||
# type: (str) -> Tuple[str, str, str]
|
||||
def test_str_strip(s: str) -> Tuple[str, str, str]:
|
||||
return (
|
||||
s.lstrip(),
|
||||
s.rstrip(),
|
||||
s.strip(),
|
||||
)
|
||||
|
||||
def test_str_strip_char_set(s, char_set):
|
||||
# type: (str, str) -> Tuple[str, str, str]
|
||||
def test_str_strip_char_set(s: str, char_set: str) -> Tuple[str, str, str]:
|
||||
return (
|
||||
s.lstrip(char_set),
|
||||
s.rstrip(char_set),
|
||||
@ -34,44 +32,34 @@ class TestScript(JitTestCase):
|
||||
"more strings with spaces", "Titular Strings", "\x0acan'tprintthis",
|
||||
"spaces at the end ", " begin"]
|
||||
|
||||
def test_str_center(i, s):
|
||||
# type: (int, str) -> str
|
||||
def test_str_center(i: int, s: str) -> str:
|
||||
return s.center(i)
|
||||
|
||||
def test_str_center_fc(i, s):
|
||||
# type: (int, str) -> str
|
||||
def test_str_center_fc(i: int, s: str) -> str:
|
||||
return s.center(i, '*')
|
||||
|
||||
def test_str_center_error(s):
|
||||
# type: (str) -> str
|
||||
def test_str_center_error(s: str) -> str:
|
||||
return s.center(10, '**')
|
||||
|
||||
def test_ljust(s, i):
|
||||
# type: (str, int) -> str
|
||||
def test_ljust(s: str, i: int) -> str:
|
||||
return s.ljust(i)
|
||||
|
||||
def test_ljust_fc(s, i, fc):
|
||||
# type: (str, int, str) -> str
|
||||
def test_ljust_fc(s: str, i: int, fc: str) -> str:
|
||||
return s.ljust(i, fc)
|
||||
|
||||
def test_ljust_fc_err(s):
|
||||
# type: (str) -> str
|
||||
def test_ljust_fc_err(s: str) -> str:
|
||||
return s.ljust(10, '**')
|
||||
|
||||
def test_rjust(s, i):
|
||||
# type: (str, int) -> str
|
||||
def test_rjust(s: str, i: int) -> str:
|
||||
return s.rjust(i)
|
||||
|
||||
def test_rjust_fc(s, i, fc):
|
||||
# type: (str, int, str) -> str
|
||||
def test_rjust_fc(s: str, i: int, fc: str) -> str:
|
||||
return s.rjust(i, fc)
|
||||
|
||||
def test_rjust_fc_err(s):
|
||||
# type: (str) -> str
|
||||
def test_rjust_fc_err(s: str) -> str:
|
||||
return s.rjust(10, '**')
|
||||
|
||||
def test_zfill(s, i):
|
||||
# type: (str, int) -> str
|
||||
def test_zfill(s: str, i: int) -> str:
|
||||
return s.zfill(i)
|
||||
|
||||
for input in inputs:
|
||||
@ -93,8 +81,7 @@ class TestScript(JitTestCase):
|
||||
test_str_center_error("error")
|
||||
test_ljust("error")
|
||||
|
||||
def test_count():
|
||||
# type: () -> Tuple[int, int, int, int, int, int, int, int, int, int, int, int]
|
||||
def test_count() -> Tuple[int, int, int, int, int, int, int, int, int, int, int, int]:
|
||||
return (
|
||||
"hello".count("h"),
|
||||
"hello".count("h", 0, 1),
|
||||
@ -111,8 +98,7 @@ class TestScript(JitTestCase):
|
||||
)
|
||||
self.checkScript(test_count, ())
|
||||
|
||||
def test_endswith():
|
||||
# type: () -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]
|
||||
def test_endswith() -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]:
|
||||
return (
|
||||
"hello".endswith("lo"),
|
||||
"hello".endswith("lo", 0),
|
||||
@ -131,8 +117,7 @@ class TestScript(JitTestCase):
|
||||
)
|
||||
self.checkScript(test_endswith, ())
|
||||
|
||||
def test_startswith():
|
||||
# type: () -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]
|
||||
def test_startswith() -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]:
|
||||
return (
|
||||
"hello".startswith("lo"),
|
||||
"hello".startswith("lo", 0),
|
||||
@ -151,8 +136,7 @@ class TestScript(JitTestCase):
|
||||
)
|
||||
self.checkScript(test_startswith, ())
|
||||
|
||||
def test_expandtabs():
|
||||
# type: () -> Tuple[str, str, str, str, str, str]
|
||||
def test_expandtabs() -> Tuple[str, str, str, str, str, str]:
|
||||
return (
|
||||
'xyz\t82345\tabc'.expandtabs(),
|
||||
'xyz\t32345\tabc'.expandtabs(3),
|
||||
@ -163,8 +147,7 @@ class TestScript(JitTestCase):
|
||||
)
|
||||
self.checkScript(test_expandtabs, ())
|
||||
|
||||
def test_rfind():
|
||||
# type: () -> Tuple[int, int, int, int, int, int, int, int, int]
|
||||
def test_rfind() -> Tuple[int, int, int, int, int, int, int, int, int]:
|
||||
return (
|
||||
"hello123abc".rfind("llo"),
|
||||
"hello123abc".rfind("12"),
|
||||
@ -178,8 +161,7 @@ class TestScript(JitTestCase):
|
||||
)
|
||||
self.checkScript(test_rfind, ())
|
||||
|
||||
def test_find():
|
||||
# type: () -> Tuple[int, int, int, int, int, int, int, int, int]
|
||||
def test_find() -> Tuple[int, int, int, int, int, int, int, int, int]:
|
||||
return (
|
||||
"hello123abc".find("llo"),
|
||||
"hello123abc".find("12"),
|
||||
@ -193,8 +175,7 @@ class TestScript(JitTestCase):
|
||||
)
|
||||
self.checkScript(test_find, ())
|
||||
|
||||
def test_index():
|
||||
# type: () -> Tuple[int, int, int, int, int, int]
|
||||
def test_index() -> Tuple[int, int, int, int, int, int]:
|
||||
return (
|
||||
"hello123abc".index("llo"),
|
||||
"hello123abc".index("12"),
|
||||
@ -205,8 +186,7 @@ class TestScript(JitTestCase):
|
||||
)
|
||||
self.checkScript(test_index, ())
|
||||
|
||||
def test_rindex():
|
||||
# type: () -> Tuple[int, int, int, int, int, int]
|
||||
def test_rindex() -> Tuple[int, int, int, int, int, int]:
|
||||
return (
|
||||
"hello123abc".rindex("llo"),
|
||||
"hello123abc".rindex("12"),
|
||||
@ -217,8 +197,7 @@ class TestScript(JitTestCase):
|
||||
)
|
||||
self.checkScript(test_rindex, ())
|
||||
|
||||
def test_replace():
|
||||
# type: () -> Tuple[str, str, str, str, str, str, str]
|
||||
def test_replace() -> Tuple[str, str, str, str, str, str, str]:
|
||||
return (
|
||||
"hello123abc".replace("llo", "sdf"),
|
||||
"ff".replace("f", "ff"),
|
||||
@ -230,11 +209,9 @@ class TestScript(JitTestCase):
|
||||
)
|
||||
self.checkScript(test_replace, ())
|
||||
|
||||
def test_partition():
|
||||
"""
|
||||
type: () -> Tuple[Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str],
|
||||
Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str]]
|
||||
"""
|
||||
def test_partition() -> Tuple[Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str],
|
||||
Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str],
|
||||
Tuple[str, str, str]]:
|
||||
return (
|
||||
"hello123abc".partition("llo"),
|
||||
"ff".partition("f"),
|
||||
@ -246,11 +223,9 @@ class TestScript(JitTestCase):
|
||||
)
|
||||
self.checkScript(test_partition, ())
|
||||
|
||||
def test_rpartition():
|
||||
"""
|
||||
type: () -> Tuple[Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str],
|
||||
Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str]]
|
||||
"""
|
||||
def test_rpartition() -> Tuple[Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str],
|
||||
Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str],
|
||||
Tuple[str, str, str]]:
|
||||
return (
|
||||
"hello123abc".rpartition("llo"),
|
||||
"ff".rpartition("f"),
|
||||
@ -262,11 +237,8 @@ class TestScript(JitTestCase):
|
||||
)
|
||||
self.checkScript(test_rpartition, ())
|
||||
|
||||
def test_split():
|
||||
"""
|
||||
type: () -> Tuple[List[str], List[str], List[str], List[str], List[str],
|
||||
List[str], List[str], List[str], List[str], List[str], List[str]]
|
||||
"""
|
||||
def test_split() -> Tuple[List[str], List[str], List[str], List[str], List[str],
|
||||
List[str], List[str], List[str], List[str], List[str], List[str]]:
|
||||
return (
|
||||
"a a a a a".split(),
|
||||
"a a a a a".split(),
|
||||
@ -290,8 +262,8 @@ class TestScript(JitTestCase):
|
||||
self.checkScriptRaisesRegex(test_split_empty_separator, (), Exception,
|
||||
"empty separator")
|
||||
|
||||
def test_rsplit():
|
||||
# type: () -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str], List[str], List[str]]
|
||||
def test_rsplit() -> Tuple[List[str], List[str], List[str], List[str], List[str],
|
||||
List[str], List[str], List[str], List[str]]:
|
||||
return (
|
||||
"a a a a a".rsplit(),
|
||||
" a a a a a ".rsplit(" "),
|
||||
@ -305,8 +277,8 @@ class TestScript(JitTestCase):
|
||||
)
|
||||
self.checkScript(test_rsplit, ())
|
||||
|
||||
def test_splitlines():
|
||||
# type: () -> Tuple[ List[str], List[str], List[str], List[str], List[str], List[str] ]
|
||||
def test_splitlines() -> Tuple[List[str], List[str], List[str], List[str],
|
||||
List[str], List[str]]:
|
||||
return (
|
||||
"hello\ntest".splitlines(),
|
||||
"hello\n\ntest\n".splitlines(),
|
||||
@ -317,8 +289,7 @@ class TestScript(JitTestCase):
|
||||
)
|
||||
self.checkScript(test_splitlines, ())
|
||||
|
||||
def test_str_cmp(a, b):
|
||||
# type: (str, str) -> Tuple[bool, bool, bool, bool, bool, bool]
|
||||
def test_str_cmp(a: str, b: str) -> Tuple[bool, bool, bool, bool, bool, bool]:
|
||||
return a != b, a == b, a < b, a > b, a <= b, a >= b
|
||||
|
||||
for i in range(len(inputs) - 1):
|
||||
|
@ -3,6 +3,7 @@ import torch
|
||||
from torch import nn
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
class StaticRuntime:
|
||||
def __init__(self, scripted):
|
||||
@ -27,8 +28,7 @@ class StaticRuntime:
|
||||
)
|
||||
|
||||
|
||||
def linear_shim(input, weight, bias=None):
|
||||
# type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
|
||||
def linear_shim(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
output = input.matmul(weight.t())
|
||||
if bias is not None:
|
||||
output += bias
|
||||
@ -116,7 +116,7 @@ def loop_graph(a, b, iters : int):
|
||||
def output_graph(a, b, c, iters : int):
|
||||
s = torch.tensor([[3, 3], [3, 3]])
|
||||
k = a + b * c + s
|
||||
d : Dict[int, Tensor] = {}
|
||||
d : Dict[int, torch.Tensor] = {}
|
||||
for i in range(iters):
|
||||
d[i] = k + i
|
||||
return d
|
||||
|
@ -1163,13 +1163,11 @@ class TestTensorExprFuser(BaseTestClass):
|
||||
|
||||
def test_scalar(self):
|
||||
@torch.jit.script
|
||||
def test_float(x, y, z, a, b):
|
||||
# type: (Tensor, Tensor, Tensor, float, float) -> Tensor
|
||||
def test_float(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: float, b: float) -> torch.Tensor:
|
||||
return torch.add(torch.add(x, y, alpha=a), z, alpha=b)
|
||||
|
||||
@torch.jit.script
|
||||
def test_int(x, y, z, a, b):
|
||||
# type: (Tensor, Tensor, Tensor, int, int) -> Tensor
|
||||
def test_int(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: int, b: int) -> torch.Tensor:
|
||||
return torch.add(torch.add(x, y, alpha=a), z, alpha=b)
|
||||
|
||||
for test in (test_float, test_int):
|
||||
@ -1186,8 +1184,7 @@ class TestTensorExprFuser(BaseTestClass):
|
||||
|
||||
def test_loop(self):
|
||||
@torch.jit.script
|
||||
def test(x, y, z):
|
||||
# type: (Tensor, Tensor, int) -> Tensor
|
||||
def test(x: torch.Tensor, y: torch.Tensor, z: int) -> torch.Tensor:
|
||||
b = y
|
||||
for i in range(0, z):
|
||||
a = x + y
|
||||
|
Reference in New Issue
Block a user