use flake8-mypy (#17721)

Summary:
Use flake8 installed with mypy checks so that our linter matches fbcode. Mypy type errors also provide valuable signal
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17721

Differential Revision: D14357778

Pulled By: eellison

fbshipit-source-id: d8c9ea3fe3b5f550c3b70fe259e0eabf95e4c92d
This commit is contained in:
Elias Ellison
2019-03-07 09:12:35 -08:00
committed by Facebook Github Bot
parent 1d522598fb
commit 561037aef8
6 changed files with 60 additions and 55 deletions

View File

@ -1,4 +1,4 @@
[flake8] [flake8]
max-line-length = 120 max-line-length = 120
ignore = E203,E305,E402,E721,E741,F401,F403,F405,F821,F841,F999,W503,W504 ignore = E203,E305,E402,E721,E741,F401,F403,F405,F821,F841,F999,W503,W504
exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install,build,torch/include

View File

@ -27,5 +27,5 @@ matrix:
include: include:
env: LINT_CHECK env: LINT_CHECK
python: "2.7" python: "2.7"
install: pip install flake8 install: pip install flake8-mypy
script: flake8 script: flake8

View File

@ -29,7 +29,7 @@ matrix:
python: "3.7" python: "3.7"
dist: xenial # required for Python 3.7 (travis-ci/travis-ci#9069) dist: xenial # required for Python 3.7 (travis-ci/travis-ci#9069)
sudo: required # required for Python 3.7 (travis-ci/travis-ci#9069) sudo: required # required for Python 3.7 (travis-ci/travis-ci#9069)
install: pip install flake8 install: pip install flake8-mypy
script: flake8 script: flake8
- name: "MyPy typecheck" - name: "MyPy typecheck"
python: "3.6" python: "3.6"

View File

@ -5,11 +5,13 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.nn.parallel as dp import torch.nn.parallel as dp
import torch.optim as optim import torch.optim as optim
import torch.cuda
import torch.jit.quantized import torch.jit.quantized
from contextlib import contextmanager from contextlib import contextmanager
from itertools import product, chain from itertools import product, chain
import torch.jit.frontend import torch.jit.frontend
from torch.autograd import Variable, Function from torch.autograd import Variable, Function
from torch.nn import Module
from torch.autograd.function import traceable from torch.autograd.function import traceable
from torch.testing import assert_allclose from torch.testing import assert_allclose
from torch.onnx import OperatorExportTypes from torch.onnx import OperatorExportTypes
@ -44,9 +46,11 @@ from torch._C import TensorType, TupleType, FloatType, IntType, \
ListType, StringType, DictType ListType, StringType, DictType
from copy import deepcopy from copy import deepcopy
import random import random
from typing import List, Dict, Optional from typing import List, Dict, Optional, Tuple
from torch.jit.frontend import NotSupportedError from torch.jit.frontend import NotSupportedError
from torch.jit import BatchTensor from torch.jit import BatchTensor
from torch import Tensor
from torch.jit.annotations import BroadcastingList2, BroadcastingList3
# For testing truediv in python 2 # For testing truediv in python 2
from test_module.future_div import div_int_future, div_float_future from test_module.future_div import div_int_future, div_float_future
@ -96,7 +100,7 @@ if WINDOWS:
finally: finally:
os.unlink(f.name) os.unlink(f.name)
else: else:
@contextmanager @contextmanager # noqa: T484
def TemporaryFileName(): def TemporaryFileName():
with tempfile.NamedTemporaryFile() as f: with tempfile.NamedTemporaryFile() as f:
yield f.name yield f.name
@ -2262,7 +2266,7 @@ class TestJit(JitTestCase):
with self.assertRaisesRegex(RuntimeError, "Expected a default value"): with self.assertRaisesRegex(RuntimeError, "Expected a default value"):
@torch.jit.script @torch.jit.script
def hints_bad_types(x, a=10, b=0.5): def hints_bad_types(x, a=10, b=0.5): # noqa: T484
# type: (Tensor, float, int) -> Tensor # type: (Tensor, float, int) -> Tensor
return x + a + b return x + a + b
@ -3113,7 +3117,7 @@ class TestScript(JitTestCase):
def sum_list(a): def sum_list(a):
# type: (int) -> int # type: (int) -> int
sum = 0 sum = 0
for i in a: for i in a: # noqa: T484
sum += i sum += i
return sum return sum
@ -4727,23 +4731,23 @@ a")
x = 1 x = 1
else: else:
x = torch.jit._unwrap_optional(x) x = torch.jit._unwrap_optional(x)
return x return x # noqa: T484
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"): with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
@torch.jit.script @torch.jit.script
def or_error(x, y): def or_error(x, y):
# type: (Optional[int], Optional[int]) -> int # type: (Optional[int], Optional[int]) -> None
if x is None or y is None: if x is None or y is None:
print(x + y) print(x + y) # noqa: T484
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"): with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
@torch.jit.script @torch.jit.script
def and_error(x, y): def and_error(x, y):
# type: (Optional[int], Optional[int]) -> int # type: (Optional[int], Optional[int]) -> None
if x is None and y is None: if x is None and y is None:
pass pass
else: else:
print(x + y) print(x + y) # noqa: T484
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"): with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
@torch.jit.script @torch.jit.script
@ -4751,7 +4755,7 @@ a")
# type: (Optional[int]) -> None # type: (Optional[int]) -> None
x_none = x is not None x_none = x is not None
if x_none: if x_none:
print(x + 1) print(x + 1) # noqa: T484
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"): with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
@torch.jit.script @torch.jit.script
@ -4759,7 +4763,7 @@ a")
# type: (Optional[int], Optional[int]) -> None # type: (Optional[int], Optional[int]) -> None
x_none = x is not None x_none = x is not None
if y is not None and x_none: if y is not None and x_none:
print(x + y) print(x + y) # noqa: T484
def test_while_write_outer_then_read(self): def test_while_write_outer_then_read(self):
def func(a, b): def func(a, b):
@ -5057,10 +5061,11 @@ a")
self.checkScript(multiple_returns, [a], optimize=True) self.checkScript(multiple_returns, [a], optimize=True)
with self.assertRaisesRegex(RuntimeError, "but is actually of type None"): with self.assertRaisesRegex(RuntimeError, "but is actually of type None"):
@torch.jit.script torch.jit.CompilationUnit('''
def no_return_bad_annotation(a): def no_return_bad_annotation(a):
# type: (Tensor) -> Tensor # type: (Tensor) -> Tensor
a + 1 a + 1
''')
def test_error(self): def test_error(self):
@torch.jit.script @torch.jit.script
@ -5654,8 +5659,6 @@ a")
hiddens = hx hiddens = hx
if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell): if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
from typing import Tuple
class ScriptWrapper(torch.jit.ScriptModule): class ScriptWrapper(torch.jit.ScriptModule):
def __init__(self, cell): def __init__(self, cell):
super(ScriptWrapper, self).__init__() super(ScriptWrapper, self).__init__()
@ -6650,7 +6653,7 @@ a")
with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"): with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"):
def foo(): def foo():
# type: () -> Tensor # type: () -> Tensor
return ((3, 4),) return ((3, 4),) # noqa: T484
@torch.jit.script @torch.jit.script
def bar(): def bar():
@ -6769,7 +6772,7 @@ a")
if x: if x:
y = [1] y = [1]
else: else:
y = [None] y = [None] # noqa: T484
return y[0] return y[0]
@torch.jit.script @torch.jit.script
@ -6815,18 +6818,18 @@ a")
print(int_fn((1, 1, 1))) print(int_fn((1, 1, 1)))
with self.assertRaisesRegex(RuntimeError, "must be a positive integer:"): with self.assertRaisesRegex(RuntimeError, "must be a positive integer:"):
@torch.jit.script @torch.jit.script # noqa: T484
def fn(x): def fn(x):
# type: (BroadcastingListx[int]) -> List[int] # type: (BroadcastingListx[int]) -> List[int] # noqa: T484
return x return x
# TODO: the type comment in this seems to trip up flake8 for some reason # using CU so that flake8 error on int[2] is not raised (noqa not working)
# even though we have a noqa comment. Figure out why
with self.assertRaisesRegex(RuntimeError, "Unknown type constructor"): with self.assertRaisesRegex(RuntimeError, "Unknown type constructor"):
@torch.jit.script cu = torch.jit.CompilationUnit('''
def nested(x, y): def nested(x, y):
# type: (int, Tuple[int, int[2]]) -> List[int] # noqa: T484 # type: (int, Tuple[int, int[2]]) -> List[int]
return x return x # noqa: T484
''')
def test_ntuple_builtins(self): def test_ntuple_builtins(self):
from torch.nn.modules.utils import _single, _pair, _triple, _quadruple from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
@ -8349,7 +8352,7 @@ a")
with self.assertRaisesRegex(RuntimeError, 'but instead got value of type tuple'): with self.assertRaisesRegex(RuntimeError, 'but instead got value of type tuple'):
def somefunc(): def somefunc():
# type: () -> Tuple[Tuple[Tensor, Tensor]] # type: () -> Tuple[Tuple[Tensor, Tensor]]
return torch.zeros(3, 4), torch.zeros(4, 5) return torch.zeros(3, 4), torch.zeros(4, 5) # noqa: T484
@torch.jit.script @torch.jit.script
def wrong_return_type(): def wrong_return_type():
@ -9029,7 +9032,7 @@ a")
def test(x): def test(x):
# type: (Optional[int]) -> int # type: (Optional[int]) -> int
x = torch.jit._unwrap_optional(x) x = torch.jit._unwrap_optional(x)
x = x + x x = x + x # noqa: T484
return x return x
self.checkScript(test, (3,)) self.checkScript(test, (3,))
@ -9082,14 +9085,14 @@ a")
@torch.jit.script @torch.jit.script
def return_tup(x): def return_tup(x):
# type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor] # type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor]
return x, x return x, x # noqa: T484
def test_annotated_script_fn_arg_mismatch(self): def test_annotated_script_fn_arg_mismatch(self):
with self.assertRaisesRegex(RuntimeError, r"arguments for call are not valid"): with self.assertRaisesRegex(RuntimeError, r"arguments for call are not valid"):
@torch.jit.script @torch.jit.script
def tuple_arg(x): def tuple_arg(x):
# type: (Tuple[Tensor, Tensor]) -> Tensor # type: (Tuple[Tensor, Tensor]) -> Tensor
return x + 1 return x + 1 # noqa: T484
def test_script_non_tensor_args_outputs(self): def test_script_non_tensor_args_outputs(self):
@torch.jit.script @torch.jit.script
@ -13122,11 +13125,11 @@ class TestAsync(JitTestCase):
self.assertEqual(y, y_hat) self.assertEqual(y, y_hat)
def test_async_script_capture(self): def test_async_script_capture(self):
class Module(torch.jit.ScriptModule): class Mod(torch.jit.ScriptModule):
__constants__ = ['const'] __constants__ = ['const']
def __init__(self): def __init__(self):
super(Module, self).__init__(False) super(Mod, self).__init__(False)
self.const = 42 self.const = 42
self.param = nn.Parameter(torch.randn(2, 2)) self.param = nn.Parameter(torch.randn(2, 2))
@ -13144,7 +13147,7 @@ class TestAsync(JitTestCase):
x1 = torch.rand(3, 4) x1 = torch.rand(3, 4)
x2 = torch.rand(5, 6) x2 = torch.rand(5, 6)
m = Module() m = Mod()
y, y_hat = m.wait_script(x1, x2) y, y_hat = m.wait_script(x1, x2)
self.assertEqual(y, y_hat) self.assertEqual(y, y_hat)
@ -13244,9 +13247,9 @@ class TestAsync(JitTestCase):
def forward(self, x): def forward(self, x):
return (torch.neg(x), x) return (torch.neg(x), x)
class Module(torch.jit.ScriptModule): class Mod(torch.jit.ScriptModule):
def __init__(self): def __init__(self):
super(Module, self).__init__(False) super(Mod, self).__init__(False)
x = torch.rand(3, 3) x = torch.rand(3, 3)
self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True) self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True)
@ -13266,10 +13269,10 @@ class TestAsync(JitTestCase):
# return a nested structure of tensors # return a nested structure of tensors
return (tensor_list, tensor_tuple, tensor_tuple[1]) return (tensor_list, tensor_tuple, tensor_tuple[1])
class Tuple(nn.Module): class TupleCl(nn.Module):
def __init__(self): def __init__(self):
super(Tuple, self).__init__() super(TupleCl, self).__init__()
self.module = Module() self.module = Mod()
def forward(self, x): def forward(self, x):
z = torch.neg(x) z = torch.neg(x)
@ -13278,7 +13281,7 @@ class TestAsync(JitTestCase):
return tuple(list) return tuple(list)
x = torch.rand(3, 3) x = torch.rand(3, 3)
module = torch.jit.trace(Tuple(), (x), _force_outplace=True) module = torch.jit.trace(TupleCl(), (x), _force_outplace=True)
# Make sure we have forks # Make sure we have forks
self.assertGraphContainsExactly(module.graph, kind='prim::fork', num_kind_nodes=2) self.assertGraphContainsExactly(module.graph, kind='prim::fork', num_kind_nodes=2)
@ -13632,16 +13635,16 @@ class TestClassType(JitTestCase):
@torch.jit.script @torch.jit.script
class FooTest: class FooTest:
def __init__(self, x): def __init__(self, x):
# type: (int) # type: (int) -> None
self.foo = x self.foo = x
def incFooTest(self, y): def incFooTest(self, y):
# type: (int) # type: (int) -> None
self.foo = self.foo + y self.foo = self.foo + y
@torch.jit.script @torch.jit.script
def fn(x): def fn(x):
# type: (int) # type: (int) -> int
foo = FooTest(x) foo = FooTest(x)
foo.incFooTest(2) foo.incFooTest(2)
return foo.foo return foo.foo
@ -13689,7 +13692,7 @@ class TestClassType(JitTestCase):
@torch.jit.script @torch.jit.script
class FooTest: class FooTest:
def __init__(self, x): def __init__(self, x):
# type: (bool) # type: (bool) -> None
self.foo = x self.foo = x
@torch.jit.script @torch.jit.script
@ -13718,7 +13721,7 @@ class TestClassType(JitTestCase):
@torch.jit.script @torch.jit.script
def fn(foo): def fn(foo):
# type: (FooTest) # type: (FooTest) -> Tensor
return foo.attr return foo.attr
@torch.jit.script @torch.jit.script

