[optim] prevent problematic tensor aliasing in lr_scheduler (#163098)

Prevents aliasing edge cases in `SequentialLR` and `ReduceLROnPlateau` that could corrupt learning rates or trigger recompilation.

Supersedes #162360
Fixes #162359
Fixes #163093

While putting #162360 together, I noticed that the class of issue I was fixing (i.e., unintended aliasing in lr_schedulers when using Tensor lrs) appeared in several other places. @janeyx99 suggested I put together a follow-up PR.

There are several bugs resembling the one fixed in #162360. I added a helper to fix these:
```python
def _update_param_group_val(param_group: dict[str, Any], key: str, val: float | Tensor):
    """Set param_group[key] to val without aliasing or assignment when they're both tensors.
    Raises a KeyError if param_group[key] does not exist.
    """
    if isinstance(param_group[key], Tensor):
        param_group[key].fill_(_to_scalar(val))
    else:
        param_group[key] = val
```
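
For context, the failure mode this guards against is ordinary Python rebinding: assigning one tensor into `param_group["lr"]` makes both names refer to the same object, so a later in-place update to either one silently changes the other. A minimal standalone sketch of the pitfall (illustrative only, not code from this PR):
```python
import torch

lr_source = torch.tensor(0.1)
param_group = {"lr": torch.tensor(0.01)}

# Plain assignment rebinds the dict entry to the *same* tensor object.
param_group["lr"] = lr_source
lr_source.fill_(0.5)
print(param_group["lr"])  # tensor(0.5000) -- changed through the alias

# Filling in place copies the value and keeps the two tensors independent.
param_group["lr"] = torch.tensor(0.01)
param_group["lr"].fill_(lr_source.item())
lr_source.fill_(0.9)
print(param_group["lr"])  # tensor(0.5000) -- unaffected
```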

I applied it to fix the bugs in `SequentialLR.__init__` and `LRScheduler._update_lr`. I also added it to `CyclicLR.__init__`, which was already using an equivalent pattern, and to `CosineAnnealingWarmRestarts.step`, which *should* have had a similar issue:
```python
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
    param_group["lr"] = lr
```

But it did not, because `get_lr()` actually returns tensors when using a tensor lr (despite its `list[float]` return type annotation). Relying on that propagation seems fragile, so I conservatively applied the helper here as well. I'll fix the type annotations and several related issues in follow-up PRs built on top of this one.
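
If you want to sanity-check the in-place behavior locally, something along these lines works (a rough, hypothetical check, not the test added in this PR; the SGD/ConstantLR/SequentialLR combination is just an example configuration):
```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ConstantLR, SequentialLR

param = torch.nn.Parameter(torch.zeros(1))
opt = SGD([param], lr=torch.tensor(0.1))
lr_tensor = opt.param_groups[0]["lr"]  # keep a handle on the original lr tensor

sched = SequentialLR(
    opt,
    schedulers=[ConstantLR(opt, factor=0.5, total_iters=2), ConstantLR(opt, factor=1.0)],
    milestones=[2],
)

for _ in range(4):
    opt.step()
    sched.step()
    # The lr entry should remain the same tensor object, updated in place;
    # rebinding or aliasing it is exactly what this change prevents.
    assert opt.param_groups[0]["lr"] is lr_tensor
```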
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163098
Approved by: https://github.com/janeyx99

Filip
2025-09-17 13:40:23 +00:00
committed by PyTorch MergeBot
parent 607489f3d0
commit bc38c5baa1
2 changed files with 44 additions and 7 deletions

@@ -79,6 +79,16 @@ def _format_param(name: str, optimizer: Optimizer, param):
     return list(map(_copy, param))
 
 
+def _update_param_group_val(param_group: dict[str, Any], key: str, val: float | Tensor):
+    """Set param_group[key] to val without aliasing or assignment when they're both tensors.
+    Raises a KeyError if param_group[key] does not exist.
+    """
+    if isinstance(param_group[key], Tensor):
+        param_group[key].fill_(_to_scalar(val))
+    else:
+        param_group[key] = val
+
+
 class LRScheduler:
     r"""Adjusts the learning rate during optimization."""

@@ -219,10 +229,7 @@ class LRScheduler:
         values = self.get_lr()
 
         for param_group, lr in zip(self.optimizer.param_groups, values):
-            if isinstance(param_group["lr"], Tensor):
-                param_group["lr"].fill_(_to_scalar(lr))
-            else:
-                param_group["lr"] = lr
+            _update_param_group_val(param_group, "lr", lr)
 
         self._last_lr: list[float] = [
             group["lr"] for group in self.optimizer.param_groups

@@ -889,7 +896,7 @@ class SequentialLR(LRScheduler):
         # Reset learning rates back to initial values
         for group in self.optimizer.param_groups:
-            group["lr"] = group["initial_lr"]
+            _update_param_group_val(group, "lr", group["initial_lr"])
 
         # "Undo" the step performed by other schedulers
         self.recursive_undo()

@@ -1385,7 +1392,7 @@ class ReduceLROnPlateau(LRScheduler):
             old_lr = float(param_group["lr"])
             new_lr = max(old_lr * self.factor, self.min_lrs[i])
             if old_lr - new_lr > self.eps:
-                param_group["lr"] = new_lr
+                _update_param_group_val(param_group, "lr", new_lr)
 
     @property
     def in_cooldown(self):  # noqa: D102

@@ -1860,7 +1867,7 @@ class CosineAnnealingWarmRestarts(LRScheduler):
         with _enable_get_lr_call(self):
             for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
-                param_group["lr"] = lr
+                _update_param_group_val(param_group, "lr", lr)
 
         self._last_lr = [group["lr"] for group in self.optimizer.param_groups]