mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
include scheduler_on_plateau in optim.h (#121722)
Fixes #121593 Co-authored-by: Jane Xu <janeyx@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/121722 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
ceff2205e9
commit
ccfc87b199
@ -746,6 +746,7 @@ torch_cpp_srcs = [
|
||||
"torch/csrc/api/src/optim/serialize.cpp",
|
||||
"torch/csrc/api/src/optim/sgd.cpp",
|
||||
"torch/csrc/api/src/optim/schedulers/lr_scheduler.cpp",
|
||||
"torch/csrc/api/src/optim/schedulers/reduce_on_plateau_scheduler.cpp",
|
||||
"torch/csrc/api/src/optim/schedulers/step_lr.cpp",
|
||||
"torch/csrc/api/src/serialize/input-archive.cpp",
|
||||
"torch/csrc/api/src/serialize/output-archive.cpp",
|
||||
|
@ -510,6 +510,38 @@ void check_lr_change(
|
||||
}
|
||||
}
|
||||
|
||||
// Very similar to check_lr_change, but for ReduceLROnPlateauScheduler
|
||||
// which does not inherit from LRScheduler and requires a metrics
|
||||
// input to step().
|
||||
void check_lr_change_for_reduce_on_plateau(
|
||||
Optimizer& optimizer,
|
||||
ReduceLROnPlateauScheduler& lr_scheduler,
|
||||
std::map<unsigned, double> expected_epoch_lrs) {
|
||||
// Find maximum epoch in map
|
||||
unsigned kIterations = std::max_element(
|
||||
expected_epoch_lrs.begin(),
|
||||
expected_epoch_lrs.end(),
|
||||
[](const std::pair<unsigned, double>& a,
|
||||
const std::pair<unsigned, double>& b) -> bool {
|
||||
return a.second > b.second;
|
||||
})
|
||||
->first;
|
||||
|
||||
for (unsigned i = 0; i <= kIterations; i++) {
|
||||
const auto epoch_iter = expected_epoch_lrs.find(i);
|
||||
if (epoch_iter != expected_epoch_lrs.end()) {
|
||||
// Compare the similarity of the two floating point learning rates
|
||||
ASSERT_TRUE(
|
||||
fabs(
|
||||
epoch_iter->second -
|
||||
optimizer.param_groups()[0].options().get_lr()) <
|
||||
std::numeric_limits<double>::epsilon());
|
||||
}
|
||||
optimizer.step();
|
||||
lr_scheduler.step(5.0);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(OptimTest, CheckLRChange_StepLR_Adam) {
|
||||
torch::Tensor parameters = torch::zeros({1});
|
||||
auto optimizer = Adam({parameters}, AdamOptions().lr(1e-3));
|
||||
@ -523,3 +555,21 @@ TEST(OptimTest, CheckLRChange_StepLR_Adam) {
|
||||
|
||||
check_lr_change(optimizer, step_lr_scheduler, expected_epoch_lrs);
|
||||
}
|
||||
|
||||
TEST(OptimTest, CheckLRChange_ReduceLROnPlateau_Adam) {
|
||||
torch::Tensor parameters = torch::zeros({1});
|
||||
auto optimizer = Adam({parameters}, AdamOptions().lr(1e-3));
|
||||
const float factor = 0.5;
|
||||
const int patience = 20;
|
||||
ReduceLROnPlateauScheduler reduce_lr_on_plateau_scheduler(
|
||||
optimizer,
|
||||
ReduceLROnPlateauScheduler::SchedulerMode::min,
|
||||
factor,
|
||||
patience);
|
||||
|
||||
// The learning rate should have halved at epoch 20
|
||||
const std::map<unsigned, double> expected_epoch_lrs = {{1, 1e-3}, {25, 5e-4}};
|
||||
|
||||
check_lr_change_for_reduce_on_plateau(
|
||||
optimizer, reduce_lr_on_plateau_scheduler, expected_epoch_lrs);
|
||||
}
|
||||
|
@ -9,4 +9,5 @@
|
||||
#include <torch/optim/sgd.h>
|
||||
|
||||
#include <torch/optim/schedulers/lr_scheduler.h>
|
||||
#include <torch/optim/schedulers/reduce_on_plateau_scheduler.h>
|
||||
#include <torch/optim/schedulers/step_lr.h>
|
||||
|
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/optim/optimizer.h>
|
||||
#include <torch/optim/schedulers/lr_scheduler.h>
|
||||
|
||||
#include <torch/csrc/Export.h>
|
||||
|
||||
|
Reference in New Issue
Block a user