import torch
import numbers
from torch.nn.parameter import Parameter
from .module import Module
from ._functions import CrossMapLRN2d as _cross_map_lrn2d
from .. import functional as F
from .. import init

from torch import Tensor, Size
from typing import Union, List


class LocalResponseNorm(Module):
    r"""Applies local response normalization over an input signal composed
    of several input planes, where channels occupy the second dimension.
    Normalization is applied across channels.

    .. math::
        b_{c} = a_{c}\left(k + \frac{\alpha}{n}
        \sum_{c'=\max(0, c-n/2)}^{\min(N-1,c+n/2)}a_{c'}^2\right)^{-\beta}

    Args:
        size: number of neighbouring channels used for normalization
        alpha: multiplicative factor. Default: 0.0001
        beta: exponent. Default: 0.75
        k: additive factor. Default: 1

    Shape:
        - Input: :math:`(N, C, *)`
        - Output: :math:`(N, C, *)` (same shape as input)

    Examples::

        >>> lrn = nn.LocalResponseNorm(2)
        >>> signal_2d = torch.randn(32, 5, 24, 24)
        >>> signal_4d = torch.randn(16, 5, 7, 7, 7, 7)
        >>> output_2d = lrn(signal_2d)
        >>> output_4d = lrn(signal_4d)

    """
    __constants__ = ['size', 'alpha', 'beta', 'k']
    size: int
    alpha: float
    beta: float
    k: float

    def __init__(self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.) -> None:
        super(LocalResponseNorm, self).__init__()
        self.size = size
        self.alpha = alpha
        self.beta = beta
        self.k = k

    def forward(self, input: Tensor) -> Tensor:
        return F.local_response_norm(input, self.size, self.alpha, self.beta,
                                     self.k)

    def extra_repr(self):
        return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__)

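
# A minimal sketch (the helper name, shapes, and tolerance below are illustrative
# assumptions, not part of the module) that checks the docstring formula against
# ``F.local_response_norm`` for an odd window size, where the two agree.
def _check_local_response_norm_formula() -> None:
    size, alpha, beta, k = 3, 1e-4, 0.75, 1.0
    x = torch.randn(2, 5, 4, 4)
    num_channels = x.size(1)
    expected = torch.empty_like(x)
    for c in range(num_channels):
        # sum of squares over the clipped channel window around c
        lo = max(0, c - size // 2)
        hi = min(num_channels - 1, c + size // 2)
        sq_sum = (x[:, lo:hi + 1] ** 2).sum(dim=1)
        expected[:, c] = x[:, c] * (k + alpha / size * sq_sum) ** (-beta)
    assert torch.allclose(LocalResponseNorm(size, alpha, beta, k)(x), expected, atol=1e-6)
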
class CrossMapLRN2d(Module):
    size: int
    alpha: float
    beta: float
    k: float

    def __init__(self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1) -> None:
        super(CrossMapLRN2d, self).__init__()
        self.size = size
        self.alpha = alpha
        self.beta = beta
        self.k = k

    def forward(self, input: Tensor) -> Tensor:
        return _cross_map_lrn2d.apply(input, self.size, self.alpha, self.beta,
                                      self.k)

    def extra_repr(self) -> str:
        return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__)

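
# A minimal usage sketch (shapes are illustrative assumptions): CrossMapLRN2d
# dispatches to the autograd Function imported above and is meant for 4D
# (N, C, H, W) inputs; the output has the same shape as the input.
def _cross_map_lrn2d_example() -> None:
    m = CrossMapLRN2d(size=3)
    x = torch.randn(2, 5, 8, 8)
    y = m(x)
    assert y.shape == x.shape
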
_shape_t = Union[int, List[int], Size]


class LayerNorm(Module):
    r"""Applies Layer Normalization over a mini-batch of inputs as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated over the last dimensions,
    which must match the shape given by :attr:`normalized_shape`.
    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    .. note::
        Unlike Batch Normalization and Instance Normalization, which apply a
        scalar scale and bias for each entire channel/plane with the
        :attr:`affine` option, Layer Normalization applies per-element scale and
        bias with :attr:`elementwise_affine`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
                    \times \ldots \times \text{normalized\_shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine: a boolean value that, when set to ``True``, gives this module
            learnable per-element affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)

    Examples::

        >>> input = torch.randn(20, 5, 10, 10)
        >>> # With Learnable Parameters
        >>> m = nn.LayerNorm(input.size()[1:])
        >>> # Without Learnable Parameters
        >>> m = nn.LayerNorm(input.size()[1:], elementwise_affine=False)
        >>> # Normalize over last two dimensions
        >>> m = nn.LayerNorm([10, 10])
        >>> # Normalize over last dimension of size 10
        >>> m = nn.LayerNorm(10)
        >>> # Activating the module
        >>> output = m(input)
    """
    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
    normalized_shape: _shape_t
    eps: float
    elementwise_affine: bool

    def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True) -> None:
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = Parameter(torch.Tensor(*normalized_shape))
            self.bias = Parameter(torch.Tensor(*normalized_shape))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.elementwise_affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        return F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps)

    def extra_repr(self) -> str:
        return '{normalized_shape}, eps={eps}, ' \
            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)

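
# A minimal sketch (the helper name, shapes, and tolerance are illustrative
# assumptions) checking the LayerNorm formula above with the biased variance
# estimator the docstring refers to (equivalent to ``torch.var(..., unbiased=False)``).
def _check_layer_norm_formula() -> None:
    x = torch.randn(20, 5, 10, 10)
    m = LayerNorm(x.size()[1:], elementwise_affine=False)
    dims = (1, 2, 3)  # the last len(normalized_shape) dimensions
    mean = x.mean(dim=dims, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=dims, keepdim=True)  # biased estimator
    expected = (x - mean) / torch.sqrt(var + m.eps)
    assert torch.allclose(m(x), expected, atol=1e-5)
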
class GroupNorm(Module):
    r"""Applies Group Normalization over a mini-batch of inputs as described in
    the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The input channels are separated into :attr:`num_groups` groups, each containing
    ``num_channels / num_groups`` channels. The mean and standard-deviation are calculated
    separately over each group. :math:`\gamma` and :math:`\beta` are learnable
    per-channel affine transform parameter vectors of size :attr:`num_channels` if
    :attr:`affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        num_groups (int): number of groups to separate the channels into
        num_channels (int): number of channels expected in input
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        affine: a boolean value that, when set to ``True``, gives this module
            learnable per-channel affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.

    Shape:
        - Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}`
        - Output: :math:`(N, C, *)` (same shape as input)

    Examples::

        >>> input = torch.randn(20, 6, 10, 10)
        >>> # Separate 6 channels into 3 groups
        >>> m = nn.GroupNorm(3, 6)
        >>> # Separate 6 channels into 6 groups (equivalent to InstanceNorm)
        >>> m = nn.GroupNorm(6, 6)
        >>> # Put all 6 channels into a single group (equivalent to LayerNorm)
        >>> m = nn.GroupNorm(1, 6)
        >>> # Activating the module
        >>> output = m(input)
    """
    __constants__ = ['num_groups', 'num_channels', 'eps', 'affine']
    num_groups: int
    num_channels: int
    eps: float
    affine: bool

    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True) -> None:
        super(GroupNorm, self).__init__()
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_channels))
            self.bias = Parameter(torch.Tensor(num_channels))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        return F.group_norm(
            input, self.num_groups, self.weight, self.bias, self.eps)

    def extra_repr(self) -> str:
        return '{num_groups}, {num_channels}, eps={eps}, ' \
            'affine={affine}'.format(**self.__dict__)

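
# A minimal sketch (names and shapes are illustrative assumptions) checking the
# per-group statistics described above: with ``affine=False`` every group of
# ``num_channels / num_groups`` channels is normalized with its own mean and
# biased variance.
def _check_group_norm_formula() -> None:
    num_groups, num_channels = 3, 6
    x = torch.randn(20, num_channels, 10, 10)
    m = GroupNorm(num_groups, num_channels, affine=False)
    g = x.view(20, num_groups, -1)  # flatten each group's channels and spatial dims
    mean = g.mean(dim=2, keepdim=True)
    var = ((g - mean) ** 2).mean(dim=2, keepdim=True)  # biased estimator
    expected = ((g - mean) / torch.sqrt(var + m.eps)).view_as(x)
    assert torch.allclose(m(x), expected, atol=1e-5)
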
# TODO: ContrastiveNorm2d
# TODO: DivisiveNorm2d
# TODO: SubtractiveNorm2d