use flake8-mypy (#17721)

Summary:
Use flake8 installed with mypy checks so that our linter matches fbcode. Mypy type errors also provide valuable signal
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17721

Differential Revision: D14357778

Pulled By: eellison

fbshipit-source-id: d8c9ea3fe3b5f550c3b70fe259e0eabf95e4c92d
This commit is contained in:
Elias Ellison
2019-03-07 09:12:35 -08:00
committed by Facebook Github Bot
parent 1d522598fb
commit 561037aef8
6 changed files with 60 additions and 55 deletions

View File

@ -1,4 +1,4 @@
[flake8]
max-line-length = 120
ignore = E203,E305,E402,E721,E741,F401,F403,F405,F821,F841,F999,W503,W504
exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install
exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install,build,torch/include

View File

@ -27,5 +27,5 @@ matrix:
include:
env: LINT_CHECK
python: "2.7"
install: pip install flake8
install: pip install flake8-mypy
script: flake8

View File

@ -29,7 +29,7 @@ matrix:
python: "3.7"
dist: xenial # required for Python 3.7 (travis-ci/travis-ci#9069)
sudo: required # required for Python 3.7 (travis-ci/travis-ci#9069)
install: pip install flake8
install: pip install flake8-mypy
script: flake8
- name: "MyPy typecheck"
python: "3.6"

View File

@ -5,11 +5,13 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel as dp
import torch.optim as optim
import torch.cuda
import torch.jit.quantized
from contextlib import contextmanager
from itertools import product, chain
import torch.jit.frontend
from torch.autograd import Variable, Function
from torch.nn import Module
from torch.autograd.function import traceable
from torch.testing import assert_allclose
from torch.onnx import OperatorExportTypes
@ -44,9 +46,11 @@ from torch._C import TensorType, TupleType, FloatType, IntType, \
ListType, StringType, DictType
from copy import deepcopy
import random
from typing import List, Dict, Optional
from typing import List, Dict, Optional, Tuple
from torch.jit.frontend import NotSupportedError
from torch.jit import BatchTensor
from torch import Tensor
from torch.jit.annotations import BroadcastingList2, BroadcastingList3
# For testing truediv in python 2
from test_module.future_div import div_int_future, div_float_future
@ -96,7 +100,7 @@ if WINDOWS:
finally:
os.unlink(f.name)
else:
@contextmanager
@contextmanager # noqa: T484
def TemporaryFileName():
with tempfile.NamedTemporaryFile() as f:
yield f.name
@ -2262,7 +2266,7 @@ class TestJit(JitTestCase):
with self.assertRaisesRegex(RuntimeError, "Expected a default value"):
@torch.jit.script
def hints_bad_types(x, a=10, b=0.5):
def hints_bad_types(x, a=10, b=0.5): # noqa: T484
# type: (Tensor, float, int) -> Tensor
return x + a + b
@ -3113,7 +3117,7 @@ class TestScript(JitTestCase):
def sum_list(a):
# type: (int) -> int
sum = 0
for i in a:
for i in a: # noqa: T484
sum += i
return sum
@ -4727,23 +4731,23 @@ a")
x = 1
else:
x = torch.jit._unwrap_optional(x)
return x
return x # noqa: T484
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
@torch.jit.script
def or_error(x, y):
# type: (Optional[int], Optional[int]) -> int
# type: (Optional[int], Optional[int]) -> None
if x is None or y is None:
print(x + y)
print(x + y) # noqa: T484
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
@torch.jit.script
def and_error(x, y):
# type: (Optional[int], Optional[int]) -> int
# type: (Optional[int], Optional[int]) -> None
if x is None and y is None:
pass
else:
print(x + y)
print(x + y) # noqa: T484
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
@torch.jit.script
@ -4751,7 +4755,7 @@ a")
# type: (Optional[int]) -> None
x_none = x is not None
if x_none:
print(x + 1)
print(x + 1) # noqa: T484
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
@torch.jit.script
@ -4759,7 +4763,7 @@ a")
# type: (Optional[int], Optional[int]) -> None
x_none = x is not None
if y is not None and x_none:
print(x + y)
print(x + y) # noqa: T484
def test_while_write_outer_then_read(self):
def func(a, b):
@ -5057,10 +5061,11 @@ a")
self.checkScript(multiple_returns, [a], optimize=True)
with self.assertRaisesRegex(RuntimeError, "but is actually of type None"):
@torch.jit.script
torch.jit.CompilationUnit('''
def no_return_bad_annotation(a):
# type: (Tensor) -> Tensor
a + 1
''')
def test_error(self):
@torch.jit.script
@ -5654,8 +5659,6 @@ a")
hiddens = hx
if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
from typing import Tuple
class ScriptWrapper(torch.jit.ScriptModule):
def __init__(self, cell):
super(ScriptWrapper, self).__init__()
@ -6650,7 +6653,7 @@ a")
with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"):
def foo():
# type: () -> Tensor
return ((3, 4),)
return ((3, 4),) # noqa: T484
@torch.jit.script
def bar():
@ -6769,7 +6772,7 @@ a")
if x:
y = [1]
else:
y = [None]
y = [None] # noqa: T484
return y[0]
@torch.jit.script
@ -6815,18 +6818,18 @@ a")
print(int_fn((1, 1, 1)))
with self.assertRaisesRegex(RuntimeError, "must be a positive integer:"):
@torch.jit.script
@torch.jit.script # noqa: T484
def fn(x):
# type: (BroadcastingListx[int]) -> List[int]
# type: (BroadcastingListx[int]) -> List[int] # noqa: T484
return x
# TODO: the type comment in this seems to trip up flake8 for some reason
# even though we have a noqa comment. Figure out why
# using CU so that flake8 error on int[2] is not raised (noqa not working)
with self.assertRaisesRegex(RuntimeError, "Unknown type constructor"):
@torch.jit.script
def nested(x, y):
# type: (int, Tuple[int, int[2]]) -> List[int] # noqa: T484
return x
cu = torch.jit.CompilationUnit('''
def nested(x, y):
# type: (int, Tuple[int, int[2]]) -> List[int]
return x # noqa: T484
''')
def test_ntuple_builtins(self):
from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
@ -8349,7 +8352,7 @@ a")
with self.assertRaisesRegex(RuntimeError, 'but instead got value of type tuple'):
def somefunc():
# type: () -> Tuple[Tuple[Tensor, Tensor]]
return torch.zeros(3, 4), torch.zeros(4, 5)
return torch.zeros(3, 4), torch.zeros(4, 5) # noqa: T484
@torch.jit.script
def wrong_return_type():
@ -9029,7 +9032,7 @@ a")
def test(x):
# type: (Optional[int]) -> int
x = torch.jit._unwrap_optional(x)
x = x + x
x = x + x # noqa: T484
return x
self.checkScript(test, (3,))
@ -9082,14 +9085,14 @@ a")
@torch.jit.script
def return_tup(x):
# type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor]
return x, x
return x, x # noqa: T484
def test_annotated_script_fn_arg_mismatch(self):
with self.assertRaisesRegex(RuntimeError, r"arguments for call are not valid"):
@torch.jit.script
def tuple_arg(x):
# type: (Tuple[Tensor, Tensor]) -> Tensor
return x + 1
return x + 1 # noqa: T484
def test_script_non_tensor_args_outputs(self):
@torch.jit.script
@ -13122,11 +13125,11 @@ class TestAsync(JitTestCase):
self.assertEqual(y, y_hat)
def test_async_script_capture(self):
class Module(torch.jit.ScriptModule):
class Mod(torch.jit.ScriptModule):
__constants__ = ['const']
def __init__(self):
super(Module, self).__init__(False)
super(Mod, self).__init__(False)
self.const = 42
self.param = nn.Parameter(torch.randn(2, 2))
@ -13144,7 +13147,7 @@ class TestAsync(JitTestCase):
x1 = torch.rand(3, 4)
x2 = torch.rand(5, 6)
m = Module()
m = Mod()
y, y_hat = m.wait_script(x1, x2)
self.assertEqual(y, y_hat)
@ -13244,9 +13247,9 @@ class TestAsync(JitTestCase):
def forward(self, x):
return (torch.neg(x), x)
class Module(torch.jit.ScriptModule):
class Mod(torch.jit.ScriptModule):
def __init__(self):
super(Module, self).__init__(False)
super(Mod, self).__init__(False)
x = torch.rand(3, 3)
self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True)
@ -13266,10 +13269,10 @@ class TestAsync(JitTestCase):
# return a nested structure of tensors
return (tensor_list, tensor_tuple, tensor_tuple[1])
class Tuple(nn.Module):
class TupleCl(nn.Module):
def __init__(self):
super(Tuple, self).__init__()
self.module = Module()
super(TupleCl, self).__init__()
self.module = Mod()
def forward(self, x):
z = torch.neg(x)
@ -13278,7 +13281,7 @@ class TestAsync(JitTestCase):
return tuple(list)
x = torch.rand(3, 3)
module = torch.jit.trace(Tuple(), (x), _force_outplace=True)
module = torch.jit.trace(TupleCl(), (x), _force_outplace=True)
# Make sure we have forks
self.assertGraphContainsExactly(module.graph, kind='prim::fork', num_kind_nodes=2)
@ -13632,16 +13635,16 @@ class TestClassType(JitTestCase):
@torch.jit.script
class FooTest:
def __init__(self, x):
# type: (int)
# type: (int) -> None
self.foo = x
def incFooTest(self, y):
# type: (int)
# type: (int) -> None
self.foo = self.foo + y
@torch.jit.script
def fn(x):
# type: (int)
# type: (int) -> int
foo = FooTest(x)
foo.incFooTest(2)
return foo.foo
@ -13689,7 +13692,7 @@ class TestClassType(JitTestCase):
@torch.jit.script
class FooTest:
def __init__(self, x):
# type: (bool)
# type: (bool) -> None
self.foo = x
@torch.jit.script
@ -13718,7 +13721,7 @@ class TestClassType(JitTestCase):
@torch.jit.script
def fn(foo):
# type: (FooTest)
# type: (FooTest) -> Tensor
return foo.attr
@torch.jit.script

