Fixes #141884

This fixes the issue for all optimizers and parameter options. A member function `overwrite_from` is added to the optimizer base class; each optimizer then implements it to compare its accepted parameters against the defaults. A SFINAE approach that handles the different optimizer parameters generically (in optimizer.h only) was evaluated, but I think this is easier to review and maintain.

This mirrors the Python API up to one edge case, illustrated below. Python can distinguish between (1) a key not present in the dict ("not specified") and (2) a key present in the dict ("explicitly set"); the C++ implementation cannot. The issue hinges on whether to track that a particular parameter was set explicitly by the user (the discrepancy arises when the constructor default is explicitly passed in). Tracking this would take more intervention than it is worth (modifying TORCH_ARG to keep track, using std::optional for the parameter types, or bitset tracking) and was not pursued in the current PR. I'm happy to alter the design if appropriate.

### Example of edge case hinging on CONSTRUCTOR DEFAULTS vs OPTIMIZER DEFAULTS

1. CONSTRUCTOR DEFAULTS: the values you get when calling `AdamOptions()`:
   `AdamOptions().lr()` = 0.001, `AdamOptions().weight_decay()` = 0, `AdamOptions().eps()` = 1e-08.
2. OPTIMIZER DEFAULTS: the values the user chose when creating the optimizer:
   `optimizer.lr()` = 0.005, `optimizer.weight_decay()` = 0.1, `optimizer.eps()` = 1e-07.
3. THE PROBLEM SCENARIO: the user wants to add a parameter group with an explicit weight decay of 0.0, so they set `weight_decay(0)`.
4. THE CONFUSION: the constructor default weight_decay (0) and the user's explicit weight_decay (0) are equal, so the `overwrite_from()` logic concludes "the user didn't set weight_decay explicitly; use the optimizer default".
5. CURRENT BEHAVIOR: the final weight_decay is 0.1, but the user expected 0. ❌ No match.

KEY INSIGHT: constructor defaults are built into the C++ class definition, while optimizer defaults are chosen by the user at runtime. We want to respect the user's intention.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161825
Approved by: https://github.com/janeyx99
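To make the edge case concrete, here is a minimal, hypothetical sketch using the same C++ APIs the tests below exercise (`AdamOptions`, `OptimizerParamGroup`, `Adam`). The values simply mirror the numbered example above, and the `edge_case_sketch` function name is illustrative; it is not part of the test suite.

```cpp
#include <torch/torch.h>

#include <iostream>
#include <memory>
#include <vector>

using namespace torch::optim;

void edge_case_sketch() {
  auto p = torch::randn({2, 2}).requires_grad_(true);

  // The user explicitly asks for weight_decay = 0, which happens to equal the
  // AdamOptions() constructor default.
  std::vector<OptimizerParamGroup> groups;
  groups.emplace_back(
      std::vector<torch::Tensor>{p},
      std::make_unique<AdamOptions>(AdamOptions().weight_decay(0)));

  // Optimizer defaults chosen by the user at runtime (step 2 above).
  AdamOptions defaults;
  defaults.lr(0.005).weight_decay(0.1).eps(1e-7);

  Adam optimizer(groups, defaults);

  // Because the explicit 0 is indistinguishable from the constructor default,
  // overwrite_from() treats weight_decay as "not specified", so (per the
  // description above) the group ends up with weight_decay = 0.1 rather than
  // the 0 the user intended.
  auto& opts =
      static_cast<AdamOptions&>(optimizer.param_groups()[0].options());
  std::cout << opts.weight_decay() << '\n';
}
```

The `MergeWithDefaultOptions_*` tests later in this file cover the unambiguous cases, where the explicitly set values differ from the constructor defaults.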
1072 lines
37 KiB
C++
#include <gtest/gtest.h>

#include <c10/util/irange.h>
#include <torch/torch.h>

#include <test/cpp/api/optim_baseline.h>
#include <test/cpp/api/support.h>

#include <cmath>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

using namespace torch::nn;
using namespace torch::optim;

template <typename OptimizerClass, typename Options>
static bool test_optimizer_xor(Options options) {
  torch::manual_seed(0);

  Sequential model(
      Linear(2, 8),
      Functional(torch::sigmoid),
      Linear(8, 1),
      Functional(torch::sigmoid));

  const int64_t kBatchSize = 200;
  const int64_t kMaximumNumberOfEpochs = 3000;

  OptimizerClass optimizer(model->parameters(), std::move(options));

  double running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    auto inputs = torch::empty({kBatchSize, 2});
    auto labels = torch::empty({kBatchSize});
    for (const auto i : c10::irange(kBatchSize)) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
    }

    inputs.set_requires_grad(true);

    auto step = [&](OptimizerClass& optimizer,
                    Sequential model,
                    const torch::Tensor& inputs,
                    const torch::Tensor& labels) {
      auto closure = [&]() {
        optimizer.zero_grad();
        auto x = model->forward(inputs);
        auto loss = torch::binary_cross_entropy(x, labels);
        loss.backward();
        return loss;
      };
      return optimizer.step(closure);
    };

    torch::Tensor loss = step(optimizer, model, inputs, labels);

    running_loss = running_loss * 0.99 + loss.item<double>() * 0.01;
    if (epoch > kMaximumNumberOfEpochs) {
      std::cout << "Loss is too high after epoch " << epoch << ": "
                << running_loss << '\n';
      return false;
    }
    epoch++;
  }
  return true;
}

template <typename Parameters>
static void assign_parameter(
    const Parameters& parameters,
    const char* name,
    const torch::Tensor& new_tensor) {
  auto parameter = parameters[name];
  parameter.set_requires_grad(false);
  parameter.flatten().copy_(new_tensor);
  parameter.set_requires_grad(true);
}

template <typename OptimizerClass, typename Options>
static void check_exact_values(
    Options options,
    std::vector<std::vector<torch::Tensor>> expected_parameters) {
  const size_t kIterations = 1001;
  const size_t kSampleEvery = 100;

  torch::manual_seed(0);

  Sequential model(
      Linear(2, 3),
      Functional(torch::sigmoid),
      Linear(3, 1),
      Functional(torch::sigmoid));

  model->to(torch::kFloat64);

  // Use exact input values because matching random values is hard.
  auto parameters = model->named_parameters();
  assign_parameter(
      parameters,
      "0.weight",
      torch::tensor(
          {-0.2109, -0.4976, -0.1413, -0.3420, -0.2524, 0.6976},
          torch::kFloat64));
  assign_parameter(
      parameters,
      "0.bias",
      torch::tensor({-0.1085, -0.2979, 0.6892}, torch::kFloat64));
  assign_parameter(
      parameters,
      "2.weight",
      torch::tensor({-0.0508, -0.3941, -0.2843}, torch::kFloat64));
  assign_parameter(
      parameters, "2.bias", torch::tensor({-0.0711}, torch::kFloat64));

  auto optimizer = OptimizerClass(parameters.values(), std::move(options));
  torch::Tensor input =
      torch::tensor({0.1, 0.2, 0.3, 0.4, 0.5, 0.6}, torch::kFloat64)
          .reshape({3, 2});

  for (const auto i : c10::irange(kIterations)) {
    optimizer.zero_grad();
    auto output = model->forward(input);
    auto loss = output.sum();
    loss.backward();

    auto closure = []() { return torch::tensor({10}); };
    optimizer.step(closure);

    if (i % kSampleEvery == 0) {
      ASSERT_TRUE(
          expected_parameters.at(i / kSampleEvery).size() == parameters.size());
      for (const auto p : c10::irange(parameters.size())) {
        ASSERT_TRUE(parameters[p]->defined());
        // Always compare using double dtype, regardless of the original dtype
        // of the tensors
        auto computed = parameters[p]->flatten().to(torch::kFloat64);
        auto expected =
            expected_parameters.at(i / kSampleEvery).at(p).to(torch::kFloat64);
        if (!computed.allclose(expected, /*rtol=*/1e-3, /*atol=*/5e-4)) {
          std::cout << "Iteration " << i << ": " << computed
                    << " != " << expected << " (parameter " << p << ")" << '\n';
          ASSERT_TRUE(false);
        }
      }
    }
  }
}

TEST(OptimTest, OptimizerAccessors) {
  auto options = AdagradOptions(1.0);
  std::vector<torch::Tensor> params;
  for ([[maybe_unused]] const auto i : c10::irange(3)) {
    params.push_back(torch::randn(10));
  }
  auto optimizer = Adagrad(params, options);
  // test for defaults() method with non-const reference
  auto& options_ = static_cast<AdagradOptions&>(optimizer.defaults());
  ASSERT_TRUE(options == options_);
  // test for param_groups() with non-const reference return
  auto& params_groups = optimizer.param_groups();
  params_groups.emplace_back(params);
  auto& params_1 = params_groups[1].params();
  for (const auto i : c10::irange(params_1.size())) {
    torch::equal(params[i], params_1[i]);
  }

  // test for add_param_group() when one or more params existing in another
  // param_group are passed in the new param group to be added
  ASSERT_THROWS_WITH(
      optimizer.add_param_group(OptimizerParamGroup(params)),
      "some parameters appear in more than one parameter group");

  // test for state() with non-const reference return
  auto& state_ = static_cast<AdagradParamState&>(
      *(optimizer.state()[params_1[0].unsafeGetTensorImpl()]));
  state_.step(state_.step() + 1);

  const auto& optimizer_ = Adagrad(params, options);
  optimizer_.defaults();
  // test for param_groups() with const reference return
  (void)optimizer_.param_groups();
  // test for state() with const reference return
  optimizer_.state();
}

#define OLD_INTERFACE_WARNING_CHECK(func)       \
  {                                             \
    torch::test::WarningCapture warnings;       \
    func;                                       \
    ASSERT_EQ(                                  \
        torch::test::count_substr_occurrences(  \
            warnings.str(), "will be removed"), \
        1);                                     \
  }

struct MyOptimizerOptions
    : public OptimizerCloneableOptions<MyOptimizerOptions> {
  MyOptimizerOptions(double lr = 1.0) : lr_(lr) {}
  TORCH_ARG(double, lr) = 1.0;
};

TEST(OptimTest, OldInterface) {
  struct MyOptimizer : Optimizer {
    using Optimizer::Optimizer;
    torch::Tensor step(LossClosure closure = nullptr) override {
      return {};
    }
    explicit MyOptimizer(
        std::vector<at::Tensor> params,
        const MyOptimizerOptions& defaults = {})
        : Optimizer(
              std::move(params),
              std::make_unique<MyOptimizerOptions>(defaults)) {}
  };
  std::vector<torch::Tensor> parameters = {
      torch::ones({2, 3}), torch::zeros({2, 3}), torch::rand({2, 3})};
  {
    MyOptimizer optimizer(parameters);
    size_t size = 0;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, parameters.size());
  }
  {
    std::vector<at::Tensor> params;
    MyOptimizer optimizer(params);

    size_t size = 0;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, 0);

    OLD_INTERFACE_WARNING_CHECK(optimizer.add_parameters(parameters));

    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, parameters.size());

    std::vector<torch::Tensor> params_;
    OLD_INTERFACE_WARNING_CHECK(params_ = optimizer.parameters());
    for (const auto p : c10::irange(size)) {
      ASSERT_TRUE(params_[p].allclose(parameters[p]));
    }
  }
  {
    Linear linear(3, 4);
    MyOptimizer optimizer(linear->parameters());

    size_t size = 0;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, linear->parameters().size());
  }
}

TEST(OptimTest, XORConvergence_SGD) {
  ASSERT_TRUE(test_optimizer_xor<SGD>(
      SGDOptions(0.1).momentum(0.9).nesterov(true).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_LBFGS) {
  ASSERT_TRUE(test_optimizer_xor<LBFGS>(LBFGSOptions(1.0)));
  ASSERT_TRUE(test_optimizer_xor<LBFGS>(
      LBFGSOptions(1.0).line_search_fn("strong_wolfe")));
}

TEST(OptimTest, XORConvergence_Adagrad) {
  ASSERT_TRUE(test_optimizer_xor<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3)));
}

TEST(OptimTest, XORConvergence_RMSprop) {
  ASSERT_TRUE(test_optimizer_xor<RMSprop>(RMSpropOptions(0.1).centered(true)));
}

TEST(OptimTest, XORConvergence_RMSpropWithMomentum) {
  ASSERT_TRUE(test_optimizer_xor<RMSprop>(
      RMSpropOptions(0.1).momentum(0.9).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_Adam) {
  ASSERT_TRUE(test_optimizer_xor<Adam>(AdamOptions(0.1).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_AdamWithAmsgrad) {
  ASSERT_TRUE(test_optimizer_xor<Adam>(
      AdamOptions(0.1).weight_decay(1e-6).amsgrad(true)));
}

TEST(OptimTest, ProducesPyTorchValues_Adam) {
  check_exact_values<Adam>(AdamOptions(1.0), expected_parameters::Adam());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecay) {
  check_exact_values<Adam>(
      AdamOptions(1.0).weight_decay(1e-2),
      expected_parameters::Adam_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecayAndAMSGrad) {
  check_exact_values<Adam>(
      AdamOptions(1.0).weight_decay(1e-6).amsgrad(true),
      expected_parameters::Adam_with_weight_decay_and_amsgrad());
}

TEST(OptimTest, XORConvergence_AdamW) {
  ASSERT_TRUE(test_optimizer_xor<AdamW>(AdamWOptions(0.1)));
}

TEST(OptimTest, XORConvergence_AdamWWithAmsgrad) {
  ASSERT_TRUE(test_optimizer_xor<AdamW>(AdamWOptions(0.1).amsgrad(true)));
}

TEST(OptimTest, ProducesPyTorchValues_AdamW) {
  check_exact_values<AdamW>(AdamWOptions(1.0), expected_parameters::AdamW());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWWithoutWeightDecay) {
  check_exact_values<AdamW>(
      AdamWOptions(1.0).weight_decay(0),
      expected_parameters::AdamW_without_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWWithAMSGrad) {
  check_exact_values<AdamW>(
      AdamWOptions(1.0).amsgrad(true),
      expected_parameters::AdamW_with_amsgrad());
}

TEST(OptimTest, ProducesPyTorchValues_Adagrad) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0), expected_parameters::Adagrad());
}

TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecay) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-2),
      expected_parameters::Adagrad_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecayAndLRDecay) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3),
      expected_parameters::Adagrad_with_weight_decay_and_lr_decay());
}

TEST(OptimTest, ProducesPyTorchValues_RMSprop) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1), expected_parameters::RMSprop());
}

TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecay) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-2),
      expected_parameters::RMSprop_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecayAndCentered) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-6).centered(true),
      expected_parameters::RMSprop_with_weight_decay_and_centered());
}

TEST(
    OptimTest,
    ProducesPyTorchValues_RMSpropWithWeightDecayAndCenteredAndMomentum) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-6).centered(true).momentum(0.9),
      expected_parameters::
          RMSprop_with_weight_decay_and_centered_and_momentum());
}

TEST(OptimTest, ProducesPyTorchValues_SGD) {
  check_exact_values<SGD>(SGDOptions(0.1), expected_parameters::SGD());
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecay) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-2),
      expected_parameters::SGD_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndMomentum) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-2).momentum(0.9),
      expected_parameters::SGD_with_weight_decay_and_momentum());
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndNesterovMomentum) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-6).momentum(0.9).nesterov(true),
      expected_parameters::SGD_with_weight_decay_and_nesterov_momentum());
}

TEST(OptimTest, ProducesPyTorchValues_LBFGS) {
  check_exact_values<LBFGS>(LBFGSOptions(1.0), expected_parameters::LBFGS());
}

TEST(OptimTest, ProducesPyTorchValues_LBFGS_with_line_search) {
  check_exact_values<LBFGS>(
      LBFGSOptions(1.0).line_search_fn("strong_wolfe"),
      expected_parameters::LBFGS_with_line_search());
}

TEST(OptimTest, ZeroGrad) {
  torch::manual_seed(0);

  Linear model(2, 8);
  SGD optimizer(model->parameters(), 0.1);

  for (const auto& parameter : model->parameters()) {
    ASSERT_FALSE(parameter.grad().defined());
  }

  auto output = model->forward(torch::ones({5, 2}));
  auto loss = output.sum();
  loss.backward();

  for (const auto& parameter : model->parameters()) {
    ASSERT_TRUE(parameter.grad().defined());
    ASSERT_GT(parameter.grad().sum().item<float>(), 0);
  }

  optimizer.zero_grad();

  for (const auto& parameter : model->parameters()) {
    ASSERT_FALSE(parameter.grad().defined());
  }
}

TEST(OptimTest, ExternalVectorOfParameters) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> parameters = {
      torch::randn({2, 2}), torch::randn({3, 3}), torch::randn({4, 4})};
  std::vector<torch::Tensor> original_parameters = {
      parameters[0].clone(), parameters[1].clone(), parameters[2].clone()};

  // Set all gradients to one
  for (auto& parameter : parameters) {
    parameter.mutable_grad() = torch::ones_like(parameter);
  }

  SGD optimizer(parameters, 1.0);

  optimizer.step();

  ASSERT_TRUE(parameters[0].allclose(original_parameters[0] - 1.0));
  ASSERT_TRUE(parameters[1].allclose(original_parameters[1] - 1.0));
  ASSERT_TRUE(parameters[2].allclose(original_parameters[2] - 1.0));
}

TEST(OptimTest, AddParameter_LBFGS) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> parameters = {torch::randn({5, 5})};
  std::vector<torch::Tensor> original_parameters = {parameters[0].clone()};

  // Set all gradients to one
  for (auto& parameter : parameters) {
    parameter.mutable_grad() = torch::ones_like(parameter);
  }

  LBFGS optimizer(std::vector<torch::Tensor>{}, 1.0);
  OLD_INTERFACE_WARNING_CHECK(optimizer.add_parameters(parameters));

  optimizer.step([]() { return torch::tensor(1); });

  // REQUIRE this doesn't throw
}

// Check whether the learning rate of the parameter groups in the optimizer are
// the same as the expected learning rates given in the epoch:learning rate map
static void check_lr_change(
    Optimizer& optimizer,
    LRScheduler& lr_scheduler,
    std::map<unsigned, double> expected_epoch_lrs) {
  // Find maximum epoch in map
  unsigned kIterations = std::max_element(
                             expected_epoch_lrs.begin(),
                             expected_epoch_lrs.end(),
                             [](const std::pair<unsigned, double>& a,
                                const std::pair<unsigned, double>& b) -> bool {
                               return a.second > b.second;
                             })
                             ->first;

  for (unsigned i = 0; i <= kIterations; i++) {
    const auto epoch_iter = expected_epoch_lrs.find(i);
    if (epoch_iter != expected_epoch_lrs.end()) {
      // Compare the similarity of the two floating point learning rates
      ASSERT_TRUE(
          fabs(
              epoch_iter->second -
              optimizer.param_groups()[0].options().get_lr()) <
          std::numeric_limits<double>::epsilon());
    }
    optimizer.step();
    lr_scheduler.step();
  }
}

// Very similar to check_lr_change, but for ReduceLROnPlateauScheduler
// which does not inherit from LRScheduler and requires a metrics
// input to step().
static void check_lr_change_for_reduce_on_plateau(
    Optimizer& optimizer,
    ReduceLROnPlateauScheduler& lr_scheduler,
    std::map<unsigned, double> expected_epoch_lrs) {
  // Find maximum epoch in map
  unsigned kIterations = std::max_element(
                             expected_epoch_lrs.begin(),
                             expected_epoch_lrs.end(),
                             [](const std::pair<unsigned, double>& a,
                                const std::pair<unsigned, double>& b) -> bool {
                               return a.second > b.second;
                             })
                             ->first;

  for (unsigned i = 0; i <= kIterations; i++) {
    const auto epoch_iter = expected_epoch_lrs.find(i);
    if (epoch_iter != expected_epoch_lrs.end()) {
      // Compare the similarity of the two floating point learning rates
      ASSERT_TRUE(
          fabs(
              epoch_iter->second -
              optimizer.param_groups()[0].options().get_lr()) <
          std::numeric_limits<double>::epsilon());
    }
    optimizer.step();
    lr_scheduler.step(5.0);
  }
}

TEST(OptimTest, CheckLRChange_StepLR_Adam) {
  torch::Tensor parameters = torch::zeros({1});
  auto optimizer = Adam({parameters}, AdamOptions().lr(1e-3));

  const unsigned step_size = 20;
  const double gamma = 0.5;
  StepLR step_lr_scheduler(optimizer, step_size, gamma);

  // The learning rate should have halved at epoch 20
  const std::map<unsigned, double> expected_epoch_lrs = {{1, 1e-3}, {25, 5e-4}};

  check_lr_change(optimizer, step_lr_scheduler, expected_epoch_lrs);
}

TEST(OptimTest, CheckLRChange_ReduceLROnPlateau_Adam) {
  torch::Tensor parameters = torch::zeros({1});
  auto optimizer = Adam({parameters}, AdamOptions().lr(1e-3));
  const float factor = 0.5;
  const int patience = 20;
  ReduceLROnPlateauScheduler reduce_lr_on_plateau_scheduler(
      optimizer,
      ReduceLROnPlateauScheduler::SchedulerMode::min,
      factor,
      patience);

  // The learning rate should have halved at epoch 20
  const std::map<unsigned, double> expected_epoch_lrs = {{1, 1e-3}, {25, 5e-4}};

  check_lr_change_for_reduce_on_plateau(
      optimizer, reduce_lr_on_plateau_scheduler, expected_epoch_lrs);
}

// Tests for Issue 141884: Parameter group inheritance functionality
// Validates that partial options in parameter groups correctly inherit
// defaults from the optimizer while preserving explicitly set values
TEST(OptimTest, MergeWithDefaultOptions_Adam) {
  // Create tensors for parameter groups
  auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
  auto tensor2 = torch::randn({3, 3}).requires_grad_(true);

  // Create param groups with partial options
  std::vector<OptimizerParamGroup> param_groups;

  // Group 1: Only weight_decay specified, should inherit lr, betas, eps,
  // amsgrad
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor1},
      std::make_unique<AdamOptions>(AdamOptions().weight_decay(0.11)));

  // Group 2: Only eps specified, should inherit others
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor2},
      std::make_unique<AdamOptions>(AdamOptions().eps(1e-6)));

  // Create optimizer with specific defaults
  AdamOptions defaults;
  defaults.lr(0.002)
      .betas(std::make_tuple(0.8, 0.88))
      .eps(1e-12)
      .weight_decay(0.05)
      .amsgrad(true);

  Adam optimizer(param_groups, defaults);

  // Check Group 1: weight_decay preserved, others inherited
  auto& group1_opts =
      static_cast<AdamOptions&>(optimizer.param_groups()[0].options());
  ASSERT_EQ(group1_opts.lr(), 0.002); // Inherited
  ASSERT_EQ(group1_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
  ASSERT_EQ(group1_opts.eps(), 1e-12); // Inherited
  ASSERT_EQ(group1_opts.weight_decay(), 0.11); // Preserved
  ASSERT_TRUE(group1_opts.amsgrad()); // Inherited

  // Check Group 2: eps preserved, others inherited
  auto& group2_opts =
      static_cast<AdamOptions&>(optimizer.param_groups()[1].options());
  ASSERT_EQ(group2_opts.lr(), 0.002); // Inherited
  ASSERT_EQ(group2_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
  ASSERT_EQ(group2_opts.eps(), 1e-6); // Preserved
  ASSERT_EQ(group2_opts.weight_decay(), 0.05); // Inherited
  ASSERT_TRUE(group2_opts.amsgrad()); // Inherited
}

TEST(OptimTest, MergeWithDefaultOptions_SGD) {
  // Create tensors for parameter groups
  auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
  auto tensor2 = torch::randn({3, 3}).requires_grad_(true);

  // Create param groups with partial options
  std::vector<OptimizerParamGroup> param_groups;

  // Group 1: Only lr and weight_decay specified, should inherit momentum,
  // dampening, nesterov
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor1},
      std::make_unique<SGDOptions>(SGDOptions(0.01).weight_decay(0.22)));

  // Group 2: Only lr specified, should inherit all others
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor2},
      std::make_unique<SGDOptions>(SGDOptions(0.02)));

  // Create optimizer with specific defaults
  SGDOptions defaults(0.001); // lr should be overridden by param groups
  defaults.momentum(0.9)
      .dampening(0.0) // Must be 0 for Nesterov
      .weight_decay(0.05)
      .nesterov(true);

  SGD optimizer(param_groups, defaults);

  // Check Group 1: lr and weight_decay preserved, others inherited
  auto& group1_opts =
      static_cast<SGDOptions&>(optimizer.param_groups()[0].options());
  ASSERT_EQ(group1_opts.lr(), 0.01); // Preserved
  ASSERT_EQ(group1_opts.momentum(), 0.9); // Inherited
  ASSERT_EQ(group1_opts.dampening(), 0.0); // Inherited
  ASSERT_EQ(group1_opts.weight_decay(), 0.22); // Preserved
  ASSERT_TRUE(group1_opts.nesterov()); // Inherited

  // Check Group 2: lr preserved, others inherited
  auto& group2_opts =
      static_cast<SGDOptions&>(optimizer.param_groups()[1].options());
  ASSERT_EQ(group2_opts.lr(), 0.02); // Preserved
  ASSERT_EQ(group2_opts.momentum(), 0.9); // Inherited
  ASSERT_EQ(group2_opts.dampening(), 0.0); // Inherited
  ASSERT_EQ(group2_opts.weight_decay(), 0.05); // Inherited
  ASSERT_TRUE(group2_opts.nesterov()); // Inherited
}

TEST(OptimTest, MergeWithDefaultOptions_AdamW) {
  // Create tensors for parameter groups
  auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
  auto tensor2 = torch::randn({3, 3}).requires_grad_(true);

  // Create param groups with partial options
  std::vector<OptimizerParamGroup> param_groups;

  // Group 1: Only eps specified, should inherit others
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor1},
      std::make_unique<AdamWOptions>(AdamWOptions().eps(1e-6)));

  // Group 2: Only betas specified, should inherit others
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor2},
      std::make_unique<AdamWOptions>(
          AdamWOptions().betas(std::make_tuple(0.95, 0.999))));

  // Create optimizer with specific defaults
  AdamWOptions defaults;
  defaults.lr(0.003)
      .betas(std::make_tuple(0.9, 0.98))
      .eps(1e-8)
      .weight_decay(0.02)
      .amsgrad(false);

  AdamW optimizer(param_groups, defaults);

  // Check Group 1: eps preserved, others inherited
  auto& group1_opts =
      static_cast<AdamWOptions&>(optimizer.param_groups()[0].options());
  ASSERT_EQ(group1_opts.lr(), 0.003); // Inherited
  ASSERT_EQ(group1_opts.betas(), std::make_tuple(0.9, 0.98)); // Inherited
  ASSERT_EQ(group1_opts.eps(), 1e-6); // Preserved
  ASSERT_EQ(group1_opts.weight_decay(), 0.02); // Inherited
  ASSERT_FALSE(group1_opts.amsgrad()); // Inherited

  // Check Group 2: betas preserved, others inherited
  auto& group2_opts =
      static_cast<AdamWOptions&>(optimizer.param_groups()[1].options());
  ASSERT_EQ(group2_opts.lr(), 0.003); // Inherited
  ASSERT_EQ(group2_opts.betas(), std::make_tuple(0.95, 0.999)); // Preserved
  ASSERT_EQ(group2_opts.eps(), 1e-8); // Inherited
  ASSERT_EQ(group2_opts.weight_decay(), 0.02); // Inherited
  ASSERT_FALSE(group2_opts.amsgrad()); // Inherited
}

TEST(OptimTest, MergeWithDefaultOptions_Adagrad) {
  // Create tensors for parameter groups
  auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
  auto tensor2 = torch::randn({3, 3}).requires_grad_(true);

  // Create param groups with partial options
  std::vector<OptimizerParamGroup> param_groups;

  // Group 1: Only lr_decay specified, should inherit others
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor1},
      std::make_unique<AdagradOptions>(AdagradOptions().lr_decay(0.001)));

  // Group 2: Only initial_accumulator_value specified, should inherit others
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor2},
      std::make_unique<AdagradOptions>(
          AdagradOptions().initial_accumulator_value(0.5)));

  // Create optimizer with specific defaults
  AdagradOptions defaults;
  defaults.lr(0.04)
      .lr_decay(0.002)
      .weight_decay(0.03)
      .initial_accumulator_value(0.1)
      .eps(1e-11);

  Adagrad optimizer(param_groups, defaults);

  // Check Group 1: lr_decay preserved, others inherited
  auto& group1_opts =
      static_cast<AdagradOptions&>(optimizer.param_groups()[0].options());
  ASSERT_EQ(group1_opts.lr(), 0.04); // Inherited
  ASSERT_EQ(group1_opts.lr_decay(), 0.001); // Preserved
  ASSERT_EQ(group1_opts.weight_decay(), 0.03); // Inherited
  ASSERT_EQ(group1_opts.initial_accumulator_value(), 0.1); // Inherited
  ASSERT_EQ(group1_opts.eps(), 1e-11); // Inherited

  // Check Group 2: initial_accumulator_value preserved, others inherited
  auto& group2_opts =
      static_cast<AdagradOptions&>(optimizer.param_groups()[1].options());
  ASSERT_EQ(group2_opts.lr(), 0.04); // Inherited
  ASSERT_EQ(group2_opts.lr_decay(), 0.002); // Inherited
  ASSERT_EQ(group2_opts.weight_decay(), 0.03); // Inherited
  ASSERT_EQ(group2_opts.initial_accumulator_value(), 0.5); // Preserved
  ASSERT_EQ(group2_opts.eps(), 1e-11); // Inherited
}

TEST(OptimTest, MergeWithDefaultOptions_RMSprop) {
  // Create tensors for parameter groups
  auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
  auto tensor2 = torch::randn({3, 3}).requires_grad_(true);

  // Create param groups with partial options
  std::vector<OptimizerParamGroup> param_groups;

  // Group 1: Only alpha specified, should inherit others
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor1},
      std::make_unique<RMSpropOptions>(RMSpropOptions().alpha(0.95)));

  // Group 2: Only momentum and centered specified, should inherit others
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor2},
      std::make_unique<RMSpropOptions>(
          RMSpropOptions().momentum(0.8).centered(true)));

  // Create optimizer with specific defaults
  RMSpropOptions defaults;
  defaults.lr(0.015)
      .alpha(0.98)
      .eps(1e-9)
      .weight_decay(0.01)
      .momentum(0.7)
      .centered(false);

  RMSprop optimizer(param_groups, defaults);

  // Check Group 1: alpha preserved, others inherited
  auto& group1_opts =
      static_cast<RMSpropOptions&>(optimizer.param_groups()[0].options());
  ASSERT_EQ(group1_opts.lr(), 0.015); // Inherited
  ASSERT_EQ(group1_opts.alpha(), 0.95); // Preserved
  ASSERT_EQ(group1_opts.eps(), 1e-9); // Inherited
  ASSERT_EQ(group1_opts.weight_decay(), 0.01); // Inherited
  ASSERT_EQ(group1_opts.momentum(), 0.7); // Inherited
  ASSERT_FALSE(group1_opts.centered()); // Inherited

  // Check Group 2: momentum and centered preserved, others inherited
  auto& group2_opts =
      static_cast<RMSpropOptions&>(optimizer.param_groups()[1].options());
  ASSERT_EQ(group2_opts.lr(), 0.015); // Inherited
  ASSERT_EQ(group2_opts.alpha(), 0.98); // Inherited
  ASSERT_EQ(group2_opts.eps(), 1e-9); // Inherited
  ASSERT_EQ(group2_opts.weight_decay(), 0.01); // Inherited
  ASSERT_EQ(group2_opts.momentum(), 0.8); // Preserved
  ASSERT_TRUE(group2_opts.centered()); // Preserved
}

TEST(OptimTest, MergeWithDefaultOptions_LBFGS) {
  // Create tensors for single parameter group (LBFGS limitation)
  auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
  auto tensor2 = torch::randn({3, 3}).requires_grad_(true);

  // Create param group with partial options
  std::vector<OptimizerParamGroup> param_groups;

  // Single group: Only max_iter specified, should inherit others
  param_groups.emplace_back(
      std::vector<torch::Tensor>{
          tensor1, tensor2}, // Combine tensors in single group
      std::make_unique<LBFGSOptions>(LBFGSOptions().max_iter(15)));

  // Create optimizer with specific defaults
  LBFGSOptions defaults;
  defaults.lr(0.8)
      .max_iter(25)
      .max_eval(31) // Use same value that appears to be auto-calculated
      .tolerance_grad(1e-5)
      .tolerance_change(1e-8)
      .history_size(80)
      .line_search_fn("strong_wolfe");

  LBFGS optimizer(param_groups, defaults);

  // Check Group: max_iter preserved, others inherited
  auto& group_opts =
      static_cast<LBFGSOptions&>(optimizer.param_groups()[0].options());
  ASSERT_EQ(group_opts.lr(), 0.8); // Inherited
  ASSERT_EQ(group_opts.max_iter(), 15); // Preserved
  ASSERT_EQ(group_opts.max_eval(), 31); // Inherited
  ASSERT_EQ(group_opts.tolerance_grad(), 1e-5); // Inherited
  ASSERT_EQ(group_opts.tolerance_change(), 1e-8); // Inherited
  ASSERT_EQ(group_opts.history_size(), 80); // Inherited
  ASSERT_EQ(group_opts.line_search_fn(), "strong_wolfe"); // Inherited
}

TEST(OptimTest, MergeWithDefaultOptions_NoOptionsInheritance) {
  // Test that param groups without options get full defaults
  auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
  auto tensor2 = torch::randn({3, 3}).requires_grad_(true);

  std::vector<OptimizerParamGroup> param_groups;

  // Groups with no options - should inherit everything
  param_groups.emplace_back(std::vector<torch::Tensor>{tensor1});
  param_groups.emplace_back(std::vector<torch::Tensor>{tensor2});

  // Create optimizer with specific defaults
  AdamOptions defaults;
  defaults.lr(0.005)
      .betas(std::make_tuple(0.85, 0.95))
      .eps(1e-7)
      .weight_decay(0.08)
      .amsgrad(true);

  Adam optimizer(param_groups, defaults);

  // Both groups should have exactly the default options
  for (int i = 0; i < 2; i++) {
    auto& group_opts =
        static_cast<AdamOptions&>(optimizer.param_groups()[i].options());
    ASSERT_EQ(group_opts.lr(), 0.005);
    ASSERT_EQ(group_opts.betas(), std::make_tuple(0.85, 0.95));
    ASSERT_EQ(group_opts.eps(), 1e-7);
    ASSERT_EQ(group_opts.weight_decay(), 0.08);
    ASSERT_TRUE(group_opts.amsgrad());
  }
}

// Test that field tracking survives serialization/deserialization cycles
TEST(OptimTest, SerializationPreservesFieldTracking_Adam) {
  // Create tensors for parameter groups
  auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
  auto tensor2 = torch::randn({3, 3}).requires_grad_(true);

  // Create param groups with partial options using fluent API (marks fields as
  // explicit)
  std::vector<OptimizerParamGroup> param_groups;

  // Group 1: Only weight_decay and amsgrad explicitly set via fluent API
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor1},
      std::make_unique<AdamOptions>(
          AdamOptions().weight_decay(0.11).amsgrad(true)));

  // Group 2: Only eps explicitly set via fluent API
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor2},
      std::make_unique<AdamOptions>(AdamOptions().eps(1e-6)));

  // Create optimizer with specific defaults
  AdamOptions defaults;
  defaults.lr(0.002)
      .betas(std::make_tuple(0.8, 0.88))
      .eps(1e-12)
      .weight_decay(0.05)
      .amsgrad(false);

  Adam original_optimizer(param_groups, defaults);

  // Capture original state for comparison
  auto& orig_group1_opts =
      static_cast<AdamOptions&>(original_optimizer.param_groups()[0].options());
  auto& orig_group2_opts =
      static_cast<AdamOptions&>(original_optimizer.param_groups()[1].options());

  // Verify original state (sanity check)
  ASSERT_NEAR(orig_group1_opts.weight_decay(), 0.11, 1e-6); // Explicitly set
  ASSERT_TRUE(orig_group1_opts.amsgrad()); // Explicitly set
  ASSERT_NEAR(orig_group1_opts.lr(), 0.002, 1e-6); // Inherited
  ASSERT_NEAR(orig_group2_opts.eps(), 1e-6, 1e-9); // Explicitly set
  ASSERT_NEAR(orig_group2_opts.lr(), 0.002, 1e-6); // Inherited

  // Test serialization of the options objects (where field tracking lives)
  std::stringstream ss1, ss2;

  // Serialize the parameter group options
  {
    torch::serialize::OutputArchive archive;
    orig_group1_opts.serialize(archive);
    archive.save_to(ss1);
  }
  {
    torch::serialize::OutputArchive archive;
    orig_group2_opts.serialize(archive);
    archive.save_to(ss2);
  }

  // Create new options objects and deserialize
  AdamOptions loaded_group1_opts;
  AdamOptions loaded_group2_opts;

  {
    torch::serialize::InputArchive archive;
    archive.load_from(ss1);
    loaded_group1_opts.serialize(archive);
  }
  {
    torch::serialize::InputArchive archive;
    archive.load_from(ss2);
    loaded_group2_opts.serialize(archive);
  }

  // Verify that all parameter values are preserved after deserialization

  // Group 1: weight_decay and amsgrad should be preserved as explicitly set,
  // others inherited
  ASSERT_NEAR(loaded_group1_opts.lr(), 0.002, 1e-6); // Inherited
  ASSERT_EQ(
      loaded_group1_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
  ASSERT_NEAR(loaded_group1_opts.eps(), 1e-12, 1e-15); // Inherited
  ASSERT_NEAR(loaded_group1_opts.weight_decay(), 0.11, 1e-6); // Explicitly set
  ASSERT_TRUE(loaded_group1_opts.amsgrad()); // Explicitly set

  // Group 2: eps should be preserved as explicitly set, others inherited
  ASSERT_NEAR(loaded_group2_opts.lr(), 0.002, 1e-6); // Inherited
  ASSERT_EQ(
      loaded_group2_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
  ASSERT_NEAR(loaded_group2_opts.eps(), 1e-6, 1e-9); // Explicitly set
  ASSERT_NEAR(loaded_group2_opts.weight_decay(), 0.05, 1e-6); // Inherited
  ASSERT_FALSE(loaded_group2_opts.amsgrad()); // Inherited

  // CRITICAL: Test that field tracking is preserved after serialization
  // Create a new optimizer using the deserialized options to test inheritance
  auto tensor3 = torch::randn({2, 2}).requires_grad_(true);
  auto tensor4 = torch::randn({3, 3}).requires_grad_(true);

  std::vector<OptimizerParamGroup> test_param_groups;
  test_param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor3},
      std::make_unique<AdamOptions>(loaded_group1_opts));
  test_param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor4},
      std::make_unique<AdamOptions>(loaded_group2_opts));

  Adam test_optimizer(test_param_groups, defaults);

  // The field tracking should work correctly for inheritance
  auto& final_group1_opts =
      static_cast<AdamOptions&>(test_optimizer.param_groups()[0].options());
  auto& final_group2_opts =
      static_cast<AdamOptions&>(test_optimizer.param_groups()[1].options());

  // Group 1: weight_decay and amsgrad should still be preserved as explicitly
  // set
  ASSERT_NEAR(
      final_group1_opts.weight_decay(),
      0.11,
      1e-6); // Explicitly set (preserved)
  ASSERT_TRUE(final_group1_opts.amsgrad()); // Explicitly set (preserved)
  ASSERT_NEAR(final_group1_opts.lr(), 0.002, 1e-6); // Inherited from defaults

  // Group 2: eps should still be preserved as explicitly set
  ASSERT_NEAR(
      final_group2_opts.eps(), 1e-6, 1e-9); // Explicitly set (preserved)
  ASSERT_NEAR(final_group2_opts.lr(), 0.002, 1e-6); // Inherited from defaults
}

// Test serialization with SGD (different parameter types)
TEST(OptimTest, SerializationPreservesFieldTracking_SGD) {
  // Create tensors
  auto tensor1 = torch::randn({2, 2}).requires_grad_(true);

  // Create param group with partial options using fluent API
  std::vector<OptimizerParamGroup> param_groups;
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor1},
      std::make_unique<SGDOptions>(
          SGDOptions(0.01).weight_decay(0.22).nesterov(true)));

  // Create optimizer with defaults
  SGDOptions defaults(0.001);
  defaults.momentum(0.9).dampening(0.0).weight_decay(0.05).nesterov(false);

  SGD original_optimizer(param_groups, defaults);

  // Test serialization of the SGD options (where field tracking lives)
  auto& original_opts =
      static_cast<SGDOptions&>(original_optimizer.param_groups()[0].options());

  std::stringstream ss;
  {
    torch::serialize::OutputArchive archive;
    original_opts.serialize(archive);
    archive.save_to(ss);
  }

  SGDOptions loaded_opts(0.0); // Dummy initial value
  {
    torch::serialize::InputArchive archive;
    archive.load_from(ss);
    loaded_opts.serialize(archive);
  }
  ASSERT_NEAR(loaded_opts.lr(), 0.01, 1e-6); // Explicitly set
  ASSERT_NEAR(loaded_opts.momentum(), 0.9, 1e-6); // Inherited
  ASSERT_NEAR(loaded_opts.dampening(), 0.0, 1e-6); // Inherited
  ASSERT_NEAR(loaded_opts.weight_decay(), 0.22, 1e-6); // Explicitly set
  ASSERT_TRUE(loaded_opts.nesterov()); // Explicitly set

  // Test that field tracking still works after deserialization by creating new
  // optimizer
  auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
  std::vector<OptimizerParamGroup> test_param_groups;
  test_param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor2},
      std::make_unique<SGDOptions>(loaded_opts));

  SGD test_optimizer(test_param_groups, defaults);

  auto& final_opts =
      static_cast<SGDOptions&>(test_optimizer.param_groups()[0].options());
  ASSERT_NEAR(final_opts.lr(), 0.01, 1e-6); // Explicitly set (preserved)
  ASSERT_NEAR(
      final_opts.weight_decay(), 0.22, 1e-6); // Explicitly set (preserved)
  ASSERT_TRUE(final_opts.nesterov()); // Explicitly set (preserved)
  ASSERT_NEAR(final_opts.momentum(), 0.9, 1e-6); // Inherited from defaults
  ASSERT_NEAR(final_opts.dampening(), 0.0, 1e-6); // Inherited from defaults
}