from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
import torch.nn.quantized.functional


class LayerNorm(torch.nn.LayerNorm):
    r"""Applies Layer Normalization over a mini-batch of inputs as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard deviation are calculated separately over the last
    certain number of dimensions, which have to be of the shape specified by
    :attr:`normalized_shape`.
    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.

    .. note::
        Unlike Batch Normalization and Instance Normalization, which apply a
        scalar scale and bias to each entire channel/plane with the
        :attr:`affine` option, Layer Normalization applies a per-element scale
        and bias with :attr:`elementwise_affine`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
                    \times \ldots \times \text{normalized\_shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension, which is expected to be of that specific size.
        weight: the per-element scale (:math:`\gamma`) passed to the quantized kernel
        bias: the per-element bias (:math:`\beta`) passed to the quantized kernel
        scale: quantization scale of the output tensor
        zero_point: quantization zero point of the output tensor
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine: a boolean value that when set to ``True``, this module
            has learnable per-element affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)

    Examples::

        >>> input = torch.randn(20, 5, 10, 10)
        >>> # With Learnable Parameters
        >>> m = nn.LayerNorm(input.size()[1:])
        >>> # Without Learnable Parameters
        >>> m = nn.LayerNorm(input.size()[1:], elementwise_affine=False)
        >>> # Normalize over last two dimensions
        >>> m = nn.LayerNorm([10, 10])
        >>> # Normalize over last dimension of size 10
        >>> m = nn.LayerNorm(10)
        >>> # Activating the module
        >>> output = m(input)
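        >>> # Quantized usage (a sketch; LayerNorm below refers to this quantized
        >>> # class, and the output scale/zero_point values are illustrative)
        >>> float_m = nn.LayerNorm(input.size()[1:])
        >>> q_input = torch.quantize_per_tensor(input, 0.05, 128, torch.quint8)
        >>> qm = LayerNorm(input.size()[1:], float_m.weight, float_m.bias, 0.05, 128)
        >>> q_output = qm(q_input)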
    """

    def __init__(self, normalized_shape, weight, bias, scale, zero_point, eps=1e-5,
                 elementwise_affine=True):
        super(LayerNorm, self).__init__(
            normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
        # Store the affine parameters and the quantization parameters
        # (scale, zero_point) used for the output tensor.
        self.weight = weight
        self.bias = bias
        self.scale = scale
        self.zero_point = zero_point

    def forward(self, input):
        # Dispatch to the quantized layer_norm kernel; the output is quantized
        # with the stored scale and zero_point.
        return torch.ops.quantized.layer_norm(
            input, self.normalized_shape, weight=self.weight, bias=self.bias,
            eps=self.eps, output_scale=self.scale, output_zero_point=self.zero_point)

    def _get_name(self):
        return 'QuantizedLayerNorm'

    @classmethod
    def from_float(cls, mod):
        # Read the output quantization parameters from the observer that the
        # quantization prepare step attached to the float module.
        activation_post_process = mod.activation_post_process
        scale, zero_point = activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.normalized_shape, mod.weight, mod.bias, float(scale),
            int(zero_point), mod.eps, mod.elementwise_affine)
        return new_mod
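

# Usage sketch: shows direct construction of the quantized LayerNorm and the
# `from_float` conversion path. The observer choice and the output
# scale/zero_point values below are illustrative assumptions.
if __name__ == '__main__':
    import torch.quantization

    x = torch.randn(20, 5, 10, 10)
    float_ln = torch.nn.LayerNorm([5, 10, 10])

    # Direct construction: reuse the float affine parameters and pick
    # quantization parameters for the output tensor.
    q_ln = LayerNorm([5, 10, 10], float_ln.weight, float_ln.bias,
                     scale=0.05, zero_point=128)
    qx = torch.quantize_per_tensor(x, scale=0.05, zero_point=128, dtype=torch.quint8)
    qy = q_ln(qx)

    # Conversion path: `from_float` expects an observed float module, i.e. one
    # carrying an `activation_post_process` observer (normally attached by the
    # eager-mode quantization prepare step). Here it is attached by hand and
    # calibrated with a single forward pass.
    float_ln.activation_post_process = torch.quantization.MinMaxObserver()
    float_ln.activation_post_process(float_ln(x))
    q_ln2 = LayerNorm.from_float(float_ln)
    qy2 = q_ln2(qx)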