use flake8-mypy (#17721)
Summary: Use flake8 installed with mypy checks so that our linter matches fbcode. Mypy type errors also provide valuable signal.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17721

Differential Revision: D14357778

Pulled By: eellison

fbshipit-source-id: d8c9ea3fe3b5f550c3b70fe259e0eabf95e4c92d
Committed by: Facebook Github Bot
Parent: 1d522598fb
Commit: 561037aef8
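For context: flake8-mypy runs mypy as a flake8 plugin and reports type errors under the T4xx family of codes (T484 is the generic type-error code). Because many of the JIT tests below deliberately trigger type errors, most of this diff appends "# noqa: T484" to the offending lines rather than fixing them. A minimal sketch of that pattern, using a hypothetical function that is not part of the PR:

def add_one(x):
    # type: (int) -> int
    # Deliberately ill-typed: mypy reports this line as T484 through
    # flake8-mypy, so the error is suppressed inline rather than fixed.
    return x + "one"  # noqa: T484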
.flake8
@@ -1,4 +1,4 @@
 [flake8]
 max-line-length = 120
 ignore = E203,E305,E402,E721,E741,F401,F403,F405,F821,F841,F999,W503,W504
-exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install
+exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install,build,torch/include
@@ -27,5 +27,5 @@ matrix:
   include:
     env: LINT_CHECK
     python: "2.7"
-    install: pip install flake8
+    install: pip install flake8-mypy
     script: flake8
@@ -29,7 +29,7 @@ matrix:
       python: "3.7"
       dist: xenial # required for Python 3.7 (travis-ci/travis-ci#9069)
       sudo: required # required for Python 3.7 (travis-ci/travis-ci#9069)
-      install: pip install flake8
+      install: pip install flake8-mypy
       script: flake8
     - name: "MyPy typecheck"
       python: "3.6"
@@ -5,11 +5,13 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.parallel as dp
 import torch.optim as optim
+import torch.cuda
+import torch.jit.quantized
 from contextlib import contextmanager
 from itertools import product, chain
 import torch.jit.frontend
 from torch.autograd import Variable, Function
 from torch.nn import Module
 from torch.autograd.function import traceable
 from torch.testing import assert_allclose
 from torch.onnx import OperatorExportTypes
@@ -44,9 +46,11 @@ from torch._C import TensorType, TupleType, FloatType, IntType, \
     ListType, StringType, DictType
 from copy import deepcopy
 import random
-from typing import List, Dict, Optional
+from typing import List, Dict, Optional, Tuple
 from torch.jit.frontend import NotSupportedError
 from torch.jit import BatchTensor
+from torch import Tensor
+from torch.jit.annotations import BroadcastingList2, BroadcastingList3

 # For testing truediv in python 2
 from test_module.future_div import div_int_future, div_float_future
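The import additions above are not dead code: mypy resolves names that appear inside "# type:" comments, so Tensor, Tuple, and the BroadcastingList helpers must be in scope at module level even though plain flake8 never inspects comments. A small illustration with a hypothetical function, not taken from the PR:

from typing import Tuple

from torch import Tensor

def duplicate(t):
    # type: (Tensor) -> Tuple[Tensor, Tensor]
    # Without the imports above, mypy flags the type comment with an
    # undefined-name error, even though Python never evaluates it.
    return t, t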
@@ -96,7 +100,7 @@ if WINDOWS:
         finally:
             os.unlink(f.name)
 else:
-    @contextmanager
+    @contextmanager # noqa: T484
     def TemporaryFileName():
         with tempfile.NamedTemporaryFile() as f:
             yield f.name
@@ -2262,7 +2266,7 @@ class TestJit(JitTestCase):
         with self.assertRaisesRegex(RuntimeError, "Expected a default value"):

             @torch.jit.script
-            def hints_bad_types(x, a=10, b=0.5):
+            def hints_bad_types(x, a=10, b=0.5): # noqa: T484
                 # type: (Tensor, float, int) -> Tensor
                 return x + a + b
@@ -3113,7 +3117,7 @@ class TestScript(JitTestCase):
         def sum_list(a):
             # type: (int) -> int
             sum = 0
-            for i in a:
+            for i in a: # noqa: T484
                 sum += i

             return sum
@@ -4727,23 +4731,23 @@ a")
                 x = 1
             else:
                 x = torch.jit._unwrap_optional(x)
-            return x
+            return x # noqa: T484

         with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
             @torch.jit.script
             def or_error(x, y):
-                # type: (Optional[int], Optional[int]) -> int
+                # type: (Optional[int], Optional[int]) -> None
                 if x is None or y is None:
-                    print(x + y)
+                    print(x + y) # noqa: T484

         with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
             @torch.jit.script
             def and_error(x, y):
-                # type: (Optional[int], Optional[int]) -> int
+                # type: (Optional[int], Optional[int]) -> None
                 if x is None and y is None:
                     pass
                 else:
-                    print(x + y)
+                    print(x + y) # noqa: T484

         with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
             @torch.jit.script
@@ -4751,7 +4755,7 @@ a")
             # type: (Optional[int]) -> None
             x_none = x is not None
             if x_none:
-                print(x + 1)
+                print(x + 1) # noqa: T484

         with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
             @torch.jit.script
@@ -4759,7 +4763,7 @@ a")
             # type: (Optional[int], Optional[int]) -> None
             x_none = x is not None
             if y is not None and x_none:
-                print(x + y)
+                print(x + y) # noqa: T484

     def test_while_write_outer_then_read(self):
         def func(a, b):
@@ -5057,10 +5061,11 @@ a")
         self.checkScript(multiple_returns, [a], optimize=True)

         with self.assertRaisesRegex(RuntimeError, "but is actually of type None"):
-            @torch.jit.script
+            torch.jit.CompilationUnit('''
             def no_return_bad_annotation(a):
                 # type: (Tensor) -> Tensor
                 a + 1
+            ''')

     def test_error(self):
         @torch.jit.script
@@ -5654,8 +5659,6 @@ a")
             hiddens = hx

         if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
-            from typing import Tuple
-
             class ScriptWrapper(torch.jit.ScriptModule):
                 def __init__(self, cell):
                     super(ScriptWrapper, self).__init__()
@@ -6650,7 +6653,7 @@ a")
         with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"):
             def foo():
                 # type: () -> Tensor
-                return ((3, 4),)
+                return ((3, 4),) # noqa: T484

             @torch.jit.script
             def bar():
@@ -6769,7 +6772,7 @@ a")
             if x:
                 y = [1]
             else:
-                y = [None]
+                y = [None] # noqa: T484
             return y[0]

         @torch.jit.script
@@ -6815,18 +6818,18 @@ a")
             print(int_fn((1, 1, 1)))

         with self.assertRaisesRegex(RuntimeError, "must be a positive integer:"):
-            @torch.jit.script
+            @torch.jit.script # noqa: T484
             def fn(x):
-                # type: (BroadcastingListx[int]) -> List[int]
+                # type: (BroadcastingListx[int]) -> List[int] # noqa: T484
                 return x

-        # TODO: the type comment in this seems to trip up flake8 for some reason
-        # even though we have a noqa comment. Figure out why
+        # using CU so that flake8 error on int[2] is not raised (noqa not working)
         with self.assertRaisesRegex(RuntimeError, "Unknown type constructor"):
-            @torch.jit.script
-            def nested(x, y):
-                # type: (int, Tuple[int, int[2]]) -> List[int] # noqa: T484
-                return x
+            cu = torch.jit.CompilationUnit('''
+            def nested(x, y):
+                # type: (int, Tuple[int, int[2]]) -> List[int]
+                return x # noqa: T484
+            ''')

     def test_ntuple_builtins(self):
         from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
@@ -8349,7 +8352,7 @@ a")
         with self.assertRaisesRegex(RuntimeError, 'but instead got value of type tuple'):
             def somefunc():
                 # type: () -> Tuple[Tuple[Tensor, Tensor]]
-                return torch.zeros(3, 4), torch.zeros(4, 5)
+                return torch.zeros(3, 4), torch.zeros(4, 5) # noqa: T484

             @torch.jit.script
             def wrong_return_type():
@@ -9029,7 +9032,7 @@ a")
         def test(x):
             # type: (Optional[int]) -> int
             x = torch.jit._unwrap_optional(x)
-            x = x + x
+            x = x + x # noqa: T484
             return x

         self.checkScript(test, (3,))
@@ -9082,14 +9085,14 @@ a")
         @torch.jit.script
         def return_tup(x):
             # type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor]
-            return x, x
+            return x, x # noqa: T484

     def test_annotated_script_fn_arg_mismatch(self):
         with self.assertRaisesRegex(RuntimeError, r"arguments for call are not valid"):
             @torch.jit.script
             def tuple_arg(x):
                 # type: (Tuple[Tensor, Tensor]) -> Tensor
-                return x + 1
+                return x + 1 # noqa: T484

     def test_script_non_tensor_args_outputs(self):
         @torch.jit.script
@@ -13122,11 +13125,11 @@ class TestAsync(JitTestCase):
         self.assertEqual(y, y_hat)

     def test_async_script_capture(self):
-        class Module(torch.jit.ScriptModule):
+        class Mod(torch.jit.ScriptModule):
             __constants__ = ['const']

             def __init__(self):
-                super(Module, self).__init__(False)
+                super(Mod, self).__init__(False)
                 self.const = 42
                 self.param = nn.Parameter(torch.randn(2, 2))
@@ -13144,7 +13147,7 @@ class TestAsync(JitTestCase):
         x1 = torch.rand(3, 4)
         x2 = torch.rand(5, 6)

-        m = Module()
+        m = Mod()
         y, y_hat = m.wait_script(x1, x2)

         self.assertEqual(y, y_hat)
@@ -13244,9 +13247,9 @@ class TestAsync(JitTestCase):
             def forward(self, x):
                 return (torch.neg(x), x)

-        class Module(torch.jit.ScriptModule):
+        class Mod(torch.jit.ScriptModule):
             def __init__(self):
-                super(Module, self).__init__(False)
+                super(Mod, self).__init__(False)
                 x = torch.rand(3, 3)
                 self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True)
@@ -13266,10 +13269,10 @@ class TestAsync(JitTestCase):
                 # return a nested structure of tensors
                 return (tensor_list, tensor_tuple, tensor_tuple[1])

-        class Tuple(nn.Module):
+        class TupleCl(nn.Module):
             def __init__(self):
-                super(Tuple, self).__init__()
-                self.module = Module()
+                super(TupleCl, self).__init__()
+                self.module = Mod()

             def forward(self, x):
                 z = torch.neg(x)
@@ -13278,7 +13281,7 @@ class TestAsync(JitTestCase):
             return tuple(list)

         x = torch.rand(3, 3)
-        module = torch.jit.trace(Tuple(), (x), _force_outplace=True)
+        module = torch.jit.trace(TupleCl(), (x), _force_outplace=True)

         # Make sure we have forks
         self.assertGraphContainsExactly(module.graph, kind='prim::fork', num_kind_nodes=2)
@@ -13632,16 +13635,16 @@ class TestClassType(JitTestCase):
         @torch.jit.script
         class FooTest:
             def __init__(self, x):
-                # type: (int)
+                # type: (int) -> None
                 self.foo = x

             def incFooTest(self, y):
-                # type: (int)
+                # type: (int) -> None
                 self.foo = self.foo + y

         @torch.jit.script
         def fn(x):
-            # type: (int)
+            # type: (int) -> int
             foo = FooTest(x)
             foo.incFooTest(2)
             return foo.foo
@@ -13689,7 +13692,7 @@ class TestClassType(JitTestCase):
         @torch.jit.script
         class FooTest:
             def __init__(self, x):
-                # type: (bool)
+                # type: (bool) -> None
                 self.foo = x

         @torch.jit.script
@@ -13718,7 +13721,7 @@ class TestClassType(JitTestCase):

         @torch.jit.script
         def fn(foo):
-            # type: (FooTest)
+            # type: (FooTest) -> Tensor
             return foo.attr

         @torch.jit.script
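The recurring change from "# type: (int)" to "# type: (int) -> None" in the hunks above reflects a mypy rule: a function type comment must spell out the return type, and the argument-only form is a syntax error to mypy's parser even though TorchScript accepted it. The accepted form, sketched on a hypothetical class:

class Counter(object):
    def __init__(self, start):
        # type: (int) -> None
        # mypy rejects the bare "# type: (int)" form; the "-> None"
        # return annotation is mandatory in function type comments.
        self.value = start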
@@ -9,24 +9,24 @@ import inspect
 from torch._six import builtins

 # Tracks standalone weak script functions
-compiled_weak_fns = weakref.WeakKeyDictionary()
+compiled_weak_fns = weakref.WeakKeyDictionary() # noqa: T484

 # Tracks which methods should be converted to strong methods
-weak_script_methods = weakref.WeakKeyDictionary()
+weak_script_methods = weakref.WeakKeyDictionary() # noqa: T484

 # Converted modules and their corresponding WeakScriptModuleProxy objects
-weak_modules = weakref.WeakKeyDictionary()
+weak_modules = weakref.WeakKeyDictionary() # noqa: T484

 # Types that have been declared as weak modules
-weak_types = weakref.WeakKeyDictionary()
+weak_types = weakref.WeakKeyDictionary() # noqa: T484

 # Wrapper functions that can call either of 2 functions depending on a boolean
 # argument
-boolean_dispatched = weakref.WeakKeyDictionary()
+boolean_dispatched = weakref.WeakKeyDictionary() # noqa: T484

 # Python Op functions that should be ignored by the compiler. These will be replaced
 # with an operator that always throws an error
-ignored_fns = weakref.WeakSet()
+ignored_fns = weakref.WeakSet() # noqa: T484

 COMPILATION_PENDING = object()
 COMPILED = object()
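These module-level WeakKeyDictionary() globals are suppressed rather than annotated, presumably because the codebase still supported Python 2, which has no variable annotations; mypy otherwise needs an explicit type for an empty container. Under Python 3 one could type the global instead; a hypothetical alternative, not what the PR does:

import weakref
from typing import Any, MutableMapping

# An explicit type for the empty container lets mypy check uses of the
# global without an inline "# noqa: T484" suppression.
compiled_weak_fns: MutableMapping[Any, Any] = weakref.WeakKeyDictionary()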
@@ -223,9 +223,9 @@ except ImportError:
         def __getitem__(self, types):
             return DictInstance(types)

-    Tuple = TupleCls()
-    List = ListCls()
-    Dict = DictCls()
+    Tuple = TupleCls() # noqa: T484
+    List = ListCls() # noqa: T484
+    Dict = DictCls() # noqa: T484

 def is_tuple(ann):
     return isinstance(ann, TupleInstance)
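These three assignments need suppression because the except ImportError: branch rebinds Tuple, List, and Dict (imported from typing in the try branch) to instances of a different type, which mypy reports as an incompatible redefinition. A stripped-down sketch of the polyfill pattern, abridged from the structure the diff implies:

try:
    from typing import Tuple
except ImportError:
    # Minimal stand-in for typing.Tuple on interpreters without typing.
    class TupleInstance(object):
        def __init__(self, types):
            self.types = types

    class TupleCls(object):
        def __getitem__(self, types):
            return TupleInstance(types)

    # Rebinding Tuple to a TupleCls instance is an incompatible
    # redefinition from mypy's point of view, hence the suppression.
    Tuple = TupleCls()  # noqa: T484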
@@ -1,7 +1,9 @@
 import torch
 import copy
 import numbers
-from typing import Tuple
+from typing import Tuple, Optional
+from torch import Tensor
+from torch.jit import ScriptModule

 from torch.nn.utils.rnn import PackedSequence
 from torch.nn import _VF