Enable typechecks for torch.nn.modules.[activation|upsampling] (#44093)

Summary:
Add missing `hardsigmoid`, `silu`, `hardswish`, and `multi_head_attention_forward` to functional.pyi.in.
Embed typing annotations directly in functional.py, replacing the existing `# type:` comments.
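
As an illustration of what the new stub entries buy (not part of the commit itself): with `hardsigmoid`, `silu`, and `hardswish` declared in the generated functional.pyi, mypy can check call sites like the sketch below. The tensor `x` and the commented-out bad call are placeholders.

    # Call sites that mypy can now check against the new stubs.
    # Assumes torch is installed; `x` is just an example input.
    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 8)
    y: torch.Tensor = F.silu(x)                      # matches silu(input: Tensor, inplace: bool = False) -> Tensor
    z: torch.Tensor = F.hardswish(x, inplace=False)  # same signature shape
    # F.hardsigmoid(x, inplace="yes")                # mypy would reject this: "yes" is not a bool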

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44093

Reviewed By: ezyang

Differential Revision: D23494384

Pulled By: malfet

fbshipit-source-id: 27023c16ff5951ceaebb78799c4629efa25f7c5c
Author: Nikita Shulga
Date: 2020-09-03 13:17:26 -07:00
Committed by: Facebook GitHub Bot
Parent: a153f69417
Commit: 442684cb25
5 changed files with 70 additions and 48 deletions

mypy.ini

@@ -98,9 +98,6 @@ ignore_errors = True
[mypy-torch._tensor_str]
ignore_errors = True
[mypy-torch.nn.modules.activation]
ignore_errors = True
[mypy-torch.nn.modules.batchnorm]
ignore_errors = True
@@ -140,9 +137,6 @@ ignore_errors = True
[mypy-torch.nn.modules.sparse]
ignore_errors = True
[mypy-torch.nn.modules.upsampling]
ignore_errors = True
[mypy-torch.nn.parallel._functions]
ignore_errors = True
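
With the `ignore_errors = True` overrides removed, mypy actually reports problems in torch.nn.modules.activation and torch.nn.modules.upsampling instead of suppressing them. A minimal sketch of a local spot-check, assuming mypy is installed and the pytorch repository root is the working directory (the paths here are assumptions, not part of the diff):

    # Run mypy on the newly un-ignored modules and print its report.
    from mypy import api

    report, errors, exit_code = api.run([
        "torch/nn/modules/activation.py",
        "torch/nn/modules/upsampling.py",
    ])
    print(report)
    print("exit code:", exit_code)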

torch/nn/functional.py

@@ -11,7 +11,7 @@ from .modules import utils
from .modules.utils import _single, _pair, _triple, _list_with_default
from . import grad # noqa: F401
from torch import _VF
from .._jit_internal import boolean_dispatch, List, Optional, _overload
from .._jit_internal import boolean_dispatch, List, Optional, _overload, Tuple
from ..overrides import has_torch_function, handle_torch_function
@@ -1111,8 +1111,7 @@ In-place version of :func:`~threshold`.
""")
def relu(input, inplace=False):
# type: (Tensor, bool) -> Tensor
def relu(input: Tensor, inplace: bool = False) -> Tensor:
r"""relu(input, inplace=False) -> Tensor
Applies the rectified linear unit function element-wise. See
@@ -1135,8 +1134,7 @@ In-place version of :func:`~relu`.
""")
def glu(input, dim=-1):
# type: (Tensor, int) -> Tensor
def glu(input: Tensor, dim: int = -1) -> Tensor:
r"""
glu(input, dim=-1) -> Tensor
@@ -1162,8 +1160,7 @@ def glu(input, dim=-1):
return torch._C._nn.glu(input, dim)
def hardtanh(input, min_val=-1., max_val=1., inplace=False):
# type: (Tensor, float, float, bool) -> Tensor
def hardtanh(input: Tensor, min_val: float = -1., max_val: float = 1., inplace: bool = False) -> Tensor:
r"""
hardtanh(input, min_val=-1., max_val=1., inplace=False) -> Tensor
@@ -1732,8 +1729,7 @@ def silu(input, inplace=False):
return torch._C._nn.silu_(input)
return torch._C._nn.silu(input)
def hardswish(input, inplace=False):
# type: (Tensor, bool) -> Tensor
def hardswish(input: Tensor, inplace: bool = False) -> Tensor:
r"""Applies the hardswish function, element-wise, as described in the paper:
`Searching for MobileNetV3`_.
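
The pattern in each of these hunks is the same: the old `# type: (...) -> ...` comment is folded into an inline signature. A minimal sketch (not from the commit; `toy_hardtanh` is a made-up name) showing that torch.jit.script accepts the inline form just as it did the comment form, so nothing is lost on the TorchScript side:

    # Inline annotations are understood by both mypy and torch.jit.script,
    # so the `# type:` comments can be dropped without breaking scripting.
    import torch
    from torch import Tensor

    def toy_hardtanh(input: Tensor, min_val: float = -1., max_val: float = 1.) -> Tensor:
        return input.clamp(min_val, max_val)

    scripted = torch.jit.script(toy_hardtanh)
    print(scripted(torch.tensor([-2.0, 0.5, 3.0])))  # tensor([-1.0000, 0.5000, 1.0000])
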
@@ -3857,31 +3853,30 @@ def _pad_circular(input, padding):
return input
def multi_head_attention_forward(query, # type: Tensor
key, # type: Tensor
value, # type: Tensor
embed_dim_to_check, # type: int
num_heads, # type: int
in_proj_weight, # type: Tensor
in_proj_bias, # type: Tensor
bias_k, # type: Optional[Tensor]
bias_v, # type: Optional[Tensor]
add_zero_attn, # type: bool
dropout_p, # type: float
out_proj_weight, # type: Tensor
out_proj_bias, # type: Tensor
training=True, # type: bool
key_padding_mask=None, # type: Optional[Tensor]
need_weights=True, # type: bool
attn_mask=None, # type: Optional[Tensor]
use_separate_proj_weight=False, # type: bool
q_proj_weight=None, # type: Optional[Tensor]
k_proj_weight=None, # type: Optional[Tensor]
v_proj_weight=None, # type: Optional[Tensor]
static_k=None, # type: Optional[Tensor]
static_v=None # type: Optional[Tensor]
):
# type: (...) -> Tuple[Tensor, Optional[Tensor]]
def multi_head_attention_forward(query: Tensor,
key: Tensor,
value: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Tensor,
in_proj_bias: Tensor,
bias_k: Optional[Tensor],
bias_v: Optional[Tensor],
add_zero_attn: bool,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Tensor,
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
use_separate_proj_weight: bool = False,
q_proj_weight: Optional[Tensor] = None,
k_proj_weight: Optional[Tensor] = None,
v_proj_weight: Optional[Tensor] = None,
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
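
For reference, a small call that exercises the fully annotated signature; all sizes below are illustrative, shapes follow the (seq_len, batch, embed_dim) convention, and the result matches the declared Tuple[Tensor, Optional[Tensor]] return type:

    # Minimal call matching the annotated signature; sizes are made up.
    import torch
    import torch.nn.functional as F

    L, N, E, H = 5, 2, 8, 2                      # seq length, batch, embed dim, heads
    q = k = v = torch.randn(L, N, E)
    in_proj_weight = torch.randn(3 * E, E)       # packed q/k/v projection
    in_proj_bias = torch.zeros(3 * E)
    out_proj_weight = torch.randn(E, E)
    out_proj_bias = torch.zeros(E)

    attn_output, attn_weights = F.multi_head_attention_forward(
        q, k, v, E, H,
        in_proj_weight, in_proj_bias,
        None, None,                              # bias_k, bias_v: Optional[Tensor]
        False, 0.0,                              # add_zero_attn, dropout_p
        out_proj_weight, out_proj_bias,
        need_weights=True,
    )
    print(attn_output.shape)                     # torch.Size([5, 2, 8])
    print(None if attn_weights is None else attn_weights.shape)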

torch/nn/functional.pyi.in

@@ -166,7 +166,9 @@ def log_softmax(input: Tensor, dim: Optional[int] = ..., _stacklevel: int = ...,
def tanh(input: Any): ...
def sigmoid(input: Any): ...
def sigmoid(input: Any) -> Tensor: ...
def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: ...
def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = ...) -> Tensor: ...
@@ -175,6 +177,12 @@ def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = ...) -> Tenso
def bilinear(input1: Tensor, input2: Tensor, weight: Tensor, bias: Optional[Tensor] = ...) -> Tensor: ...
def silu(input: Tensor, inplace: bool = False) -> Tensor: ...
def hardswish(input: Tensor, inplace: bool = False) -> Tensor: ...
def embedding(input: Tensor, weight: Tensor, padding_idx: Optional[int] = ..., max_norm: Optional[float] = ...,
norm_type: float = ..., scale_grad_by_freq: bool = ..., sparse: bool = ...) -> Tensor: ...
@@ -325,6 +333,33 @@ def unfold(input: Tensor, kernel_size: _size, dilation: _size = ..., padding: _s
def fold(input: Tensor, output_size: _size, kernel_size: _size, dilation: _size = ..., padding: _size = ...,
stride: _size = ...) -> Tensor: ...
def multi_head_attention_forward(query: Tensor,
key: Tensor,
value: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Tensor,
in_proj_bias: Tensor,
bias_k: Optional[Tensor],
bias_v: Optional[Tensor],
add_zero_attn: bool,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Tensor,
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
use_separate_proj_weight: bool = False,
q_proj_weight: Optional[Tensor] = None,
k_proj_weight: Optional[Tensor] = None,
v_proj_weight: Optional[Tensor] = None,
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None
) -> Tuple[Tensor, Optional[Tensor]]: ...
${imported_hints}
${dispatched_hints}

torch/nn/modules/activation.py

@@ -853,10 +853,8 @@ class MultiheadAttention(Module):
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
"""
__annotations__ = {
'bias_k': torch._jit_internal.Optional[torch.Tensor],
'bias_v': torch._jit_internal.Optional[torch.Tensor],
}
bias_k: Optional[torch.Tensor]
bias_v: Optional[torch.Tensor]
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
super(MultiheadAttention, self).__init__()
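
The hand-written `__annotations__` dict is replaced by ordinary PEP 526 class-level annotations, which both mypy and TorchScript read directly. A standalone sketch of the same pattern (ToyAttention and its arguments are made up for illustration):

    # Class-level Optional annotation: the attribute may be a Tensor or None,
    # and no manual __annotations__ dict is needed for mypy or TorchScript.
    from typing import Optional
    import torch
    from torch import nn

    class ToyAttention(nn.Module):
        bias_k: Optional[torch.Tensor]

        def __init__(self, add_bias_k: bool = False) -> None:
            super().__init__()
            self.bias_k = torch.zeros(1, 1, 4) if add_bias_k else None

    print(ToyAttention().bias_k)                          # None
    print(ToyAttention(add_bias_k=True).bias_k is None)   # False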

torch/nn/modules/upsampling.py

@@ -120,10 +120,10 @@ class Upsample(Module):
"""
__constants__ = ['size', 'scale_factor', 'mode', 'align_corners', 'name']
name: str
size: _size_any_t
scale_factor: _ratio_any_t
size: Optional[_size_any_t]
scale_factor: Optional[_ratio_any_t]
mode: str
align_corners: bool
align_corners: Optional[bool]
def __init__(self, size: Optional[_size_any_t] = None, scale_factor: Optional[_ratio_any_t] = None,
mode: str = 'nearest', align_corners: Optional[bool] = None) -> None:
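
These three attributes default to None in `__init__`, so the Optional annotations are what mypy actually needs here. A small usage sketch (tensor shapes are illustrative) where `size` and `align_corners` stay None while `scale_factor` is set:

    # With size/scale_factor/align_corners declared Optional, this common
    # construction (size left as None) type-checks cleanly under mypy.
    import torch
    from torch import nn

    up = nn.Upsample(scale_factor=2.0, mode='nearest')
    x = torch.randn(1, 3, 4, 4)
    print(up(x).shape)                                  # torch.Size([1, 3, 8, 8])
    print(up.size, up.scale_factor, up.align_corners)   # None 2.0 None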