Fix lr_scheduler unexpectedly calls step() when init argument last_epoch is larger than -1 (#149312)

Fixes #102261

## Changes

- Use flag `_is_initial` to replace `self.last_epoch == 0` condition to judge whether `lr` should be initial value
- Add test for `ExponentialLR` checkpoint usecase

## Test Result

```python
pytest -s test/optim/test_lrscheduler.py  -vv
```

![image](https://github.com/user-attachments/assets/6fd32bcc-b4fb-4421-b891-620bd4900dc1)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149312
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
This commit is contained in:
zeshengzong
2025-05-22 08:42:33 +00:00
committed by PyTorch MergeBot
parent 423fc671e9
commit d7a83ab67b
2 changed files with 72 additions and 6 deletions

View File

@ -82,6 +82,7 @@ class LRScheduler:
r"""Adjusts the learning rate during optimization."""
_get_lr_called_within_step: bool = False
_is_initial: bool = False
def __init__(
self,
@ -141,7 +142,8 @@ class LRScheduler:
def _initial_step(self) -> None:
"""Initialize step counts and perform a step."""
self._step_count = 0
self.step()
with _initial_mode(self):
self.step()
def state_dict(self) -> dict[str, Any]:
"""Return the state of the scheduler as a :class:`dict`.
@ -195,6 +197,7 @@ class LRScheduler:
"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
UserWarning,
)
self._step_count += 1
with _enable_get_lr_call(self):
@ -248,6 +251,17 @@ class _enable_get_lr_call:
self.o._get_lr_called_within_step = False
class _initial_mode:
def __init__(self, o: LRScheduler):
self.o = o
def __enter__(self):
self.o._is_initial = True
def __exit__(self, type, value, traceback):
self.o._is_initial = False
class LambdaLR(LRScheduler):
"""Sets the initial learning rate.
@ -450,7 +464,7 @@ class MultiplicativeLR(LRScheduler):
"""Compute the learning rate of each parameter group."""
_warn_get_lr_called_within_step(self)
if self.last_epoch > 0:
if not self._is_initial:
return [
group["lr"] * lmbda(self.last_epoch)
for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)
@ -715,7 +729,7 @@ class LinearLR(LRScheduler):
group["lr"] * self.start_factor for group in self.optimizer.param_groups
]
if self.last_epoch > self.total_iters:
if self._is_initial or self.last_epoch > self.total_iters:
return [group["lr"] for group in self.optimizer.param_groups]
return [
@ -779,7 +793,9 @@ class ExponentialLR(LRScheduler):
"""Compute the learning rate of each parameter group."""
_warn_get_lr_called_within_step(self)
if self.last_epoch == 0:
# when loading from a checkpoint, we don't want _initial_step (called from the constructor)
# to update the lr one more step ahead of itself.
if self._is_initial:
return [group["lr"] for group in self.optimizer.param_groups]
return [group["lr"] * self.gamma for group in self.optimizer.param_groups]
@ -979,7 +995,7 @@ class PolynomialLR(LRScheduler):
"""Compute the learning rate."""
_warn_get_lr_called_within_step(self)
if self.last_epoch == 0 or self.last_epoch > self.total_iters:
if self._is_initial or self.last_epoch > self.total_iters:
return [group["lr"] for group in self.optimizer.param_groups]
decay_factor = (
@ -1065,7 +1081,7 @@ class CosineAnnealingLR(LRScheduler):
"""Retrieve the learning rate of each parameter group."""
_warn_get_lr_called_within_step(self)
if self.last_epoch == 0:
if self._is_initial:
return [group["lr"] for group in self.optimizer.param_groups]
elif self._step_count == 1 and self.last_epoch > 0:
return [