mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Summary: This PR was not my worst debugging annoyance, nor my smallest in lines changed, but it has the highest `debugging annoyance/lines changed` ratio. The current pattern `self.num_batches_tracked = self.num_batches_tracked + 1`, if captured in a CUDA graph, deletes an eagerly-allocated tensor and overwrites it with a captured tensor. Replays read from the (deallocated) original tensor's address. This can cause: 1. an IMA (illegal memory access) on graph replay; 2. failure to actually increment `num_batches_tracked` during graph replay, because every replay reads from the old location without adding to it; 3. numerical corruption if the allocator reassigns the original tensor's memory to some unrelated tensor; 4. combinations of 1, 2, and 3, depending on global allocation patterns and if/when the BN module is called eagerly sometimes between replays (ask me how I know). Pull Request resolved: https://github.com/pytorch/pytorch/pull/70444 Reviewed By: albanD Differential Revision: D33342203 Pulled By: ngimel fbshipit-source-id: 5f201cc25030517e75af010bbaa88c452155df21
821 lines
36 KiB
Python
821 lines
36 KiB
Python
from typing import Optional, Any
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer
|
|
|
|
from .. import functional as F
|
|
from .. import init
|
|
from ._functions import SyncBatchNorm as sync_batch_norm
|
|
from .lazy import LazyModuleMixin
|
|
from .module import Module
|
|
|
|
|
|
class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm.

    Owns the optional learnable affine parameters (``weight``/``bias``) and the
    optional running-statistics buffers (``running_mean``/``running_var``/
    ``num_batches_tracked``). Subclasses are expected to implement
    ``_check_input_dim`` (this base raises ``NotImplementedError``) and ``forward``.
    """

    _version = 2
    __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
    num_features: int
    eps: float
    # ``None`` is a supported value: batch-norm subclasses then use a cumulative
    # moving average for the running statistics instead of an exponential one.
    momentum: Optional[float]
    affine: bool
    track_running_stats: bool
    # WARNING: weight and bias purposely not defined here.
    # See https://github.com/pytorch/pytorch/issues/39670

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        """Create the parameters and buffers for ``num_features`` channels.

        Args:
            num_features: length of every per-channel parameter and buffer.
            eps: a value added to the denominator for numerical stability.
            momentum: factor used to update the running statistics, or ``None``.
            affine: if ``True``, create learnable ``weight`` and ``bias``;
                otherwise register both as ``None``.
            track_running_stats: if ``True``, create the ``running_mean``,
                ``running_var`` and ``num_batches_tracked`` buffers; otherwise
                register all three as ``None``.
            device, dtype: factory arguments for the created tensors. The
                ``num_batches_tracked`` counter ignores ``dtype`` and is always
                ``torch.long``.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_NormBase, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
            self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            # The batch counter is an integer, so only ``device`` is forwarded.
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long,
                                              **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self) -> None:
        """Reset running mean to 0, running variance to 1 and the batch counter to 0.

        Does nothing when running statistics are not tracked. The buffers are
        modified in place, so their tensor objects are kept.
        """
        if self.track_running_stats:
            # running_mean/running_var/num_batches... are registered at runtime depending
            # if self.track_running_stats is on
            self.running_mean.zero_()  # type: ignore[union-attr]
            self.running_var.fill_(1)  # type: ignore[union-attr]
            self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]

    def reset_parameters(self) -> None:
        """Reset the running statistics and, if ``affine``, set weight to 1 and bias to 0."""
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        """Validate the input rank; must be overridden by subclasses."""
        raise NotImplementedError

    def extra_repr(self):
        return (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}".format(**self.__dict__)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Backfill ``num_batches_tracked`` for checkpoints older than version 2.

        When the checkpoint has no version metadata, or a version below 2, and
        this module tracks running statistics, a missing counter entry is
        inserted into ``state_dict`` with value 0 before delegating to
        :meth:`Module._load_from_state_dict`.
        """
        version = local_metadata.get("version", None)

        if (version is None or version < 2) and self.track_running_stats:
            # at version 2: added num_batches_tracked buffer
            #               this should have a default value of 0
            num_batches_tracked_key = prefix + "num_batches_tracked"
            if num_batches_tracked_key not in state_dict:
                state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)

        super(_NormBase, self)._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
|
|
|
|
|
|
class _BatchNorm(_NormBase):
    """Shared forward pass for the BatchNorm{1,2,3}d family.

    The constructor forwards every argument unchanged to :class:`_NormBase`;
    concrete subclasses supply ``_check_input_dim``.
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        device=None,
        dtype=None
    ):
        super(_BatchNorm, self).__init__(
            num_features,
            eps,
            momentum,
            affine,
            track_running_stats,
            device=device,
            dtype=dtype,
        )

    def forward(self, input: Tensor) -> Tensor:
        """Normalize ``input``; in training mode with tracking, also update running stats.

        The batch counter is incremented in place. Rebinding
        ``num_batches_tracked`` to a new tensor would leave captured CUDA graphs
        pointing at the old, freed storage.
        """
        self._check_input_dim(input)

        # Start from momentum (0.0 when it is None) purely so that the factor is
        # recorded in the ONNX graph when this node is exported.
        if self.momentum is not None:
            avg_factor = self.momentum
        else:
            avg_factor = 0.0

        if self.training and self.track_running_stats:
            # TODO: the None check only exists so the JIT skips emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked.add_(1)  # type: ignore[has-type]
                if self.momentum is not None:  # exponential moving average
                    avg_factor = self.momentum
                else:  # cumulative moving average
                    avg_factor = 1.0 / float(self.num_batches_tracked)

        # Mini-batch statistics are used while training, and also in eval mode
        # when there are no running buffers to normalize with.
        use_batch_stats = self.training or (
            self.running_mean is None and self.running_var is None
        )

        # Buffers are handed to F.batch_norm only when they may be updated
        # (training with tracking) or read (eval mode); otherwise None is passed
        # so they cannot be modified.
        pass_buffers = not self.training or self.track_running_stats
        return F.batch_norm(
            input,
            self.running_mean if pass_buffers else None,
            self.running_var if pass_buffers else None,
            self.weight,
            self.bias,
            use_batch_stats,
            avg_factor,
            self.eps,
        )
