mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: xref gh-32838, gh-34032 This is a major refactor of parts of the documentation to split it up using sphinx's `autosummary` feature which will build out `autofunction` and `autoclass` stub files and link to them. The end result is that the top module pages like torch.nn.rst and torch.rst are now more like table-of-contents to the actual single-class or single-function documentation pages. Along the way, I modified many of the docstrings to eliminate sphinx warnings when building. I think the only thing I changed from a non-documentation perspective is to add names to `__all__` when adding them to `globals()` in `torch.__init__.py`. I do not know the CI system: are the documentation build artifacts available after the build, so reviewers can preview before merging? Pull Request resolved: https://github.com/pytorch/pytorch/pull/37419 Differential Revision: D21337640 Pulled By: ezyang fbshipit-source-id: d4ad198780c3ae7a96a9f22651e00ff2d31a0c0f
513 lines
23 KiB
Python
513 lines
23 KiB
Python
from __future__ import division
|
|
|
|
import torch
|
|
from ._functions import SyncBatchNorm as sync_batch_norm
|
|
from .module import Module
|
|
from torch.nn.parameter import Parameter
|
|
from .. import functional as F
|
|
from .. import init
|
|
|
|
|
|
class _NormBase(Module):
    """Shared base class of _InstanceNorm and _BatchNorm.

    Owns the optional affine parameters (``weight`` / ``bias``) and the
    optional running-statistics buffers (``running_mean``, ``running_var``,
    ``num_batches_tracked``), and knows how to reset and migrate them.
    Subclasses must override :meth:`_check_input_dim`.
    """

    _version = 2
    __constants__ = ['track_running_stats', 'momentum', 'eps',
                     'num_features', 'affine']

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_NormBase, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        # Learnable per-feature scale/shift. When disabled they are still
        # registered (as None) so the attributes always exist.
        if not affine:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        else:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))

        # Running statistics; likewise registered as None when not tracked.
        if not track_running_stats:
            for stat_name in ('running_mean', 'running_var', 'num_batches_tracked'):
                self.register_parameter(stat_name, None)
        else:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long))

        self.reset_parameters()

    def reset_running_stats(self):
        """Zero the running mean and batch counter and fill the running variance with 1.

        Does nothing when running statistics are not tracked.
        """
        if not self.track_running_stats:
            return
        self.running_mean.zero_()
        self.running_var.fill_(1)
        self.num_batches_tracked.zero_()

    def reset_parameters(self):
        """Reset running statistics, then set weight to 1 and bias to 0 (if affine)."""
        self.reset_running_stats()
        if not self.affine:
            return
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def _check_input_dim(self, input):
        # Each concrete subclass validates the dimensionality it accepts.
        raise NotImplementedError

    def extra_repr(self):
        template = ('{num_features}, eps={eps}, momentum={momentum}, affine={affine}, '
                    'track_running_stats={track_running_stats}')
        return template.format(**vars(self))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Load state, back-filling ``num_batches_tracked`` for pre-version-2 checkpoints."""
        version = local_metadata.get('version', None)
        saved_before_v2 = version is None or version < 2

        if saved_before_v2 and self.track_running_stats:
            # The num_batches_tracked buffer was introduced in version 2;
            # older checkpoints lack it, so default it to 0.
            counter_key = prefix + 'num_batches_tracked'
            state_dict.setdefault(counter_key, torch.tensor(0, dtype=torch.long))

        super(_NormBase, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)
|
|
|
|
|
|
class _BatchNorm(_NormBase):
    """Batch-normalization base shared by BatchNorm1d/2d/3d.

    Implements the forward pass and the running-statistics bookkeeping;
    subclasses only supply :meth:`_check_input_dim`.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_BatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input):
        self._check_input_dim(input)

        # Seed the factor with momentum (when it is available) only so that it
        # is updated in the ONNX graph when this node is exported to ONNX.
        factor = 0.0 if self.momentum is None else self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is not None:
                    # exponential moving average
                    factor = self.momentum
                else:
                    # cumulative moving average over all batches seen so far
                    factor = 1.0 / float(self.num_batches_tracked)

        # Batch statistics are used while training, and always when no
        # running statistics are kept.
        use_batch_stats = self.training or not self.track_running_stats
        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            use_batch_stats, factor, self.eps)
