Optim package docstring fix (#129086)

Fix docstrings in various files in the optim package. This is the last remaining fix for issue #112593.

The fix can be verified by running pydocstyle <path-to-file> --count on each changed file.

Fixes #112593

Related #128248

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129086
Approved by: https://github.com/janeyx99
Sahdev Zala
2024-06-21 14:30:53 +00:00
committed by PyTorch MergeBot
parent b697808056
commit 9795dba1e0
8 changed files with 76 additions and 55 deletions

torch/optim/swa_utils.py

@@ -1,4 +1,5 @@
# mypy: allow-untyped-defs
r"""Implementation for Stochastic Weight Averaging implementation."""
import itertools
import math
import warnings
@@ -28,6 +29,8 @@ PARAM_LIST = Union[Tuple[Tensor, ...], List[Tensor]]
def get_ema_multi_avg_fn(decay=0.999):
"""Get the function applying exponential moving average (EMA) across multiple params."""
@torch.no_grad()
def ema_update(ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _):
# foreach lerp only handles float and complex
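For context, the "foreach lerp" in the comment above refers to the multi-tensor lerp primitive; per tensor, the EMA update is just a lerp toward the current parameter with weight 1 - decay. A minimal illustrative sketch (not part of this diff, with made-up values):

    import torch

    decay = 0.999
    ema = torch.zeros(3)
    cur = torch.ones(3)
    # lerp(ema, cur, 1 - decay) == decay * ema + (1 - decay) * cur
    ema.lerp_(cur, 1 - decay)
    assert torch.allclose(ema, torch.full((3,), 1 - decay))
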
@@ -43,6 +46,8 @@ def get_ema_multi_avg_fn(decay=0.999):
def get_swa_multi_avg_fn():
"""Get the function applying stochastic weight average (SWA) across multiple params."""
@torch.no_grad()
def swa_update(
averaged_param_list: PARAM_LIST,
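For reference, the SWA callback above maintains a running mean of the parameters. A minimal sketch of the per-tensor update rule (illustrative only; the helper swa_step and its values are hypothetical, not from this commit), where n is the number of models already averaged:

    import torch

    def swa_step(averaged: torch.Tensor, current: torch.Tensor, n: int) -> torch.Tensor:
        # running mean: avg_{n+1} = avg_n + (x - avg_n) / (n + 1)
        return averaged + (current - averaged) / (n + 1)

    avg = torch.tensor([1.0, 1.0])
    avg = swa_step(avg, torch.tensor([3.0, 5.0]), n=1)  # tensor([2., 3.])
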
@@ -73,6 +78,8 @@ def get_swa_multi_avg_fn():
def get_ema_avg_fn(decay=0.999):
"""Get the function applying exponential moving average (EMA) across a single param."""
@torch.no_grad()
def ema_update(ema_param: Tensor, current_param: Tensor, num_averaged):
return decay * ema_param + (1 - decay) * current_param
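A quick numeric check of the return expression above (illustrative values only, not part of this diff):

    import torch

    decay = 0.9
    ema_param = torch.tensor([2.0])
    current_param = torch.tensor([4.0])
    # 0.9 * 2.0 + 0.1 * 4.0 = 2.2
    print(decay * ema_param + (1 - decay) * current_param)  # tensor([2.2000])
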
@@ -81,6 +88,8 @@ def get_ema_avg_fn(decay=0.999):
def get_swa_avg_fn():
"""Get the function applying stochastic weight average (SWA) across a single param."""
@torch.no_grad()
def swa_update(
averaged_param: Tensor, current_param: Tensor, num_averaged: Union[Tensor, int]
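For context, the single-tensor factories here are passed to AveragedModel via avg_fn, while the multi-tensor variants earlier in the file go to multi_avg_fn (the constructor accepts at most one of the two, per the assert further below). A minimal sketch, assuming a small nn.Linear model purely for illustration:

    from torch import nn
    from torch.optim.swa_utils import AveragedModel, get_swa_avg_fn, get_swa_multi_avg_fn

    model = nn.Linear(4, 2)
    swa_model = AveragedModel(model, avg_fn=get_swa_avg_fn())  # per-tensor callback
    # same averaging, but all parameters are updated in one foreach call:
    swa_model_fast = AveragedModel(model, multi_avg_fn=get_swa_multi_avg_fn())
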
@@ -91,8 +100,7 @@ def get_swa_avg_fn():
class AveragedModel(Module):
r"""Implements averaged model for Stochastic Weight Averaging (SWA) and
Exponential Moving Average (EMA).
r"""Implements averaged model for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA).
Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii
@@ -189,6 +197,7 @@ class AveragedModel(Module):
.. _Polyak averaging:
https://paperswithcode.com/method/polyak-averaging
"""
n_averaged: Tensor
def __init__(
@@ -200,7 +209,7 @@ class AveragedModel(Module):
Callable[[PARAM_LIST, PARAM_LIST, Union[Tensor, int]], None]
] = None,
use_buffers=False,
):
): # noqa: D107
super().__init__()
assert (
avg_fn is None or multi_avg_fn is None
@@ -216,9 +225,11 @@ class AveragedModel(Module):
self.use_buffers = use_buffers
def forward(self, *args, **kwargs):
"""Forward pass."""
return self.module(*args, **kwargs)
def update_parameters(self, model: Module):
"""Update model parameters."""
self_param = (
itertools.chain(self.module.parameters(), self.module.buffers())
if self.use_buffers
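For reference, update_parameters is what folds the live model into the running average after each optimizer step. A minimal training-loop sketch (illustrative only; the dummy model, optimizer, and data are assumptions, not from this commit):

    import torch
    from torch import nn
    from torch.optim.swa_utils import AveragedModel

    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    swa_model = AveragedModel(model)  # defaults to the equal-weight SWA average

    loader = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(5)]  # dummy data
    for x, y in loader:
        loss = nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        swa_model.update_parameters(model)  # fold the new weights into the average
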
@@ -287,7 +298,7 @@ def update_bn(
model: Module,
device: Optional[Union[int, torch.device]] = None,
):
r"""Updates BatchNorm running_mean, running_var buffers in the model.
r"""Update BatchNorm running_mean, running_var buffers in the model.
It performs one pass over data in `loader` to estimate the activation
statistics for BatchNorm layers in the model.
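A minimal usage sketch for update_bn (illustrative; swa_model and loader are assumed to come from a training setup like the one sketched above):

    from torch.optim.swa_utils import update_bn

    # one pass over the data to recompute BatchNorm running statistics
    # for the averaged weights before evaluating swa_model
    update_bn(loader, swa_model)
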
@@ -390,7 +401,7 @@ class SWALR(LRScheduler):
anneal_epochs=10,
anneal_strategy: Literal["cos", "linear"] = "cos",
last_epoch=-1,
):
): # noqa: D107
swa_lrs = _format_param("swa_lr", optimizer, swa_lr)
for swa_lr, group in zip(swa_lrs, optimizer.param_groups):
group["swa_lr"] = swa_lr
@@ -425,6 +436,7 @@
return (lr - alpha * swa_lr) / (1 - alpha)
def get_lr(self):
"""Get learning rate."""
# `_get_lr_called_within_step` is only available within `_enable_get_lr_call`,
# so we ignore the type error here. See `LRScheduler.step()` for more details.
if not self._get_lr_called_within_step: # type: ignore[attr-defined]
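For reference, the return expression a few lines up recovers the pre-SWA learning rate by inverting a blend of the form lr_new = alpha * swa_lr + (1 - alpha) * lr_prev (the blend itself is inferred from that inverse and is not shown in this hunk). A quick numeric check with made-up values:

    # assumed blend: lr_new = alpha * swa_lr + (1 - alpha) * lr_prev
    alpha, swa_lr, lr_prev = 0.25, 0.05, 0.1
    lr_new = alpha * swa_lr + (1 - alpha) * lr_prev  # 0.0875
    # inverse used above: (lr_new - alpha * swa_lr) / (1 - alpha) == lr_prev
    assert abs((lr_new - alpha * swa_lr) / (1 - alpha) - lr_prev) < 1e-12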