Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add scripts to generate plots of LRSchedulers (#149189)
Fixes #92007

## Changes

- Add a script to generate plots for `lr_scheduler` (a minimal sketch of the plotting approach is shown after the commit metadata below)
- Add the plots to the `lr_scheduler` docs
- Add an example section where it is missing in the `lr_scheduler` docs

## Test Result

Plots were generated and attached to the PR for: LambdaLR, MultiplicativeLR, StepLR, MultiStepLR, ConstantLR, LinearLR, ExponentialLR, PolynomialLR, CosineAnnealingLR, ChainedScheduler, SequentialLR, ReduceLROnPlateau, CyclicLR, OneCycleLR, CosineAnnealingWarmRestarts.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149189
Approved by: https://github.com/janeyx99
Committed by: PyTorch MergeBot
parent: 5a64476ed6
commit: 5eebcb991a
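Each plot is produced by repeatedly stepping a scheduler while recording `optimizer.param_groups[0]["lr"]`, then saving the curve with matplotlib. A minimal sketch of that approach, shown here for StepLR only (the committed `build_lr_scheduler_images.py` below covers all of the schedulers listed above):

```python
# Minimal sketch of the plotting approach; StepLR only, for illustration.
import matplotlib

matplotlib.use("Agg")  # headless backend, as in the docs build
from matplotlib import pyplot as plt

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

model = torch.nn.Linear(10, 1)
optimizer = SGD(model.parameters(), lr=0.05)
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

num_epochs = 100
lrs = []
for _ in range(num_epochs):
    lrs.append(optimizer.param_groups[0]["lr"])  # record the LR actually in use
    scheduler.step()

plt.plot(range(num_epochs), lrs)
plt.xlabel("Epoch")
plt.ylabel("Learning Rate")
plt.title("Learning Rate: StepLR")
plt.savefig("StepLR.png")
```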
.gitignore (1 line changed):

@@ -178,6 +178,7 @@ compile_commands.json
 *.egg-info/
 docs/source/scripts/activation_images/
 docs/source/scripts/quantization_backend_configs/
+docs/source/scripts/lr_scheduler_images/
 
 ## General
docs/Makefile (1 line changed):

@@ -16,6 +16,7 @@ help:
 figures:
 	@$(PYCMD) source/scripts/build_activation_images.py
 	@$(PYCMD) source/scripts/build_quantization_configs.py
+	@$(PYCMD) source/scripts/build_lr_scheduler_images.py
 
 onnx:
 	@$(PYCMD) source/scripts/onnx/build_onnx_torchscript_supported_aten_op_csv_table.py
docs/source/scripts/build_lr_scheduler_images.py (new file, 96 lines):

from pathlib import Path

import matplotlib
from matplotlib import pyplot as plt

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import (
    ChainedScheduler,
    ConstantLR,
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
    CyclicLR,
    ExponentialLR,
    LambdaLR,
    LinearLR,
    MultiplicativeLR,
    MultiStepLR,
    OneCycleLR,
    PolynomialLR,
    ReduceLROnPlateau,
    SequentialLR,
    StepLR,
)


matplotlib.use("Agg")

LR_SCHEDULER_IMAGE_PATH = Path(__file__).parent / "lr_scheduler_images"

if not LR_SCHEDULER_IMAGE_PATH.exists():
    LR_SCHEDULER_IMAGE_PATH.mkdir()

model = torch.nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.05)

num_epochs = 100

scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=num_epochs // 5)
scheduler2 = ExponentialLR(optimizer, gamma=0.9)

schedulers = [
    (lambda opt: LambdaLR(opt, lr_lambda=lambda epoch: epoch // 30)),
    (lambda opt: MultiplicativeLR(opt, lr_lambda=lambda epoch: 0.95)),
    (lambda opt: StepLR(opt, step_size=30, gamma=0.1)),
    (lambda opt: MultiStepLR(opt, milestones=[30, 80], gamma=0.1)),
    (lambda opt: ConstantLR(opt, factor=0.5, total_iters=40)),
    (lambda opt: LinearLR(opt, start_factor=0.05, total_iters=40)),
    (lambda opt: ExponentialLR(opt, gamma=0.95)),
    (lambda opt: PolynomialLR(opt, total_iters=num_epochs / 2, power=0.9)),
    (lambda opt: CosineAnnealingLR(opt, T_max=num_epochs)),
    (lambda opt: CosineAnnealingWarmRestarts(opt, T_0=20)),
    (lambda opt: CyclicLR(opt, base_lr=0.01, max_lr=0.1, step_size_up=10)),
    (lambda opt: OneCycleLR(opt, max_lr=0.01, epochs=10, steps_per_epoch=10)),
    (lambda opt: ReduceLROnPlateau(opt, mode="min")),
    (lambda opt: ChainedScheduler([scheduler1, scheduler2])),
    (
        lambda opt: SequentialLR(
            opt, schedulers=[scheduler1, scheduler2], milestones=[num_epochs // 5]
        )
    ),
]


def plot_function(scheduler):
    plt.clf()
    plt.grid(color="k", alpha=0.2, linestyle="--")
    lrs = []
    optimizer.param_groups[0]["lr"] = 0.05
    scheduler = scheduler(optimizer)

    plot_path = LR_SCHEDULER_IMAGE_PATH / f"{scheduler.__class__.__name__}.png"
    if plot_path.exists():
        return

    for _ in range(num_epochs):
        lrs.append(optimizer.param_groups[0]["lr"])
        if isinstance(scheduler, ReduceLROnPlateau):
            val_loss = torch.randn(1).item()
            scheduler.step(val_loss)
        else:
            scheduler.step()

    plt.plot(range(num_epochs), lrs)
    plt.title(f"Learning Rate: {scheduler.__class__.__name__}")
    plt.xlabel("Epoch")
    plt.ylabel("Learning Rate")
    plt.xlim([0, num_epochs])
    plt.savefig(plot_path)
    print(
        f"Saved learning rate scheduler image for {scheduler.__class__.__name__} at {plot_path}"
    )


for scheduler in schedulers:
    plot_function(scheduler)
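The script skips any plot that already exists, so it can be re-run incrementally. A hypothetical sanity check (not part of the PR; the scheduler names are taken from the Test Result list above) to confirm one PNG per scheduler was written:

```python
# Hypothetical check, not part of the PR: verify one PNG per scheduler exists.
from pathlib import Path

expected = [
    "LambdaLR", "MultiplicativeLR", "StepLR", "MultiStepLR", "ConstantLR",
    "LinearLR", "ExponentialLR", "PolynomialLR", "CosineAnnealingLR",
    "ChainedScheduler", "SequentialLR", "ReduceLROnPlateau", "CyclicLR",
    "OneCycleLR", "CosineAnnealingWarmRestarts",
]
out_dir = Path("docs/source/scripts/lr_scheduler_images")
missing = [name for name in expected if not (out_dir / f"{name}.png").exists()]
print("missing plots:", missing or "none")
```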
torch/optim/lr_scheduler.py:

@@ -257,13 +257,23 @@ class LambdaLR(LRScheduler):
     Example:
         >>> # xdoctest: +SKIP
         >>> # Assuming optimizer has two groups.
+        >>> num_epochs = 100
         >>> lambda1 = lambda epoch: epoch // 30
         >>> lambda2 = lambda epoch: 0.95 ** epoch
         >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
-        >>> for epoch in range(100):
+        >>> for epoch in range(num_epochs):
         >>>     train(...)
         >>>     validate(...)
         >>>     scheduler.step()
+        >>>
+        >>> # Alternatively, you can use a single lambda function for all groups.
+        >>> scheduler = LambdaLR(opt, lr_lambda=lambda epoch: epoch // 30)
+        >>> for epoch in range(num_epochs):
+        >>>     train(...)
+        >>>     validate(...)
+        >>>     scheduler.step()
+
+    .. image:: ../scripts/lr_scheduler_images/LambdaLR.png
     """
 
     def __init__(
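The doctest above is skipped by xdoctest because `optimizer`, `train(...)`, and `validate(...)` are placeholders. A self-contained variant, with an assumed dummy model and the training/validation calls stubbed out, so the schedule itself can be run and inspected:

```python
# Runnable variant of the updated LambdaLR example; train/validate are stubbed out.
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(10, 2)
# Two parameter groups, as the docstring example assumes.
optimizer = SGD([{"params": model.weight}, {"params": model.bias}], lr=0.05)

num_epochs = 100
lambda1 = lambda epoch: epoch // 30    # step-wise multiplier for group 0
lambda2 = lambda epoch: 0.95 ** epoch  # exponential decay for group 1
scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])

for epoch in range(num_epochs):
    pass  # train(...) / validate(...) would go here
    scheduler.step()

print([group["lr"] for group in optimizer.param_groups])
```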
@@ -357,6 +367,8 @@ class MultiplicativeLR(LRScheduler):
         >>>     train(...)
         >>>     validate(...)
         >>>     scheduler.step()
+
+    .. image:: ../scripts/lr_scheduler_images/MultiplicativeLR.png
     """
 
     def __init__(
@@ -454,6 +466,8 @@ class StepLR(LRScheduler):
         >>>     train(...)
         >>>     validate(...)
         >>>     scheduler.step()
+
+    .. image:: ../scripts/lr_scheduler_images/StepLR.png
     """
 
     def __init__(
@@ -506,6 +520,8 @@ class MultiStepLR(LRScheduler):
         >>>     train(...)
         >>>     validate(...)
         >>>     scheduler.step()
+
+    .. image:: ../scripts/lr_scheduler_images/MultiStepLR.png
     """
 
     def __init__(
@@ -560,12 +576,15 @@ class ConstantLR(LRScheduler):
         >>> # lr = 0.025 if epoch == 1
         >>> # lr = 0.025 if epoch == 2
         >>> # lr = 0.025 if epoch == 3
-        >>> # lr = 0.05  if epoch >= 4
-        >>> scheduler = ConstantLR(optimizer, factor=0.5, total_iters=4)
+        >>> # ...
+        >>> # lr = 0.05  if epoch >= 40
+        >>> scheduler = ConstantLR(optimizer, factor=0.5, total_iters=40)
         >>> for epoch in range(100):
         >>>     train(...)
         >>>     validate(...)
         >>>     scheduler.step()
+
+    .. image:: ../scripts/lr_scheduler_images/ConstantLR.png
     """
 
     def __init__(
@@ -627,16 +646,19 @@ class LinearLR(LRScheduler):
     Example:
         >>> # xdoctest: +SKIP
         >>> # Assuming optimizer uses lr = 0.05 for all groups
-        >>> # lr = 0.025    if epoch == 0
-        >>> # lr = 0.03125  if epoch == 1
-        >>> # lr = 0.0375   if epoch == 2
-        >>> # lr = 0.04375  if epoch == 3
-        >>> # lr = 0.05     if epoch >= 4
-        >>> scheduler = LinearLR(optimizer, start_factor=0.5, total_iters=4)
+        >>> # lr = 0.003687 if epoch == 0
+        >>> # lr = 0.004875 if epoch == 1
+        >>> # lr = 0.006062 if epoch == 2
+        >>> # lr = 0.00725  if epoch == 3
+        >>> # ...
+        >>> # lr = 0.05     if epoch >= 40
+        >>> scheduler = LinearLR(optimizer, start_factor=0.05, total_iters=40)
         >>> for epoch in range(100):
         >>>     train(...)
         >>>     validate(...)
         >>>     scheduler.step()
+
+    .. image:: ../scripts/lr_scheduler_images/LinearLR.png
     """
 
     def __init__(
@@ -709,6 +731,16 @@ class ExponentialLR(LRScheduler):
         optimizer (Optimizer): Wrapped optimizer.
         gamma (float): Multiplicative factor of learning rate decay.
         last_epoch (int): The index of last epoch. Default: -1.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> scheduler = ExponentialLR(optimizer, gamma=0.95)
+        >>> for epoch in range(100):
+        >>>     train(...)
+        >>>     validate(...)
+        >>>     scheduler.step()
+
+    .. image:: ../scripts/lr_scheduler_images/ExponentialLR.png
     """
 
     def __init__(
@@ -746,19 +778,23 @@ class SequentialLR(LRScheduler):
 
     Example:
         >>> # xdoctest: +SKIP
-        >>> # Assuming optimizer uses lr = 1. for all groups
-        >>> # lr = 0.1     if epoch == 0
-        >>> # lr = 0.1     if epoch == 1
-        >>> # lr = 0.9     if epoch == 2
-        >>> # lr = 0.81    if epoch == 3
-        >>> # lr = 0.729   if epoch == 4
-        >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
+        >>> # Assuming optimizer uses lr = 0.05 for all groups
+        >>> # lr = 0.005   if epoch == 0
+        >>> # lr = 0.005   if epoch == 1
+        >>> # lr = 0.005   if epoch == 2
+        >>> # ...
+        >>> # lr = 0.05    if epoch == 20
+        >>> # lr = 0.045   if epoch == 21
+        >>> # lr = 0.0405  if epoch == 22
+        >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=20)
         >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
-        >>> scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[2])
+        >>> scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[20])
         >>> for epoch in range(100):
         >>>     train(...)
         >>>     validate(...)
         >>>     scheduler.step()
+
+    .. image:: ../scripts/lr_scheduler_images/SequentialLR.png
     """
 
     def __init__(
@@ -887,17 +923,19 @@ class PolynomialLR(LRScheduler):
 
     Example:
         >>> # xdoctest: +SKIP("undefined vars")
-        >>> # Assuming optimizer uses lr = 0.001 for all groups
-        >>> # lr = 0.001    if epoch == 0
-        >>> # lr = 0.00075  if epoch == 1
-        >>> # lr = 0.00050  if epoch == 2
-        >>> # lr = 0.00025  if epoch == 3
-        >>> # lr = 0.0      if epoch >= 4
-        >>> scheduler = PolynomialLR(optimizer, total_iters=4, power=1.0)
+        >>> # Assuming optimizer uses lr = 0.05 for all groups
+        >>> # lr = 0.0490   if epoch == 0
+        >>> # lr = 0.0481   if epoch == 1
+        >>> # lr = 0.0472   if epoch == 2
+        >>> # ...
+        >>> # lr = 0.0      if epoch >= 50
+        >>> scheduler = PolynomialLR(optimizer, total_iters=50, power=0.9)
         >>> for epoch in range(100):
         >>>     train(...)
        >>>     validate(...)
         >>>     scheduler.step()
+
+    .. image:: ../scripts/lr_scheduler_images/PolynomialLR.png
     """
 
     def __init__(
@@ -972,6 +1010,17 @@ class CosineAnnealingLR(LRScheduler):
 
     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
         https://arxiv.org/abs/1608.03983
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> num_epochs = 100
+        >>> scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
+        >>> for epoch in range(num_epochs):
+        >>>     train(...)
+        >>>     validate(...)
+        >>>     scheduler.step()
+
+    .. image:: ../scripts/lr_scheduler_images/CosineAnnealingLR.png
     """
 
     def __init__(
@@ -1035,19 +1084,23 @@ class ChainedScheduler(LRScheduler):
 
     Example:
         >>> # xdoctest: +SKIP
-        >>> # Assuming optimizer uses lr = 1. for all groups
-        >>> # lr = 0.09     if epoch == 0
-        >>> # lr = 0.081    if epoch == 1
-        >>> # lr = 0.729    if epoch == 2
-        >>> # lr = 0.6561   if epoch == 3
-        >>> # lr = 0.59049  if epoch >= 4
-        >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
+        >>> # Assuming optimizer uses lr = 0.05 for all groups
+        >>> # lr = 0.05     if epoch == 0
+        >>> # lr = 0.0450   if epoch == 1
+        >>> # lr = 0.0405   if epoch == 2
+        >>> # ...
+        >>> # lr = 0.00675  if epoch == 19
+        >>> # lr = 0.06078  if epoch == 20
+        >>> # lr = 0.05470  if epoch == 21
+        >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=20)
         >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
         >>> scheduler = ChainedScheduler([scheduler1, scheduler2], optimizer=optimizer)
         >>> for epoch in range(100):
         >>>     train(...)
         >>>     validate(...)
         >>>     scheduler.step()
+
+    .. image:: ../scripts/lr_scheduler_images/ChainedScheduler.png
     """
 
     def __init__(
@@ -1179,6 +1232,8 @@ class ReduceLROnPlateau(LRScheduler):
         >>>     val_loss = validate(...)
         >>>     # Note that step should be called after validate()
         >>>     scheduler.step(val_loss)
+
+    .. image:: ../scripts/lr_scheduler_images/ReduceLROnPlateau.png
     """
 
     def __init__(
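Unlike the other schedulers, `ReduceLROnPlateau.step()` takes the monitored metric as an argument, which is why the plotting script above passes a dummy `val_loss` for this scheduler. A self-contained sketch with a synthetic validation loss (no real training loop; the model is an assumed stand-in):

```python
# Runnable sketch of ReduceLROnPlateau; the validation loss is synthetic.
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(10, 1)
optimizer = SGD(model.parameters(), lr=0.05)
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=10)

for epoch in range(100):
    val_loss = torch.randn(1).item()  # stand-in for validate(...)
    # Note that step should be called after validation, with the metric.
    scheduler.step(val_loss)

print("final lr:", optimizer.param_groups[0]["lr"])
```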
@@ -1401,13 +1456,14 @@ class CyclicLR(LRScheduler):
     Example:
         >>> # xdoctest: +SKIP
         >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
-        >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
+        >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1, step_size_up=10)
         >>> data_loader = torch.utils.data.DataLoader(...)
         >>> for epoch in range(10):
         >>>     for batch in data_loader:
         >>>         train_batch(...)
         >>>         scheduler.step()
 
+    .. image:: ../scripts/lr_scheduler_images/CyclicLR.png
 
     .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
     .. _bckenstler/CLR: https://github.com/bckenstler/CLR
@@ -1620,6 +1676,17 @@ class CosineAnnealingWarmRestarts(LRScheduler):
 
     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
         https://arxiv.org/abs/1608.03983
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
+        >>> scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20)
+        >>> for epoch in range(100):
+        >>>     train(...)
+        >>>     validate(...)
+        >>>     scheduler.step()
+
+    .. image:: ../scripts/lr_scheduler_images/CosineAnnealingWarmRestarts.png
     """
 
     def __init__(
@@ -1825,6 +1892,7 @@ class OneCycleLR(LRScheduler):
         >>>         optimizer.step()
         >>>         scheduler.step()
 
+    .. image:: ../scripts/lr_scheduler_images/OneCycleLR.png
 
     .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
         https://arxiv.org/abs/1708.07120