mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
use reset_running_stats in swa_utils.update_bn (#103801)
the stat reset in `swa_utils.update_bn` already exists in `NormBase.reset_running_stats`, so use that Pull Request resolved: https://github.com/pytorch/pytorch/pull/103801 Approved by: https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
75716fb060
commit
4624afaa30
@ -252,8 +252,7 @@ def update_bn(loader, model, device=None):
|
||||
momenta = {}
|
||||
for module in model.modules():
|
||||
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
|
||||
module.running_mean = torch.zeros_like(module.running_mean)
|
||||
module.running_var = torch.ones_like(module.running_var)
|
||||
module.reset_running_stats()
|
||||
momenta[module] = module.momentum
|
||||
|
||||
if not momenta:
|
||||
@ -263,7 +262,6 @@ def update_bn(loader, model, device=None):
|
||||
model.train()
|
||||
for module in momenta.keys():
|
||||
module.momentum = None
|
||||
module.num_batches_tracked *= 0
|
||||
|
||||
for input in loader:
|
||||
if isinstance(input, (list, tuple)):
|
||||
|
Reference in New Issue
Block a user