Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[optim] prevent problematic tensor aliasing in lr_scheduler (#163098)
Prevents edge cases in SequentialLR and ReduceLROnPlateau which could corrupt learning rates or trigger recompilation.

Supersedes #162360
Fixes #162359
Fixes #163093

While putting #162360 together, I noticed that the class of issue I was fixing (i.e. unintended aliasing in lr_schedulers when using Tensor lrs) appeared in several other places. @janeyx99 suggested I put together a follow-up PR. There are several bugs resembling the one fixed in #162360, so I added a helper to fix them:

```python
def _update_param_group_val(param_group: dict[str, Any], key: str, val: float | Tensor):
    """Set param_group[key] to val without aliasing or assignment when they're both tensors.

    Raises a KeyError if param_group[key] does not exist.
    """
    if isinstance(param_group[key], Tensor):
        param_group[key].fill_(_to_scalar(val))
    else:
        param_group[key] = val
```

I applied it to fix bugs in `SequentialLR.__init__` and `LRScheduler._update_lr`. I also added it to `CyclicLR.__init__`, which was using an equivalent pattern, and to `CosineAnnealingWarmRestarts.step`, which *should* have had a similar issue:

```python
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
    param_group["lr"] = lr
```

but did not, because `get_lr()` actually returns tensors when using a tensor lr (despite its `list[float]` return type annotation). Relying on that propagation seems fragile, so I conservatively applied the helper here as well. I'll fix the type annotations and several related issues in follow-up PRs built on top of this one.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163098
Approved by: https://github.com/janeyx99
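For context, here is a minimal sketch (not part of the PR) of the aliasing failure mode the helper guards against; the bare dict stands in for an optimizer `param_group`:

```python
import torch

# With a Tensor lr, plain dict assignment makes "lr" and "initial_lr" refer
# to the SAME tensor object (an alias), so an in-place update to one
# silently mutates the other.
group = {"lr": torch.tensor(2.0)}
group["initial_lr"] = group["lr"]   # alias, not a copy
group["lr"].fill_(0.5)              # in-place update...
print(group["initial_lr"])          # tensor(0.5000) -- "initial_lr" corrupted

# Writing into the existing tensor instead (the helper's fill_ branch)
# avoids creating an alias and keeps the tensor object's identity and type
# stable, which also prevents the float-vs-Tensor type flip that would
# trigger recompilation of a compiled optimizer step.
group = {"lr": torch.tensor(2.0), "initial_lr": torch.tensor(2.0)}
group["lr"].fill_(0.5)
print(group["initial_lr"])          # tensor(2.) -- unaffected
```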
```diff
@@ -710,6 +710,15 @@ class TestLRScheduler(TestCase):
             scheduler.get_last_lr(), [0.5 for param_group in self.opt.param_groups]
         )
 
+    def test_reduce_lr_on_plateau_preserves_lr_type(self):
+        # Ensures that tensor lrs are preserved, preventing recompilations.
+        types = [type(group["lr"]) for group in self.opt.param_groups]
+        scheduler = ReduceLROnPlateau(self.opt, mode="min", patience=0)
+        scheduler.step(1.0)
+        scheduler.step(2.0)  # Triggers scheduler._reduce_lr
+        for group, type_ in zip(self.opt.param_groups, types):
+            self.assertEqual(type(group["lr"]), type_)
+
     def test_sequentiallr1(self):
         epochs = 19
         schedulers = [None] * 2
@@ -822,6 +831,27 @@ class TestLRScheduler(TestCase):
         targets = [single_targets, [x * 10 for x in single_targets]]
         self._test_get_last_lr(scheduler, targets, epochs)
 
+    def test_sequentiallr_does_not_alias_lr_and_initial_lr(self):
+        # The TestLRScheduler object uses self.opt to avoid instantiating a new optimizer for each test.
+        # self.opt has a float lr, and we need to use a Tensor lr to ensure that a former SequentialLR bug is fixed.
+        # For more context, see https://github.com/pytorch/pytorch/issues/162359
+        old_opt = self.opt
+        lr = torch.tensor(2.0)
+        self.opt = SGD(self.net.parameters(), lr=lr)
+        milestone = 4
+        epochs = 8
+        start, end = 0.1, 0.8
+
+        schedulers = [
+            LinearLR(self.opt, start, end, total_iters=milestone),
+            LinearLR(self.opt, end, start, total_iters=epochs - milestone),
+        ]
+        targets = [[0.2, 0.55, 0.9, 1.25, 1.6, 1.25, 0.9, 0.55]]
+
+        scheduler = SequentialLR(self.opt, schedulers, milestones=[milestone])
+        self._test(scheduler, targets, epochs)
+        self.opt = old_opt
+
     def test_chained_lr2_get_last_lr_before_step(self):
         schedulers = [
             LinearLR(self.opt, start_factor=0.4, total_iters=3),
@@ -79,6 +79,16 @@ def _format_param(name: str, optimizer: Optimizer, param):
     return list(map(_copy, param))
 
 
+def _update_param_group_val(param_group: dict[str, Any], key: str, val: float | Tensor):
+    """Set param_group[key] to val without aliasing or assignment when they're both tensors.
+
+    Raises a KeyError if param_group[key] does not exist.
+    """
+    if isinstance(param_group[key], Tensor):
+        param_group[key].fill_(_to_scalar(val))
+    else:
+        param_group[key] = val
+
+
 class LRScheduler:
     r"""Adjusts the learning rate during optimization."""
 
@@ -219,10 +229,7 @@ class LRScheduler:
         values = self.get_lr()
 
         for param_group, lr in zip(self.optimizer.param_groups, values):
-            if isinstance(param_group["lr"], Tensor):
-                param_group["lr"].fill_(_to_scalar(lr))
-            else:
-                param_group["lr"] = lr
+            _update_param_group_val(param_group, "lr", lr)
 
         self._last_lr: list[float] = [
             group["lr"] for group in self.optimizer.param_groups
@@ -889,7 +896,7 @@ class SequentialLR(LRScheduler):
 
         # Reset learning rates back to initial values
         for group in self.optimizer.param_groups:
-            group["lr"] = group["initial_lr"]
+            _update_param_group_val(group, "lr", group["initial_lr"])
 
         # "Undo" the step performed by other schedulers
         self.recursive_undo()
@@ -1385,7 +1392,7 @@ class ReduceLROnPlateau(LRScheduler):
                 old_lr = float(param_group["lr"])
                 new_lr = max(old_lr * self.factor, self.min_lrs[i])
                 if old_lr - new_lr > self.eps:
-                    param_group["lr"] = new_lr
+                    _update_param_group_val(param_group, "lr", new_lr)
 
     @property
     def in_cooldown(self):  # noqa: D102
@@ -1860,7 +1867,7 @@ class CosineAnnealingWarmRestarts(LRScheduler):
 
         with _enable_get_lr_call(self):
             for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
-                param_group["lr"] = lr
+                _update_param_group_val(param_group, "lr", lr)
 
         self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
```