Files
pytorch/torch/nn/init.py
Aaron Gokaslan 88ab3e4322 [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings.

I also enabled RUF017 which looks for accidental quadratic list summation. Luckily, seems like there are no instances of it in our codebase, so enabling it so that it stays like that. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-20 01:36:18 +00:00

568 lines
20 KiB
Python

import math
import warnings
from torch import Tensor
import torch
from typing import Optional as _Optional
# These no_grad_* functions are necessary as wrappers around the parts of these
# functions that use `with torch.no_grad()`. The JIT doesn't support context
# managers, so these need to be implemented as builtins. Using these wrappers
# lets us keep those builtins small and re-usable.
def _no_grad_uniform_(tensor, a, b):
    """Call ``tensor.uniform_(a, b)`` with autograd recording disabled.

    Args:
        tensor: the tensor to fill in place
        a: first argument forwarded to ``uniform_`` (lower bound)
        b: second argument forwarded to ``uniform_`` (upper bound)

    Returns:
        The value returned by ``tensor.uniform_``.
    """
    with torch.no_grad():
        filled = tensor.uniform_(a, b)
    return filled
def _no_grad_normal_(tensor, mean, std):
    """Call ``tensor.normal_(mean, std)`` with autograd recording disabled.

    Args:
        tensor: the tensor to fill in place
        mean: mean forwarded to ``normal_``
        std: standard deviation forwarded to ``normal_``

    Returns:
        The value returned by ``tensor.normal_``.
    """
    with torch.no_grad():
        filled = tensor.normal_(mean, std)
    return filled
def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None):
    """Fill ``tensor`` in place with truncated-normal samples, without autograd.

    Method based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Args:
        tensor: the tensor to fill in place
        mean: mean of the underlying normal distribution
        std: standard deviation of the underlying normal distribution
        a: lower cutoff; results are clamped to be at least this value
        b: upper cutoff; results are clamped to be at most this value
        generator: forwarded to ``tensor.uniform_`` as its ``generator``

    Returns:
        The same ``tensor`` object that was passed in.
    """
    sqrt_two = math.sqrt(2.)

    def standard_normal_cdf(x):
        # Cumulative distribution function of the standard normal.
        return (1. + math.erf(x / sqrt_two)) / 2.

    # Sampling quality degrades when the mean sits far outside [a, b];
    # warn the caller rather than fail.
    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Inverse-CDF sampling: draw uniformly between the CDF values of the
        # two cutoffs, then map back through the normal inverse CDF.
        cdf_low = standard_normal_cdf((a - mean) / std)
        cdf_high = standard_normal_cdf((b - mean) / std)

        # Uniform on [cdf_low, cdf_high], rescaled to [2*cdf_low-1, 2*cdf_high-1]
        # which is the domain erfinv expects.
        tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1, generator=generator)

        # erfinv turns the uniform draw into a truncated standard normal
        # (up to the sqrt(2) factor applied next).
        tensor.erfinv_()

        # Rescale and shift to the requested std and mean.
        tensor.mul_(std * sqrt_two)
        tensor.add_(mean)

        # Guard against numerical drift just outside the cutoffs.
        tensor.clamp_(min=a, max=b)
    return tensor
def _no_grad_fill_(tensor, val):
    """Call ``tensor.fill_(val)`` with autograd recording disabled.

    Args:
        tensor: the tensor to fill in place
        val: the value forwarded to ``fill_``

    Returns:
        The value returned by ``tensor.fill_``.
    """
    with torch.no_grad():
        filled = tensor.fill_(val)
    return filled
def _no_grad_zero_(tensor):
    """Call ``tensor.zero_()`` with autograd recording disabled.

    Args:
        tensor: the tensor to zero in place

    Returns:
        The value returned by ``tensor.zero_``.
    """
    with torch.no_grad():
        zeroed = tensor.zero_()
    return zeroed
