Fix lr_scheduler unexpectedly calls step() when init argument last_epoch is larger than -1 (#149312)

Fixes #102261

## Changes

- Use a flag `_is_initial` instead of the `self.last_epoch == 0` condition to decide whether `lr` should be the initial value
- Add a test for the `ExponentialLR` checkpoint use case

## Test Result

```bash
pytest -s test/optim/test_lrscheduler.py -vv
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149312
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
commit d7a83ab67b (parent 423fc671e9)
committed by PyTorch MergeBot
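For context, a minimal sketch of the resume scenario behind issue #102261 (the optimizer setup here is illustrative, not taken from the PR): constructing a scheduler with `last_epoch > -1` used to run a real `step()` inside `__init__`, advancing the learning rate one step past the checkpointed value.

```python
# Illustrative repro of issue #102261 (hypothetical setup, not from the PR).
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR

opt = SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)

# Resuming with last_epoch > -1 requires "initial_lr" in each param group,
# and the optimizer should carry the checkpointed LR (0.1 * 0.9**3 after 3 steps).
for group in opt.param_groups:
    group["initial_lr"] = 0.1
    group["lr"] = 0.1 * 0.9**3

sched = ExponentialLR(opt, gamma=0.9, last_epoch=3)

# Before this fix, the constructor's internal step() multiplied the LR by gamma
# once more (0.1 * 0.9**4); with _is_initial it stays at the checkpointed value.
print(sched.get_last_lr())  # [0.1 * 0.9**3] after the fix
```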
torch/optim/lr_scheduler.py:

```diff
@@ -82,6 +82,7 @@ class LRScheduler:
     r"""Adjusts the learning rate during optimization."""
 
     _get_lr_called_within_step: bool = False
+    _is_initial: bool = False
 
     def __init__(
         self,
```
```diff
@@ -141,7 +142,8 @@ class LRScheduler:
     def _initial_step(self) -> None:
         """Initialize step counts and perform a step."""
         self._step_count = 0
-        self.step()
+        with _initial_mode(self):
+            self.step()
 
     def state_dict(self) -> dict[str, Any]:
         """Return the state of the scheduler as a :class:`dict`.
```
```diff
@@ -195,6 +197,7 @@ class LRScheduler:
                     "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
                     UserWarning,
                 )
+
         self._step_count += 1
 
         with _enable_get_lr_call(self):
```
```diff
@@ -248,6 +251,17 @@ class _enable_get_lr_call:
         self.o._get_lr_called_within_step = False
 
 
+class _initial_mode:
+    def __init__(self, o: LRScheduler):
+        self.o = o
+
+    def __enter__(self):
+        self.o._is_initial = True
+
+    def __exit__(self, type, value, traceback):
+        self.o._is_initial = False
+
+
 class LambdaLR(LRScheduler):
     """Sets the initial learning rate.
 
```
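The new context manager is a small guard around the constructor's one-off `step()`. A hedged, self-contained sketch of the same pattern (the `Sched` class is a hypothetical stand-in, not PyTorch code) shows why that bookkeeping step no longer advances the schedule:

```python
# Hypothetical stand-in for LRScheduler; only the guard pattern matches the diff.
class Sched:
    _is_initial: bool = False

    def __init__(self) -> None:
        self.lr = 1.0
        with _initial_mode(self):
            self.step()  # bookkeeping step at construction: must not advance

    def step(self) -> None:
        if not self._is_initial:
            self.lr *= 0.9  # only real steps advance the schedule


class _initial_mode:
    def __init__(self, o: Sched) -> None:
        self.o = o

    def __enter__(self) -> None:
        self.o._is_initial = True

    def __exit__(self, exc_type, exc, tb) -> None:
        self.o._is_initial = False


assert Sched().lr == 1.0  # construction left the schedule untouched
```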
```diff
@@ -450,7 +464,7 @@ class MultiplicativeLR(LRScheduler):
         """Compute the learning rate of each parameter group."""
         _warn_get_lr_called_within_step(self)
 
-        if self.last_epoch > 0:
+        if not self._is_initial:
             return [
                 group["lr"] * lmbda(self.last_epoch)
                 for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)
```
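The `MultiplicativeLR` change captures the core of the fix: `last_epoch > 0` cannot tell a genuine training step apart from the constructor's bookkeeping step when resuming, whereas the flag can. A tiny sketch (hypothetical helper, not PR code):

```python
# Old guard vs. new guard during a resume (last_epoch=5 at construction).
def advances_lr(last_epoch: int, is_initial: bool) -> bool:
    # old: return last_epoch > 0  -> True in the constructor when resuming (bug)
    return not is_initial         # new: never True inside the constructor

assert advances_lr(last_epoch=5, is_initial=True) is False  # constructor: hold LR
assert advances_lr(last_epoch=5, is_initial=False) is True  # real step: advance
```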
```diff
@@ -715,7 +729,7 @@ class LinearLR(LRScheduler):
                 group["lr"] * self.start_factor for group in self.optimizer.param_groups
             ]
 
-        if self.last_epoch > self.total_iters:
+        if self._is_initial or self.last_epoch > self.total_iters:
             return [group["lr"] for group in self.optimizer.param_groups]
 
         return [
```
```diff
@@ -779,7 +793,9 @@ class ExponentialLR(LRScheduler):
         """Compute the learning rate of each parameter group."""
         _warn_get_lr_called_within_step(self)
 
-        if self.last_epoch == 0:
+        # when loading from a checkpoint, we don't want _initial_step (called from the constructor)
+        # to update the lr one more step ahead of itself.
+        if self._is_initial:
             return [group["lr"] for group in self.optimizer.param_groups]
         return [group["lr"] * self.gamma for group in self.optimizer.param_groups]
 
```
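Since `ExponentialLR` updates recursively (the current LR is multiplied by `gamma` on every real step), here is a quick sanity check of the intended behavior (assumes torch is installed; the expected values follow from the recursion, not from the PR's test file):

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR

opt = SGD([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
sched = ExponentialLR(opt, gamma=0.5)  # construction itself must not decay

assert sched.get_last_lr() == [1.0]
for expected in (0.5, 0.25, 0.125):
    opt.step()
    sched.step()
    assert abs(sched.get_last_lr()[0] - expected) < 1e-12
```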
```diff
@@ -979,7 +995,7 @@ class PolynomialLR(LRScheduler):
         """Compute the learning rate."""
         _warn_get_lr_called_within_step(self)
 
-        if self.last_epoch == 0 or self.last_epoch > self.total_iters:
+        if self._is_initial or self.last_epoch > self.total_iters:
             return [group["lr"] for group in self.optimizer.param_groups]
 
         decay_factor = (
```
```diff
@@ -1065,7 +1081,7 @@ class CosineAnnealingLR(LRScheduler):
         """Retrieve the learning rate of each parameter group."""
         _warn_get_lr_called_within_step(self)
 
-        if self.last_epoch == 0:
+        if self._is_initial:
             return [group["lr"] for group in self.optimizer.param_groups]
         elif self._step_count == 1 and self.last_epoch > 0:
             return [
```
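`CosineAnnealingLR` keeps a separate `elif self._step_count == 1 and self.last_epoch > 0` branch (visible in the context above) that recomputes the LR from the closed form when resuming. For reference, a sketch of that closed form as documented for the scheduler (`eta_min` defaults to 0):

```python
import math

def cosine_annealing_lr(base_lr: float, t: int, T_max: int, eta_min: float = 0.0) -> float:
    """eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2 -- the value the recursive updates track."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / T_max)) / 2

assert cosine_annealing_lr(1.0, 0, 10) == 1.0          # start of schedule
assert abs(cosine_annealing_lr(1.0, 10, 10)) < 1e-12   # annealed down to eta_min
```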