[functorch] Add Batch Norm module utilities to not track running stats (pytorch/functorch#505)

* adds converter and tests, show omniglot wrong

* fix example to use batch norm, test to use group norm
This commit is contained in:
Samantha Andow
2022-03-08 11:49:34 -05:00
committed by Jon Janzen
parent 664cd284c5
commit 0af3636257
4 changed files with 72 additions and 24 deletions

View File

@ -85,22 +85,21 @@ def main():
)
# Create a vanilla PyTorch neural network.
# TODO (samdow): fix batch norm support
inplace_relu = True
net = nn.Sequential(
nn.Conv2d(1, 64, 3),
nn.GroupNorm(32, 64, affine=True),
nn.BatchNorm2d(64, affine=True, track_running_stats=False),
nn.ReLU(inplace=inplace_relu),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 64, 3),
nn.GroupNorm(32, 64, affine=True),
nn.BatchNorm2d(64, affine=True, track_running_stats=False),
nn.ReLU(inplace=inplace_relu),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 64, 3),
nn.GroupNorm(32, 64, affine=True),
nn.BatchNorm2d(64, affine=True, track_running_stats=False),
nn.ReLU(inplace=inplace_relu),
nn.MaxPool2d(2, 2),
Flatten(),
nn.Flatten(),
nn.Linear(64, args.n_way)).to(device)
net.train()
@ -260,12 +259,5 @@ def plot(log):
plt.close(fig)
# Won't need this after this PR is merged in:
# https://github.com/pytorch/pytorch/pull/22245
class Flatten(nn.Module):
def forward(self, input):
return input.view(input.size(0), -1)
if __name__ == '__main__':
main()

View File

@ -1,2 +1 @@
# PyTorch forward-mode is not mature yet
from .._src.eager_transforms import jvp, jacfwd, hessian
from .batch_norm_replacement import replace_all_batch_norm_modules_

View File

@ -0,0 +1,22 @@
import torch.nn as nn
def batch_norm_without_running_stats(module: nn.Module):
if isinstance(module, nn.modules.batchnorm._BatchNorm) and module.track_running_stats:
module.running_mean = None
module.running_var = None
module.num_batches_tracked = None
module.track_running_stats = False
def replace_all_batch_norm_modules_(root: nn.Module) -> nn.Module:
"""
In place updates :attr:`root` by setting the ``running_mean`` and ``running_var`` to be None and
setting track_running_stats to be False for any nn.BatchNorm module in :attr:`root`
"""
# base case
batch_norm_without_running_stats(root)
for obj in root.modules():
batch_norm_without_running_stats(obj)
return root

View File

@ -17,18 +17,16 @@ import math
from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCPU
from torch.testing._internal.common_dtype import get_all_fp_dtypes
from functools import partial
from functorch.experimental import replace_all_batch_norm_modules_
import functorch
from functorch import (
grad, vjp, vmap, jacrev, grad_and_value,
make_functional, make_functional_with_buffers,
grad, vjp, vmap, jacrev, jacfwd, grad_and_value, hessian,
jvp, make_functional, make_functional_with_buffers,
)
from functorch._src.make_functional import (
functional_init, functional_init_with_buffers,
)
from functorch.experimental import (
jvp, jacfwd, hessian,
)
from functorch._src.eager_transforms import _argnums_partial
from functorch._src.custom_function import custom_vjp
@ -2109,10 +2107,8 @@ class TestExamplesCorrectness(TestCase):
n_inner_iter = 2
num_tasks = 2
class Flatten(nn.Module):
def forward(self, input):
return input.view(input.size(0), -1)
# real example uses batch norm but it's numerically unstable in the first
# iteration, when near 0, and won't produce same gradients. Uses group norm instead
net = nn.Sequential(
nn.Conv2d(1, 64, 3),
nn.GroupNorm(64, 64, affine=True),
@ -2126,7 +2122,7 @@ class TestExamplesCorrectness(TestCase):
nn.GroupNorm(64, 64, affine=True),
nn.ReLU(inplace=inplace_relu),
nn.MaxPool2d(2, 2),
Flatten(),
nn.Flatten(),
nn.Linear(64, n_way)).to(device).to(dtype)
fnet, params, buffers = make_functional_with_buffers(net)
@ -2176,6 +2172,45 @@ class TestExamplesCorrectness(TestCase):
self.assertEqual(result_grads, expected_grads)
@parametrize('originally_track_running_stats', [True, False])
def test_update_batch_norm(self, device, originally_track_running_stats):
dtype = torch.double
inplace_relu = False
classes = 5
num_batches = 2
net = nn.Sequential(
nn.Conv2d(64, 64, 3),
nn.BatchNorm2d(64, affine=True, track_running_stats=originally_track_running_stats),
nn.ReLU(inplace=inplace_relu),
nn.Flatten(),
nn.Linear(43264, classes)).to(device).to(dtype)
replace_all_batch_norm_modules_(net)
transformed_net = net
fnet, params, buffers = make_functional_with_buffers(transformed_net)
net = (params, buffers, fnet)
criterion = nn.CrossEntropyLoss()
def compute_loss(x, y, params, buffers):
return criterion(fnet(params, buffers, x), y)
# Get some sample inputs...
x = torch.randn(num_batches, 1, 64, 28, 28, device=device, dtype=dtype)
y = torch.randint(0, classes, (num_batches, 1), device=device)
# compute some per sample grads with vmap + grad
result_grads = vmap(grad(compute_loss, argnums=2), in_dims=(0, 0, None, None))(x, y, params, buffers)
# compute some per sample grads without vmap + grad
fnet, params, buffers = make_functional_with_buffers(transformed_net)
expected_grads = [
torch.autograd.grad(compute_loss(x[i], y[i], params, buffers), params)
for i in range(num_batches)
]
expected_grads = [torch.stack(shards) for shards in zip(*expected_grads)]
self.assertEqual(result_grads, expected_grads)
def test_lennard_jones_batched_jacrev(self, device):
sigma = 0.5
epsilon = 4.