Add scripts to generate plots of LRSchedulers (#149189)

Fixes #92007

## Changes

- Add a script to generate plots for each `lr_scheduler` class (the plotting pattern is sketched below)
- Embed the generated plots in the `lr_scheduler` docs
- Add an example section to `lr_scheduler` docstrings that were missing one
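
The plots are generated by stepping each scheduler for a fixed number of epochs while recording the optimizer's learning rate. A minimal sketch of that pattern (the scheduler choice and output path here are illustrative; the full script is `docs/source/scripts/build_lr_scheduler_images.py` in the diff below):

```python
import matplotlib

matplotlib.use("Agg")  # render without a display
from matplotlib import pyplot as plt

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

model = torch.nn.Linear(10, 1)
optimizer = SGD(model.parameters(), lr=0.05)
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

num_epochs = 100
lrs = []
for _ in range(num_epochs):
    lrs.append(optimizer.param_groups[0]["lr"])  # record the LR used this epoch
    scheduler.step()  # only tracing the schedule, so no optimizer.step() here

plt.plot(range(num_epochs), lrs)
plt.title("Learning Rate: StepLR")
plt.xlabel("Epoch")
plt.ylabel("Learning Rate")
plt.savefig("StepLR.png")
```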

## Test Results

### LambdaLR

![image](https://github.com/user-attachments/assets/37fc0894-e2ec-48f2-a2d6-3514e51e1ea2)

### MultiplicativeLR

![image](https://github.com/user-attachments/assets/2122b3a0-a4ce-42c7-bb45-559c1fc73e0f)

### StepLR

![image](https://github.com/user-attachments/assets/47bc9d96-4b60-4586-a000-f213583bbe8f)

### MultiStepLR

![image](https://github.com/user-attachments/assets/c822b849-d5be-4b94-aa7a-0017a2c9ff15)

### ConstantLR

![image](https://github.com/user-attachments/assets/83107cdd-7b00-44a6-b09d-e8ee849b4a12)

### LinearLR

![image](https://github.com/user-attachments/assets/60190105-691a-4101-8966-5b0c396093a4)

### ExponentialLR

![image](https://github.com/user-attachments/assets/dfcbcbca-89e5-4a2f-b1bd-33e25d2405ec)

### PolynomialLR

![image](https://github.com/user-attachments/assets/7c3d4fce-c846-40a0-b62e-f3e81c7e08bd)

### CosineAnnealingLR

![image](https://github.com/user-attachments/assets/26712769-dde9-4faa-b61b-e23c51daef50)

### ChainedScheduler

![image](https://github.com/user-attachments/assets/20734a8b-e939-424f-b45a-773f86f020b1)

### SequentialLR

![image](https://github.com/user-attachments/assets/2cd3ed67-2a0a-4c42-9ad2-e0be090d3751)

### ReduceLROnPlateau

![image](https://github.com/user-attachments/assets/b77f641e-4810-450d-b2cd-8b3f134ea188)

### CyclicLR

![image](https://github.com/user-attachments/assets/29b8666f-41b3-45e4-9159-6929074e6108)

### OneCycleLR

![image](https://github.com/user-attachments/assets/d5b683ef-41e8-4ca8-9fe8-0f1e6b433866)

### CosineAnnealingWarmRestarts

![image](https://github.com/user-attachments/assets/1d45ea80-dea8-494d-a8ab-e9cfc94c55d6)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149189
Approved by: https://github.com/janeyx99
Author: zeshengzong
Date: 2025-04-14 09:53:35 +00:00
Committed by: PyTorch MergeBot
Parent: 5a64476ed6
Commit: 5eebcb991a
4 changed files with 198 additions and 32 deletions

.gitignore

@@ -178,6 +178,7 @@ compile_commands.json
*.egg-info/
docs/source/scripts/activation_images/
docs/source/scripts/quantization_backend_configs/
docs/source/scripts/lr_scheduler_images/
## General

docs/Makefile

@@ -16,6 +16,7 @@ help:
figures:
@$(PYCMD) source/scripts/build_activation_images.py
@$(PYCMD) source/scripts/build_quantization_configs.py
@$(PYCMD) source/scripts/build_lr_scheduler_images.py
onnx:
@$(PYCMD) source/scripts/onnx/build_onnx_torchscript_supported_aten_op_csv_table.py

docs/source/scripts/build_lr_scheduler_images.py

@@ -0,0 +1,96 @@
from pathlib import Path

import matplotlib
from matplotlib import pyplot as plt

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import (
    ChainedScheduler,
    ConstantLR,
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
    CyclicLR,
    ExponentialLR,
    LambdaLR,
    LinearLR,
    MultiplicativeLR,
    MultiStepLR,
    OneCycleLR,
    PolynomialLR,
    ReduceLROnPlateau,
    SequentialLR,
    StepLR,
)


matplotlib.use("Agg")

LR_SCHEDULER_IMAGE_PATH = Path(__file__).parent / "lr_scheduler_images"

if not LR_SCHEDULER_IMAGE_PATH.exists():
    LR_SCHEDULER_IMAGE_PATH.mkdir()

model = torch.nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.05)

num_epochs = 100

scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=num_epochs // 5)
scheduler2 = ExponentialLR(optimizer, gamma=0.9)

schedulers = [
    (lambda opt: LambdaLR(opt, lr_lambda=lambda epoch: epoch // 30)),
    (lambda opt: MultiplicativeLR(opt, lr_lambda=lambda epoch: 0.95)),
    (lambda opt: StepLR(opt, step_size=30, gamma=0.1)),
    (lambda opt: MultiStepLR(opt, milestones=[30, 80], gamma=0.1)),
    (lambda opt: ConstantLR(opt, factor=0.5, total_iters=40)),
    (lambda opt: LinearLR(opt, start_factor=0.05, total_iters=40)),
    (lambda opt: ExponentialLR(opt, gamma=0.95)),
    (lambda opt: PolynomialLR(opt, total_iters=num_epochs / 2, power=0.9)),
    (lambda opt: CosineAnnealingLR(opt, T_max=num_epochs)),
    (lambda opt: CosineAnnealingWarmRestarts(opt, T_0=20)),
    (lambda opt: CyclicLR(opt, base_lr=0.01, max_lr=0.1, step_size_up=10)),
    (lambda opt: OneCycleLR(opt, max_lr=0.01, epochs=10, steps_per_epoch=10)),
    (lambda opt: ReduceLROnPlateau(opt, mode="min")),
    (lambda opt: ChainedScheduler([scheduler1, scheduler2])),
    (
        lambda opt: SequentialLR(
            opt, schedulers=[scheduler1, scheduler2], milestones=[num_epochs // 5]
        )
    ),
]


def plot_function(scheduler):
    plt.clf()
    plt.grid(color="k", alpha=0.2, linestyle="--")

    lrs = []
    optimizer.param_groups[0]["lr"] = 0.05
    scheduler = scheduler(optimizer)
    plot_path = LR_SCHEDULER_IMAGE_PATH / f"{scheduler.__class__.__name__}.png"
    if plot_path.exists():
        return

    for _ in range(num_epochs):
        lrs.append(optimizer.param_groups[0]["lr"])
        if isinstance(scheduler, ReduceLROnPlateau):
            val_loss = torch.randn(1).item()
            scheduler.step(val_loss)
        else:
            scheduler.step()

    plt.plot(range(num_epochs), lrs)
    plt.title(f"Learning Rate: {scheduler.__class__.__name__}")
    plt.xlabel("Epoch")
    plt.ylabel("Learning Rate")
    plt.xlim([0, num_epochs])
    plt.savefig(plot_path)
    print(
        f"Saved learning rate scheduler image for {scheduler.__class__.__name__} at {plot_path}"
    )


for scheduler in schedulers:
    plot_function(scheduler)

torch/optim/lr_scheduler.py

@@ -257,13 +257,23 @@ class LambdaLR(LRScheduler):
Example:
>>> # xdoctest: +SKIP
>>> # Assuming optimizer has two groups.
>>> num_epochs = 100
>>> lambda1 = lambda epoch: epoch // 30
>>> lambda2 = lambda epoch: 0.95 ** epoch
>>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
>>> for epoch in range(100):
>>> for epoch in range(num_epochs):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
>>>
>>> # Alternatively, you can use a single lambda function for all groups.
>>> scheduler = LambdaLR(opt, lr_lambda=lambda epoch: epoch // 30)
>>> for epoch in range(num_epochs):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/LambdaLR.png
"""
def __init__(
@@ -357,6 +367,8 @@ class MultiplicativeLR(LRScheduler):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/MultiplicativeLR.png
"""
def __init__(
@@ -454,6 +466,8 @@ class StepLR(LRScheduler):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/StepLR.png
"""
def __init__(
@@ -506,6 +520,8 @@ class MultiStepLR(LRScheduler):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/MultiStepLR.png
"""
def __init__(
@@ -560,12 +576,15 @@ class ConstantLR(LRScheduler):
>>> # lr = 0.025 if epoch == 1
>>> # lr = 0.025 if epoch == 2
>>> # lr = 0.025 if epoch == 3
>>> # lr = 0.05 if epoch >= 4
>>> scheduler = ConstantLR(optimizer, factor=0.5, total_iters=4)
>>> # ...
>>> # lr = 0.05 if epoch >= 40
>>> scheduler = ConstantLR(optimizer, factor=0.5, total_iters=40)
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/ConstantLR.png
"""
def __init__(
@@ -627,16 +646,19 @@ class LinearLR(LRScheduler):
Example:
>>> # xdoctest: +SKIP
>>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.025 if epoch == 0
>>> # lr = 0.03125 if epoch == 1
>>> # lr = 0.0375 if epoch == 2
>>> # lr = 0.04375 if epoch == 3
>>> # lr = 0.05 if epoch >= 4
>>> scheduler = LinearLR(optimizer, start_factor=0.5, total_iters=4)
>>> # lr = 0.003687 if epoch == 0
>>> # lr = 0.004875 if epoch == 1
>>> # lr = 0.006062 if epoch == 2
>>> # lr = 0.00725 if epoch == 3
>>> # ...
>>> # lr = 0.05 if epoch >= 40
>>> scheduler = LinearLR(optimizer, start_factor=0.05, total_iters=40)
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/LinearLR.png
"""
def __init__(
@@ -709,6 +731,16 @@ class ExponentialLR(LRScheduler):
optimizer (Optimizer): Wrapped optimizer.
gamma (float): Multiplicative factor of learning rate decay.
last_epoch (int): The index of last epoch. Default: -1.
Example:
>>> # xdoctest: +SKIP
>>> scheduler = ExponentialLR(optimizer, gamma=0.95)
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/ExponentialLR.png
"""
def __init__(
@@ -746,19 +778,23 @@ class SequentialLR(LRScheduler):
Example:
>>> # xdoctest: +SKIP
>>> # Assuming optimizer uses lr = 1. for all groups
>>> # lr = 0.1 if epoch == 0
>>> # lr = 0.1 if epoch == 1
>>> # lr = 0.9 if epoch == 2
>>> # lr = 0.81 if epoch == 3
>>> # lr = 0.729 if epoch == 4
>>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
>>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.005 if epoch == 0
>>> # lr = 0.005 if epoch == 1
>>> # lr = 0.005 if epoch == 2
>>> # ...
>>> # lr = 0.05 if epoch == 20
>>> # lr = 0.045 if epoch == 21
>>> # lr = 0.0405 if epoch == 22
>>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=20)
>>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
>>> scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[2])
>>> scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[20])
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/SequentialLR.png
"""
def __init__(
@@ -887,17 +923,19 @@ class PolynomialLR(LRScheduler):
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> # Assuming optimizer uses lr = 0.001 for all groups
>>> # lr = 0.001 if epoch == 0
>>> # lr = 0.00075 if epoch == 1
>>> # lr = 0.00050 if epoch == 2
>>> # lr = 0.00025 if epoch == 3
>>> # lr = 0.0 if epoch >= 4
>>> scheduler = PolynomialLR(optimizer, total_iters=4, power=1.0)
>>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.0490 if epoch == 0
>>> # lr = 0.0481 if epoch == 1
>>> # lr = 0.0472 if epoch == 2
>>> # ...
>>> # lr = 0.0 if epoch >= 50
>>> scheduler = PolynomialLR(optimizer, total_iters=50, power=0.9)
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/PolynomialLR.png
"""
def __init__(
@@ -972,6 +1010,17 @@ class CosineAnnealingLR(LRScheduler):
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
https://arxiv.org/abs/1608.03983
Example:
>>> # xdoctest: +SKIP
>>> num_epochs = 100
>>> scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
>>> for epoch in range(num_epochs):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/CosineAnnealingLR.png
"""
def __init__(
@@ -1035,19 +1084,23 @@ class ChainedScheduler(LRScheduler):
Example:
>>> # xdoctest: +SKIP
>>> # Assuming optimizer uses lr = 1. for all groups
>>> # lr = 0.09 if epoch == 0
>>> # lr = 0.081 if epoch == 1
>>> # lr = 0.729 if epoch == 2
>>> # lr = 0.6561 if epoch == 3
>>> # lr = 0.59049 if epoch >= 4
>>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
>>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.05 if epoch == 0
>>> # lr = 0.0450 if epoch == 1
>>> # lr = 0.0405 if epoch == 2
>>> # ...
>>> # lr = 0.00675 if epoch == 19
>>> # lr = 0.06078 if epoch == 20
>>> # lr = 0.05470 if epoch == 21
>>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=20)
>>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
>>> scheduler = ChainedScheduler([scheduler1, scheduler2], optimizer=optimizer)
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/ChainedScheduler.png
"""
def __init__(
@@ -1179,6 +1232,8 @@ class ReduceLROnPlateau(LRScheduler):
>>> val_loss = validate(...)
>>> # Note that step should be called after validate()
>>> scheduler.step(val_loss)
.. image:: ../scripts/lr_scheduler_images/ReduceLROnPlateau.png
"""
def __init__(
@@ -1401,13 +1456,14 @@ class CyclicLR(LRScheduler):
Example:
>>> # xdoctest: +SKIP
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
>>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1, step_size_up=10)
>>> data_loader = torch.utils.data.DataLoader(...)
>>> for epoch in range(10):
>>> for batch in data_loader:
>>> train_batch(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/CyclicLR.png
.. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
.. _bckenstler/CLR: https://github.com/bckenstler/CLR
@@ -1620,6 +1676,17 @@ class CosineAnnealingWarmRestarts(LRScheduler):
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
https://arxiv.org/abs/1608.03983
Example:
>>> # xdoctest: +SKIP
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20)
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/CosineAnnealingWarmRestarts.png
"""
def __init__(
@@ -1825,6 +1892,7 @@ class OneCycleLR(LRScheduler):
>>> optimizer.step()
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/OneCycleLR.png
.. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
https://arxiv.org/abs/1708.07120