Allow zero annealing epochs (#47579)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/47578.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/47579

Reviewed By: H-Huang

Differential Revision: D25429403

Pulled By: vincentqb

fbshipit-source-id: c42fbcd71b46e07c672a1e9661468848ac16de38
This commit is contained in:
Daniil Osokin
2020-12-16 14:02:57 -08:00
committed by Facebook GitHub Bot
parent 4431731c68
commit 09173ae65e

View File

@ -219,8 +219,8 @@ class SWALR(_LRScheduler):
self.anneal_func = self._cosine_anneal
elif anneal_strategy == 'linear':
self.anneal_func = self._linear_anneal
if not isinstance(anneal_epochs, int) or anneal_epochs < 1:
raise ValueError("anneal_epochs must be a positive integer, got {}".format(
if not isinstance(anneal_epochs, int) or anneal_epochs < 0:
raise ValueError("anneal_epochs must be equal or greater than 0, got {}".format(
anneal_epochs))
self.anneal_epochs = anneal_epochs
@ -257,11 +257,13 @@ class SWALR(_LRScheduler):
warnings.warn("To get the last learning rate computed by the scheduler, "
"please use `get_last_lr()`.", UserWarning)
step = self._step_count - 1
prev_t = max(0, min(1, (step - 1) / self.anneal_epochs))
if self.anneal_epochs == 0:
step = max(1, step)
prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
prev_alpha = self.anneal_func(prev_t)
prev_lrs = [self._get_initial_lr(group['lr'], group['swa_lr'], prev_alpha)
for group in self.optimizer.param_groups]
t = max(0, min(1, step / self.anneal_epochs))
t = max(0, min(1, step / max(1, self.anneal_epochs)))
alpha = self.anneal_func(t)
return [group['swa_lr'] * alpha + lr * (1 - alpha)
for group, lr in zip(self.optimizer.param_groups, prev_lrs)]