use reset_running_stats in swa_utils.update_bn (#103801)

The running-stats reset in `swa_utils.update_bn` already exists as `_NormBase.reset_running_stats`, so use that method instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103801
Approved by: https://github.com/janeyx99
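For context, a hedged sketch of what the batch-norm base class does in `reset_running_stats` (paraphrased and simplified from `torch/nn/modules/batchnorm.py`; consult the PyTorch source for the exact code):

    # Simplified sketch of _NormBase.reset_running_stats, not the verbatim implementation.
    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            # The buffers only exist when track_running_stats is True.
            self.running_mean.zero_()         # same effect as assigning torch.zeros_like(...)
            self.running_var.fill_(1)         # same effect as assigning torch.ones_like(...)
            self.num_batches_tracked.zero_()  # makes the manual `*= 0` in update_bn redundant

Because the method also zeroes `num_batches_tracked`, the second hunk below can drop that line from `update_bn` as well.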
Niklas Nolte
2023-06-23 01:17:09 +00:00
committed by PyTorch MergeBot
parent 75716fb060
commit 4624afaa30

torch/optim/swa_utils.py

@@ -252,8 +252,7 @@ def update_bn(loader, model, device=None):
     momenta = {}
     for module in model.modules():
         if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
-            module.running_mean = torch.zeros_like(module.running_mean)
-            module.running_var = torch.ones_like(module.running_var)
+            module.reset_running_stats()
             momenta[module] = module.momentum
 
     if not momenta:
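A quick sanity check of the equivalence this hunk relies on; the snippet is illustrative only and not part of the commit:

    import torch

    bn = torch.nn.BatchNorm2d(3)
    bn(torch.randn(4, 3, 8, 8))  # run one batch in training mode to populate the running stats

    bn.reset_running_stats()
    assert torch.equal(bn.running_mean, torch.zeros_like(bn.running_mean))
    assert torch.equal(bn.running_var, torch.ones_like(bn.running_var))
    assert bn.num_batches_tracked.item() == 0  # also reset, which the next hunk takes advantage of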
@@ -263,7 +262,6 @@ def update_bn(loader, model, device=None):
     model.train()
     for module in momenta.keys():
         module.momentum = None
-        module.num_batches_tracked *= 0
 
     for input in loader:
         if isinstance(input, (list, tuple)):
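For reference, a typical call site for `update_bn` in an SWA workflow; the toy model and loader here are placeholders to keep the example self-contained, not code from the commit:

    import torch
    from torch.optim.swa_utils import AveragedModel, update_bn

    # Toy model and "loader" just for illustration; in real use these come
    # from your training setup.
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
    loader = [torch.randn(4, 3, 16, 16) for _ in range(5)]

    swa_model = AveragedModel(model)
    # ... training loop that periodically calls swa_model.update_parameters(model) ...

    # Recompute the BatchNorm running statistics for the averaged weights;
    # after this commit, update_bn starts from module.reset_running_stats() internally.
    update_bn(loader, swa_model)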