Add missing interfaces of torch.optim.swa_utils (#117036)

Add type hints for the function/class interfaces that appear in torch/optim/swa_utils.py but are missing in torch/optim/swa_utils.pyi.

- get_ema_multi_avg_fn
- get_swa_multi_avg_fn
- get_ema_avg_fn
- get_swa_avg_fn
- AveragedModel.__init__(multi_avg_fn)
- SWALR.get_lr

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117036
Approved by: https://github.com/janeyx99
commit ab647bd325
parent 5c0a380bdf
committed by PyTorch MergeBot
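As a rough illustration of what the added annotations declare (the signatures described in the comments are paraphrased from the type hints in the diff below, not quoted from the .pyi):

import torch
from torch.optim.swa_utils import (
    get_ema_multi_avg_fn,
    get_swa_multi_avg_fn,
    get_ema_avg_fn,
    get_swa_avg_fn,
)

# Multi-tensor (foreach) update functions: take two parameter lists plus the
# averaging step count and update the first list in place (return None).
multi_ema = get_ema_multi_avg_fn(decay=0.999)
multi_swa = get_swa_multi_avg_fn()

# Per-tensor update functions: take (averaged, current, num_averaged) and
# return the new averaged tensor.
ema = get_ema_avg_fn(decay=0.999)
swa = get_swa_avg_fn()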
@@ -1,12 +1,15 @@
 import itertools
 import math
 from copy import deepcopy
+from typing import Any, Callable, Iterable, List, Optional, Tuple, Union, Dict, cast
 import warnings
 
 import torch
+from torch import Tensor
 from torch.nn import Module
 from torch.optim.lr_scheduler import LRScheduler
 from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices
+from .optimizer import Optimizer
 
 __all__ = [
     'AveragedModel',
@@ -18,12 +21,14 @@ __all__ = [
     'get_swa_avg_fn'
 ]
 
-from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
+from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype, TensorListList, Indices
+
+PARAM_LIST = Union[Tuple[Tensor, ...], List[Tensor]]
 
 
 def get_ema_multi_avg_fn(decay=0.999):
     @torch.no_grad()
-    def ema_update(ema_param_list, current_param_list, _):
+    def ema_update(ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _):
         # foreach lerp only handles float and complex
         if torch.is_floating_point(ema_param_list[0]) or torch.is_complex(ema_param_list[0]):
             torch._foreach_lerp_(ema_param_list, current_param_list, 1 - decay)
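For context, the returned ema_update can be exercised directly on plain tensor lists; a minimal sketch (tensor values made up) of the in-place EMA step that AveragedModel drives through multi_avg_fn:

import torch
from torch.optim.swa_utils import get_ema_multi_avg_fn

ema_update = get_ema_multi_avg_fn(decay=0.9)

ema_params = [torch.zeros(3), torch.zeros(2)]             # running EMA copies
current_params = [torch.ones(3), torch.full((2,), 2.0)]   # latest model values

# In place: ema = decay * ema + (1 - decay) * current
ema_update(ema_params, current_params, None)
print(ema_params)  # [tensor([0.1000, 0.1000, 0.1000]), tensor([0.2000, 0.2000])]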
@@ -36,20 +41,23 @@ def get_ema_multi_avg_fn(decay=0.999):
 
 def get_swa_multi_avg_fn():
     @torch.no_grad()
-    def swa_update(averaged_param_list, current_param_list, num_averaged):
+    def swa_update(averaged_param_list: PARAM_LIST, current_param_list: PARAM_LIST, num_averaged: Union[Tensor, int]):
         # foreach lerp only handles float and complex
         if torch.is_floating_point(averaged_param_list[0]) or torch.is_complex(averaged_param_list[0]):
             torch._foreach_lerp_(averaged_param_list, current_param_list, 1 / (num_averaged + 1))
         else:
             diffs = torch._foreach_sub(current_param_list, averaged_param_list)
-            torch._foreach_addcdiv_(averaged_param_list, diffs, [num_averaged + 1] * len(averaged_param_list))
+            if isinstance(num_averaged, Tensor):
+                torch._foreach_addcdiv_(averaged_param_list, diffs, [num_averaged + 1] * len(averaged_param_list))
+            else:
+                torch._foreach_add_(averaged_param_list, diffs, alpha=1.0 / (num_averaged + 1))
 
     return swa_update
 
 
 def get_ema_avg_fn(decay=0.999):
     @torch.no_grad()
-    def ema_update(ema_param, current_param, num_averaged):
+    def ema_update(ema_param: Tensor, current_param: Tensor, num_averaged):
         return decay * ema_param + (1 - decay) * current_param
 
     return ema_update
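The widened num_averaged: Union[Tensor, int] annotation reflects both ways the SWA update is driven; a small sketch with made-up values using the public helper:

import torch
from torch.optim.swa_utils import get_swa_multi_avg_fn

swa_update = get_swa_multi_avg_fn()
averaged = [torch.zeros(2)]
current = [torch.full((2,), 3.0)]

# num_averaged may be a plain int (as here) or AveragedModel's n_averaged
# buffer, a Tensor -- hence the Union[Tensor, int] annotation above. The new
# isinstance branch only matters for non-float dtypes, where foreach lerp is
# unavailable.
swa_update(averaged, current, 2)   # 0 + (3 - 0) / 3 -> tensor([1., 1.])
swa_update(averaged, current, 3)   # 1 + (3 - 1) / 4 -> tensor([1.5, 1.5])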
@@ -57,7 +65,7 @@ def get_ema_avg_fn(decay=0.999):
 
 def get_swa_avg_fn():
     @torch.no_grad()
-    def swa_update(averaged_param, current_param, num_averaged):
+    def swa_update(averaged_param: Tensor, current_param: Tensor, num_averaged: Union[Tensor, int]):
         return averaged_param + (current_param - averaged_param) / (num_averaged + 1)
 
     return swa_update
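The single-tensor get_swa_avg_fn implements the running-mean recurrence avg_new = avg + (current - avg) / (num_averaged + 1); a quick numeric check with made-up values:

import torch
from torch.optim.swa_utils import get_swa_avg_fn

swa_update = get_swa_avg_fn()
samples = [torch.tensor(2.0), torch.tensor(4.0), torch.tensor(9.0)]

avg = samples[0].clone()
for n, x in enumerate(samples[1:], start=1):
    avg = swa_update(avg, x, n)       # n tensors have been averaged so far

print(avg)                            # tensor(5.)
print(torch.stack(samples).mean())    # tensor(5.), same result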
@@ -162,7 +170,17 @@ class AveragedModel(Module):
     .. _Polyak averaging:
         https://paperswithcode.com/method/polyak-averaging
     """
-    def __init__(self, model, device=None, avg_fn=None, multi_avg_fn=None, use_buffers=False):
+
+    def __init__(
+        self,
+        model: Module,
+        device: Optional[Union[int, torch.device]] = None,
+        avg_fn: Optional[Callable[[Tensor, Tensor, Union[Tensor, int]],
+                                  Tensor]] = None,
+        multi_avg_fn: Optional[Callable[
+            [PARAM_LIST, PARAM_LIST, Union[Tensor, int]], None]] = None,
+        use_buffers=False,
+    ):
         super().__init__()
         assert avg_fn is None or multi_avg_fn is None, 'Only one of avg_fn and multi_avg_fn should be provided'
         self.module = deepcopy(model)
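With the annotated __init__, the documented EMA/SWA setups look like this (torch.nn.Linear is just a stand-in model); per the assert above, at most one of avg_fn / multi_avg_fn may be given:

import torch
from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn, get_swa_avg_fn

model = torch.nn.Linear(10, 2)

# EMA of the weights via the multi-tensor (foreach) path.
ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.999))

# Plain SWA via the per-tensor avg_fn path (or omit both for the default SWA update).
swa_model = AveragedModel(model, avg_fn=get_swa_avg_fn())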
@@ -177,7 +195,7 @@ class AveragedModel(Module):
     def forward(self, *args, **kwargs):
         return self.module(*args, **kwargs)
 
-    def update_parameters(self, model):
+    def update_parameters(self, model: Module):
         self_param = (
             itertools.chain(self.module.parameters(), self.module.buffers())
             if self.use_buffers else self.parameters()
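update_parameters(model: Module) is called on the live model after each optimizer step; a condensed, self-contained loop sketch with random stand-in data:

import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Linear(10, 2)
swa_model = AveragedModel(model)                      # default SWA running average
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(100):                                  # stand-in for iterating a DataLoader
    inputs, targets = torch.randn(8, 10), torch.randn(8, 2)
    optimizer.zero_grad()
    torch.nn.functional.mse_loss(model(inputs), targets).backward()
    optimizer.step()
    swa_model.update_parameters(model)                # fold the new weights into the average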
@@ -197,7 +215,11 @@ class AveragedModel(Module):
 
         if self.n_averaged > 0:
             if self.multi_avg_fn is not None or self.avg_fn is None:
-                grouped_tensors = _group_tensors_by_device_and_dtype([self_param_detached, model_param_detached])
+                grouped_tensors = _group_tensors_by_device_and_dtype(
+                    cast(TensorListList, [self_param_detached, model_param_detached]))
+                grouped_tensors = cast(
+                    Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], Indices]],
+                    grouped_tensors)
                 for ((device, _), ([self_params, model_params], _)) in grouped_tensors.items():
                     if self.multi_avg_fn:
                         self.multi_avg_fn(self_params, model_params, self.n_averaged.to(device))
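The added cast lines only spell out, for the type checker, what _group_tensors_by_device_and_dtype already returns: a mapping from (device, dtype) to the regrouped tensor lists plus their original indices. Conceptually the regrouping is equivalent to this hand-rolled sketch, which skips the index bookkeeping:

from collections import defaultdict
from typing import Dict, List, Tuple

import torch
from torch import Tensor


def group_by_device_and_dtype(
    self_params: List[Tensor], model_params: List[Tensor]
) -> Dict[Tuple[torch.device, torch.dtype], Tuple[List[Tensor], List[Tensor]]]:
    # Split the parallel lists into per-(device, dtype) buckets so each bucket
    # can be handed to a single foreach/multi_avg_fn call.
    groups: Dict[Tuple[torch.device, torch.dtype], Tuple[List[Tensor], List[Tensor]]] = defaultdict(
        lambda: ([], [])
    )
    for p_avg, p_cur in zip(self_params, model_params):
        key = (p_avg.device, p_avg.dtype)
        groups[key][0].append(p_avg)
        groups[key][1].append(p_cur)
    return dict(groups)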
@@ -223,7 +245,7 @@ class AveragedModel(Module):
 
 
 @torch.no_grad()
-def update_bn(loader, model, device=None):
+def update_bn(loader: Iterable[Any], model: Module, device: Optional[Union[int, torch.device]] = None):
     r"""Updates BatchNorm running_mean, running_var buffers in the model.
 
     It performs one pass over data in `loader` to estimate the activation
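update_bn is the documented post-averaging pass that refreshes BatchNorm statistics; typical use, where train_loader and swa_model are placeholders from the surrounding training code:

import torch
from torch.optim.swa_utils import update_bn

# One pass over the training data to recompute running_mean / running_var
# of the averaged model's BatchNorm layers.
update_bn(train_loader, swa_model, device=torch.device("cuda"))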
@@ -319,7 +341,7 @@ class SWALR(LRScheduler):
     .. _Averaging Weights Leads to Wider Optima and Better Generalization:
         https://arxiv.org/abs/1803.05407
     """
-    def __init__(self, optimizer, swa_lr, anneal_epochs=10, anneal_strategy='cos', last_epoch=-1):
+    def __init__(self, optimizer: Optimizer, swa_lr: float, anneal_epochs=10, anneal_strategy='cos', last_epoch=-1):
         swa_lrs = self._format_param(optimizer, swa_lr)
         for swa_lr, group in zip(swa_lrs, optimizer.param_groups):
             group['swa_lr'] = swa_lr
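The annotated __init__(optimizer: Optimizer, swa_lr: float, ...) is used as before; a minimal sketch with made-up hyperparameters:

import torch
from torch.optim.swa_utils import SWALR

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Anneal each param group's lr toward swa_lr over 5 epochs with the cosine strategy.
swa_scheduler = SWALR(optimizer, swa_lr=0.05, anneal_epochs=5, anneal_strategy="cos")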
@@ -361,10 +383,13 @@ class SWALR(LRScheduler):
         return (lr - alpha * swa_lr) / (1 - alpha)
 
     def get_lr(self):
-        if not self._get_lr_called_within_step:
+        # `_get_lr_called_within_step` is only available `_enable_get_lr_call`,
+        # so we ignore the type error here. See `LRScheduler.step()` for more details.
+        if not self._get_lr_called_within_step:  # type: ignore[attr-defined]
             warnings.warn("To get the last learning rate computed by the scheduler, "
                           "please use `get_last_lr()`.", UserWarning)
-        step = self._step_count - 1
+        # Set in `LRScheduler._initial_step()`
+        step = self._step_count - 1  # type: ignore[attr-defined]
         if self.anneal_epochs == 0:
             step = max(1, step)
         prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
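As the warning in get_lr says, user code should read the current value through get_last_lr() after stepping; a short sketch continuing the SWALR example above:

for epoch in range(10):
    # ... train for one epoch, then:
    swa_scheduler.step()
    print(epoch, swa_scheduler.get_last_lr())   # anneals toward swa_lr over anneal_epochs steps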