use flake8-mypy (#17721)
Summary: Use flake8 installed with mypy checks so that our linter matches fbcode. Mypy type errors also provide valuable signal.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17721

Differential Revision: D14357778

Pulled By: eellison

fbshipit-source-id: d8c9ea3fe3b5f550c3b70fe259e0eabf95e4c92d
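flake8-mypy runs mypy as a flake8 plugin and reports type errors under flake8's T4xx codes (T484 for PEP 484 type errors), which is why the diff below silences individual lines with "# noqa: T484" rather than mypy-specific ignores. A minimal sketch of the mechanism, with a function invented for illustration:

from typing import Optional


def add_one(x):
    # type: (Optional[int]) -> int
    # mypy flags this line because x may be None; with flake8-mypy installed,
    # a plain flake8 run reports it as T484, and the noqa silences only that code
    return x + 1  # noqa: T484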
Committed by: Facebook Github Bot
Parent: 1d522598fb
Commit: 561037aef8
.flake8

@@ -1,4 +1,4 @@
 [flake8]
 max-line-length = 120
 ignore = E203,E305,E402,E721,E741,F401,F403,F405,F821,F841,F999,W503,W504
-exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install
+exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install,build,torch/include
.travis.yml

@@ -27,5 +27,5 @@ matrix:
   include:
     env: LINT_CHECK
     python: "2.7"
-    install: pip install flake8
+    install: pip install flake8-mypy
     script: flake8
@@ -29,7 +29,7 @@ matrix:
       python: "3.7"
       dist: xenial  # required for Python 3.7 (travis-ci/travis-ci#9069)
       sudo: required  # required for Python 3.7 (travis-ci/travis-ci#9069)
-      install: pip install flake8
+      install: pip install flake8-mypy
       script: flake8
     - name: "MyPy typecheck"
       python: "3.6"
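Since flake8-mypy hooks into flake8 as a plugin, only the install step changes; the script command stays flake8 and picks the plugin up automatically. A sketch of an equivalent local check (this helper is illustrative, not part of the PR):

import subprocess
import sys


def run_lint():
    # same command CI runs; flake8 loads installed plugins on its own, so
    # type errors appear alongside style errors once flake8-mypy is present
    return subprocess.call([sys.executable, "-m", "flake8"])


if __name__ == "__main__":
    sys.exit(run_lint())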
test/test_jit.py

@@ -5,11 +5,13 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.parallel as dp
 import torch.optim as optim
+import torch.cuda
 import torch.jit.quantized
 from contextlib import contextmanager
 from itertools import product, chain
 import torch.jit.frontend
 from torch.autograd import Variable, Function
+from torch.nn import Module
 from torch.autograd.function import traceable
 from torch.testing import assert_allclose
 from torch.onnx import OperatorExportTypes
@@ -44,9 +46,11 @@ from torch._C import TensorType, TupleType, FloatType, IntType, \
     ListType, StringType, DictType
 from copy import deepcopy
 import random
-from typing import List, Dict, Optional
+from typing import List, Dict, Optional, Tuple
 from torch.jit.frontend import NotSupportedError
 from torch.jit import BatchTensor
+from torch import Tensor
+from torch.jit.annotations import BroadcastingList2, BroadcastingList3
 
 # For testing truediv in python 2
 from test_module.future_div import div_int_future, div_float_future
@@ -96,7 +100,7 @@ if WINDOWS:
         finally:
             os.unlink(f.name)
 else:
-    @contextmanager
+    @contextmanager  # noqa: T484
     def TemporaryFileName():
         with tempfile.NamedTemporaryFile() as f:
             yield f.name
@@ -2262,7 +2266,7 @@ class TestJit(JitTestCase):
         with self.assertRaisesRegex(RuntimeError, "Expected a default value"):
 
             @torch.jit.script
-            def hints_bad_types(x, a=10, b=0.5):
+            def hints_bad_types(x, a=10, b=0.5):  # noqa: T484
                 # type: (Tensor, float, int) -> Tensor
                 return x + a + b
 
@@ -3113,7 +3117,7 @@ class TestScript(JitTestCase):
         def sum_list(a):
             # type: (int) -> int
             sum = 0
-            for i in a:
+            for i in a:  # noqa: T484
                 sum += i
 
             return sum
@@ -4727,23 +4731,23 @@ a")
                     x = 1
                 else:
                     x = torch.jit._unwrap_optional(x)
-                return x
+                return x  # noqa: T484
 
         with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
             @torch.jit.script
             def or_error(x, y):
-                # type: (Optional[int], Optional[int]) -> int
+                # type: (Optional[int], Optional[int]) -> None
                 if x is None or y is None:
-                    print(x + y)
+                    print(x + y)  # noqa: T484
 
         with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
             @torch.jit.script
             def and_error(x, y):
-                # type: (Optional[int], Optional[int]) -> int
+                # type: (Optional[int], Optional[int]) -> None
                 if x is None and y is None:
                     pass
                 else:
-                    print(x + y)
+                    print(x + y)  # noqa: T484
 
         with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
             @torch.jit.script
@@ -4751,7 +4755,7 @@ a")
                 # type: (Optional[int]) -> None
                 x_none = x is not None
                 if x_none:
-                    print(x + 1)
+                    print(x + 1)  # noqa: T484
 
         with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
             @torch.jit.script
@@ -4759,7 +4763,7 @@ a")
                 # type: (Optional[int], Optional[int]) -> None
                 x_none = x is not None
                 if y is not None and x_none:
-                    print(x + y)
+                    print(x + y)  # noqa: T484
 
     def test_while_write_outer_then_read(self):
         def func(a, b):
@@ -5057,10 +5061,11 @@ a")
         self.checkScript(multiple_returns, [a], optimize=True)
 
         with self.assertRaisesRegex(RuntimeError, "but is actually of type None"):
-            @torch.jit.script
+            torch.jit.CompilationUnit('''
             def no_return_bad_annotation(a):
                 # type: (Tensor) -> Tensor
                 a + 1
+            ''')
 
     def test_error(self):
         @torch.jit.script
@@ -5654,8 +5659,6 @@ a")
             hiddens = hx
 
         if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
-            from typing import Tuple
-
             class ScriptWrapper(torch.jit.ScriptModule):
                 def __init__(self, cell):
                     super(ScriptWrapper, self).__init__()
@@ -6650,7 +6653,7 @@ a")
         with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"):
             def foo():
                 # type: () -> Tensor
-                return ((3, 4),)
+                return ((3, 4),)  # noqa: T484
 
         @torch.jit.script
         def bar():
@@ -6769,7 +6772,7 @@ a")
             if x:
                 y = [1]
             else:
-                y = [None]
+                y = [None]  # noqa: T484
             return y[0]
 
         @torch.jit.script
@@ -6815,18 +6818,18 @@ a")
         print(int_fn((1, 1, 1)))
 
         with self.assertRaisesRegex(RuntimeError, "must be a positive integer:"):
-            @torch.jit.script
+            @torch.jit.script  # noqa: T484
             def fn(x):
-                # type: (BroadcastingListx[int]) -> List[int]
+                # type: (BroadcastingListx[int]) -> List[int]  # noqa: T484
                 return x
 
-        # TODO: the type comment in this seems to trip up flake8 for some reason
-        # even though we have a noqa comment. Figure out why
+        # using CU so that flake8 error on int[2] is not raised (noqa not working)
         with self.assertRaisesRegex(RuntimeError, "Unknown type constructor"):
-            @torch.jit.script
+            cu = torch.jit.CompilationUnit('''
             def nested(x, y):
-                # type: (int, Tuple[int, int[2]]) -> List[int]  # noqa: T484
-                return x
+                # type: (int, Tuple[int, int[2]]) -> List[int]
+                return x  # noqa: T484
+            ''')
 
     def test_ntuple_builtins(self):
         from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
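Several hunks above swap @torch.jit.script for torch.jit.CompilationUnit where a test deliberately contains annotations the linter would reject; the in-diff comment notes that a noqa marker did not suppress the error on int[2]. Source handed to a CompilationUnit is compiled by the TorchScript frontend from a plain string, so host-side linters never parse it. A minimal sketch:

import torch

# the string body is TorchScript, not module-level Python,
# so flake8 and mypy never see it
cu = torch.jit.CompilationUnit('''
def double(x):
    # type: (int) -> int
    return x + x
''')
print(cu.double(21))  # prints 42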
@@ -8349,7 +8352,7 @@ a")
         with self.assertRaisesRegex(RuntimeError, 'but instead got value of type tuple'):
             def somefunc():
                 # type: () -> Tuple[Tuple[Tensor, Tensor]]
-                return torch.zeros(3, 4), torch.zeros(4, 5)
+                return torch.zeros(3, 4), torch.zeros(4, 5)  # noqa: T484
 
         @torch.jit.script
         def wrong_return_type():
@@ -9029,7 +9032,7 @@ a")
         def test(x):
             # type: (Optional[int]) -> int
             x = torch.jit._unwrap_optional(x)
-            x = x + x
+            x = x + x  # noqa: T484
             return x
 
         self.checkScript(test, (3,))
@@ -9082,14 +9085,14 @@ a")
             @torch.jit.script
             def return_tup(x):
                 # type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor]
-                return x, x
+                return x, x  # noqa: T484
 
     def test_annotated_script_fn_arg_mismatch(self):
         with self.assertRaisesRegex(RuntimeError, r"arguments for call are not valid"):
             @torch.jit.script
             def tuple_arg(x):
                 # type: (Tuple[Tensor, Tensor]) -> Tensor
-                return x + 1
+                return x + 1  # noqa: T484
 
     def test_script_non_tensor_args_outputs(self):
         @torch.jit.script
@@ -13122,11 +13125,11 @@ class TestAsync(JitTestCase):
         self.assertEqual(y, y_hat)
 
     def test_async_script_capture(self):
-        class Module(torch.jit.ScriptModule):
+        class Mod(torch.jit.ScriptModule):
             __constants__ = ['const']
 
             def __init__(self):
-                super(Module, self).__init__(False)
+                super(Mod, self).__init__(False)
                 self.const = 42
                 self.param = nn.Parameter(torch.randn(2, 2))
 
@@ -13144,7 +13147,7 @@ class TestAsync(JitTestCase):
         x1 = torch.rand(3, 4)
         x2 = torch.rand(5, 6)
 
-        m = Module()
+        m = Mod()
         y, y_hat = m.wait_script(x1, x2)
 
         self.assertEqual(y, y_hat)
@@ -13244,9 +13247,9 @@ class TestAsync(JitTestCase):
             def forward(self, x):
                 return (torch.neg(x), x)
 
-        class Module(torch.jit.ScriptModule):
+        class Mod(torch.jit.ScriptModule):
             def __init__(self):
-                super(Module, self).__init__(False)
+                super(Mod, self).__init__(False)
                 x = torch.rand(3, 3)
                 self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True)
 
@@ -13266,10 +13269,10 @@ class TestAsync(JitTestCase):
                 # return a nested structure of tensors
                 return (tensor_list, tensor_tuple, tensor_tuple[1])
 
-        class Tuple(nn.Module):
+        class TupleCl(nn.Module):
             def __init__(self):
-                super(Tuple, self).__init__()
-                self.module = Module()
+                super(TupleCl, self).__init__()
+                self.module = Mod()
 
             def forward(self, x):
                 z = torch.neg(x)
@@ -13278,7 +13281,7 @@ class TestAsync(JitTestCase):
             return tuple(list)
 
         x = torch.rand(3, 3)
-        module = torch.jit.trace(Tuple(), (x), _force_outplace=True)
+        module = torch.jit.trace(TupleCl(), (x), _force_outplace=True)
 
         # Make sure we have forks
         self.assertGraphContainsExactly(module.graph, kind='prim::fork', num_kind_nodes=2)
@@ -13632,16 +13635,16 @@ class TestClassType(JitTestCase):
         @torch.jit.script
         class FooTest:
             def __init__(self, x):
-                # type: (int)
+                # type: (int) -> None
                 self.foo = x
 
             def incFooTest(self, y):
-                # type: (int)
+                # type: (int) -> None
                 self.foo = self.foo + y
 
         @torch.jit.script
         def fn(x):
-            # type: (int)
+            # type: (int) -> int
             foo = FooTest(x)
             foo.incFooTest(2)
             return foo.foo
@@ -13689,7 +13692,7 @@ class TestClassType(JitTestCase):
         @torch.jit.script
         class FooTest:
             def __init__(self, x):
-                # type: (bool)
+                # type: (bool) -> None
                 self.foo = x
 
         @torch.jit.script
@@ -13718,7 +13721,7 @@ class TestClassType(JitTestCase):
 
         @torch.jit.script
         def fn(foo):
-            # type: (FooTest)
+            # type: (FooTest) -> Tensor
             return foo.attr
 
         @torch.jit.script
torch/_jit_internal.py

@@ -9,24 +9,24 @@ import inspect
 from torch._six import builtins
 
 # Tracks standalone weak script functions
-compiled_weak_fns = weakref.WeakKeyDictionary()
+compiled_weak_fns = weakref.WeakKeyDictionary()  # noqa: T484
 
 # Tracks which methods should be converted to strong methods
-weak_script_methods = weakref.WeakKeyDictionary()
+weak_script_methods = weakref.WeakKeyDictionary()  # noqa: T484
 
 # Converted modules and their corresponding WeakScriptModuleProxy objects
-weak_modules = weakref.WeakKeyDictionary()
+weak_modules = weakref.WeakKeyDictionary()  # noqa: T484
 
 # Types that have been declared as weak modules
-weak_types = weakref.WeakKeyDictionary()
+weak_types = weakref.WeakKeyDictionary()  # noqa: T484
 
 # Wrapper functions that can call either of 2 functions depending on a boolean
 # argument
-boolean_dispatched = weakref.WeakKeyDictionary()
+boolean_dispatched = weakref.WeakKeyDictionary()  # noqa: T484
 
 # Python Op functions that should be ignored by the compiler. These will be replaced
 # with an operator that always throws an error
-ignored_fns = weakref.WeakSet()
+ignored_fns = weakref.WeakSet()  # noqa: T484
 
 COMPILATION_PENDING = object()
 COMPILED = object()
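The noqa markers above presumably silence mypy's "need type annotation" complaint for empty weak collections, whose element types it cannot infer. A hedged sketch of the typed alternative (the annotations are illustrative, not from the PR):

import weakref
from typing import Any, Callable, MutableMapping

# explicit types on the empty containers satisfy mypy without a noqa
compiled_weak_fns = weakref.WeakKeyDictionary()  # type: MutableMapping[Callable, Any]
ignored_fns = weakref.WeakSet()  # type: weakref.WeakSet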
torch/jit/annotations.py

@@ -223,9 +223,9 @@ except ImportError:
         def __getitem__(self, types):
             return DictInstance(types)
 
-    Tuple = TupleCls()
-    List = ListCls()
-    Dict = DictCls()
+    Tuple = TupleCls()  # noqa: T484
+    List = ListCls()  # noqa: T484
+    Dict = DictCls()  # noqa: T484
 
 def is_tuple(ann):
     return isinstance(ann, TupleInstance)
torch/jit/quantized.py

@@ -1,7 +1,9 @@
 import torch
 import copy
 import numbers
-from typing import Tuple
+from typing import Tuple, Optional
+from torch import Tensor
+from torch.jit import ScriptModule
 
 from torch.nn.utils.rnn import PackedSequence
 from torch.nn import _VF
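The added imports in this last hunk matter because mypy resolves the names used inside "# type:" comments against the enclosing module's scope: Tensor, Optional, and ScriptModule must be importable even when they appear only in comments. A minimal sketch with an invented signature:

from typing import Optional, Tuple

from torch import Tensor


def rnn_step(input, hx=None):
    # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tensor
    # without the imports above, mypy would report these names as undefined
    return input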