Optim package docstring fix (#129086)
Fix docstrings in various files in the optim package. This is the last remaining fix for issue #112593; it can be verified by running `pydocstyle path-to-file --count`.

Fixes #112593
Related: #128248

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129086
Approved by: https://github.com/janeyx99
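For example, run against the file shown in the diff below (assuming pydocstyle is installed in the environment):

pydocstyle torch/optim/swa_utils.py --count

With --count, a clean file reports 0 violations.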
Committed by: PyTorch MergeBot
Parent: b697808056
Commit: 9795dba1e0
--- a/torch/optim/swa_utils.py
+++ b/torch/optim/swa_utils.py
@@ -1,4 +1,5 @@
 # mypy: allow-untyped-defs
+r"""Implementation for Stochastic Weight Averaging implementation."""
 import itertools
 import math
 import warnings
@@ -28,6 +29,8 @@ PARAM_LIST = Union[Tuple[Tensor, ...], List[Tensor]]
 
 
 def get_ema_multi_avg_fn(decay=0.999):
+    """Get the function applying exponential moving average (EMA) across multiple params."""
+
     @torch.no_grad()
     def ema_update(ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _):
         # foreach lerp only handles float and complex
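For context, not part of the diff: a minimal sketch of how this helper is normally consumed, via the `multi_avg_fn` argument of `AveragedModel`. The tiny model and the explicit decay value here are illustrative only.

import torch
from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn

model = torch.nn.Linear(4, 2)  # any nn.Module; chosen only for illustration
# Maintain an EMA copy of the weights; 0.999 matches the helper's default decay
ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay=0.999))

ema_model.update_parameters(model)  # call after each optimizer step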
@@ -43,6 +46,8 @@ def get_ema_multi_avg_fn(decay=0.999):
 
 
 def get_swa_multi_avg_fn():
+    """Get the function applying stochastic weight average (SWA) across multiple params."""
+
     @torch.no_grad()
     def swa_update(
         averaged_param_list: PARAM_LIST,
@@ -73,6 +78,8 @@ def get_swa_multi_avg_fn():
 
 
 def get_ema_avg_fn(decay=0.999):
+    """Get the function applying exponential moving average (EMA) across a single param."""
+
     @torch.no_grad()
     def ema_update(ema_param: Tensor, current_param: Tensor, num_averaged):
         return decay * ema_param + (1 - decay) * current_param
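The closure returned above implements the standard EMA recurrence, new_ema = decay * ema + (1 - decay) * current. A quick numeric check, with arbitrary values:

import torch
from torch.optim.swa_utils import get_ema_avg_fn

ema_update = get_ema_avg_fn(decay=0.9)
ema = torch.tensor([1.0])
current = torch.tensor([2.0])
# 0.9 * 1.0 + 0.1 * 2.0 = 1.1
print(ema_update(ema, current, 1))  # tensor([1.1000])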
@@ -81,6 +88,8 @@ def get_ema_avg_fn(decay=0.999):
 
 
 def get_swa_avg_fn():
+    """Get the function applying stochastic weight average (SWA) across a single param."""
+
     @torch.no_grad()
     def swa_update(
         averaged_param: Tensor, current_param: Tensor, num_averaged: Union[Tensor, int]
@@ -91,8 +100,7 @@ def get_swa_avg_fn():
 
 
 class AveragedModel(Module):
-    r"""Implements averaged model for Stochastic Weight Averaging (SWA) and
-    Exponential Moving Average (EMA).
+    r"""Implements averaged model for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA).
 
     Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
     Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii
@@ -189,6 +197,7 @@ class AveragedModel(Module):
     .. _Polyak averaging:
         https://paperswithcode.com/method/polyak-averaging
     """
+
     n_averaged: Tensor
 
     def __init__(
@@ -200,7 +209,7 @@ class AveragedModel(Module):
             Callable[[PARAM_LIST, PARAM_LIST, Union[Tensor, int]], None]
         ] = None,
         use_buffers=False,
-    ):
+    ):  # noqa: D107
         super().__init__()
         assert (
             avg_fn is None or multi_avg_fn is None
@@ -216,9 +225,11 @@ class AveragedModel(Module):
         self.use_buffers = use_buffers
 
     def forward(self, *args, **kwargs):
+        """Forward pass."""
         return self.module(*args, **kwargs)
 
     def update_parameters(self, model: Module):
+        """Update model parameters."""
         self_param = (
             itertools.chain(self.module.parameters(), self.module.buffers())
             if self.use_buffers
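The two methods gaining docstrings here are the heart of the class: `forward` delegates to the wrapped module, while `update_parameters` folds the live weights into the running average. A minimal sketch, with a toy model and a synthetic batch standing in for real training:

import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Linear(4, 2)            # toy model, illustrative only
swa_model = AveragedModel(model)         # equal-weight running average by default
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 2)  # synthetic batch

for _ in range(3):
    optimizer.zero_grad()
    torch.nn.functional.mse_loss(model(x), y).backward()
    optimizer.step()
    swa_model.update_parameters(model)   # average in the freshly updated weights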
@@ -287,7 +298,7 @@ def update_bn(
     model: Module,
     device: Optional[Union[int, torch.device]] = None,
 ):
-    r"""Updates BatchNorm running_mean, running_var buffers in the model.
+    r"""Update BatchNorm running_mean, running_var buffers in the model.
 
     It performs one pass over data in `loader` to estimate the activation
     statistics for BatchNorm layers in the model.
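Because the averaged weights never run a forward pass during training, their BatchNorm buffers are stale; this function recomputes them with one pass over the loader. A minimal sketch with synthetic data in place of a real DataLoader:

import torch
from torch.optim.swa_utils import AveragedModel, update_bn

net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.BatchNorm1d(4))
swa_net = AveragedModel(net)
loader = [torch.randn(8, 4) for _ in range(5)]  # stand-in for a real DataLoader
update_bn(loader, swa_net)  # refreshes running_mean / running_var inside swa_net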
@@ -390,7 +401,7 @@ class SWALR(LRScheduler):
         anneal_epochs=10,
         anneal_strategy: Literal["cos", "linear"] = "cos",
         last_epoch=-1,
-    ):
+    ):  # noqa: D107
         swa_lrs = _format_param("swa_lr", optimizer, swa_lr)
         for swa_lr, group in zip(swa_lrs, optimizer.param_groups):
             group["swa_lr"] = swa_lr
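For reference, the constructor annotated here is typically set up as below; the model, learning rates, and anneal settings are illustrative:

import torch
from torch.optim.swa_utils import SWALR

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Anneal every param group's lr toward swa_lr over 10 epochs on a cosine curve
swa_scheduler = SWALR(optimizer, swa_lr=0.05, anneal_epochs=10, anneal_strategy="cos")
swa_scheduler.step()  # called once per epoch during the SWA phase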
@@ -425,6 +436,7 @@ class SWALR(LRScheduler):
         return (lr - alpha * swa_lr) / (1 - alpha)
 
     def get_lr(self):
+        """Get learning rate."""
         # `_get_lr_called_within_step` is only available `_enable_get_lr_call`,
         # so we ignore the type error here. See `LRScheduler.step()` for more details.
         if not self._get_lr_called_within_step:  # type: ignore[attr-defined]