Files
pytorch/torch/optim/sgd.py
Maggie Moss 4ab847bbc7 Pyrefly suppressions 4/n (#164615)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: uncomment lines in the pyrefly.toml file
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/356645cf8cfe33123d9a27f23b30f7b1

after:

0 errors (2,753 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164615
Approved by: https://github.com/oulgen
2025-10-06 16:14:36 +00:00

547 lines
20 KiB
Python

# mypy: allow-untyped-defs
r"""Implementation for Stochastic Gradient Descent optimizer."""
from typing import cast, Optional, Union
import torch
from torch import Tensor
from .optimizer import (
_default_to_fused_or_foreach,
_device_dtype_check_for_fused,
_differentiable_doc,
_foreach_doc,
_fused_doc,
_maximize_doc,
_params_doc,
_to_scalar,
_use_grad_for_differentiable,
DeviceDict,
Optimizer,
ParamsT,
)
__all__ = ["SGD", "sgd"]
class SGD(Optimizer): # noqa: D101
def __init__(
self,
params: ParamsT,
lr: Union[float, Tensor] = 1e-3,
momentum: float = 0,
dampening: float = 0,
weight_decay: Union[float, Tensor] = 0,
nesterov: bool = False,
*,
maximize: bool = False,
foreach: Optional[bool] = None,
differentiable: bool = False,
fused: Optional[bool] = None,
): # noqa: D107
if isinstance(lr, Tensor) and lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
if momentum < 0.0:
raise ValueError(f"Invalid momentum value: {momentum}")
if weight_decay < 0.0:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = {
"lr": lr,
"momentum": momentum,
"dampening": dampening,
"weight_decay": weight_decay,
"nesterov": nesterov,
"maximize": maximize,
"foreach": foreach,
"differentiable": differentiable,
"fused": fused,
}
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super().__init__(params, defaults)
if fused:
self._step_supports_amp_scaling = True
self._need_device_dtype_check_for_fused = True
if differentiable:
raise RuntimeError("`fused` does not support `differentiable`")
if foreach:
raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
def __setstate__(self, state): # noqa: D105
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("nesterov", False)
group.setdefault("maximize", False)
group.setdefault("foreach", None)
group.setdefault("differentiable", False)
group.setdefault("fused", False)
def _init_group(self, group, params, grads, momentum_buffer_list):
has_sparse_grad = False
for p in group["params"]:
if p.grad is not None:
if group["fused"] and getattr(
self, "_need_device_dtype_check_for_fused", True
):
_device_dtype_check_for_fused(p)
self._need_device_dtype_check_for_fused = False
params.append(p)
grads.append(p.grad)
if p.grad.is_sparse:
has_sparse_grad = True
if group["momentum"] != 0:
state = self.state[p]
momentum_buffer_list.append(state.get("momentum_buffer"))
return has_sparse_grad
@_use_grad_for_differentiable
def step(self, closure=None):
"""Perform a single optimization step.
Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params: list[Tensor] = []
grads: list[Tensor] = []
momentum_buffer_list: list[Optional[Tensor]] = []
has_sparse_grad = self._init_group(
group, params, grads, momentum_buffer_list
)
sgd(
params,
grads,
momentum_buffer_list,
weight_decay=group["weight_decay"],
momentum=group["momentum"],
lr=group["lr"],
dampening=group["dampening"],
nesterov=group["nesterov"],
maximize=group["maximize"],
has_sparse_grad=has_sparse_grad,
foreach=group["foreach"],
fused=group["fused"],
grad_scale=getattr(self, "grad_scale", None),
found_inf=getattr(self, "found_inf", None),
)
if group["momentum"] != 0:
# update momentum_buffers in state
for p, momentum_buffer in zip(params, momentum_buffer_list):
state = self.state[p]
state["momentum_buffer"] = momentum_buffer
return loss
SGD.__doc__ = (
r"""Implements stochastic gradient descent (optionally with momentum).
.. math::
\begin{aligned}
&\rule{110mm}{0.4pt} \\
&\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
\text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
&\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)},
\:\textit{ nesterov,}\:\textit{ maximize} \\[-1.ex]
&\rule{110mm}{0.4pt} \\
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
&\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
&\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}\textbf{else} \\
&\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
&\hspace{5mm}\textbf{if} \: \mu \neq 0 \\
&\hspace{10mm}\textbf{if} \: t > 1 \\
&\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t \\
&\hspace{10mm}\textbf{else} \\
&\hspace{15mm} \textbf{b}_t \leftarrow g_t \\
&\hspace{10mm}\textbf{if} \: \textit{nesterov} \\
&\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t \\
&\hspace{10mm}\textbf{else} \\[-1.ex]
&\hspace{15mm} g_t \leftarrow \textbf{b}_t \\
&\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t \\[-1.ex]
&\rule{110mm}{0.4pt} \\[-1.ex]
&\bf{return} \: \theta_t \\[-1.ex]
&\rule{110mm}{0.4pt} \\[-1.ex]
\end{aligned}
Nesterov momentum is based on the formula from
`On the importance of initialization and momentum in deep learning`__.
"""
+ rf"""
Args:
{_params_doc}
lr (float, Tensor, optional): learning rate (default: 1e-3)
momentum (float, optional): momentum factor (default: 0)
dampening (float, optional): dampening for momentum (default: 0)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
nesterov (bool, optional): enables Nesterov momentum. Only applicable
when momentum is non-zero. (default: False)
{_maximize_doc}
{_foreach_doc}
{_differentiable_doc}
{_fused_doc}
"""
+ r"""
Example:
>>> # xdoctest: +SKIP
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
__ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
.. note::
The implementation of SGD with Momentum/Nesterov subtly differs from
Sutskever et al. and implementations in some other frameworks.
Considering the specific case of Momentum, the update can be written as
.. math::
\begin{aligned}
v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
\end{aligned}
where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
parameters, gradient, velocity, and momentum respectively.
This is in contrast to Sutskever et al. and
other frameworks which employ an update of the form
.. math::
\begin{aligned}
v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
p_{t+1} & = p_{t} - v_{t+1}.
\end{aligned}
The Nesterov version is analogously modified.
Moreover, the initial value of the momentum buffer is set to the
gradient value at the first step. This is in contrast to some other
frameworks that initialize it to all zeros. One notable side effect
of this decision is that the first momentum value will not be scaled
by dampening. Dampening will be applied starting at the second step.
"""
)
def sgd(
params: list[Tensor],
d_p_list: list[Tensor],
momentum_buffer_list: list[Optional[Tensor]],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
has_sparse_grad: bool = False,
foreach: Optional[bool] = None,
fused: Optional[bool] = None,
grad_scale: Optional[Tensor] = None,
found_inf: Optional[Tensor] = None,
*,
weight_decay: float,
momentum: float,
lr: float,
dampening: float,
nesterov: bool,
maximize: bool,
):
r"""Functional API that performs SGD algorithm computation.
See :class:`~torch.optim.SGD` for details.
"""
# Respect when the user inputs False/True for foreach or fused. We only want to change
# the default when neither have been user-specified. Note that we default to foreach
# and pass False to use_fused. This is not a mistake--we want to give the fused impl
# bake-in time before making it the default, even if it is typically faster.
if foreach is None and fused is None:
# why must we be explicit about an if statement for torch.jit.is_scripting here?
# because JIT can't handle Optionals nor fancy conditionals when scripting
if not torch.jit.is_scripting():
fused, foreach = _default_to_fused_or_foreach(
params, differentiable=False, use_fused=False
)
else:
foreach = False
fused = False
if foreach is None:
foreach = False
if fused is None:
fused = False
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if fused and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with fused optimizers")
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_sgd
elif fused and not torch.jit.is_scripting():
func = _fused_sgd
else:
func = _single_tensor_sgd
func(
params,
d_p_list,
momentum_buffer_list,
weight_decay=weight_decay,
momentum=momentum,
lr=lr,
dampening=dampening,
nesterov=nesterov,
has_sparse_grad=has_sparse_grad,
maximize=maximize,
grad_scale=grad_scale,
found_inf=found_inf,
)
def _single_tensor_sgd(
params: list[Tensor],
grads: list[Tensor],
momentum_buffer_list: list[Optional[Tensor]],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
weight_decay: float,
momentum: float,
lr: float,
dampening: float,
nesterov: bool,
maximize: bool,
has_sparse_grad: bool,
):
assert grad_scale is None and found_inf is None
if not torch.jit.is_scripting():
lr = _to_scalar(lr)
# pyrefly: ignore # bad-assignment
for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]
if weight_decay != 0:
# Nested if is necessary to bypass jitscript rules
if isinstance(weight_decay, Tensor):
if weight_decay.requires_grad:
# usually this is the differentiable path, which is why the param.clone() is needed
grad = grad.addcmul_(param.clone(), weight_decay)
else:
# pyrefly: ignore # bad-argument-type
grad = grad.add(param, alpha=weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
if momentum != 0:
buf = momentum_buffer_list[i]
if buf is None:
buf = grad.detach().clone()
momentum_buffer_list[i] = buf
else:
buf.mul_(momentum).add_(grad, alpha=1 - dampening)
if nesterov:
grad = grad.add(buf, alpha=momentum)
else:
grad = buf
# Nested if is necessary to bypass jitscript rules
if isinstance(lr, Tensor):
if lr.requires_grad:
param.addcmul_(grad, lr, value=-1)
else:
# pyrefly: ignore # bad-argument-type
param.add_(grad, alpha=-lr)
else:
param.add_(grad, alpha=-lr)
def _multi_tensor_sgd(
params: list[Tensor],
grads: list[Tensor],
momentum_buffer_list: list[Optional[Tensor]],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
weight_decay: float,
momentum: float,
lr: float,
dampening: float,
nesterov: bool,
maximize: bool,
has_sparse_grad: bool,
):
assert grad_scale is None and found_inf is None
if len(params) == 0:
return
lr = _to_scalar(lr)
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, momentum_buffer_list], # type: ignore[list-item]
with_indices=True,
)
for (
device_params_,
device_grads_,
device_momentum_buffer_list,
), indices in grouped_tensors.values():
device_params: list[Tensor] = cast(list[Tensor], device_params_)
device_grads: list[Tensor] = cast(list[Tensor], device_grads_)
device_has_sparse_grad = has_sparse_grad and any(
grad.is_sparse for grad in device_grads
)
if maximize:
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
if weight_decay != 0:
# Reuse the intermediate memory (device_grads) already allocated for maximize
if maximize:
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
else:
device_grads = torch._foreach_add( # type: ignore[assignment]
device_grads, device_params, alpha=weight_decay
)
if momentum != 0:
bufs: list[Tensor] = []
all_states_with_momentum_buffer = True
for i in range(len(device_momentum_buffer_list)):
# pyrefly: ignore # index-error
if device_momentum_buffer_list[i] is None:
all_states_with_momentum_buffer = False
break
else:
# pyrefly: ignore # index-error
bufs.append(cast(Tensor, device_momentum_buffer_list[i]))
if all_states_with_momentum_buffer:
torch._foreach_mul_(bufs, momentum)
torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
else:
bufs = []
# pyrefly: ignore # bad-assignment
for i in range(len(device_momentum_buffer_list)):
# pyrefly: ignore # index-error
if device_momentum_buffer_list[i] is None:
buf = device_momentum_buffer_list[i] = momentum_buffer_list[
indices[i]
] = device_grads[i].detach().clone()
else:
# pyrefly: ignore # index-error
buf = cast(Tensor, device_momentum_buffer_list[i])
buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)
bufs.append(buf)
if nesterov:
torch._foreach_add_(device_grads, bufs, alpha=momentum)
else:
device_grads = bufs
if not device_has_sparse_grad:
# handle internal item() call if lr is a tensor
if isinstance(lr, torch.Tensor) and torch.compiler.is_compiling():
grads_x_lr = torch._foreach_mul(device_grads, -lr)
torch._foreach_add_(device_params, grads_x_lr)
else:
torch._foreach_add_(device_params, device_grads, alpha=-lr)
else:
# foreach APIs don't support sparse
for i in range(len(device_params)):
device_params[i].add_(device_grads[i], alpha=-lr)
def _fused_sgd(
params: list[Tensor],
grads: list[Tensor],
momentum_buffer_list: list[Optional[Tensor]],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
weight_decay: float,
momentum: float,
lr: float,
dampening: float,
nesterov: bool,
maximize: bool,
has_sparse_grad: bool,
) -> None:
if not params:
return
if has_sparse_grad:
raise RuntimeError("`_fused_sgd` does not support sparse gradients")
grad_scale_dict: DeviceDict = (
{grad_scale.device: grad_scale} if grad_scale is not None else {}
)
found_inf_dict: DeviceDict = (
{found_inf.device: found_inf} if found_inf is not None else {}
)
no_momentum_buffer = momentum == 0
is_first_step = (
all(t is None for t in momentum_buffer_list) and not no_momentum_buffer
)
if is_first_step:
for i, g in enumerate(grads):
momentum_buffer_list[i] = torch.empty_like(g)
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, momentum_buffer_list], # type: ignore[list-item]
with_indices=False,
)
for (device, _), (
(device_params_, device_grads_, device_momentum_buffer_list),
_,
) in grouped_tensors.items():
device_params: list[Tensor] = cast(list[Tensor], device_params_)
device_grads: list[Tensor] = cast(list[Tensor], device_grads_)
device_grad_scale, device_found_inf = None, None
if grad_scale is not None:
device_grad_scale = grad_scale_dict.setdefault(
device, grad_scale.to(device)
)
if found_inf_dict is not None and found_inf is not None:
device_found_inf = found_inf_dict.setdefault(device, found_inf.to(device))
torch._fused_sgd_(
device_params,
device_grads,
[]
if no_momentum_buffer
else cast(list[Tensor], device_momentum_buffer_list),
weight_decay=weight_decay,
momentum=momentum,
lr=lr,
dampening=dampening,
nesterov=nesterov,
maximize=maximize,
is_first_step=is_first_step,
grad_scale=device_grad_scale,
found_inf=device_found_inf,
)