Patching Batch Norm
===================

What's happening?
-----------------
Batch Norm requires in-place updates to ``running_mean`` and ``running_var`` that are the same size as the input.
functorch does not support an in-place update to a regular tensor when the update comes from a batched tensor
(i.e. ``regular.add_(batched)`` is not allowed). So when vmapping over a batch of inputs to a single module,
we end up with an error.
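
For concreteness, here is a minimal sketch of the failure mode (the module and input shapes are illustrative assumptions, not from the original doc):

.. code-block:: python

    import torch
    from torch.func import vmap

    net = torch.nn.BatchNorm2d(3)    # tracks running stats by default
    x = torch.randn(5, 2, 3, 8, 8)   # 5 examples to vmap over, each of shape (2, 3, 8, 8)

    # Each per-example call would update net.running_mean / net.running_var in place
    # with a batched tensor, which vmap does not allow, so this raises an error.
    vmap(net)(x)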

How to fix
----------
One of the best-supported ways is to switch BatchNorm for GroupNorm. Options 1 and 2 support this.

All of these options assume that you don't need running stats. If you're using a module, this means
it's assumed you won't use batch norm in evaluation mode. If you have a use case that involves
running batch norm with vmap in evaluation mode, please file an issue.

Option 1: Change the BatchNorm
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
If you want to switch to GroupNorm, anywhere that you have BatchNorm, replace it with:

.. code-block:: python

    GroupNorm(G, C)

Here ``C`` is the same ``C`` as in the original BatchNorm and ``G`` is the number of groups to
break ``C`` into. As such, ``C % G == 0`` must hold; as a fallback, you can set ``G == C``, meaning
each channel will be treated separately.
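
For example, a module that originally constructed ``nn.BatchNorm2d(64)`` could use (the channel and group counts below are illustrative):

.. code-block:: python

    import torch.nn as nn

    C, G = 64, 8                # 64 channels split into 8 groups; 64 % 8 == 0
    layer = nn.GroupNorm(G, C)  # replaces nn.BatchNorm2d(C)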

If you must use BatchNorm and you've built the module yourself, you can change the module to
not use running stats. In other words, anywhere that there's a BatchNorm module, set the
``track_running_stats`` flag to ``False``:

.. code-block:: python

    BatchNorm2d(64, track_running_stats=False)
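
As a fuller sketch, a hand-built module would construct its norm layers with the flag set (the network below is a hypothetical example):

.. code-block:: python

    import torch.nn as nn

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
            # No running stats, so nothing needs an in-place update under vmap.
            self.bn = nn.BatchNorm2d(64, track_running_stats=False)

        def forward(self, x):
            return self.bn(self.conv(x)).relu()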

Option 2: torchvision parameter
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Some torchvision models, like resnet and regnet, can take a ``norm_layer`` parameter. This
parameter often defaults to ``BatchNorm2d``.

Instead, you can set it to be GroupNorm:

.. code-block:: python

    import torchvision
    from torch.nn import GroupNorm

    torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, num_channels=c))

Here, once again, ``c % g == 0`` must hold, so as a fallback, set ``g = c``.
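
For example, both a concrete group count and the ``g = c`` fallback can be passed this way (32 is just an illustrative choice that divides each of resnet18's channel counts):

.. code-block:: python

    import torchvision
    from torch.nn import GroupNorm

    # Concrete choice: 32 groups (64, 128, 256 and 512 are all divisible by 32).
    model = torchvision.models.resnet18(
        norm_layer=lambda c: GroupNorm(num_groups=32, num_channels=c)
    )

    # Fallback g = c: one group per channel, so each channel is treated separately.
    model = torchvision.models.resnet18(
        norm_layer=lambda c: GroupNorm(num_groups=c, num_channels=c)
    )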

If you are attached to BatchNorm, be sure to use a version that doesn't use running stats:

.. code-block:: python

    import torchvision
    from functools import partial
    from torch.nn import BatchNorm2d

    torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))

Option 3: functorch's patching
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
functorch has added some functionality to allow for quick, in-place patching of a module so that it does
not use running stats. Changing the norm layer is more fragile, so we have not offered that. If you
have a net where you want the BatchNorm layers to not use running stats, you can run
``replace_all_batch_norm_modules_`` to update the module in-place:

.. code-block:: python

    from torch.func import replace_all_batch_norm_modules_
    replace_all_batch_norm_modules_(net)
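
Putting it together, a rough end-to-end sketch (the torchvision model and the shapes are illustrative assumptions):

.. code-block:: python

    import torch
    import torchvision
    from torch.func import replace_all_batch_norm_modules_, vmap

    net = torchvision.models.resnet18()   # uses BatchNorm2d with running stats by default
    replace_all_batch_norm_modules_(net)  # patch every BatchNorm in place to drop running stats

    x = torch.randn(4, 2, 3, 64, 64)      # 4 examples to vmap over
    out = vmap(net)(x)                    # no in-place buffer updates, so vmap works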

Option 4: eval mode
^^^^^^^^^^^^^^^^^^^
When run under eval mode, ``running_mean`` and ``running_var`` will not be updated. Therefore, vmap can support this mode:

.. code-block:: python

    model.eval()
    vmap(model)(x)
    model.train()