|
|
|
|
|
|
class _LazyNormBase(LazyModuleMixin, _NormBase):
    """Base for normalization layers whose ``num_features`` is inferred lazily.

    Until :meth:`initialize_parameters` runs, ``weight``/``bias`` (when
    ``affine``) are :class:`UninitializedParameter` and ``running_mean``/
    ``running_var`` (when ``track_running_stats``) are
    :class:`UninitializedBuffer`. On initialization ``num_features`` is taken
    from ``input.shape[1]``.
    """

    weight: UninitializedParameter  # type: ignore[assignment]
    bias: UninitializedParameter  # type: ignore[assignment]

    def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
                 device=None, dtype=None) -> None:
        dd = {'device': device, 'dtype': dtype}
        # Build the base with num_features=0 and affine/track_running_stats forced
        # to False so it does not allocate tensors that would be replaced below.
        super(_LazyNormBase, self).__init__(0, eps, momentum, False, False, **dd)
        self.affine = affine
        self.track_running_stats = track_running_stats
        if affine:
            self.weight = UninitializedParameter(**dd)
            self.bias = UninitializedParameter(**dd)
        if track_running_stats:
            self.running_mean = UninitializedBuffer(**dd)
            self.running_var = UninitializedBuffer(**dd)
            # The counter's shape does not depend on num_features, so it is created
            # eagerly. It is always int64, so only ``device`` is forwarded.
            self.num_batches_tracked = torch.tensor(0, dtype=torch.long, device=device)

    def reset_parameters(self) -> None:
        """Reset parameters/buffers, but only once they have been materialized."""
        if self.has_uninitialized_params() or self.num_features == 0:
            return
        super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore[override]
        """Materialize lazy parameters/buffers using ``input.shape[1]`` as ``num_features``."""
        if not self.has_uninitialized_params():
            return
        self.num_features = input.shape[1]
        width = (self.num_features,)
        if self.affine:
            assert isinstance(self.weight, UninitializedParameter)
            assert isinstance(self.bias, UninitializedParameter)
            for lazy_param in (self.weight, self.bias):
                lazy_param.materialize(width)
        if self.track_running_stats:
            for lazy_buf in (self.running_mean, self.running_var):
                lazy_buf.materialize(width)  # type:ignore[union-attr]
        self.reset_parameters()