View File

@ -9,24 +9,24 @@ import inspect
from torch._six import builtins
# Tracks standalone weak script functions
compiled_weak_fns = weakref.WeakKeyDictionary()
compiled_weak_fns = weakref.WeakKeyDictionary() # noqa: T484
# Tracks which methods should be converted to strong methods
weak_script_methods = weakref.WeakKeyDictionary()
weak_script_methods = weakref.WeakKeyDictionary() # noqa: T484
# Converted modules and their corresponding WeakScriptModuleProxy objects
weak_modules = weakref.WeakKeyDictionary()
weak_modules = weakref.WeakKeyDictionary() # noqa: T484
# Types that have been declared as weak modules
weak_types = weakref.WeakKeyDictionary()
weak_types = weakref.WeakKeyDictionary() # noqa: T484
# Wrapper functions that can call either of 2 functions depending on a boolean
# argument
boolean_dispatched = weakref.WeakKeyDictionary()
boolean_dispatched = weakref.WeakKeyDictionary() # noqa: T484
# Python Op functions that should be ignored by the compiler. These will be replaced
# with an operator that always throws an error
ignored_fns = weakref.WeakSet()
ignored_fns = weakref.WeakSet() # noqa: T484
COMPILATION_PENDING = object()
COMPILED = object()
@ -223,9 +223,9 @@ except ImportError:
def __getitem__(self, types):
return DictInstance(types)
Tuple = TupleCls()
List = ListCls()
Dict = DictCls()
Tuple = TupleCls() # noqa: T484
List = ListCls() # noqa: T484
Dict = DictCls() # noqa: T484
def is_tuple(ann):
return isinstance(ann, TupleInstance)

View File

@ -1,7 +1,9 @@
import torch
import copy
import numbers
from typing import Tuple
from typing import Tuple, Optional
from torch import Tensor
from torch.jit import ScriptModule
from torch.nn.utils.rnn import PackedSequence
from torch.nn import _VF