mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Clean up duplicated code in lr_scheduler (#150984)
## Changes - Remove duplicated code in `ReduceLROnPlateau` - Remove redundant `noqa` comment ## Test Result ```bash pytest test/optim/test_lrscheduler.py ```  Pull Request resolved: https://github.com/pytorch/pytorch/pull/150984 Approved by: https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
b59f3d3ae0
commit
304633152c
@ -1214,15 +1214,7 @@ class ReduceLROnPlateau(LRScheduler):
|
||||
self.min_lrs = [min_lr] * len(optimizer.param_groups)
|
||||
|
||||
self.patience = patience
|
||||
|
||||
self.cooldown = cooldown
|
||||
self.cooldown_counter = 0
|
||||
self.mode = mode
|
||||
self.threshold = threshold
|
||||
self.threshold_mode = threshold_mode
|
||||
self.best: float
|
||||
self.num_bad_epochs: int
|
||||
self.mode_worse: float # the worse value for the chosen mode
|
||||
self.eps = eps
|
||||
self.last_epoch = 0
|
||||
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
|
||||
@ -1310,6 +1302,7 @@ class ReduceLROnPlateau(LRScheduler):
|
||||
if threshold_mode not in {"rel", "abs"}:
|
||||
raise ValueError("threshold mode " + threshold_mode + " is unknown!")
|
||||
|
||||
# the worse value for the chosen mode
|
||||
if mode == "min":
|
||||
self.mode_worse = inf
|
||||
else: # mode == 'max':
|
||||
@ -1319,11 +1312,6 @@ class ReduceLROnPlateau(LRScheduler):
|
||||
self.threshold = threshold
|
||||
self.threshold_mode = threshold_mode
|
||||
|
||||
def state_dict(self): # noqa: D102
|
||||
return {
|
||||
key: value for key, value in self.__dict__.items() if key != "optimizer"
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Load the scheduler's state."""
|
||||
self.__dict__.update(state_dict)
|
||||
@ -2007,7 +1995,7 @@ class OneCycleLR(LRScheduler):
|
||||
|
||||
if step_num > self.total_steps:
|
||||
raise ValueError(
|
||||
f"Tried to step {step_num} times. The specified number of total steps is {self.total_steps}" # noqa: UP032
|
||||
f"Tried to step {step_num} times. The specified number of total steps is {self.total_steps}"
|
||||
)
|
||||
|
||||
for group in self.optimizer.param_groups:
|
||||
|
Reference in New Issue
Block a user