mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Improved error lr last epoch (#162368)
Fixes #160626 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162368 Approved by: https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
955e195c7d
commit
cfc539fe15
@ -369,6 +369,16 @@ class TestLRScheduler(TestCase):
|
||||
scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
|
||||
self._test_get_last_lr(scheduler, targets, epochs)
|
||||
|
||||
def test_raise_error_when_last_epoch_is_greater_than_0_and_initial_lr_is_not_specified(
|
||||
self,
|
||||
):
|
||||
optimizer = SGD([Parameter(torch.randn(2, 2, requires_grad=True))], 0.1)
|
||||
with self.assertRaisesRegex(
|
||||
KeyError,
|
||||
r"param \'initial_lr\' is not specified in param_groups\[0\] when resuming scheduler with last_epoch >= 0",
|
||||
):
|
||||
StepLR(optimizer, step_size=3, gamma=0.1, last_epoch=1)
|
||||
|
||||
def test_multi_step_lr(self):
|
||||
# lr = 0.05 if epoch < 2
|
||||
# lr = 0.005 if 2 <= epoch < 5
|
||||
|
@ -106,8 +106,10 @@ class LRScheduler:
|
||||
for i, group in enumerate(optimizer.param_groups):
|
||||
if "initial_lr" not in group:
|
||||
raise KeyError(
|
||||
"param 'initial_lr' is not specified "
|
||||
f"in param_groups[{i}] when resuming an optimizer"
|
||||
f"param 'initial_lr' is not specified in param_groups[{i}] when resuming scheduler with last_epoch >= 0.\n"
|
||||
"This typically happens when:\n"
|
||||
"1. You're trying to resume training from a checkpoint but haven't properly loaded the optimizer state\n"
|
||||
"2. You're using last_epoch >= 0 for a fresh training run (not recommended)"
|
||||
)
|
||||
self.base_lrs: list[float] = [
|
||||
group["initial_lr"] for group in optimizer.param_groups
|
||||
|
Reference in New Issue
Block a user