Add missing interfaces of torch.optim.swa_utils (#117036)

Add type hints for the function/class interfaces that appear in torch/optim/swa_utils.py but are missing from torch/optim/swa_utils.pyi (a rough sketch of the corresponding stub entries follows the list below):

- get_ema_multi_avg_fn
- get_swa_multi_avg_fn
- get_ema_avg_fn
- get_swa_avg_fn
- AveragedModel.__init__(multi_avg_fn)
- SWALR.get_lr
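
A rough sketch of what the added stub entries could look like, inferred from the runtime signatures shown in the swa_utils.py diff below; the exact layout of the updated .pyi may differ:

```python
# Hypothetical stub sketch only -- inferred from swa_utils.py, not copied from the PR's .pyi.
from typing import Callable, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim.lr_scheduler import LRScheduler

PARAM_LIST = Union[Tuple[Tensor, ...], List[Tensor]]

def get_ema_multi_avg_fn(decay: float = ...) -> Callable[[PARAM_LIST, PARAM_LIST, Union[Tensor, int]], None]: ...
def get_swa_multi_avg_fn() -> Callable[[PARAM_LIST, PARAM_LIST, Union[Tensor, int]], None]: ...
def get_ema_avg_fn(decay: float = ...) -> Callable[[Tensor, Tensor, Union[Tensor, int]], Tensor]: ...
def get_swa_avg_fn() -> Callable[[Tensor, Tensor, Union[Tensor, int]], Tensor]: ...

class AveragedModel(Module):
    def __init__(
        self,
        model: Module,
        device: Optional[Union[int, torch.device]] = ...,
        avg_fn: Optional[Callable[[Tensor, Tensor, Union[Tensor, int]], Tensor]] = ...,
        multi_avg_fn: Optional[Callable[[PARAM_LIST, PARAM_LIST, Union[Tensor, int]], None]] = ...,
        use_buffers: bool = ...,
    ) -> None: ...

class SWALR(LRScheduler):
    def get_lr(self) -> List[float]: ...
```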

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117036
Approved by: https://github.com/janeyx99
Author: David Chiu
Date: 2024-04-12 17:17:34 +00:00
Committed by: PyTorch MergeBot
Parent: 5c0a380bdf
Commit: ab647bd325
5 changed files with 53 additions and 53 deletions

torch/optim/swa_utils.py

@@ -1,12 +1,15 @@
import itertools
import math
from copy import deepcopy
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union, Dict, cast
import warnings
import torch
from torch import Tensor
from torch.nn import Module
from torch.optim.lr_scheduler import LRScheduler
from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices
from .optimizer import Optimizer
__all__ = [
'AveragedModel',
@@ -18,12 +21,14 @@ __all__ = [
'get_swa_avg_fn'
]
-from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
+from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype, TensorListList, Indices
PARAM_LIST = Union[Tuple[Tensor, ...], List[Tensor]]
def get_ema_multi_avg_fn(decay=0.999):
@torch.no_grad()
-def ema_update(ema_param_list, current_param_list, _):
+def ema_update(ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _):
# foreach lerp only handles float and complex
if torch.is_floating_point(ema_param_list[0]) or torch.is_complex(ema_param_list[0]):
torch._foreach_lerp_(ema_param_list, current_param_list, 1 - decay)
@@ -36,20 +41,23 @@ def get_ema_multi_avg_fn(decay=0.999):
def get_swa_multi_avg_fn():
@torch.no_grad()
-def swa_update(averaged_param_list, current_param_list, num_averaged):
+def swa_update(averaged_param_list: PARAM_LIST, current_param_list: PARAM_LIST, num_averaged: Union[Tensor, int]):
# foreach lerp only handles float and complex
if torch.is_floating_point(averaged_param_list[0]) or torch.is_complex(averaged_param_list[0]):
torch._foreach_lerp_(averaged_param_list, current_param_list, 1 / (num_averaged + 1))
else:
diffs = torch._foreach_sub(current_param_list, averaged_param_list)
-torch._foreach_addcdiv_(averaged_param_list, diffs, [num_averaged + 1] * len(averaged_param_list))
+if isinstance(num_averaged, Tensor):
+    torch._foreach_addcdiv_(averaged_param_list, diffs, [num_averaged + 1] * len(averaged_param_list))
+else:
+    torch._foreach_add_(averaged_param_list, diffs, alpha=1.0 / (num_averaged + 1))
return swa_update
def get_ema_avg_fn(decay=0.999):
@torch.no_grad()
-def ema_update(ema_param, current_param, num_averaged):
+def ema_update(ema_param: Tensor, current_param: Tensor, num_averaged):
return decay * ema_param + (1 - decay) * current_param
return ema_update
@@ -57,7 +65,7 @@ def get_ema_avg_fn(decay=0.999):
def get_swa_avg_fn():
@torch.no_grad()
-def swa_update(averaged_param, current_param, num_averaged):
+def swa_update(averaged_param: Tensor, current_param: Tensor, num_averaged: Union[Tensor, int]):
return averaged_param + (current_param - averaged_param) / (num_averaged + 1)
return swa_update
@@ -162,7 +170,17 @@ class AveragedModel(Module):
.. _Polyak averaging:
https://paperswithcode.com/method/polyak-averaging
"""
-def __init__(self, model, device=None, avg_fn=None, multi_avg_fn=None, use_buffers=False):
+def __init__(
+    self,
+    model: Module,
+    device: Optional[Union[int, torch.device]] = None,
+    avg_fn: Optional[Callable[[Tensor, Tensor, Union[Tensor, int]],
+                              Tensor]] = None,
+    multi_avg_fn: Optional[Callable[
+        [PARAM_LIST, PARAM_LIST, Union[Tensor, int]], None]] = None,
+    use_buffers=False,
+):
super().__init__()
assert avg_fn is None or multi_avg_fn is None, 'Only one of avg_fn and multi_avg_fn should be provided'
self.module = deepcopy(model)
@@ -177,7 +195,7 @@ class AveragedModel(Module):
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
-def update_parameters(self, model):
+def update_parameters(self, model: Module):
self_param = (
itertools.chain(self.module.parameters(), self.module.buffers())
if self.use_buffers else self.parameters()
@@ -197,7 +215,11 @@ class AveragedModel(Module):
if self.n_averaged > 0:
if self.multi_avg_fn is not None or self.avg_fn is None:
-grouped_tensors = _group_tensors_by_device_and_dtype([self_param_detached, model_param_detached])
+grouped_tensors = _group_tensors_by_device_and_dtype(
+    cast(TensorListList, [self_param_detached, model_param_detached]))
+grouped_tensors = cast(
+    Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], Indices]],
+    grouped_tensors)
for ((device, _), ([self_params, model_params], _)) in grouped_tensors.items():
if self.multi_avg_fn:
self.multi_avg_fn(self_params, model_params, self.n_averaged.to(device))
@@ -223,7 +245,7 @@ class AveragedModel(Module):
@torch.no_grad()
-def update_bn(loader, model, device=None):
+def update_bn(loader: Iterable[Any], model: Module, device: Optional[Union[int, torch.device]] = None):
r"""Updates BatchNorm running_mean, running_var buffers in the model.
It performs one pass over data in `loader` to estimate the activation
@@ -319,7 +341,7 @@ class SWALR(LRScheduler):
.. _Averaging Weights Leads to Wider Optima and Better Generalization:
https://arxiv.org/abs/1803.05407
"""
-def __init__(self, optimizer, swa_lr, anneal_epochs=10, anneal_strategy='cos', last_epoch=-1):
+def __init__(self, optimizer: Optimizer, swa_lr: float, anneal_epochs=10, anneal_strategy='cos', last_epoch=-1):
swa_lrs = self._format_param(optimizer, swa_lr)
for swa_lr, group in zip(swa_lrs, optimizer.param_groups):
group['swa_lr'] = swa_lr
@@ -361,10 +383,13 @@ class SWALR(LRScheduler):
return (lr - alpha * swa_lr) / (1 - alpha)
def get_lr(self):
-if not self._get_lr_called_within_step:
+# `_get_lr_called_within_step` is only available `_enable_get_lr_call`,
+# so we ignore the type error here. See `LRScheduler.step()` for more details.
+if not self._get_lr_called_within_step:  # type: ignore[attr-defined]
warnings.warn("To get the last learning rate computed by the scheduler, "
"please use `get_last_lr()`.", UserWarning)
-step = self._step_count - 1
+# Set in `LRScheduler._initial_step()`
+step = self._step_count - 1  # type: ignore[attr-defined]
if self.anneal_epochs == 0:
step = max(1, step)
prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
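
For reference, a minimal usage sketch that exercises the interfaces typed above; the model, data, and hyperparameters below are placeholders for illustration, not taken from this PR:

```python
import torch
from torch import nn
from torch.optim import SGD
from torch.optim.swa_utils import AveragedModel, SWALR, get_ema_multi_avg_fn

model = nn.Linear(10, 2)                      # placeholder model
optimizer = SGD(model.parameters(), lr=0.1)   # placeholder optimizer

# AveragedModel.__init__(multi_avg_fn=...) and get_ema_multi_avg_fn are among
# the interfaces covered by the new stubs.
ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay=0.999))
scheduler = SWALR(optimizer, swa_lr=0.05, anneal_epochs=10, anneal_strategy="cos")

for _ in range(5):                            # toy training loop
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    optimizer.step()
    ema_model.update_parameters(model)        # typed as update_parameters(model: Module)
    scheduler.step()                          # SWALR.get_lr runs inside step()
```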