Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
Fixes #92007

## Changes

- Add a script to generate plots for `lr_scheduler`
- Add the plots to the `lr_scheduler` docs
- Add an example section to the `lr_scheduler` docs where one is missing

## Test Result

Plots were generated for: LambdaLR, MultiplicativeLR, StepLR, MultiStepLR, ConstantLR, LinearLR, ExponentialLR, PolynomialLR, CosineAnnealingLR, ChainedScheduler, SequentialLR, ReduceLROnPlateau, CyclicLR, OneCycleLR, and CosineAnnealingWarmRestarts.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149189
Approved by: https://github.com/janeyx99
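For reference, the StepLR curve among the generated plots can be reproduced numerically with the same settings the script below uses (initial lr 0.05, `step_size=30`, `gamma=0.1`, 100 epochs). This is a minimal sketch, not part of the PR, and the printed values are approximate:

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

# Same settings as the script below: initial lr 0.05, decayed by 10x every 30 epochs.
opt = SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.05)
sched = StepLR(opt, step_size=30, gamma=0.1)

lrs = []
for _ in range(100):
    lrs.append(opt.param_groups[0]["lr"])
    opt.step()
    sched.step()

print(lrs[0], lrs[30], lrs[60])  # ~0.05, ~0.005, ~0.0005
```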
97 lines · 2.8 KiB · Python
from pathlib import Path

import matplotlib
from matplotlib import pyplot as plt

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import (
    ChainedScheduler,
    ConstantLR,
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
    CyclicLR,
    ExponentialLR,
    LambdaLR,
    LinearLR,
    MultiplicativeLR,
    MultiStepLR,
    OneCycleLR,
    PolynomialLR,
    ReduceLROnPlateau,
    SequentialLR,
    StepLR,
)


# Non-interactive backend so the plots can be rendered without a display.
matplotlib.use("Agg")

LR_SCHEDULER_IMAGE_PATH = Path(__file__).parent / "lr_scheduler_images"

if not LR_SCHEDULER_IMAGE_PATH.exists():
    LR_SCHEDULER_IMAGE_PATH.mkdir()

model = torch.nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.05)

num_epochs = 100

scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=num_epochs // 5)
scheduler2 = ExponentialLR(optimizer, gamma=0.9)

# One factory per scheduler; plot_function calls each with the shared optimizer.
schedulers = [
    (lambda opt: LambdaLR(opt, lr_lambda=lambda epoch: epoch // 30)),
    (lambda opt: MultiplicativeLR(opt, lr_lambda=lambda epoch: 0.95)),
    (lambda opt: StepLR(opt, step_size=30, gamma=0.1)),
    (lambda opt: MultiStepLR(opt, milestones=[30, 80], gamma=0.1)),
    (lambda opt: ConstantLR(opt, factor=0.5, total_iters=40)),
    (lambda opt: LinearLR(opt, start_factor=0.05, total_iters=40)),
    (lambda opt: ExponentialLR(opt, gamma=0.95)),
    (lambda opt: PolynomialLR(opt, total_iters=num_epochs / 2, power=0.9)),
    (lambda opt: CosineAnnealingLR(opt, T_max=num_epochs)),
    (lambda opt: CosineAnnealingWarmRestarts(opt, T_0=20)),
    (lambda opt: CyclicLR(opt, base_lr=0.01, max_lr=0.1, step_size_up=10)),
    (lambda opt: OneCycleLR(opt, max_lr=0.01, epochs=10, steps_per_epoch=10)),
    (lambda opt: ReduceLROnPlateau(opt, mode="min")),
    (lambda opt: ChainedScheduler([scheduler1, scheduler2])),
    (
        lambda opt: SequentialLR(
            opt, schedulers=[scheduler1, scheduler2], milestones=[num_epochs // 5]
        )
    ),
]


def plot_function(scheduler):
    plt.clf()
    plt.grid(color="k", alpha=0.2, linestyle="--")
    lrs = []
    # Reset the learning rate so every scheduler starts from the same initial value.
    optimizer.param_groups[0]["lr"] = 0.05
    scheduler = scheduler(optimizer)

    plot_path = LR_SCHEDULER_IMAGE_PATH / f"{scheduler.__class__.__name__}.png"
    if plot_path.exists():
        return

    for _ in range(num_epochs):
        lrs.append(optimizer.param_groups[0]["lr"])
        if isinstance(scheduler, ReduceLROnPlateau):
            # ReduceLROnPlateau expects a metric, so feed it a random validation loss.
            val_loss = torch.randn(1).item()
            scheduler.step(val_loss)
        else:
            scheduler.step()

    plt.plot(range(num_epochs), lrs)
    plt.title(f"Learning Rate: {scheduler.__class__.__name__}")
    plt.xlabel("Epoch")
    plt.ylabel("Learning Rate")
    plt.xlim([0, num_epochs])
    plt.savefig(plot_path)
    print(
        f"Saved learning rate scheduler image for {scheduler.__class__.__name__} at {plot_path}"
    )


for scheduler in schedulers:
    plot_function(scheduler)
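# Hedged follow-up sketch (an assumed addition, not part of the script above): after the
# loop has run, check that every scheduler in `schedulers` produced a PNG under
# LR_SCHEDULER_IMAGE_PATH.
expected = {
    "LambdaLR",
    "MultiplicativeLR",
    "StepLR",
    "MultiStepLR",
    "ConstantLR",
    "LinearLR",
    "ExponentialLR",
    "PolynomialLR",
    "CosineAnnealingLR",
    "CosineAnnealingWarmRestarts",
    "CyclicLR",
    "OneCycleLR",
    "ReduceLROnPlateau",
    "ChainedScheduler",
    "SequentialLR",
}
generated = {p.stem for p in LR_SCHEDULER_IMAGE_PATH.glob("*.png")}
missing = expected - generated
print("missing plots:", sorted(missing) if missing else "none")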