Clean up duplicated code in lr_scheduler (#150984)

## Changes

- Remove duplicated code in `ReduceLROnPlateau` (see the sketch after this list)
- Remove redundant `noqa` comment
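
For context, the duplication being removed follows a familiar shape: `__init__` assigned attributes that the helper methods `_init_is_better` and `_reset` set again right afterwards. The sketch below illustrates the consolidated pattern with a hypothetical `Plateau` class; it is not the actual `torch.optim.lr_scheduler` source.

```python
from math import inf


class Plateau:
    """Hypothetical stand-in for ReduceLROnPlateau, used only to illustrate the pattern."""

    def __init__(
        self, mode: str = "min", threshold: float = 1e-4, threshold_mode: str = "rel"
    ) -> None:
        # __init__ keeps only the state the helpers do not already cover ...
        self.last_epoch = 0
        # ... and delegates the rest instead of assigning it a second time here.
        self._init_is_better(mode, threshold, threshold_mode)
        self._reset()

    def _init_is_better(self, mode: str, threshold: float, threshold_mode: str) -> None:
        # the worse value for the chosen mode
        self.mode_worse = inf if mode == "min" else -inf
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode

    def _reset(self) -> None:
        self.best = self.mode_worse
        self.num_bad_epochs = 0
        self.cooldown_counter = 0


print(Plateau().best)  # inf: initialized exactly once, through the helpers
```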

## Test Result

```bash
pytest test/optim/test_lrscheduler.py
```

![image](https://github.com/user-attachments/assets/37f91f31-0e77-4abf-9dd1-75538c0f0792)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150984
Approved by: https://github.com/janeyx99
Author: zeshengzong
Date: 2025-04-13 09:18:47 +00:00
Committed by: PyTorch MergeBot
Parent: b59f3d3ae0
Commit: 304633152c

```diff
--- a/torch/optim/lr_scheduler.py
+++ b/torch/optim/lr_scheduler.py
@@ -1214,15 +1214,7 @@ class ReduceLROnPlateau(LRScheduler):
             self.min_lrs = [min_lr] * len(optimizer.param_groups)

         self.patience = patience
         self.cooldown = cooldown
-        self.cooldown_counter = 0
-        self.mode = mode
-        self.threshold = threshold
-        self.threshold_mode = threshold_mode
-        self.best: float
-        self.num_bad_epochs: int
-        self.mode_worse: float  # the worse value for the chosen mode
         self.eps = eps
         self.last_epoch = 0
         self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
@@ -1310,6 +1302,7 @@ class ReduceLROnPlateau(LRScheduler):
         if threshold_mode not in {"rel", "abs"}:
             raise ValueError("threshold mode " + threshold_mode + " is unknown!")

+        # the worse value for the chosen mode
         if mode == "min":
             self.mode_worse = inf
         else:  # mode == 'max':
@@ -1319,11 +1312,6 @@ class ReduceLROnPlateau(LRScheduler):
         self.threshold = threshold
         self.threshold_mode = threshold_mode

-    def state_dict(self):  # noqa: D102
-        return {
-            key: value for key, value in self.__dict__.items() if key != "optimizer"
-        }
-
     def load_state_dict(self, state_dict):
         """Load the scheduler's state."""
         self.__dict__.update(state_dict)
@@ -2007,7 +1995,7 @@ class OneCycleLR(LRScheduler):
         if step_num > self.total_steps:
             raise ValueError(
-                f"Tried to step {step_num} times. The specified number of total steps is {self.total_steps}"  # noqa: UP032
+                f"Tried to step {step_num} times. The specified number of total steps is {self.total_steps}"
             )
         for group in self.optimizer.param_groups:
```
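
With the duplicated `state_dict` override removed, `ReduceLROnPlateau` now relies on the base `LRScheduler.state_dict`, which still excludes the wrapped optimizer from the saved state. A quick sanity check along these lines (not part of the PR's test suite; the toy parameter and optimizer are illustrative only) exercises that save/restore path:

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Toy parameter and optimizer, purely for demonstration.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=1)

# Feed a flat validation metric so the scheduler eventually cuts the LR.
for val_loss in (1.0, 1.0, 1.0, 1.0):
    scheduler.step(val_loss)

# Round-trip the state: the base-class state_dict excludes the optimizer itself.
state = scheduler.state_dict()
assert "optimizer" not in state

restored = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=1)
restored.load_state_dict(state)
assert restored.best == scheduler.best
assert restored.num_bad_epochs == scheduler.num_bad_epochs

print(optimizer.param_groups[0]["lr"])  # reduced from 0.1 after the plateau
```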