mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Allow zero annealing epochs (#47579)
Summary: Fixes https://github.com/pytorch/pytorch/issues/47578. Pull Request resolved: https://github.com/pytorch/pytorch/pull/47579 Reviewed By: H-Huang Differential Revision: D25429403 Pulled By: vincentqb fbshipit-source-id: c42fbcd71b46e07c672a1e9661468848ac16de38
This commit is contained in:
committed by
Facebook GitHub Bot
parent
4431731c68
commit
09173ae65e
@ -219,8 +219,8 @@ class SWALR(_LRScheduler):
|
||||
self.anneal_func = self._cosine_anneal
|
||||
elif anneal_strategy == 'linear':
|
||||
self.anneal_func = self._linear_anneal
|
||||
if not isinstance(anneal_epochs, int) or anneal_epochs < 1:
|
||||
raise ValueError("anneal_epochs must be a positive integer, got {}".format(
|
||||
if not isinstance(anneal_epochs, int) or anneal_epochs < 0:
|
||||
raise ValueError("anneal_epochs must be equal or greater than 0, got {}".format(
|
||||
anneal_epochs))
|
||||
self.anneal_epochs = anneal_epochs
|
||||
|
||||
@ -257,11 +257,13 @@ class SWALR(_LRScheduler):
|
||||
warnings.warn("To get the last learning rate computed by the scheduler, "
|
||||
"please use `get_last_lr()`.", UserWarning)
|
||||
step = self._step_count - 1
|
||||
prev_t = max(0, min(1, (step - 1) / self.anneal_epochs))
|
||||
if self.anneal_epochs == 0:
|
||||
step = max(1, step)
|
||||
prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
|
||||
prev_alpha = self.anneal_func(prev_t)
|
||||
prev_lrs = [self._get_initial_lr(group['lr'], group['swa_lr'], prev_alpha)
|
||||
for group in self.optimizer.param_groups]
|
||||
t = max(0, min(1, step / self.anneal_epochs))
|
||||
t = max(0, min(1, step / max(1, self.anneal_epochs)))
|
||||
alpha = self.anneal_func(t)
|
||||
return [group['swa_lr'] * alpha + lr * (1 - alpha)
|
||||
for group, lr in zip(self.optimizer.param_groups, prev_lrs)]
|
||||
|
Reference in New Issue
Block a user