Allow zero annealing epochs (#47579)
Summary: Fixes https://github.com/pytorch/pytorch/issues/47578

Pull Request resolved: https://github.com/pytorch/pytorch/pull/47579

Reviewed By: H-Huang

Differential Revision: D25429403

Pulled By: vincentqb

fbshipit-source-id: c42fbcd71b46e07c672a1e9661468848ac16de38
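For context, a minimal sketch of the usage this change enables; the model, optimizer, and learning-rate values below are illustrative, not taken from the PR:

import torch
from torch.optim.swa_utils import SWALR

# Illustrative setup; any torch optimizer works with SWALR.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Before this fix, anneal_epochs=0 raised
# "anneal_epochs must be a positive integer, got 0".
# With the fix, zero annealing epochs is valid: the learning rate
# switches to swa_lr immediately instead of being annealed toward it.
swa_scheduler = SWALR(optimizer, swa_lr=0.05, anneal_epochs=0,
                      anneal_strategy='linear')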
Commit 09173ae65e (parent 4431731c68), committed by Facebook GitHub Bot
@@ -219,8 +219,8 @@ class SWALR(_LRScheduler):
             self.anneal_func = self._cosine_anneal
         elif anneal_strategy == 'linear':
             self.anneal_func = self._linear_anneal
-        if not isinstance(anneal_epochs, int) or anneal_epochs < 1:
-            raise ValueError("anneal_epochs must be a positive integer, got {}".format(
+        if not isinstance(anneal_epochs, int) or anneal_epochs < 0:
+            raise ValueError("anneal_epochs must be equal or greater than 0, got {}".format(
                 anneal_epochs))
         self.anneal_epochs = anneal_epochs
 
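A quick sanity check of the relaxed validation, reusing the illustrative optimizer from the sketch above; the error text for the negative case matches the new message in this hunk:

from torch.optim.swa_utils import SWALR

SWALR(optimizer, swa_lr=0.05, anneal_epochs=0)   # now accepted
SWALR(optimizer, swa_lr=0.05, anneal_epochs=-1)  # still raises ValueError:
# "anneal_epochs must be equal or greater than 0, got -1"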
@@ -257,11 +257,13 @@ class SWALR(_LRScheduler):
             warnings.warn("To get the last learning rate computed by the scheduler, "
                           "please use `get_last_lr()`.", UserWarning)
         step = self._step_count - 1
-        prev_t = max(0, min(1, (step - 1) / self.anneal_epochs))
+        if self.anneal_epochs == 0:
+            step = max(1, step)
+        prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
         prev_alpha = self.anneal_func(prev_t)
         prev_lrs = [self._get_initial_lr(group['lr'], group['swa_lr'], prev_alpha)
                     for group in self.optimizer.param_groups]
-        t = max(0, min(1, step / self.anneal_epochs))
+        t = max(0, min(1, step / max(1, self.anneal_epochs)))
         alpha = self.anneal_func(t)
         return [group['swa_lr'] * alpha + lr * (1 - alpha)
                 for group, lr in zip(self.optimizer.param_groups, prev_lrs)]
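A sketch of the arithmetic behind the guarded division above, using the linear annealing function alpha(t) = t (what SWALR._linear_anneal computes); it assumes get_lr is first reached with self._step_count - 1 == 0:

def linear_anneal(t):
    return t

anneal_epochs = 0
step = 0                   # self._step_count - 1 on the first call
if anneal_epochs == 0:
    step = max(1, step)    # clamp so t below reaches 1 immediately

# max(1, anneal_epochs) guards the division by zero when anneal_epochs == 0.
t = max(0, min(1, step / max(1, anneal_epochs)))
alpha = linear_anneal(t)
assert alpha == 1.0        # lr = swa_lr * 1 + prev_lr * 0 = swa_lr right away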