|
|
|
|
|
|
class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D
    inputs with optional additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of features or channels of the input).
    By default, the elements of :math:`\gamma` are set to 1 and the elements of
    :math:`\beta` are set to 0.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.

    Args:
        num_features: number of features or channels :math:`C` of the input,
            i.e. :math:`C` from an expected input of size :math:`(N, C)` or
            :math:`(N, C, L)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 2D ``(N, C)`` or 3D ``(N, C, L)``."""
        if input.dim() not in (2, 3):
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
|
|
|
|
|
|
class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization to a 4D input, i.e. a mini-batch of 2D
    inputs that carry an additional channel dimension, following the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Mean and standard deviation are computed per channel across the
    mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter
    vectors of size `C` (the number of input channels); initially every
    element of :math:`\gamma` is 1 and every element of :math:`\beta` is 0.

    Unless disabled, the layer also maintains running estimates of the mean
    and variance while training and normalizes with those estimates in
    evaluation mode. The estimates are updated with a default
    :attr:`momentum` of 0.1.

    With :attr:`track_running_stats` set to ``False`` no running estimates
    are kept, and batch statistics are used in evaluation mode as well.

    .. note::
        This :attr:`momentum` is not the momentum used by optimizers, nor the
        conventional notion of momentum. The running statistics are updated as
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        newly observed value.

    Since normalization is performed over the `C` dimension with statistics
    computed on `(N, H, W)` slices, this is commonly called Spatial Batch
    Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: weight given to the current batch when updating
            running_mean and running_var. Set to ``None`` for a cumulative
            moving average (i.e. simple average). Default: 0.1
        affine: if ``True``, the module has learnable affine parameters.
            Default: ``True``
        track_running_stats: if ``True``, the module tracks the running mean
            and variance; if ``False``, it keeps no such statistics and always
            uses batch statistics in both training and eval modes.
            Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Reject any input that is not 4D ``(N, C, H, W)``."""
        ndim = input.dim()
        if ndim == 4:
            return
        raise ValueError('expected 4D input (got {}D input)'.format(ndim))
|
|
|
|
|
|
class BatchNorm3d(_BatchNorm):
    r"""Applies Batch Normalization to a 5D input, i.e. a mini-batch of 3D
    inputs that carry an additional channel dimension, following the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Mean and standard deviation are computed per channel across the
    mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter
    vectors of size `C` (the number of input channels); initially every
    element of :math:`\gamma` is 1 and every element of :math:`\beta` is 0.

    Unless disabled, the layer also maintains running estimates of the mean
    and variance while training and normalizes with those estimates in
    evaluation mode. The estimates are updated with a default
    :attr:`momentum` of 0.1.

    With :attr:`track_running_stats` set to ``False`` no running estimates
    are kept, and batch statistics are used in evaluation mode as well.

    .. note::
        This :attr:`momentum` is not the momentum used by optimizers, nor the
        conventional notion of momentum. The running statistics are updated as
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        newly observed value.

    Since normalization is performed over the `C` dimension with statistics
    computed on `(N, D, H, W)` slices, this is commonly called Volumetric
    Batch Normalization or Spatio-temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: weight given to the current batch when updating
            running_mean and running_var. Set to ``None`` for a cumulative
            moving average (i.e. simple average). Default: 0.1
        affine: if ``True``, the module has learnable affine parameters.
            Default: ``True``
        track_running_stats: if ``True``, the module tracks the running mean
            and variance; if ``False``, it keeps no such statistics and always
            uses batch statistics in both training and eval modes.
            Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Reject any input that is not 5D ``(N, C, D, H, W)``."""
        ndim = input.dim()
        if ndim == 5:
            return
        raise ValueError('expected 5D input (got {}D input)'.format(ndim))
|
|
|
|
|
|
class SyncBatchNorm(_BatchNorm):
    r"""Applies Batch Normalization over a N-Dimensional input (a mini-batch of [N-2]D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over all
    mini-batches of the same process groups. :math:`\gamma` and :math:`\beta`
    are learnable parameter vectors of size `C` (where `C` is the input size).
    By default, the elements of :math:`\gamma` are set to 1 and the elements
    of :math:`\beta` are set to 0.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, +)` slices, it's common terminology to call this Volumetric Batch Normalization
    or Spatio-temporal Batch Normalization.

    Currently SyncBatchNorm only supports DistributedDataParallel with single GPU per process. Use
    torch.nn.SyncBatchNorm.convert_sync_batchnorm() to convert BatchNorm layer to SyncBatchNorm before wrapping
    Network with DDP.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, +)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``
        process_group: synchronization of stats happen within each process group
            individually. Default behavior is synchronization across the whole
            world

    Shape:
        - Input: :math:`(N, C, +)`
        - Output: :math:`(N, C, +)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.SyncBatchNorm(100)
        >>> # creating process group (optional)
        >>> # process_ids is a list of int identifying rank ids.
        >>> process_group = torch.distributed.new_group(process_ids)
        >>> # Without Learnable Parameters
        >>> m = nn.SyncBatchNorm(100, affine=False, process_group=process_group)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)

        >>> # network is nn.BatchNorm layer
        >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
        >>> # only single gpu per process is currently supported
        >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
        >>>                         sync_bn_network,
        >>>                         device_ids=[args.local_rank],
        >>>                         output_device=args.local_rank)
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, process_group=None):
        super(SyncBatchNorm, self).__init__(num_features, eps, momentum, affine, track_running_stats)
        self.process_group = process_group
        # gpu_size is set through DistributedDataParallel initialization. This is to ensure that SyncBatchNorm is used
        # under supported condition (single GPU per process)
        self.ddp_gpu_size = None

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` has at least 2 dimensions ``(N, C, ...)``."""
        if input.dim() < 2:
            raise ValueError('expected at least 2D input (got {}D input)'
                             .format(input.dim()))

    def _specify_ddp_gpu_num(self, gpu_size):
        """Record the number of GPUs per DDP process; only 1 is supported.

        Raises:
            ValueError: if ``gpu_size`` is greater than 1.
        """
        if gpu_size > 1:
            raise ValueError('SyncBatchNorm is only supported for DDP with single GPU per process')
        self.ddp_gpu_size = gpu_size

    def forward(self, input):
        """Normalize ``input`` (which must be a CUDA tensor).

        Falls back to :func:`F.batch_norm` when no cross-process
        synchronization is needed (eval mode with running stats, or a
        process group of size 1); otherwise uses the synchronized kernel.

        Raises:
            ValueError: if ``input`` is not on the GPU or has fewer than 2 dims.
            AttributeError: if synchronization is needed but the module was
                not set up by DistributedDataParallel.
        """
        # currently only GPU input is supported
        if not input.is_cuda:
            raise ValueError('SyncBatchNorm expected input tensor to be on GPU')

        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # Same None-guard as _BatchNorm.forward.
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # Batch statistics (and hence possibly synchronization) are only
        # needed in training mode or when no running statistics are kept.
        need_sync = self.training or not self.track_running_stats
        if need_sync:
            process_group = torch.distributed.group.WORLD
            if self.process_group is not None:
                process_group = self.process_group
            world_size = torch.distributed.get_world_size(process_group)
            need_sync = world_size > 1

        # fallback to framework BN when synchronization is not necessary
        if not need_sync:
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training or not self.track_running_stats,
                exponential_average_factor, self.eps)

        if not self.ddp_gpu_size:
            raise AttributeError('SyncBatchNorm is only supported within torch.nn.parallel.DistributedDataParallel')

        return sync_batch_norm.apply(
            input, self.weight, self.bias, self.running_mean, self.running_var,
            self.eps, exponential_average_factor, process_group, world_size)

    @classmethod
    def convert_sync_batchnorm(cls, module, process_group=None):
        r"""Helper function to convert all `torch.nn.BatchNormND` layers in the model to
        `torch.nn.SyncBatchNorm` layers.

        The converted layers reuse the original layers' ``weight``/``bias``
        parameters and running-statistics tensors (so device, dtype and
        ``requires_grad`` are preserved), and keep the original layers'
        training/eval mode.

        Args:
            module (nn.Module): containing module
            process_group (optional): process group to scope synchronization,
                default is the whole world

        Returns:
            If ``module`` is itself a BatchNorm layer, a new
            `torch.nn.SyncBatchNorm` layer replacing it; otherwise the original
            ``module`` with its BatchNorm descendants replaced in place.

        Example::

            >>> # Network with nn.BatchNorm layer
            >>> module = torch.nn.Sequential(
            >>>            torch.nn.Linear(20, 100),
            >>>            torch.nn.BatchNorm1d(100)
            >>>          ).cuda()
            >>> # creating process group (optional)
            >>> # process_ids is a list of int identifying rank ids.
            >>> process_group = torch.distributed.new_group(process_ids)
            >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)

        """
        module_output = module
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module_output = torch.nn.SyncBatchNorm(module.num_features,
                                                   module.eps, module.momentum,
                                                   module.affine,
                                                   module.track_running_stats,
                                                   process_group)
            if module.affine:
                # Reuse the original Parameter objects rather than copying into
                # the freshly allocated default-dtype/default-device tensors of
                # the new layer; this keeps weight/bias on the same device and
                # dtype as the (shared) running statistics below, and keeps
                # requires_grad unchanged.
                module_output.weight = module.weight
                module_output.bias = module.bias
            module_output.running_mean = module.running_mean
            module_output.running_var = module.running_var
            module_output.num_batches_tracked = module.num_batches_tracked
            # A newly constructed module starts in training mode; keep the
            # mode of the layer being replaced.
            module_output.training = module.training
        for name, child in module.named_children():
            module_output.add_module(name, cls.convert_sync_batchnorm(child, process_group))
        return module_output
|