|
|
|
|
|
|
class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D
    inputs with optional additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of features or channels of the input). By default,
    the elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set
    to 0. The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.

    Args:
        num_features: number of features or channels :math:`C` of the input,
            i.e. :math:`C` from an expected input of size :math:`(N, C)` or :math:`(N, C, L)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 2D or 3D."""
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError(
                "expected 2D or 3D input (got {}D input)".format(input.dim())
            )
|
|
|
|
|
|
class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization of
    the ``num_features`` argument of the :class:`BatchNorm1d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    # Class the module turns into once its lazy parameters are initialized.
    cls_to_become = BatchNorm1d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 2D or 3D."""
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError(
                "expected 2D or 3D input (got {}D input)".format(input.dim())
            )
|
|
|
|
|
|
class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of features or channels of the input). By default,
    the elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set
    to 0. The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 4D."""
        if input.dim() != 4:
            raise ValueError("expected 4D input (got {}D input)".format(input.dim()))
|
|
|
|
|
|
class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization of
    the ``num_features`` argument of the :class:`BatchNorm2d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    # Class the module turns into once its lazy parameters are initialized.
    cls_to_become = BatchNorm2d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 4D."""
        if input.dim() != 4:
            raise ValueError("expected 4D input (got {}D input)".format(input.dim()))
|
|
|
|
|
|
class BatchNorm3d(_BatchNorm):
    r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of features or channels of the input). By default,
    the elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set
    to 0. The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
    or Spatio-temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 5D."""
        if input.dim() != 5:
            raise ValueError("expected 5D input (got {}D input)".format(input.dim()))
|
|
|
|
|
|
class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization of
    the ``num_features`` argument of the :class:`BatchNorm3d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    # Class the module turns into once its lazy parameters are initialized.
    cls_to_become = BatchNorm3d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 5D."""
        if input.dim() != 5:
            raise ValueError("expected 5D input (got {}D input)".format(input.dim()))
|
|
|
|
|
|
class SyncBatchNorm(_BatchNorm):
    r"""Applies Batch Normalization over an N-Dimensional input (a mini-batch of [N-2]D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over all
    mini-batches of the same process groups. :math:`\gamma` and :math:`\beta`
    are learnable parameter vectors of size `C` (where `C` is the input size).
    By default, the elements of :math:`\gamma` are set to 1 and the elements of
    :math:`\beta` are set to 0. The standard-deviation is calculated via the
    biased estimator, equivalent to `torch.var(input, unbiased=False)`.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done for each channel in the ``C`` dimension, computing
    statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch
    Normalization or Spatio-temporal Batch Normalization.

    Currently :class:`SyncBatchNorm` only supports
    :class:`~torch.nn.DistributedDataParallel` (DDP) with single GPU per process. Use
    :meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert
    :attr:`BatchNorm*D` layer to :class:`SyncBatchNorm` before wrapping
    Network with DDP.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, +)`
        eps: a value added to the denominator for numerical stability.
            Default: ``1e-5``
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
        process_group: synchronization of stats happen within each process group
            individually. Default behavior is synchronization across the whole
            world

    Shape:
        - Input: :math:`(N, C, +)`
        - Output: :math:`(N, C, +)` (same shape as input)

    .. note::
        Synchronization of batchnorm statistics occurs only while training, i.e.
        synchronization is disabled when ``model.eval()`` is set or if
        ``self.training`` is otherwise ``False``. It is also skipped, falling back
        to regular batch normalization, when :mod:`torch.distributed` is not
        available or not initialized, or when the process group has a single member.

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.SyncBatchNorm(100)
        >>> # creating process group (optional)
        >>> # ranks is a list of int identifying rank ids.
        >>> ranks = list(range(8))
        >>> r1, r2 = ranks[:4], ranks[4:]
        >>> # Note: every rank calls into new_group for every
        >>> # process group created, even if that rank is not
        >>> # part of the group.
        >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
        >>> process_group = process_groups[0 if torch.distributed.get_rank() <= 3 else 1]
        >>> # Without Learnable Parameters
        >>> m = nn.SyncBatchNorm(100, affine=False, process_group=process_group)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)

        >>> # network is nn.BatchNorm layer
        >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
        >>> # only single gpu per process is currently supported
        >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
        >>>                         sync_bn_network,
        >>>                         device_ids=[args.local_rank],
        >>>                         output_device=args.local_rank)
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        process_group: Optional[Any] = None,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(SyncBatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        # Group used to synchronize statistics; ``None`` means the default (WORLD) group.
        self.process_group = process_group

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` has at least 2 dimensions."""
        if input.dim() < 2:
            raise ValueError(
                "expected at least 2D input (got {}D input)".format(input.dim())
            )

    def _check_non_zero_input_channels(self, input):
        """Raise ``ValueError`` if ``input.size(1)`` (the channel dimension) is 0."""
        if input.size(1) == 0:
            raise ValueError(
                "SyncBatchNorm number of input channels should be non-zero"
            )

    def forward(self, input: Tensor) -> Tensor:
        """Normalize ``input``, synchronizing batch statistics across the process group.

        Raises:
            ValueError: if ``input`` is not a CUDA tensor, has fewer than 2
                dimensions, or has zero channels.
        """
        # currently only GPU input is supported
        if not input.is_cuda:
            raise ValueError("SyncBatchNorm expected input tensor to be on GPU")

        self._check_input_dim(input)
        self._check_non_zero_input_channels(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            assert self.num_batches_tracked is not None
            # Increment in place so the buffer keeps its original storage.
            self.num_batches_tracked.add_(1)
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        # Decide whether the mini-batch stats should be used for normalization rather
        # than the buffers. Mini-batch stats are used in training mode, and in eval
        # mode when buffers are None.
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        # Buffers are only updated if they are to be tracked and we are in training
        # mode. Thus they only need to be passed when the update should occur (i.e. in
        # training mode when they are tracked), or when buffer stats are used for
        # normalization (i.e. in eval mode when buffers are not None).
        # If buffers are not to be tracked, ensure that they won't be updated.
        running_mean = (
            self.running_mean if not self.training or self.track_running_stats else None
        )
        running_var = (
            self.running_var if not self.training or self.track_running_stats else None
        )

        # Don't sync batchnorm stats in inference mode (model.eval()). Also skip the
        # sync when torch.distributed is unavailable or uninitialized: there is no
        # process group to query, and the collective calls below would fail.
        need_sync = (
            bn_training
            and self.training
            and torch.distributed.is_available()
            and torch.distributed.is_initialized()
        )
        if need_sync:
            process_group = torch.distributed.group.WORLD
            if self.process_group:
                process_group = self.process_group
            world_size = torch.distributed.get_world_size(process_group)
            need_sync = world_size > 1

        # fallback to framework BN when synchronization is not necessary
        if not need_sync:
            return F.batch_norm(
                input,
                running_mean,
                running_var,
                self.weight,
                self.bias,
                bn_training,
                exponential_average_factor,
                self.eps,
            )
        else:
            assert bn_training
            return sync_batch_norm.apply(
                input,
                self.weight,
                self.bias,
                running_mean,
                running_var,
                self.eps,
                exponential_average_factor,
                process_group,
                world_size,
            )

    @classmethod
    def convert_sync_batchnorm(cls, module, process_group=None):
        r"""Helper function to convert all :attr:`BatchNorm*D` layers in the model to
        :class:`torch.nn.SyncBatchNorm` layers.

        Each converted layer shares the original layer's parameters and buffers
        (they are assigned, not copied) and keeps its train/eval mode.

        Args:
            module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers
            process_group (optional): process group to scope synchronization,
                default is the whole world

        Returns:
            The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm`
            layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer,
            a new :class:`torch.nn.SyncBatchNorm` layer object will be returned
            instead.

        Example::

            >>> # Network with nn.BatchNorm layer
            >>> module = torch.nn.Sequential(
            >>>            torch.nn.Linear(20, 100),
            >>>            torch.nn.BatchNorm1d(100),
            >>>          ).cuda()
            >>> # creating process group (optional)
            >>> # ranks is a list of int identifying rank ids.
            >>> ranks = list(range(8))
            >>> r1, r2 = ranks[:4], ranks[4:]
            >>> # Note: every rank calls into new_group for every
            >>> # process group created, even if that rank is not
            >>> # part of the group.
            >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
            >>> process_group = process_groups[0 if torch.distributed.get_rank() <= 3 else 1]
            >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)

        """
        module_output = module
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module_output = torch.nn.SyncBatchNorm(
                module.num_features,
                module.eps,
                module.momentum,
                module.affine,
                module.track_running_stats,
                process_group,
            )
            if module.affine:
                with torch.no_grad():
                    module_output.weight = module.weight
                    module_output.bias = module.bias
            module_output.running_mean = module.running_mean
            module_output.running_var = module.running_var
            module_output.num_batches_tracked = module.num_batches_tracked
            # A freshly constructed module starts in training mode; preserve the
            # mode of the layer being replaced (e.g. a model already in eval()).
            module_output.training = module.training
            if hasattr(module, "qconfig"):
                module_output.qconfig = module.qconfig
        for name, child in module.named_children():
            module_output.add_module(
                name, cls.convert_sync_batchnorm(child, process_group)
            )
        del module
        return module_output
|