Add load_state_dict doc hint about invocation order with lr_scheduler (#149942)
Fixes #119168

## Test Result

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149942
Approved by: https://github.com/janeyx99
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Committed by: PyTorch MergeBot
Parent: 781ba0ac9d
Commit: 82dc3457e0
@@ -858,6 +858,10 @@ class Optimizer:
             state_dict (dict): optimizer state. Should be an object returned
                 from a call to :meth:`state_dict`.
 
+        .. warning::
+            Make sure this method is called after initializing :class:`torch.optim.lr_scheduler.LRScheduler`,
+            as calling it beforehand will overwrite the loaded learning rates.
+
         .. note::
             The names of the parameters (if they exist under the "param_names" key of each param group
             in :meth:`state_dict`) will not affect the loading process.
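The warning added above is easiest to see as a full save/resume round trip. The sketch below is illustrative, not part of the patch: the save side and the dummy step loop are assumptions for demonstration, and the checkpoint file names mirror the doctest in the next hunk.

```python
import torch

# Save side (illustrative): checkpoint both the optimizer and the scheduler.
model = torch.nn.Linear(10, 10)
optim = torch.optim.SGD(model.parameters(), lr=3e-4)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=80, eta_min=3e-5)
for _ in range(10):
    optim.step()   # no gradients here; a real loop would backprop first
    sched.step()
torch.save(optim.state_dict(), "./save_optim.pt")
torch.save(sched.state_dict(), "./save_seq.pt")

# Resume side: construct the LRScheduler BEFORE loading the optimizer state.
# The scheduler constructor rewrites the optimizer's learning rates, so
# constructing it afterwards would clobber the values restored from the
# checkpoint -- exactly the failure mode the new warning describes.
model = torch.nn.Linear(10, 10)
optim = torch.optim.SGD(model.parameters(), lr=3e-4)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=80, eta_min=3e-5)
sched.load_state_dict(torch.load("./save_seq.pt"))
optim.load_state_dict(torch.load("./save_optim.pt"))
```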
@@ -868,6 +872,18 @@ class Optimizer:
             If ``param_names`` exist in loaded state dict ``param_groups`` they will be saved and override
             the current names, if present, in the optimizer state. If they do not exist in loaded state dict,
             the optimizer ``param_names`` will remain unchanged.
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> model = torch.nn.Linear(10, 10)
+            >>> optim = torch.optim.SGD(model.parameters(), lr=3e-4)
+            >>> scheduler1 = torch.optim.lr_scheduler.LinearLR(optim, start_factor=0.1, end_factor=1, total_iters=20)
+            >>> scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=80, eta_min=3e-5)
+            >>> lr = torch.optim.lr_scheduler.SequentialLR(optim, schedulers=[scheduler1, scheduler2], milestones=[20])
+            >>> lr.load_state_dict(torch.load('./save_seq.pt'))
+            >>> # now load the optimizer checkpoint after loading the LRScheduler
+            >>> optim.load_state_dict(torch.load('./save_optim.pt'))
+
         """
         # shallow copy, to be consistent with module API
         state_dict = state_dict.copy()
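The `param_names` note in the surrounding docstring can also be made concrete. A minimal sketch, assuming a PyTorch version whose optimizers accept `model.named_parameters()` and record a `param_names` entry per param group (that constructor form is an assumption here, not part of this patch):

```python
import torch

model = torch.nn.Linear(10, 10)
# Assumption: recent PyTorch optimizers accept an iterable of (name, param)
# pairs, which records a "param_names" entry in each param group.
optim = torch.optim.SGD(model.named_parameters(), lr=3e-4)
saved = optim.state_dict()
print(saved["param_groups"][0]["param_names"])  # ['weight', 'bias']

# Per the docstring: names present in the loaded param_groups override the
# optimizer's current names; a state dict without "param_names" leaves the
# existing names unchanged.
saved["param_groups"][0]["param_names"] = ["w", "b"]
optim.load_state_dict(saved)
print(optim.state_dict()["param_groups"][0]["param_names"])  # ['w', 'b']
```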