Type annotations in test/jit (#50293)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50293

Switching to type annotations for improved safety and import tracking.

Test Plan: Sandcastle tests

Reviewed By: xush6528

Differential Revision: D25853949

fbshipit-source-id: fb873587bb521a0a55021ee4d34d1b05ea8f000d
This commit is contained in:
Richard Barnes
2021-01-12 16:45:16 -08:00
committed by Facebook GitHub Bot
parent 4c97ef8d77
commit 8c25b9701b
7 changed files with 206 additions and 390 deletions

View File

@ -2,10 +2,11 @@ import os
import sys
import typing
import typing_extensions
from typing import List, Dict, Optional
from typing import List, Dict, Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from torch.testing import FileCheck
from collections import OrderedDict
@ -284,8 +285,7 @@ class TestRecursiveScript(JitTestCase):
test_module_dir(nn.ModuleDict(OrderedDict([("conv", conv), ("linear", linear)])))
def test_class_compile(self):
def other_fn(a, b):
# type: (int, Tensor) -> Tensor
def other_fn(a: int, b: Tensor) -> Tensor:
return a * b
class B(object):
@ -307,8 +307,7 @@ class TestRecursiveScript(JitTestCase):
self.checkModule(N(), (torch.randn(2, 2),))
def test_error_stack(self):
def d(x):
# type: (int) -> int
def d(x: int) -> int:
return x + 10
def c(x):
@ -331,8 +330,7 @@ class TestRecursiveScript(JitTestCase):
checker.run(str(e))
def test_error_stack_module(self):
def d(x):
# type: (int) -> int
def d(x: int) -> int:
return x + 10
def c(x):
@ -565,8 +563,7 @@ class TestRecursiveScript(JitTestCase):
self.a = 4
self.inner = Inner2()
def __setstate__(self, obj):
# type: (Tuple[int, Inner2]) -> None
def __setstate__(self, obj: Tuple[int, Inner2]) -> None:
a, inner = obj
self.a = a
self.inner = inner