Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[BE][PYFMT] migrate PYFMT for {torch,test}/{nn,optim}/** to ruff format (#144548)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144548
Approved by: https://github.com/ezyang

Committed by: PyTorch MergeBot
Parent: 3e38feb05f
Commit: 596b418391
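The change below is mechanical: the PYFMT formatting of these paths moves from the previous Black-style layout to ruff format, so each hunk pairs an old block with an equivalently formatted new one. The most common pattern is how a multi-line assert is wrapped. As a minimal illustrative sketch (the wrapper function names here are hypothetical; the assert bodies are copied from the diff), the two layouts compare as follows:

def check_grad_scaling_old_style(grad_scale, found_inf):
    # Pre-migration layout: the condition is wrapped in parentheses and the
    # message trails the closing parenthesis.
    assert (
        grad_scale is None and found_inf is None
    ), "Grad scaling should occur outside of optimizer.step()"


def check_grad_scaling_new_style(grad_scale, found_inf):
    # Post-migration (ruff format) layout: the condition stays on one line and
    # the message is parenthesized on its own line.
    assert grad_scale is None and found_inf is None, (
        "Grad scaling should occur outside of optimizer.step()"
    )

Both layouts raise the same AssertionError with the same message; only the line wrapping differs, which is why the hunks below change formatting but not behavior.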
@@ -347,9 +347,9 @@ def _single_tensor_adafactor(
maximize: bool,
has_complex: bool,
):
assert (
grad_scale is None and found_inf is None
), "Grad scaling should occur outside of optimizer.step()"
assert grad_scale is None and found_inf is None, (
"Grad scaling should occur outside of optimizer.step()"
)

if torch.jit.is_scripting():
# this assert is due to JIT being dumb and not realizing that the ops below
@@ -381,9 +381,9 @@ def _single_tensor_adafactor(
param.mul_(1 - lr * weight_decay)

if grad.dim() > 1:
assert (
row_var is not None and col_var is not None
), "row_var and col_var should be defined when grad is multidimensional"
assert row_var is not None and col_var is not None, (
"row_var and col_var should be defined when grad is multidimensional"
)
# same as (g * g).mean(dim=-1) w/o materializing an intermediate size g
row_mean = (
torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1))
@@ -397,9 +397,9 @@ def _single_tensor_adafactor(
var_estimate = row_var @ col_var
var_estimate.div_(row_var.mean(dim=-2, keepdim=True).clamp_(min=eps1))
else:
assert (
variance is not None
), "variance should be defined when grad is a vector"
assert variance is not None, (
"variance should be defined when grad is a vector"
)
grad_squared = grad * grad
variance.lerp_(grad_squared, one_minus_beta2_t)
# avoid writing into variance during update
@@ -472,9 +472,9 @@ def _multi_tensor_adafactor(
if len(params) == 0:
return

assert (
grad_scale is None and found_inf is None
), "Grad scaling should occur outside of optimizer.step()"
assert grad_scale is None and found_inf is None, (
"Grad scaling should occur outside of optimizer.step()"
)

lr = _to_scalar(lr)

@@ -495,9 +495,9 @@ def _multi_tensor_adafactor(
device_grads = cast(list[Tensor], device_grads_)
device_state_steps = cast(list[Tensor], device_state_steps_)
if eps1 is None:
assert (
dtype is not None
), "dtype is needed to compute eps1 when eps1 is unset"
assert dtype is not None, (
"dtype is needed to compute eps1 when eps1 is unset"
)
eps1 = torch.finfo(dtype).eps

if TYPE_CHECKING:
@@ -537,9 +537,9 @@ def _multi_tensor_adafactor(
if is_multidim:
device_row_vars = cast(list[Tensor], device_row_vars_)
device_col_vars = cast(list[Tensor], device_col_vars_)
assert (
device_row_vars[0] is not None and device_col_vars[0] is not None
), "row_var and col_var should be defined when grad is multidimensional"
assert device_row_vars[0] is not None and device_col_vars[0] is not None, (
"row_var and col_var should be defined when grad is multidimensional"
)
# same as (g * g).mean(dim=-1) w/o materializing an intermediate size g
row_means = [
torch.norm(grad, dim=-1, keepdim=True) for grad in device_grads
@@ -570,9 +570,9 @@ def _multi_tensor_adafactor(
del row_var_means
else:
device_variances = cast(list[Tensor], device_variances_)
assert (
device_variances[0] is not None
), "variance should be defined when grad is a vector"
assert device_variances[0] is not None, (
"variance should be defined when grad is a vector"
)

grads_squared = torch._foreach_mul(device_grads, device_grads)
torch._foreach_lerp_(device_variances, grads_squared, one_minus_beta2_ts)
@@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Functional interface."""

import math

from torch import Tensor
@@ -5,6 +5,7 @@ Most commonly used methods are already supported, and the interface is general
enough, so that more sophisticated ones can be also easily integrated in the
future.
"""

from functools import partialmethod

from torch import optim
@@ -267,7 +267,9 @@ def _single_tensor_adadelta(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

if not torch.jit.is_scripting():
lr = _to_scalar(lr)
@@ -326,7 +328,9 @@ def _multi_tensor_adadelta(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

if len(params) == 0:
return
@@ -398,7 +398,9 @@ def _single_tensor_adam(
assert (
param.device.type == step_t.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

# update step
step_t += 1
@@ -433,7 +435,9 @@ def _single_tensor_adam(
# cast to workaround https://github.com/pytorch/pytorch/issues/140601
key = (device, dtype)
if key not in beta1_dict:
beta1_dict[key] = beta1.to(device=device, dtype=dtype, non_blocking=True) # type: ignore[union-attr]
beta1_dict[key] = beta1.to( # type: ignore[union-attr]
device=device, dtype=dtype, non_blocking=True
)

device_beta1: Union[float, Tensor] = beta1_dict[key]
else:
@@ -593,7 +597,9 @@ def _multi_tensor_adam(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

assert grad_scale is None and found_inf is None

@@ -769,7 +775,10 @@ def _multi_tensor_adam(
torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
torch._foreach_add_(exp_avg_sq_sqrt, eps)
torch._foreach_addcdiv_(
device_params, device_exp_avgs, exp_avg_sq_sqrt, step_size # type: ignore[arg-type]
device_params,
device_exp_avgs,
exp_avg_sq_sqrt,
step_size, # type: ignore[arg-type]
)
@@ -256,7 +256,9 @@ def _single_tensor_adamax(
assert (
param.device.type == step_t.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

# update step
step_t += 1
@@ -331,7 +333,9 @@ def _multi_tensor_adamax(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

lr = _to_scalar(lr)

@@ -305,7 +305,9 @@ def _multi_tensor_asgd(
p.device.type == mu.device.type == eta.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, mu, eta, step in zip(params, mus, etas, state_steps)
), f"If capturable=True, params, mus, etas, and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params, mus, etas, and state_steps must be on supported devices: {capturable_supported_devices}."
)

lr = _to_scalar(lr)
@@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Learning Rate Scheduler."""

from __future__ import annotations

import math
@@ -280,7 +281,7 @@ class LambdaLR(LRScheduler):
>>> # Assuming optimizer has two groups.
>>> num_epochs = 100
>>> lambda1 = lambda epoch: epoch // 30
>>> lambda2 = lambda epoch: 0.95 ** epoch
>>> lambda2 = lambda epoch: 0.95**epoch
>>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
>>> for epoch in range(num_epochs):
>>> train(...)
@@ -548,7 +549,7 @@ class MultiStepLR(LRScheduler):
>>> # lr = 0.05 if epoch < 30
>>> # lr = 0.005 if 30 <= epoch < 80
>>> # lr = 0.0005 if epoch >= 80
>>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
>>> scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
@@ -827,7 +828,11 @@ class SequentialLR(LRScheduler):
>>> # lr = 0.0405 if epoch == 22
>>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=20)
>>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
>>> scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[20])
>>> scheduler = SequentialLR(
... optimizer,
... schedulers=[scheduler1, scheduler2],
... milestones=[20],
... )
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
@@ -1271,11 +1276,11 @@ class ReduceLROnPlateau(LRScheduler):
Example:
>>> # xdoctest: +SKIP
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = ReduceLROnPlateau(optimizer, 'min')
>>> scheduler = ReduceLROnPlateau(optimizer, "min")
>>> for epoch in range(10):
>>> train(...)
>>> val_loss = validate(...)
>>> # Note that step should be called after validate()
>>> # Note that step should be called after validate()
>>> scheduler.step(val_loss)

.. image:: ../scripts/lr_scheduler_images/ReduceLROnPlateau.png
@@ -1502,7 +1507,12 @@ class CyclicLR(LRScheduler):
Example:
>>> # xdoctest: +SKIP
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1, step_size_up=10)
>>> scheduler = torch.optim.lr_scheduler.CyclicLR(
... optimizer,
... base_lr=0.01,
... max_lr=0.1,
... step_size_up=10,
... )
>>> data_loader = torch.utils.data.DataLoader(...)
>>> for epoch in range(10):
>>> for batch in data_loader:
@@ -1729,7 +1739,9 @@ class CosineAnnealingWarmRestarts(LRScheduler):
Example:
>>> # xdoctest: +SKIP
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
... optimizer, T_0=20
... )
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
@@ -1800,7 +1812,7 @@ class CosineAnnealingWarmRestarts(LRScheduler):
>>> for epoch in range(20):
>>> scheduler.step()
>>> scheduler.step(26)
>>> scheduler.step() # scheduler.step(27), instead of scheduler(20)
>>> scheduler.step() # scheduler.step(27), instead of scheduler(20)
"""
if epoch is None and self.last_epoch < 0:
epoch = 0
@@ -1936,7 +1948,9 @@ class OneCycleLR(LRScheduler):
>>> # xdoctest: +SKIP
>>> data_loader = torch.utils.data.DataLoader(...)
>>> optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
>>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
>>> scheduler = torch.optim.lr_scheduler.OneCycleLR(
... optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10
... )
>>> for epoch in range(10):
>>> for batch in data_loader:
>>> train_batch(...)
@@ -2141,8 +2155,6 @@ class OneCycleLR(LRScheduler):
if self.use_beta1:
group["betas"] = (computed_momentum, *group["betas"][1:]) # type: ignore[possibly-undefined]
else:
group[
"momentum"
] = computed_momentum # type: ignore[possibly-undefined]
group["momentum"] = computed_momentum # type: ignore[possibly-undefined]

return lrs
@@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for the NAdam algorithm."""

from typing import cast, Optional, Union

import torch
@@ -408,7 +409,11 @@ def _multi_tensor_nadam(
p.device.type == mp.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, mp, step in zip(params, mu_products, state_steps)
), f"If capturable=True, params, mu_products, and state_steps must be on supported devices: {capturable_supported_devices}."
), (
"If capturable=True, "
"params, mu_products, and state_steps must be on supported devices: "
f"{capturable_supported_devices}."
)

lr = _to_scalar(lr)

@@ -576,10 +581,16 @@ def _multi_tensor_nadam(
)

torch._foreach_addcdiv_(
grouped_params, grouped_grads, exp_avg_sq_sqrt, step_size_grads # type: ignore[arg-type]
grouped_params,
grouped_grads,
exp_avg_sq_sqrt,
step_size_grads, # type: ignore[arg-type]
)
torch._foreach_addcdiv_(
grouped_params, grouped_exp_avgs, exp_avg_sq_sqrt, step_size_expavg # type: ignore[arg-type]
grouped_params,
grouped_exp_avgs,
exp_avg_sq_sqrt,
step_size_expavg, # type: ignore[arg-type]
)
@@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
"""Base optimizer."""

import functools
import warnings
from collections import defaultdict, OrderedDict
@@ -103,7 +104,7 @@ def _stack_if_compiling(x):


def _disable_dynamo_if_unsupported(
single_tensor_fn: Optional[Callable[..., object]] = None
single_tensor_fn: Optional[Callable[..., object]] = None,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
# workaround for torchscript BC
# it requires all called functions to be in the
@@ -349,15 +350,24 @@ class Optimizer:
options (used when a parameter group doesn't specify them).
"""

OptimizerPreHook: TypeAlias = Callable[[Self, Args, Kwargs], Optional[tuple[Args, Kwargs]]] # type: ignore[misc]
OptimizerPreHook: TypeAlias = Callable[
[Self, Args, Kwargs], # type: ignore[misc]
Optional[tuple[Args, Kwargs]],
]
OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None] # type: ignore[misc]

_optimizer_step_pre_hooks: dict[int, OptimizerPreHook]
_optimizer_step_post_hooks: dict[int, OptimizerPostHook]
_optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
_optimizer_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
_optimizer_load_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
_optimizer_load_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
_optimizer_state_dict_post_hooks: (
'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
)
_optimizer_load_state_dict_pre_hooks: (
'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
)
_optimizer_load_state_dict_post_hooks: (
'OrderedDict[int, Callable[["Optimizer"], None]]'
)

def __init__(self, params: ParamsT, defaults: dict[str, Any]) -> None: # noqa: D107
torch._C._log_api_usage_once("python.optimizer")
@@ -847,7 +857,9 @@ class Optimizer:
handle = hooks.RemovableHandle(self._optimizer_load_state_dict_post_hooks)
self._optimizer_load_state_dict_post_hooks[handle.id] = hook
if prepend:
self._optimizer_load_state_dict_post_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined]
self._optimizer_load_state_dict_post_hooks.move_to_end(
handle.id, last=False
) # type: ignore[attr-defined]
return handle

@torch._disable_dynamo
@@ -877,12 +889,25 @@ class Optimizer:
>>> # xdoctest: +SKIP
>>> model = torch.nn.Linear(10, 10)
>>> optim = torch.optim.SGD(model.parameters(), lr=3e-4)
>>> scheduler1 = torch.optim.lr_scheduler.LinearLR(optim, start_factor=0.1, end_factor=1, total_iters=20)
>>> scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=80, eta_min=3e-5)
>>> lr = torch.optim.lr_scheduler.SequentialLR(optim, schedulers=[scheduler1, scheduler2], milestones=[20])
>>> lr.load_state_dict(torch.load('./save_seq.pt'))
>>> scheduler1 = torch.optim.lr_scheduler.LinearLR(
... optim,
... start_factor=0.1,
... end_factor=1,
... total_iters=20,
... )
>>> scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(
... optim,
... T_max=80,
... eta_min=3e-5,
... )
>>> lr = torch.optim.lr_scheduler.SequentialLR(
... optim,
... schedulers=[scheduler1, scheduler2],
... milestones=[20],
... )
>>> lr.load_state_dict(torch.load("./save_seq.pt"))
>>> # now load the optimizer checkpoint after loading the LRScheduler
>>> optim.load_state_dict(torch.load('./save_optim.pt'))
>>> optim.load_state_dict(torch.load("./save_optim.pt"))

"""
# shallow copy, to be consistent with module API
@@ -933,7 +958,10 @@ class Optimizer:
for k, v in value.items()
}
elif isinstance(value, Iterable):
return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) # type: ignore[call-arg]
return type(value)(
_cast(param, v, param_id=param_id, param_groups=param_groups)
for v in value
) # type: ignore[call-arg]
else:
return value

@@ -1021,12 +1049,10 @@ class Optimizer:
torch._foreach_zero_(grads)

@overload
def step(self, closure: None = None) -> None:
...
def step(self, closure: None = None) -> None: ...

@overload
def step(self, closure: Callable[[], float]) -> float:
...
def step(self, closure: Callable[[], float]) -> float: ...

def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
r"""Perform a single optimization step to update parameter.
@@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for the RAdam algorithm."""

from typing import cast, Optional, Union

import torch
@@ -285,7 +286,9 @@ def _single_tensor_radam(
assert (
param.device.type == step_t.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

if torch.is_complex(param):
param = torch.view_as_real(param)
@@ -386,7 +389,9 @@ def _multi_tensor_radam(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

lr = _to_scalar(lr)
@@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for the RMSprop algorithm."""

from typing import cast, Optional, Union

import torch
@@ -292,7 +293,9 @@ def _single_tensor_rmsprop(
assert (
param.device.type == step.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

grad = grads[i]
grad = grad if not maximize else -grad
@@ -366,7 +369,9 @@ def _multi_tensor_rmsprop(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

lr = _to_scalar(lr)
@@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for the Resilient backpropagation."""

from typing import cast, Optional, Union

import torch
@@ -248,7 +249,9 @@ def _single_tensor_rprop(
assert (
param.device.type == step.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

step += 1

@@ -315,7 +318,9 @@ def _multi_tensor_rprop(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, prevs, step_sizes, state_steps] # type: ignore[list-item]
@@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for Stochastic Gradient Descent optimizer."""

from typing import cast, Optional, Union

import torch
@@ -397,7 +398,8 @@ def _multi_tensor_sgd(
lr = _to_scalar(lr)

grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, momentum_buffer_list], with_indices=True # type: ignore[list-item]
[params, grads, momentum_buffer_list], # type: ignore[list-item]
with_indices=True,
)
for (
device_params_,
@@ -502,7 +504,8 @@ def _fused_sgd(
for i, g in enumerate(grads):
momentum_buffer_list[i] = torch.empty_like(g)
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, momentum_buffer_list], with_indices=False # type: ignore[list-item]
[params, grads, momentum_buffer_list], # type: ignore[list-item]
with_indices=False,
)
for (device, _), (
(device_params_, device_grads_, device_momentum_buffer_list),
@@ -37,9 +37,9 @@ class SparseAdam(Optimizer):
sparse_params = []
complex_params = []
for index, param_group in enumerate(self.param_groups):
assert isinstance(
param_group, dict
), f"param_groups must be a list of dicts, but got {type(param_group)}"
assert isinstance(param_group, dict), (
f"param_groups must be a list of dicts, but got {type(param_group)}"
)
# given param group, convert given params to a list first before iterating
for d_index, d_param in enumerate(param_group["params"]):
if d_param.is_sparse:
@@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for Stochastic Weight Averaging implementation."""

import itertools
import math
import warnings
@@ -225,9 +226,9 @@ class AveragedModel(Module):
use_buffers=False,
): # noqa: D107
super().__init__()
assert (
avg_fn is None or multi_avg_fn is None
), "Only one of avg_fn and multi_avg_fn should be provided"
assert avg_fn is None or multi_avg_fn is None, (
"Only one of avg_fn and multi_avg_fn should be provided"
)
self.module = deepcopy(model)
if device is not None:
self.module = self.module.to(device)
@@ -274,7 +275,9 @@ class AveragedModel(Module):
) in grouped_tensors.items():
if self.multi_avg_fn:
self.multi_avg_fn(
self_params, model_params, self.n_averaged.to(device) # type: ignore[arg-type]
self_params, # type: ignore[arg-type]
model_params, # type: ignore[arg-type]
self.n_averaged.to(device),
)
elif (
device is not None