Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[functorch] Add Batch Norm module utilities to not track running stats (pytorch/functorch#505)
* Add the converter and tests; show the omniglot example was wrong
* Fix the example to use batch norm and the test to use group norm
Committed by: Jon Janzen
Parent: 664cd284c5
Commit: 0af3636257
@@ -85,22 +85,21 @@ def main():
     )
 
     # Create a vanilla PyTorch neural network.
-    # TODO (samdow): fix batch norm support
     inplace_relu = True
     net = nn.Sequential(
         nn.Conv2d(1, 64, 3),
-        nn.GroupNorm(32, 64, affine=True),
+        nn.BatchNorm2d(64, affine=True, track_running_stats=False),
         nn.ReLU(inplace=inplace_relu),
         nn.MaxPool2d(2, 2),
         nn.Conv2d(64, 64, 3),
-        nn.GroupNorm(32, 64, affine=True),
+        nn.BatchNorm2d(64, affine=True, track_running_stats=False),
         nn.ReLU(inplace=inplace_relu),
         nn.MaxPool2d(2, 2),
         nn.Conv2d(64, 64, 3),
-        nn.GroupNorm(32, 64, affine=True),
+        nn.BatchNorm2d(64, affine=True, track_running_stats=False),
         nn.ReLU(inplace=inplace_relu),
         nn.MaxPool2d(2, 2),
-        Flatten(),
+        nn.Flatten(),
         nn.Linear(64, args.n_way)).to(device)
 
     net.train()
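An aside, not part of the diff: the example can drop GroupNorm in favor of BatchNorm2d(..., track_running_stats=False) because a batch-norm layer built this way registers no running_mean/running_var/num_batches_tracked buffers, so functorch's functional transforms see no mutable running state to update in place. A minimal check (illustrative, assuming a recent PyTorch build):

import torch.nn as nn

bn_tracking = nn.BatchNorm2d(64, affine=True, track_running_stats=True)
bn_stateless = nn.BatchNorm2d(64, affine=True, track_running_stats=False)

# With tracking enabled, the running statistics live in registered buffers...
print([name for name, _ in bn_tracking.named_buffers()])   # ['running_mean', 'running_var', 'num_batches_tracked']
# ...without tracking, no such buffers are registered.
print([name for name, _ in bn_stateless.named_buffers()])  # []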
@@ -260,12 +259,5 @@ def plot(log):
     plt.close(fig)
 
 
-# Won't need this after this PR is merged in:
-# https://github.com/pytorch/pytorch/pull/22245
-class Flatten(nn.Module):
-    def forward(self, input):
-        return input.view(input.size(0), -1)
-
-
 if __name__ == '__main__':
     main()
@@ -1,2 +1 @@
-# PyTorch forward-mode is not mature yet
-from .._src.eager_transforms import jvp, jacfwd, hessian
+from .batch_norm_replacement import replace_all_batch_norm_modules_
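For orientation, not part of the diff: after this change the new helper is exposed from functorch.experimental, while (as the updated test imports further down suggest) jvp, jacfwd, and hessian are imported from the top-level functorch namespace instead:

from functorch import grad, vmap, jvp, jacfwd, hessian
from functorch.experimental import replace_all_batch_norm_modules_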
functorch/functorch/experimental/batch_norm_replacement.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+import torch.nn as nn
+
+
+def batch_norm_without_running_stats(module: nn.Module):
+    if isinstance(module, nn.modules.batchnorm._BatchNorm) and module.track_running_stats:
+        module.running_mean = None
+        module.running_var = None
+        module.num_batches_tracked = None
+        module.track_running_stats = False
+
+
+def replace_all_batch_norm_modules_(root: nn.Module) -> nn.Module:
+    """
+    In place updates :attr:`root` by setting the ``running_mean`` and ``running_var`` to be None and
+    setting track_running_stats to be False for any nn.BatchNorm module in :attr:`root`
+    """
+    # base case
+    batch_norm_without_running_stats(root)
+
+    for obj in root.modules():
+        batch_norm_without_running_stats(obj)
+    return root
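A minimal usage sketch for the new helper (illustrative only, not from the diff; the toy model, shapes, and variable names are assumptions): convert a model in place so its batch-norm layers keep no running statistics, then compute per-sample gradients with vmap + grad, mirroring the pattern used in the test added further down.

import torch
import torch.nn as nn
from functorch import make_functional_with_buffers, vmap, grad
from functorch.experimental import replace_all_batch_norm_modules_

model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.BatchNorm2d(8),  # tracks running stats by default
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 30 * 30, 10),
)
replace_all_batch_norm_modules_(model)  # running_mean/running_var are now None

fmodel, params, buffers = make_functional_with_buffers(model)
criterion = nn.CrossEntropyLoss()

def compute_loss(params, buffers, x, y):
    # x and y carry a single sample; restore the batch dimension BatchNorm2d expects
    return criterion(fmodel(params, buffers, x.unsqueeze(0)), y.unsqueeze(0))

x = torch.randn(16, 3, 32, 32)
y = torch.randint(0, 10, (16,))

# One gradient per sample: grad differentiates w.r.t. params, vmap maps over the batch
per_sample_grads = vmap(grad(compute_loss), in_dims=(None, None, 0, 0))(params, buffers, x, y)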
@@ -17,18 +17,16 @@ import math
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCPU
 from torch.testing._internal.common_dtype import get_all_fp_dtypes
 from functools import partial
+from functorch.experimental import replace_all_batch_norm_modules_
 
 import functorch
 from functorch import (
-    grad, vjp, vmap, jacrev, grad_and_value,
-    make_functional, make_functional_with_buffers,
+    grad, vjp, vmap, jacrev, jacfwd, grad_and_value, hessian,
+    jvp, make_functional, make_functional_with_buffers,
 )
 from functorch._src.make_functional import (
     functional_init, functional_init_with_buffers,
 )
-from functorch.experimental import (
-    jvp, jacfwd, hessian,
-)
 from functorch._src.eager_transforms import _argnums_partial
 from functorch._src.custom_function import custom_vjp
@@ -2109,10 +2107,8 @@ class TestExamplesCorrectness(TestCase):
         n_inner_iter = 2
         num_tasks = 2
 
-        class Flatten(nn.Module):
-            def forward(self, input):
-                return input.view(input.size(0), -1)
-
+        # real example uses batch norm but it's numerically unstable in the first
+        # iteration, when near 0, and won't produce same gradients. Uses group norm instead
         net = nn.Sequential(
             nn.Conv2d(1, 64, 3),
             nn.GroupNorm(64, 64, affine=True),
@@ -2126,7 +2122,7 @@ class TestExamplesCorrectness(TestCase):
             nn.GroupNorm(64, 64, affine=True),
             nn.ReLU(inplace=inplace_relu),
             nn.MaxPool2d(2, 2),
-            Flatten(),
+            nn.Flatten(),
             nn.Linear(64, n_way)).to(device).to(dtype)
 
         fnet, params, buffers = make_functional_with_buffers(net)
@@ -2176,6 +2172,45 @@ class TestExamplesCorrectness(TestCase):
 
         self.assertEqual(result_grads, expected_grads)
 
+    @parametrize('originally_track_running_stats', [True, False])
+    def test_update_batch_norm(self, device, originally_track_running_stats):
+        dtype = torch.double
+        inplace_relu = False
+        classes = 5
+        num_batches = 2
+        net = nn.Sequential(
+            nn.Conv2d(64, 64, 3),
+            nn.BatchNorm2d(64, affine=True, track_running_stats=originally_track_running_stats),
+            nn.ReLU(inplace=inplace_relu),
+            nn.Flatten(),
+            nn.Linear(43264, classes)).to(device).to(dtype)
+
+        replace_all_batch_norm_modules_(net)
+        transformed_net = net
+        fnet, params, buffers = make_functional_with_buffers(transformed_net)
+        net = (params, buffers, fnet)
+        criterion = nn.CrossEntropyLoss()
+
+        def compute_loss(x, y, params, buffers):
+            return criterion(fnet(params, buffers, x), y)
+
+        # Get some sample inputs...
+        x = torch.randn(num_batches, 1, 64, 28, 28, device=device, dtype=dtype)
+        y = torch.randint(0, classes, (num_batches, 1), device=device)
+
+        # compute some per sample grads with vmap + grad
+        result_grads = vmap(grad(compute_loss, argnums=2), in_dims=(0, 0, None, None))(x, y, params, buffers)
+
+        # compute some per sample grads without vmap + grad
+        fnet, params, buffers = make_functional_with_buffers(transformed_net)
+        expected_grads = [
+            torch.autograd.grad(compute_loss(x[i], y[i], params, buffers), params)
+            for i in range(num_batches)
+        ]
+        expected_grads = [torch.stack(shards) for shards in zip(*expected_grads)]
+
+        self.assertEqual(result_grads, expected_grads)
+
     def test_lennard_jones_batched_jacrev(self, device):
         sigma = 0.5
         epsilon = 4.
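A final illustrative sketch (not part of the diff) of the property the originally_track_running_stats=True case of test_update_batch_norm relies on: after replace_all_batch_norm_modules_, a batch-norm layer that previously tracked running statistics carries none.

import torch.nn as nn
from functorch.experimental import replace_all_batch_norm_modules_

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8, track_running_stats=True))
replace_all_batch_norm_modules_(net)

bn = net[1]
assert bn.running_mean is None and bn.running_var is None
assert bn.track_running_stats is False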