Improved error lr last epoch (#162368)

Fixes #160626

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162368
Approved by: https://github.com/janeyx99
This commit is contained in:
Kushagra Rastogi
2025-09-15 23:33:10 +00:00
committed by PyTorch MergeBot
parent 955e195c7d
commit cfc539fe15
2 changed files with 14 additions and 2 deletions

View File

@@ -369,6 +369,16 @@ class TestLRScheduler(TestCase):
scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
self._test_get_last_lr(scheduler, targets, epochs)
def test_raise_error_when_last_epoch_is_greater_than_0_and_initial_lr_is_not_specified(
    self,
):
    """Resuming a scheduler (``last_epoch >= 0``) on an optimizer whose
    ``param_groups`` lack ``'initial_lr'`` must raise a ``KeyError`` whose
    message names the offending group and explains the resume requirement.
    """
    # Fresh optimizer: its param_groups carry 'lr' but no 'initial_lr'.
    optimizer = SGD([Parameter(torch.randn(2, 2, requires_grad=True))], 0.1)
    with self.assertRaisesRegex(
        KeyError,
        r"param \'initial_lr\' is not specified in param_groups\[0\] when resuming scheduler with last_epoch >= 0",
    ):
        # last_epoch=1 signals "resume", triggering the initial_lr check.
        StepLR(optimizer, step_size=3, gamma=0.1, last_epoch=1)
def test_multi_step_lr(self):
# lr = 0.05 if epoch < 2
# lr = 0.005 if 2 <= epoch < 5

View File

@@ -106,8 +106,10 @@ class LRScheduler:
for i, group in enumerate(optimizer.param_groups):
if "initial_lr" not in group:
raise KeyError(
-                        "param 'initial_lr' is not specified "
-                        f"in param_groups[{i}] when resuming an optimizer"
+                        f"param 'initial_lr' is not specified in param_groups[{i}] when resuming scheduler with last_epoch >= 0.\n"
+                        "This typically happens when:\n"
+                        "1. You're trying to resume training from a checkpoint but haven't properly loaded the optimizer state\n"
+                        "2. You're using last_epoch >= 0 for a fresh training run (not recommended)"
)
self.base_lrs: list[float] = [
group["initial_lr"] for group in optimizer.param_groups