def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.

    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    SELU              :math:`\frac{3}{4}`
    ================= ====================================================

    .. warning::
        In order to implement `Self-Normalizing Neural Networks`_ ,
        you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
        This gives the initial weights a variance of ``1 / N``,
        which is necessary to induce a stable fixed point in the forward pass.
        In contrast, the default gain for ``SELU`` sacrifices the normalization
        effect for more stable gradient flow in rectangular layers.

    Args:
        nonlinearity: the non-linear function (`nn.functional` name)
        param: optional parameter for the non-linear function; only read for
            ``'leaky_relu'``, where it is the negative slope (``0.01`` is
            used when it is ``None``)

    Returns:
        The gain as a number (the integer ``1`` for the linear/conv/sigmoid
        family, a float otherwise).

    Raises:
        ValueError: if ``nonlinearity`` is not one of the supported names, or
            if ``param`` is neither ``None`` nor a non-bool int/float for
            ``'leaky_relu'``.

    Examples:
        >>> gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2

    .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
    """
    unit_gain_fns = (
        'linear',
        'conv1d', 'conv2d', 'conv3d',
        'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d',
        'sigmoid',
    )
    if nonlinearity in unit_gain_fns:
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity == 'selu':
        return 3.0 / 4  # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
    if nonlinearity != 'leaky_relu':
        raise ValueError(f"Unsupported nonlinearity {nonlinearity}")

    # Only leaky_relu remains; resolve its negative slope.
    if param is None:
        negative_slope = 0.01
    else:
        # True/False are instances of int, so they are excluded explicitly.
        is_plain_int = isinstance(param, int) and not isinstance(param, bool)
        if not (is_plain_int or isinstance(param, float)):
            raise ValueError(f"negative_slope {param} not a valid number")
        negative_slope = param
    return math.sqrt(2.0 / (1 + negative_slope ** 2))
def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor:
    r"""Fills the input Tensor with values drawn from the uniform
    distribution :math:`\mathcal{U}(a, b)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution

    Returns:
        The result of the in-place fill, or whatever the ``__torch_function__``
        override returns when ``tensor`` defines one.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.uniform_(w)
    """
    if not torch.overrides.has_torch_function_variadic(tensor):
        return _no_grad_uniform_(tensor, a, b)
    # Tensor-like objects with a __torch_function__ override handle the call themselves.
    return torch.overrides.handle_torch_function(uniform_, (tensor,), tensor=tensor, a=a, b=b)
def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:
    r"""Fills the input Tensor with values drawn from the normal
    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution

    Returns:
        The result of the in-place fill, or whatever the ``__torch_function__``
        override returns when ``tensor`` defines one.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.normal_(w)
    """
    if not torch.overrides.has_torch_function_variadic(tensor):
        return _no_grad_normal_(tensor, mean, std)
    # Tensor-like objects with a __torch_function__ override handle the call themselves.
    return torch.overrides.handle_torch_function(normal_, (tensor,), tensor=tensor, mean=mean, std=std)
def trunc_normal_(
    tensor: Tensor,
    mean: float = 0.,
    std: float = 1.,
    a: float = -2.,
    b: float = 2.,
    generator: _Optional[torch.Generator] = None
) -> Tensor:
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
        generator: optional `torch.Generator` passed through to the
            sampling helper (default: ``None``)

    Returns:
        The value returned by the no-grad sampling helper.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    initialized = _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator)
    return initialized
def constant_(tensor: Tensor, val: float) -> Tensor:
    r"""Fills the input Tensor with the value :math:`\text{val}`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        val: the value to fill the tensor with

    Returns:
        The result of the in-place fill, or whatever the ``__torch_function__``
        override returns when ``tensor`` defines one.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.constant_(w, 0.3)
    """
    if not torch.overrides.has_torch_function_variadic(tensor):
        return _no_grad_fill_(tensor, val)
    # Tensor-like objects with a __torch_function__ override handle the call themselves.
    return torch.overrides.handle_torch_function(constant_, (tensor,), tensor=tensor, val=val)
def ones_(tensor: Tensor) -> Tensor:
    r"""Fills the input Tensor with the scalar value `1`.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Returns:
        The value returned by the no-grad fill helper.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.ones_(w)
    """
    fill_value = 1.
    return _no_grad_fill_(tensor, fill_value)
def zeros_(tensor: Tensor) -> Tensor:
    r"""Fills the input Tensor with the scalar value `0`.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Returns:
        The value returned by the no-grad zeroing helper.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.zeros_(w)
    """
    zeroed = _no_grad_zero_(tensor)
    return zeroed
def eye_(tensor):
    r"""Fills the 2-dimensional input `Tensor` with the identity
    matrix. Preserves the identity of the inputs in `Linear` layers, where as
    many inputs are preserved as possible.

    Args:
        tensor: a 2-dimensional `torch.Tensor`

    Returns:
        The same ``tensor`` object, overwritten in place.

    Raises:
        ValueError: if ``tensor`` does not have exactly 2 dimensions.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.eye_(w)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    n_rows, n_cols = tensor.shape
    with torch.no_grad():
        # Write the identity directly into the existing storage via ``out=``.
        torch.eye(n_rows, n_cols, out=tensor, requires_grad=tensor.requires_grad)
    return tensor
def dirac_(tensor, groups=1):
    r"""Fills the {3, 4, 5}-dimensional input `Tensor` with the Dirac
    delta function. Preserves the identity of the inputs in `Convolutional`
    layers, where as many input channels are preserved as possible. In case
    of groups>1, each group of channels preserves identity

    Args:
        tensor: a {3, 4, 5}-dimensional `torch.Tensor`
        groups (int, optional): number of groups in the conv layer (default: 1)

    Returns:
        The same ``tensor`` object, overwritten in place.

    Raises:
        ValueError: if ``tensor`` is not 3-, 4- or 5-dimensional, or if its
            first dimension is not divisible by ``groups``.

    Examples:
        >>> w = torch.empty(3, 16, 5, 5)
        >>> nn.init.dirac_(w)
        >>> w = torch.empty(3, 24, 5, 5)
        >>> nn.init.dirac_(w, 3)
    """
    rank = tensor.ndimension()
    if rank not in [3, 4, 5]:
        raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")

    shape = tensor.size()
    if shape[0] % groups != 0:
        raise ValueError('dim 0 must be divisible by groups')

    rows_per_group = shape[0] // groups
    identity_count = min(rows_per_group, shape[1])
    # Middle index along every trailing (temporal / spatial / volumetric) dim.
    center = tuple(extent // 2 for extent in shape[2:])

    with torch.no_grad():
        tensor.zero_()
        for group in range(groups):
            row_base = group * rows_per_group
            for chan in range(identity_count):
                # One unit entry at the kernel centre maps input channel
                # ``chan`` to output row ``row_base + chan``.
                tensor[(row_base + chan, chan) + center] = 1
    return tensor
def _calculate_fan_in_and_fan_out(tensor):
    """Compute ``(fan_in, fan_out)`` for a tensor with at least 2 dimensions.

    ``fan_in`` is ``size(1)`` and ``fan_out`` is ``size(0)``, each multiplied
    by the product of all trailing dimension sizes.

    Args:
        tensor: a `torch.Tensor` with 2 or more dimensions

    Returns:
        A ``(fan_in, fan_out)`` tuple.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    if tensor.dim() < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    shape = tensor.shape
    # math.prod is not always available, accumulate the product manually
    # we could use functools.reduce but that is not supported by TorchScript
    receptive_field_size = 1
    for extent in shape[2:]:
        receptive_field_size *= extent

    return shape[1] * receptive_field_size, shape[0] * receptive_field_size
def xavier_uniform_(tensor: Tensor, gain: float = 1.) -> Tensor:
    r"""Fills the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor

    Returns:
        The value returned by the no-grad uniform helper.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    total_fan = float(fan_in + fan_out)
    sigma = gain * math.sqrt(2.0 / total_fan)
    # A uniform distribution on (-limit, limit) has standard deviation limit / sqrt(3).
    limit = math.sqrt(3.0) * sigma
    return _no_grad_uniform_(tensor, -limit, limit)
def xavier_normal_(tensor: Tensor, gain: float = 1.) -> Tensor:
    r"""Fills the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a normal
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor

    Returns:
        The value returned by the no-grad normal helper.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_normal_(w)
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    total_fan = float(fan_in + fan_out)
    sigma = gain * math.sqrt(2.0 / total_fan)
    return _no_grad_normal_(tensor, 0., sigma)
def _calculate_correct_fan(tensor, mode):
    """Return the fan value of ``tensor`` selected by ``mode``.

    Args:
        tensor: tensor handed to ``_calculate_fan_in_and_fan_out``
        mode: ``'fan_in'`` or ``'fan_out'``, compared case-insensitively

    Returns:
        ``fan_in`` when the lower-cased mode is ``'fan_in'``, else ``fan_out``.

    Raises:
        ValueError: if the lower-cased ``mode`` is not a supported name.
    """
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}")

    # ``valid_modes`` is ordered to line up with the (fan_in, fan_out) result.
    fan_by_mode = dict(zip(valid_modes, _calculate_fan_in_and_fan_out(tensor)))
    return fan_by_mode[mode]
def kaiming_uniform_(
    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
):
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    uniform distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Returns:
        The result of the in-place fill; the untouched ``tensor`` when any of
        its dimensions is zero; or whatever the ``__torch_function__``
        override returns when ``tensor`` defines one.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        # Tensor-like objects with a __torch_function__ override handle the call themselves.
        override_kwargs = dict(tensor=tensor, a=a, mode=mode, nonlinearity=nonlinearity)
        return torch.overrides.handle_torch_function(
            kaiming_uniform_, (tensor,), **override_kwargs
        )

    if 0 in tensor.shape:
        # Nothing to fill; warn and hand the tensor back unchanged.
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor

    fan = _calculate_correct_fan(tensor, mode)
    sigma = calculate_gain(nonlinearity, a) / math.sqrt(fan)
    # A uniform distribution on (-limit, limit) has standard deviation limit / sqrt(3).
    limit = math.sqrt(3.0) * sigma
    with torch.no_grad():
        return tensor.uniform_(-limit, limit)
def kaiming_normal_(
    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
):
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    normal distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Returns:
        The result of the in-place fill, or the untouched ``tensor`` when any
        of its dimensions is zero.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
    """
    if 0 in tensor.shape:
        # Nothing to fill; warn and hand the tensor back unchanged.
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor

    fan = _calculate_correct_fan(tensor, mode)
    sigma = calculate_gain(nonlinearity, a) / math.sqrt(fan)
    with torch.no_grad():
        return tensor.normal_(0, sigma)
def orthogonal_(tensor, gain=1):
    r"""Fills the input `Tensor` with a (semi) orthogonal matrix, as
    described in `Exact solutions to the nonlinear dynamics of learning in deep
    linear neural networks` - Saxe, A. et al. (2013). The input tensor must have
    at least 2 dimensions, and for tensors with more than 2 dimensions the
    trailing dimensions are flattened.

    Args:
        tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
        gain: optional scaling factor

    Returns:
        The same ``tensor`` object, overwritten in place (returned untouched
        when it has no elements).

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.

    Examples:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> w = torch.empty(3, 5)
        >>> nn.init.orthogonal_(w)
    """
    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

    if tensor.numel() == 0:
        # no-op
        return tensor

    rows = tensor.size(0)
    cols = tensor.numel() // rows
    # ``new_empty`` is the documented way to allocate a scratch tensor with
    # the same dtype/device; the legacy ``Tensor.new(rows, cols)`` constructor
    # is avoided.
    flattened = tensor.new_empty((rows, cols)).normal_(0, 1)

    # QR needs a tall (or square) matrix to yield ``min(rows, cols)``
    # orthonormal columns, so work on the transpose when the target is wide.
    if rows < cols:
        flattened.t_()

    # Compute the qr factorization
    q, r = torch.linalg.qr(flattened)
    # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
    d = torch.diag(r, 0)
    ph = d.sign()
    q *= ph

    # Undo the earlier transpose so ``q`` matches the (rows, cols) layout.
    if rows < cols:
        q.t_()

    with torch.no_grad():
        tensor.view_as(q).copy_(q)
        tensor.mul_(gain)
    return tensor
def sparse_(tensor, sparsity, std=0.01):
    r"""Fills the 2D input `Tensor` as a sparse matrix, where the
    non-zero elements will be drawn from a normal distribution with mean 0
    and standard deviation ``std``, as described in `Deep learning via
    Hessian-free optimization` - Martens, J. (2010).

    Args:
        tensor: a 2-dimensional `torch.Tensor`
        sparsity: The fraction of elements in each column to be set to zero
        std: the standard deviation of the normal distribution used to generate
            the non-zero values

    Returns:
        The same ``tensor`` object, overwritten in place.

    Raises:
        ValueError: if ``tensor`` does not have exactly 2 dimensions.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.sparse_(w, sparsity=0.1)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    n_rows, n_cols = tensor.shape
    # Round up so that at least the requested fraction of each column is zeroed.
    zeros_per_col = int(math.ceil(sparsity * n_rows))

    with torch.no_grad():
        tensor.normal_(0, std)
        for col in range(n_cols):
            # An independent random subset of rows is zeroed in every column.
            dropped_rows = torch.randperm(n_rows)[:zeros_per_col]
            tensor[dropped_rows, col] = 0
    return tensor
# for backward compatibility
def _make_deprecate(meth):
    """Build a deprecated alias for an initializer whose name ends in ``_``.

    The alias is named like ``meth`` with its last character dropped. Calling
    it emits a warning pointing at ``meth`` and then forwards all positional
    and keyword arguments to ``meth``.

    Args:
        meth: the current initializer function; its ``__name__`` supplies
            both the new name and (minus the final character) the old name

    Returns:
        The wrapper function, with ``__name__`` and ``__doc__`` set for the
        old name.
    """
    new_name = meth.__name__
    old_name = new_name[:-1]

    def deprecated_init(*args, **kwargs):
        # stacklevel=2 attributes the warning to the caller of the alias.
        warnings.warn(f"nn.init.{old_name} is now deprecated in favor of nn.init.{new_name}.", stacklevel=2)
        return meth(*args, **kwargs)

    # The body of the ``.. warning::`` directive must be indented and the
    # directive separated by blank lines, otherwise reST does not treat the
    # following text as part of the warning.
    deprecated_init.__doc__ = fr"""
    {old_name}(...)

    .. warning::
        This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.

    See :func:`~torch.nn.init.{new_name}` for details."""
    deprecated_init.__name__ = old_name
    return deprecated_init
# Deprecated aliases without the trailing underscore. Each one is built by
# ``_make_deprecate``: it warns and then forwards to the in-place initializer.
uniform = _make_deprecate(uniform_)
normal = _make_deprecate(normal_)
constant = _make_deprecate(constant_)
eye = _make_deprecate(eye_)
dirac = _make_deprecate(dirac_)
xavier_uniform = _make_deprecate(xavier_uniform_)
xavier_normal = _make_deprecate(xavier_normal_)
kaiming_uniform = _make_deprecate(kaiming_uniform_)
kaiming_normal = _make_deprecate(kaiming_normal_)
orthogonal = _make_deprecate(orthogonal_)
sparse = _make_deprecate(sparse_)