View File

@ -9,24 +9,24 @@ import inspect
from torch._six import builtins from torch._six import builtins
# Tracks standalone weak script functions # Tracks standalone weak script functions
compiled_weak_fns = weakref.WeakKeyDictionary() compiled_weak_fns = weakref.WeakKeyDictionary() # noqa: T484
# Tracks which methods should be converted to strong methods # Tracks which methods should be converted to strong methods
weak_script_methods = weakref.WeakKeyDictionary() weak_script_methods = weakref.WeakKeyDictionary() # noqa: T484
# Converted modules and their corresponding WeakScriptModuleProxy objects # Converted modules and their corresponding WeakScriptModuleProxy objects
weak_modules = weakref.WeakKeyDictionary() weak_modules = weakref.WeakKeyDictionary() # noqa: T484
# Types that have been declared as weak modules # Types that have been declared as weak modules
weak_types = weakref.WeakKeyDictionary() weak_types = weakref.WeakKeyDictionary() # noqa: T484
# Wrapper functions that can call either of 2 functions depending on a boolean # Wrapper functions that can call either of 2 functions depending on a boolean
# argument # argument
boolean_dispatched = weakref.WeakKeyDictionary() boolean_dispatched = weakref.WeakKeyDictionary() # noqa: T484
# Python Op functions that should be ignored by the compiler. These will be replaced # Python Op functions that should be ignored by the compiler. These will be replaced
# with an operator that always throws an error # with an operator that always throws an error
ignored_fns = weakref.WeakSet() ignored_fns = weakref.WeakSet() # noqa: T484
COMPILATION_PENDING = object() COMPILATION_PENDING = object()
COMPILED = object() COMPILED = object()
@ -223,9 +223,9 @@ except ImportError:
def __getitem__(self, types): def __getitem__(self, types):
return DictInstance(types) return DictInstance(types)
Tuple = TupleCls() Tuple = TupleCls() # noqa: T484
List = ListCls() List = ListCls() # noqa: T484
Dict = DictCls() Dict = DictCls() # noqa: T484
def is_tuple(ann): def is_tuple(ann):
return isinstance(ann, TupleInstance) return isinstance(ann, TupleInstance)

View File

@ -1,7 +1,9 @@
import torch import torch
import copy import copy
import numbers import numbers
from typing import Tuple from typing import Tuple, Optional
from torch import Tensor
from torch.jit import ScriptModule
from torch.nn.utils.rnn import PackedSequence from torch.nn.utils.rnn import PackedSequence
from torch.nn import _VF from torch.nn import _VF