Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "C++ API handle optimizer defaults (#161825)"
This reverts commit f33201729416ed17467228e80b04d01d4d02b5f3. Reverted https://github.com/pytorch/pytorch/pull/161825 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/161825#issuecomment-3391506427))
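For context on what is being removed: the reverted PR made parameter groups with partially specified options inherit any unset fields from the optimizer's defaults, mirroring the Python API. The sketch below is adapted from the doc comment that this revert deletes (see the optimizer.h hunk further down); `params1` stands for some `std::vector<torch::Tensor>` and is not defined here.

```cpp
// Behavior removed by this revert (adapted from the deleted doc comment).
// Python equivalent:
//   optimizer = Adam([{'params': params1, 'weight_decay': 0.1}], lr=0.01)
//   -> weight_decay=0.1 preserved, lr=0.01 inherited

AdamOptions defaults;
defaults.lr(0.01).weight_decay(0.05);

std::vector<OptimizerParamGroup> groups;
groups.emplace_back(
    params1, // only weight_decay specified for this group
    std::make_unique<AdamOptions>(AdamOptions().weight_decay(0.1)));

Adam optimizer(groups, defaults);
// With the reverted change: the group kept weight_decay=0.1 and inherited
// lr=0.01 from `defaults`. After this revert, add_param_group clones the
// group's own options unchanged, so lr comes from AdamOptions' constructor
// default instead.
```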
@@ -564,508 +564,3 @@ TEST(OptimTest, CheckLRChange_ReduceLROnPlateau_Adam) {
  check_lr_change_for_reduce_on_plateau(
      optimizer, reduce_lr_on_plateau_scheduler, expected_epoch_lrs);
}

// Tests for Issue 141884: Parameter group inheritance functionality
// Validates that partial options in parameter groups correctly inherit
// defaults from the optimizer while preserving explicitly set values
TEST(OptimTest, MergeWithDefaultOptions_Adam) {
  // Create tensors for parameter groups
  auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
  auto tensor2 = torch::randn({3, 3}).requires_grad_(true);

  // Create param groups with partial options
  std::vector<OptimizerParamGroup> param_groups;

  // Group 1: Only weight_decay specified, should inherit lr, betas, eps,
  // amsgrad
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor1},
      std::make_unique<AdamOptions>(AdamOptions().weight_decay(0.11)));

  // Group 2: Only eps specified, should inherit others
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor2},
      std::make_unique<AdamOptions>(AdamOptions().eps(1e-6)));

  // Create optimizer with specific defaults
  AdamOptions defaults;
  defaults.lr(0.002)
      .betas(std::make_tuple(0.8, 0.88))
      .eps(1e-12)
      .weight_decay(0.05)
      .amsgrad(true);

  Adam optimizer(param_groups, defaults);

  // Check Group 1: weight_decay preserved, others inherited
  auto& group1_opts =
      static_cast<AdamOptions&>(optimizer.param_groups()[0].options());
  ASSERT_EQ(group1_opts.lr(), 0.002); // Inherited
  ASSERT_EQ(group1_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
  ASSERT_EQ(group1_opts.eps(), 1e-12); // Inherited
  ASSERT_EQ(group1_opts.weight_decay(), 0.11); // Preserved
  ASSERT_TRUE(group1_opts.amsgrad()); // Inherited

  // Check Group 2: eps preserved, others inherited
  auto& group2_opts =
      static_cast<AdamOptions&>(optimizer.param_groups()[1].options());
  ASSERT_EQ(group2_opts.lr(), 0.002); // Inherited
  ASSERT_EQ(group2_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
  ASSERT_EQ(group2_opts.eps(), 1e-6); // Preserved
  ASSERT_EQ(group2_opts.weight_decay(), 0.05); // Inherited
  ASSERT_TRUE(group2_opts.amsgrad()); // Inherited
}

TEST(OptimTest, MergeWithDefaultOptions_SGD) {
  // Create tensors for parameter groups
  auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
  auto tensor2 = torch::randn({3, 3}).requires_grad_(true);

  // Create param groups with partial options
  std::vector<OptimizerParamGroup> param_groups;

  // Group 1: Only lr and weight_decay specified, should inherit momentum,
  // dampening, nesterov
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor1},
      std::make_unique<SGDOptions>(SGDOptions(0.01).weight_decay(0.22)));

  // Group 2: Only lr specified, should inherit all others
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor2},
      std::make_unique<SGDOptions>(SGDOptions(0.02)));

  // Create optimizer with specific defaults
  SGDOptions defaults(0.001); // lr should be overridden by param groups
  defaults.momentum(0.9)
      .dampening(0.0) // Must be 0 for Nesterov
      .weight_decay(0.05)
      .nesterov(true);

  SGD optimizer(param_groups, defaults);

  // Check Group 1: lr and weight_decay preserved, others inherited
  auto& group1_opts =
      static_cast<SGDOptions&>(optimizer.param_groups()[0].options());
  ASSERT_EQ(group1_opts.lr(), 0.01); // Preserved
  ASSERT_EQ(group1_opts.momentum(), 0.9); // Inherited
  ASSERT_EQ(group1_opts.dampening(), 0.0); // Inherited
  ASSERT_EQ(group1_opts.weight_decay(), 0.22); // Preserved
  ASSERT_TRUE(group1_opts.nesterov()); // Inherited

  // Check Group 2: lr preserved, others inherited
  auto& group2_opts =
      static_cast<SGDOptions&>(optimizer.param_groups()[1].options());
  ASSERT_EQ(group2_opts.lr(), 0.02); // Preserved
  ASSERT_EQ(group2_opts.momentum(), 0.9); // Inherited
  ASSERT_EQ(group2_opts.dampening(), 0.0); // Inherited
  ASSERT_EQ(group2_opts.weight_decay(), 0.05); // Inherited
  ASSERT_TRUE(group2_opts.nesterov()); // Inherited
}

TEST(OptimTest, MergeWithDefaultOptions_AdamW) {
  // Create tensors for parameter groups
  auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
  auto tensor2 = torch::randn({3, 3}).requires_grad_(true);

  // Create param groups with partial options
  std::vector<OptimizerParamGroup> param_groups;

  // Group 1: Only eps specified, should inherit others
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor1},
      std::make_unique<AdamWOptions>(AdamWOptions().eps(1e-6)));

  // Group 2: Only betas specified, should inherit others
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor2},
      std::make_unique<AdamWOptions>(
          AdamWOptions().betas(std::make_tuple(0.95, 0.999))));

  // Create optimizer with specific defaults
  AdamWOptions defaults;
  defaults.lr(0.003)
      .betas(std::make_tuple(0.9, 0.98))
      .eps(1e-8)
      .weight_decay(0.02)
      .amsgrad(false);

  AdamW optimizer(param_groups, defaults);

  // Check Group 1: eps preserved, others inherited
  auto& group1_opts =
      static_cast<AdamWOptions&>(optimizer.param_groups()[0].options());
  ASSERT_EQ(group1_opts.lr(), 0.003); // Inherited
  ASSERT_EQ(group1_opts.betas(), std::make_tuple(0.9, 0.98)); // Inherited
  ASSERT_EQ(group1_opts.eps(), 1e-6); // Preserved
  ASSERT_EQ(group1_opts.weight_decay(), 0.02); // Inherited
  ASSERT_FALSE(group1_opts.amsgrad()); // Inherited

  // Check Group 2: betas preserved, others inherited
  auto& group2_opts =
      static_cast<AdamWOptions&>(optimizer.param_groups()[1].options());
  ASSERT_EQ(group2_opts.lr(), 0.003); // Inherited
  ASSERT_EQ(group2_opts.betas(), std::make_tuple(0.95, 0.999)); // Preserved
  ASSERT_EQ(group2_opts.eps(), 1e-8); // Inherited
  ASSERT_EQ(group2_opts.weight_decay(), 0.02); // Inherited
  ASSERT_FALSE(group2_opts.amsgrad()); // Inherited
}

TEST(OptimTest, MergeWithDefaultOptions_Adagrad) {
  // Create tensors for parameter groups
  auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
  auto tensor2 = torch::randn({3, 3}).requires_grad_(true);

  // Create param groups with partial options
  std::vector<OptimizerParamGroup> param_groups;

  // Group 1: Only lr_decay specified, should inherit others
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor1},
      std::make_unique<AdagradOptions>(AdagradOptions().lr_decay(0.001)));

  // Group 2: Only initial_accumulator_value specified, should inherit others
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor2},
      std::make_unique<AdagradOptions>(
          AdagradOptions().initial_accumulator_value(0.5)));

  // Create optimizer with specific defaults
  AdagradOptions defaults;
  defaults.lr(0.04)
      .lr_decay(0.002)
      .weight_decay(0.03)
      .initial_accumulator_value(0.1)
      .eps(1e-11);

  Adagrad optimizer(param_groups, defaults);

  // Check Group 1: lr_decay preserved, others inherited
  auto& group1_opts =
      static_cast<AdagradOptions&>(optimizer.param_groups()[0].options());
  ASSERT_EQ(group1_opts.lr(), 0.04); // Inherited
  ASSERT_EQ(group1_opts.lr_decay(), 0.001); // Preserved
  ASSERT_EQ(group1_opts.weight_decay(), 0.03); // Inherited
  ASSERT_EQ(group1_opts.initial_accumulator_value(), 0.1); // Inherited
  ASSERT_EQ(group1_opts.eps(), 1e-11); // Inherited

  // Check Group 2: initial_accumulator_value preserved, others inherited
  auto& group2_opts =
      static_cast<AdagradOptions&>(optimizer.param_groups()[1].options());
  ASSERT_EQ(group2_opts.lr(), 0.04); // Inherited
  ASSERT_EQ(group2_opts.lr_decay(), 0.002); // Inherited
  ASSERT_EQ(group2_opts.weight_decay(), 0.03); // Inherited
  ASSERT_EQ(group2_opts.initial_accumulator_value(), 0.5); // Preserved
  ASSERT_EQ(group2_opts.eps(), 1e-11); // Inherited
}

TEST(OptimTest, MergeWithDefaultOptions_RMSprop) {
  // Create tensors for parameter groups
  auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
  auto tensor2 = torch::randn({3, 3}).requires_grad_(true);

  // Create param groups with partial options
  std::vector<OptimizerParamGroup> param_groups;

  // Group 1: Only alpha specified, should inherit others
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor1},
      std::make_unique<RMSpropOptions>(RMSpropOptions().alpha(0.95)));

  // Group 2: Only momentum and centered specified, should inherit others
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor2},
      std::make_unique<RMSpropOptions>(
          RMSpropOptions().momentum(0.8).centered(true)));

  // Create optimizer with specific defaults
  RMSpropOptions defaults;
  defaults.lr(0.015)
      .alpha(0.98)
      .eps(1e-9)
      .weight_decay(0.01)
      .momentum(0.7)
      .centered(false);

  RMSprop optimizer(param_groups, defaults);

  // Check Group 1: alpha preserved, others inherited
  auto& group1_opts =
      static_cast<RMSpropOptions&>(optimizer.param_groups()[0].options());
  ASSERT_EQ(group1_opts.lr(), 0.015); // Inherited
  ASSERT_EQ(group1_opts.alpha(), 0.95); // Preserved
  ASSERT_EQ(group1_opts.eps(), 1e-9); // Inherited
  ASSERT_EQ(group1_opts.weight_decay(), 0.01); // Inherited
  ASSERT_EQ(group1_opts.momentum(), 0.7); // Inherited
  ASSERT_FALSE(group1_opts.centered()); // Inherited

  // Check Group 2: momentum and centered preserved, others inherited
  auto& group2_opts =
      static_cast<RMSpropOptions&>(optimizer.param_groups()[1].options());
  ASSERT_EQ(group2_opts.lr(), 0.015); // Inherited
  ASSERT_EQ(group2_opts.alpha(), 0.98); // Inherited
  ASSERT_EQ(group2_opts.eps(), 1e-9); // Inherited
  ASSERT_EQ(group2_opts.weight_decay(), 0.01); // Inherited
  ASSERT_EQ(group2_opts.momentum(), 0.8); // Preserved
  ASSERT_TRUE(group2_opts.centered()); // Preserved
}

TEST(OptimTest, MergeWithDefaultOptions_LBFGS) {
  // Create tensors for single parameter group (LBFGS limitation)
  auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
  auto tensor2 = torch::randn({3, 3}).requires_grad_(true);

  // Create param group with partial options
  std::vector<OptimizerParamGroup> param_groups;

  // Single group: Only max_iter specified, should inherit others
  param_groups.emplace_back(
      std::vector<torch::Tensor>{
          tensor1, tensor2}, // Combine tensors in single group
      std::make_unique<LBFGSOptions>(LBFGSOptions().max_iter(15)));

  // Create optimizer with specific defaults
  LBFGSOptions defaults;
  defaults.lr(0.8)
      .max_iter(25)
      .max_eval(31) // Use same value that appears to be auto-calculated
      .tolerance_grad(1e-5)
      .tolerance_change(1e-8)
      .history_size(80)
      .line_search_fn("strong_wolfe");

  LBFGS optimizer(param_groups, defaults);

  // Check Group: max_iter preserved, others inherited
  auto& group_opts =
      static_cast<LBFGSOptions&>(optimizer.param_groups()[0].options());
  ASSERT_EQ(group_opts.lr(), 0.8); // Inherited
  ASSERT_EQ(group_opts.max_iter(), 15); // Preserved
  ASSERT_EQ(group_opts.max_eval(), 31); // Inherited
  ASSERT_EQ(group_opts.tolerance_grad(), 1e-5); // Inherited
  ASSERT_EQ(group_opts.tolerance_change(), 1e-8); // Inherited
  ASSERT_EQ(group_opts.history_size(), 80); // Inherited
  ASSERT_EQ(group_opts.line_search_fn(), "strong_wolfe"); // Inherited
}

TEST(OptimTest, MergeWithDefaultOptions_NoOptionsInheritance) {
  // Test that param groups without options get full defaults
  auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
  auto tensor2 = torch::randn({3, 3}).requires_grad_(true);

  std::vector<OptimizerParamGroup> param_groups;

  // Groups with no options - should inherit everything
  param_groups.emplace_back(std::vector<torch::Tensor>{tensor1});
  param_groups.emplace_back(std::vector<torch::Tensor>{tensor2});

  // Create optimizer with specific defaults
  AdamOptions defaults;
  defaults.lr(0.005)
      .betas(std::make_tuple(0.85, 0.95))
      .eps(1e-7)
      .weight_decay(0.08)
      .amsgrad(true);

  Adam optimizer(param_groups, defaults);

  // Both groups should have exactly the default options
  for (int i = 0; i < 2; i++) {
    auto& group_opts =
        static_cast<AdamOptions&>(optimizer.param_groups()[i].options());
    ASSERT_EQ(group_opts.lr(), 0.005);
    ASSERT_EQ(group_opts.betas(), std::make_tuple(0.85, 0.95));
    ASSERT_EQ(group_opts.eps(), 1e-7);
    ASSERT_EQ(group_opts.weight_decay(), 0.08);
    ASSERT_TRUE(group_opts.amsgrad());
  }
}

// Test that field tracking survives serialization/deserialization cycles
TEST(OptimTest, SerializationPreservesFieldTracking_Adam) {
  // Create tensors for parameter groups
  auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
  auto tensor2 = torch::randn({3, 3}).requires_grad_(true);

  // Create param groups with partial options using fluent API (marks fields as
  // explicit)
  std::vector<OptimizerParamGroup> param_groups;

  // Group 1: Only weight_decay and amsgrad explicitly set via fluent API
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor1},
      std::make_unique<AdamOptions>(
          AdamOptions().weight_decay(0.11).amsgrad(true)));

  // Group 2: Only eps explicitly set via fluent API
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor2},
      std::make_unique<AdamOptions>(AdamOptions().eps(1e-6)));

  // Create optimizer with specific defaults
  AdamOptions defaults;
  defaults.lr(0.002)
      .betas(std::make_tuple(0.8, 0.88))
      .eps(1e-12)
      .weight_decay(0.05)
      .amsgrad(false);

  Adam original_optimizer(param_groups, defaults);

  // Capture original state for comparison
  auto& orig_group1_opts =
      static_cast<AdamOptions&>(original_optimizer.param_groups()[0].options());
  auto& orig_group2_opts =
      static_cast<AdamOptions&>(original_optimizer.param_groups()[1].options());

  // Verify original state (sanity check)
  ASSERT_NEAR(orig_group1_opts.weight_decay(), 0.11, 1e-6); // Explicitly set
  ASSERT_TRUE(orig_group1_opts.amsgrad()); // Explicitly set
  ASSERT_NEAR(orig_group1_opts.lr(), 0.002, 1e-6); // Inherited
  ASSERT_NEAR(orig_group2_opts.eps(), 1e-6, 1e-9); // Explicitly set
  ASSERT_NEAR(orig_group2_opts.lr(), 0.002, 1e-6); // Inherited

  // Test serialization of the options objects (where field tracking lives)
  std::stringstream ss1, ss2;

  // Serialize the parameter group options
  {
    torch::serialize::OutputArchive archive;
    orig_group1_opts.serialize(archive);
    archive.save_to(ss1);
  }
  {
    torch::serialize::OutputArchive archive;
    orig_group2_opts.serialize(archive);
    archive.save_to(ss2);
  }

  // Create new options objects and deserialize
  AdamOptions loaded_group1_opts;
  AdamOptions loaded_group2_opts;

  {
    torch::serialize::InputArchive archive;
    archive.load_from(ss1);
    loaded_group1_opts.serialize(archive);
  }
  {
    torch::serialize::InputArchive archive;
    archive.load_from(ss2);
    loaded_group2_opts.serialize(archive);
  }

  // Verify that all parameter values are preserved after deserialization

  // Group 1: weight_decay and amsgrad should be preserved as explicitly set,
  // others inherited
  ASSERT_NEAR(loaded_group1_opts.lr(), 0.002, 1e-6); // Inherited
  ASSERT_EQ(
      loaded_group1_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
  ASSERT_NEAR(loaded_group1_opts.eps(), 1e-12, 1e-15); // Inherited
  ASSERT_NEAR(loaded_group1_opts.weight_decay(), 0.11, 1e-6); // Explicitly set
  ASSERT_TRUE(loaded_group1_opts.amsgrad()); // Explicitly set

  // Group 2: eps should be preserved as explicitly set, others inherited
  ASSERT_NEAR(loaded_group2_opts.lr(), 0.002, 1e-6); // Inherited
  ASSERT_EQ(
      loaded_group2_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
  ASSERT_NEAR(loaded_group2_opts.eps(), 1e-6, 1e-9); // Explicitly set
  ASSERT_NEAR(loaded_group2_opts.weight_decay(), 0.05, 1e-6); // Inherited
  ASSERT_FALSE(loaded_group2_opts.amsgrad()); // Inherited

  // CRITICAL: Test that field tracking is preserved after serialization
  // Create a new optimizer using the deserialized options to test inheritance
  auto tensor3 = torch::randn({2, 2}).requires_grad_(true);
  auto tensor4 = torch::randn({3, 3}).requires_grad_(true);

  std::vector<OptimizerParamGroup> test_param_groups;
  test_param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor3},
      std::make_unique<AdamOptions>(loaded_group1_opts));
  test_param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor4},
      std::make_unique<AdamOptions>(loaded_group2_opts));

  Adam test_optimizer(test_param_groups, defaults);

  // The field tracking should work correctly for inheritance
  auto& final_group1_opts =
      static_cast<AdamOptions&>(test_optimizer.param_groups()[0].options());
  auto& final_group2_opts =
      static_cast<AdamOptions&>(test_optimizer.param_groups()[1].options());

  // Group 1: weight_decay and amsgrad should still be preserved as explicitly
  // set
  ASSERT_NEAR(
      final_group1_opts.weight_decay(),
      0.11,
      1e-6); // Explicitly set (preserved)
  ASSERT_TRUE(final_group1_opts.amsgrad()); // Explicitly set (preserved)
  ASSERT_NEAR(final_group1_opts.lr(), 0.002, 1e-6); // Inherited from defaults

  // Group 2: eps should still be preserved as explicitly set
  ASSERT_NEAR(
      final_group2_opts.eps(), 1e-6, 1e-9); // Explicitly set (preserved)
  ASSERT_NEAR(final_group2_opts.lr(), 0.002, 1e-6); // Inherited from defaults
}

// Test serialization with SGD (different parameter types)
TEST(OptimTest, SerializationPreservesFieldTracking_SGD) {
  // Create tensors
  auto tensor1 = torch::randn({2, 2}).requires_grad_(true);

  // Create param group with partial options using fluent API
  std::vector<OptimizerParamGroup> param_groups;
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor1},
      std::make_unique<SGDOptions>(
          SGDOptions(0.01).weight_decay(0.22).nesterov(true)));

  // Create optimizer with defaults
  SGDOptions defaults(0.001);
  defaults.momentum(0.9).dampening(0.0).weight_decay(0.05).nesterov(false);

  SGD original_optimizer(param_groups, defaults);

  // Test serialization of the SGD options (where field tracking lives)
  auto& original_opts =
      static_cast<SGDOptions&>(original_optimizer.param_groups()[0].options());

  std::stringstream ss;
  {
    torch::serialize::OutputArchive archive;
    original_opts.serialize(archive);
    archive.save_to(ss);
  }

  SGDOptions loaded_opts(0.0); // Dummy initial value
  {
    torch::serialize::InputArchive archive;
    archive.load_from(ss);
    loaded_opts.serialize(archive);
  }
  ASSERT_NEAR(loaded_opts.lr(), 0.01, 1e-6); // Explicitly set
  ASSERT_NEAR(loaded_opts.momentum(), 0.9, 1e-6); // Inherited
  ASSERT_NEAR(loaded_opts.dampening(), 0.0, 1e-6); // Inherited
  ASSERT_NEAR(loaded_opts.weight_decay(), 0.22, 1e-6); // Explicitly set
  ASSERT_TRUE(loaded_opts.nesterov()); // Explicitly set

  // Test that field tracking still works after deserialization by creating new
  // optimizer
  auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
  std::vector<OptimizerParamGroup> test_param_groups;
  test_param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor2},
      std::make_unique<SGDOptions>(loaded_opts));

  SGD test_optimizer(test_param_groups, defaults);

  auto& final_opts =
      static_cast<SGDOptions&>(test_optimizer.param_groups()[0].options());
  ASSERT_NEAR(final_opts.lr(), 0.01, 1e-6); // Explicitly set (preserved)
  ASSERT_NEAR(
      final_opts.weight_decay(), 0.22, 1e-6); // Explicitly set (preserved)
  ASSERT_TRUE(final_opts.nesterov()); // Explicitly set (preserved)
  ASSERT_NEAR(final_opts.momentum(), 0.9, 1e-6); // Inherited from defaults
  ASSERT_NEAR(final_opts.dampening(), 0.0, 1e-6); // Inherited from defaults
}
@@ -12,7 +12,6 @@
#include <iterator>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>

// Forward declarations confuse Doxygen
@@ -67,332 +66,12 @@ class TORCH_API OptimizerOptions {
  virtual void set_lr(const double lr);
};

// Forward declarations for optimizer option types
struct SGDOptions;
struct AdamOptions;
struct AdamWOptions;
struct AdagradOptions;
struct RMSpropOptions;
struct LBFGSOptions;

/**
 * OptimizerCloneableOptions provides parameter group inheritance functionality
 * for PyTorch C++ optimizer options. When creating parameter groups with
 * partial options (e.g., AdamOptions().weight_decay(0.1)), fields not
 * explicitly set by the user inherit from the optimizer's default values,
 * while explicitly set fields are preserved.
 *
 * This enables Python-like behavior in C++:
 * ```cpp
 * // Python equivalent:
 * // optimizer = Adam([{'params': params1, 'weight_decay': 0.1}], lr=0.01)
 * // Result: weight_decay=0.1 preserved, lr=0.01 inherited
 *
 * AdamOptions defaults;
 * defaults.lr(0.01).weight_decay(0.05);
 *
 * std::vector<OptimizerParamGroup> groups;
 * groups.emplace_back(params1, std::make_unique<AdamOptions>(
 *     AdamOptions().weight_decay(0.1))); // Only weight_decay specified
 *
 * Adam optimizer(groups, defaults);
 * // Result: group inherits lr=0.01, preserves weight_decay=0.1
 * ```
 *
 * **Implementation**: Uses SFINAE-based field detection and constructor-default
 * comparison to distinguish explicitly set fields from default values.
 * Fields that match constructor defaults are inherited; others are preserved.
 */
template <typename Derived>
class OptimizerCloneableOptions : public OptimizerOptions {
 private:
  std::unique_ptr<OptimizerOptions> clone() const override {
    return std::make_unique<Derived>(static_cast<const Derived&>(*this));
  }

  // SFINAE field detection - detects optimizer fields using public accessor
  // methods
  template <class T, class Enable = void>
  struct _has_lr : std::false_type {};
  template <class T>
  struct _has_lr<T, std::void_t<decltype(std::declval<const T&>().lr())>>
      : std::true_type {};

  template <class T, class Enable = void>
  struct _has_momentum : std::false_type {};
  template <class T>
  struct _has_momentum<
      T,
      std::void_t<decltype(std::declval<const T&>().momentum())>>
      : std::true_type {};

  template <class T, class Enable = void>
  struct _has_weight_decay : std::false_type {};
  template <class T>
  struct _has_weight_decay<
      T,
      std::void_t<decltype(std::declval<const T&>().weight_decay())>>
      : std::true_type {};

  template <class T, class Enable = void>
  struct _has_dampening : std::false_type {};
  template <class T>
  struct _has_dampening<
      T,
      std::void_t<decltype(std::declval<const T&>().dampening())>>
      : std::true_type {};

  template <class T, class Enable = void>
  struct _has_nesterov : std::false_type {};
  template <class T>
  struct _has_nesterov<
      T,
      std::void_t<decltype(std::declval<const T&>().nesterov())>>
      : std::true_type {};

  template <class T, class Enable = void>
  struct _has_betas : std::false_type {};
  template <class T>
  struct _has_betas<T, std::void_t<decltype(std::declval<const T&>().betas())>>
      : std::true_type {};

  template <class T, class Enable = void>
  struct _has_eps : std::false_type {};
  template <class T>
  struct _has_eps<T, std::void_t<decltype(std::declval<const T&>().eps())>>
      : std::true_type {};

  template <class T, class Enable = void>
  struct _has_amsgrad : std::false_type {};
  template <class T>
  struct _has_amsgrad<
      T,
      std::void_t<decltype(std::declval<const T&>().amsgrad())>>
      : std::true_type {};

  // Optimizer-specific field detection
  template <class T, class Enable = void>
  struct _has_lr_decay : std::false_type {};
  template <class T>
  struct _has_lr_decay<
      T,
      std::void_t<decltype(std::declval<const T&>().lr_decay())>>
      : std::true_type {};

  template <class T, class Enable = void>
  struct _has_alpha : std::false_type {};
  template <class T>
  struct _has_alpha<T, std::void_t<decltype(std::declval<const T&>().alpha())>>
      : std::true_type {};

  template <class T, class Enable = void>
  struct _has_centered : std::false_type {};
  template <class T>
  struct _has_centered<
      T,
      std::void_t<decltype(std::declval<const T&>().centered())>>
      : std::true_type {};

  template <class T, class Enable = void>
  struct _has_initial_accumulator_value : std::false_type {};
  template <class T>
  struct _has_initial_accumulator_value<
      T,
      std::void_t<
          decltype(std::declval<const T&>().initial_accumulator_value())>>
      : std::true_type {};

  // LBFGS-specific fields with appropriate types
  template <class T, class Enable = void>
  struct _has_max_iter : std::false_type {};
  template <class T>
  struct _has_max_iter<
      T,
      std::void_t<decltype(std::declval<const T&>().max_iter())>>
      : std::true_type {};

  template <class T, class Enable = void>
  struct _has_max_eval : std::false_type {};
  template <class T>
  struct _has_max_eval<
      T,
      std::void_t<decltype(std::declval<const T&>().max_eval())>>
      : std::true_type {};

  template <class T, class Enable = void>
  struct _has_tolerance_grad : std::false_type {};
  template <class T>
  struct _has_tolerance_grad<
      T,
      std::void_t<decltype(std::declval<const T&>().tolerance_grad())>>
      : std::true_type {};

  template <class T, class Enable = void>
  struct _has_tolerance_change : std::false_type {};
  template <class T>
  struct _has_tolerance_change<
      T,
      std::void_t<decltype(std::declval<const T&>().tolerance_change())>>
      : std::true_type {};

  template <class T, class Enable = void>
  struct _has_history_size : std::false_type {};
  template <class T>
  struct _has_history_size<
      T,
      std::void_t<decltype(std::declval<const T&>().history_size())>>
      : std::true_type {};

  template <class T, class Enable = void>
  struct _has_line_search_fn : std::false_type {};
  template <class T>
  struct _has_line_search_fn<
      T,
      std::void_t<decltype(std::declval<const T&>().line_search_fn())>>
      : std::true_type {};

  /**
   * Merges user-specified options with optimizer defaults using
   * constructor-default comparison to detect explicitly set fields.
   *
   * Algorithm:
   * 1. Start with optimizer defaults as base
   * 2. Create fresh constructor instance for comparison
   * 3. If user_value != constructor_default → user explicitly set it → preserve
   * 4. If user_value == constructor_default → user didn't set it → inherit from
   *    defaults
   */
  void _merge_by_comparison(
      const Derived& defaults,
      const Derived& user_options) {
    auto* result = static_cast<Derived*>(this);
    *result = defaults; // Start with optimizer defaults

    // Create constructor defaults instance for comparison
    Derived constructor_defaults = []() {
      if constexpr (std::is_default_constructible_v<Derived>) {
        return Derived{};
      } else {
        // Handle optimizers requiring constructor parameters
        if constexpr (std::is_same_v<Derived, SGDOptions>) {
          return Derived(1e-3);
        } else if constexpr (std::is_same_v<Derived, AdagradOptions>) {
          return Derived(1e-2);
        } else if constexpr (std::is_same_v<Derived, RMSpropOptions>) {
          return Derived(1e-2);
        } else if constexpr (std::is_same_v<Derived, LBFGSOptions>) {
          return Derived(1);
        } else {
          return Derived{};
        }
      }
    }();

    // Merge fields: preserve user-set values, inherit defaults for unset values

    if constexpr (_has_lr<Derived>::value) {
      if (user_options.lr() != constructor_defaults.lr()) {
        result->lr(user_options.lr());
      }
    }
    if constexpr (_has_momentum<Derived>::value) {
      if (user_options.momentum() != constructor_defaults.momentum()) {
        result->momentum(user_options.momentum());
      }
    }
    if constexpr (_has_weight_decay<Derived>::value) {
      if (user_options.weight_decay() != constructor_defaults.weight_decay()) {
        result->weight_decay(user_options.weight_decay());
      }
    }
    if constexpr (_has_dampening<Derived>::value) {
      if (user_options.dampening() != constructor_defaults.dampening()) {
        result->dampening(user_options.dampening());
      }
    }
    if constexpr (_has_nesterov<Derived>::value) {
      if (user_options.nesterov() != constructor_defaults.nesterov()) {
        result->nesterov(user_options.nesterov());
      }
    }
    if constexpr (_has_betas<Derived>::value) {
      if (user_options.betas() != constructor_defaults.betas()) {
        result->betas(user_options.betas());
      }
    }
    if constexpr (_has_eps<Derived>::value) {
      if (user_options.eps() != constructor_defaults.eps()) {
        result->eps(user_options.eps());
      }
    }
    if constexpr (_has_amsgrad<Derived>::value) {
      if (user_options.amsgrad() != constructor_defaults.amsgrad()) {
        result->amsgrad(user_options.amsgrad());
      }
    }

    // Optimizer-specific fields - automatically detected and handled
    if constexpr (_has_lr_decay<Derived>::value) {
      if (user_options.lr_decay() != constructor_defaults.lr_decay()) {
        result->lr_decay(user_options.lr_decay());
      }
    }
    if constexpr (_has_alpha<Derived>::value) {
      if (user_options.alpha() != constructor_defaults.alpha()) {
        result->alpha(user_options.alpha());
      }
    }
    if constexpr (_has_centered<Derived>::value) {
      if (user_options.centered() != constructor_defaults.centered()) {
        result->centered(user_options.centered());
      }
    }
    if constexpr (_has_initial_accumulator_value<Derived>::value) {
      if (user_options.initial_accumulator_value() !=
          constructor_defaults.initial_accumulator_value()) {
        result->initial_accumulator_value(
            user_options.initial_accumulator_value());
      }
    }

    // LBFGS-specific fields with appropriate types
    if constexpr (_has_max_iter<Derived>::value) {
      if (user_options.max_iter() != constructor_defaults.max_iter()) {
        result->max_iter(user_options.max_iter());
      }
    }
    if constexpr (_has_max_eval<Derived>::value) {
      if (user_options.max_eval() != constructor_defaults.max_eval()) {
        result->max_eval(user_options.max_eval());
      }
    }
    if constexpr (_has_tolerance_grad<Derived>::value) {
      if (user_options.tolerance_grad() !=
          constructor_defaults.tolerance_grad()) {
        result->tolerance_grad(user_options.tolerance_grad());
      }
    }
    if constexpr (_has_tolerance_change<Derived>::value) {
      if (user_options.tolerance_change() !=
          constructor_defaults.tolerance_change()) {
        result->tolerance_change(user_options.tolerance_change());
      }
    }
    if constexpr (_has_history_size<Derived>::value) {
      if (user_options.history_size() != constructor_defaults.history_size()) {
        result->history_size(user_options.history_size());
      }
    }
    if constexpr (_has_line_search_fn<Derived>::value) {
      if (user_options.line_search_fn() !=
          constructor_defaults.line_search_fn()) {
        result->line_search_fn(user_options.line_search_fn());
      }
    }
  }

  // Friend class for controlled access to private _merge_by_comparison method
  friend class Optimizer;
};

/// Stores parameters in the param_group and stores a pointer to the
@@ -507,43 +186,6 @@ class TORCH_API Optimizer {
  /// Deserializes the optimizer state from the given `archive`.
  virtual void load(serialize::InputArchive& archive);

 private:
  /// Helper function to try merging for a specific optimizer type
  template <typename OptimizerType>
  static bool _try_merge_optimizer_type(
      std::unique_ptr<OptimizerOptions>& final_options,
      const OptimizerOptions& user_options,
      const OptimizerOptions& defaults) {
    auto* typed_final = dynamic_cast<OptimizerType*>(final_options.get());
    auto* typed_user = dynamic_cast<const OptimizerType*>(&user_options);
    auto* typed_defaults = dynamic_cast<const OptimizerType*>(&defaults);

    if (typed_final && typed_user && typed_defaults) {
      typed_final->_merge_by_comparison(*typed_defaults, *typed_user);
      return true;
    }
    return false;
  }

  /// Simple variadic dispatch helper - try all optimizer types in one call
  template <typename... OptimizerTypes>
  static void _try_merge_all_optimizer_types(
      std::unique_ptr<OptimizerOptions>& final_options,
      const OptimizerOptions& user_options,
      const OptimizerOptions& defaults) {
    // Try each optimizer type until one succeeds - much cleaner than manual
    // chain
    (void)(_try_merge_optimizer_type<OptimizerTypes>(
               final_options, user_options, defaults) ||
           ...);
  }

  /// Convenience function with all known PyTorch optimizers
  static void _try_merge_all_optimizers(
      std::unique_ptr<OptimizerOptions>& final_options,
      const OptimizerOptions& user_options,
      const OptimizerOptions& defaults);

 protected:
  std::vector<OptimizerParamGroup> param_groups_;
  ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> state_;
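Two C++17 idioms carry most of the weight in the deleted header above: `std::void_t`-based member detection (the `_has_*` traits) and a short-circuiting fold over `||` (the variadic dispatch). A minimal, self-contained sketch of both, using toy types rather than the torch option classes:

```cpp
#include <type_traits>
#include <utility>

// Toy stand-ins; not the torch option types.
struct WithLr {
  double lr() const { return 0.1; }
};
struct WithoutLr {};

// std::void_t detection idiom, as in the deleted _has_lr/_has_eps/... traits:
// the partial specialization is chosen only when `const T&` has a callable lr().
template <class T, class = void>
struct has_lr : std::false_type {};
template <class T>
struct has_lr<T, std::void_t<decltype(std::declval<const T&>().lr())>>
    : std::true_type {};

static_assert(has_lr<WithLr>::value);
static_assert(!has_lr<WithoutLr>::value);

// Short-circuiting fold over ||, as in _try_merge_all_optimizer_types: each
// try_as<T> reports success via its return value, and the fold stops
// evaluating further types after the first success.
struct Base {
  virtual ~Base() = default;
};
struct A : Base {};
struct B : Base {};

template <class T>
bool try_as(const Base& b, int& handled) {
  if (dynamic_cast<const T*>(&b) != nullptr) {
    ++handled;
    return true;
  }
  return false;
}

template <class... Ts>
int dispatch(const Base& b) {
  int handled = 0;
  (void)(try_as<Ts>(b, handled) || ...); // left-to-right, stops at first true
  return handled;                        // at most 1
}

int main() {
  B b;
  return dispatch<A, B>(b) == 1 ? 0 : 1;
}
```

The deleted code applies the same pattern with a `dynamic_cast` probe per options type and calls `_merge_by_comparison` on the first one that matches.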
@@ -3,35 +3,12 @@
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/types.h>

// Include complete type definitions for all optimizers to enable dynamic_cast
#include <torch/optim/adagrad.h>
#include <torch/optim/adam.h>
#include <torch/optim/adamw.h>
#include <torch/optim/lbfgs.h>
#include <torch/optim/rmsprop.h>
#include <torch/optim/sgd.h>

#include <string>
#include <utility>
#include <vector>

namespace torch::optim {

// Simple implementation using variadic template helper
void Optimizer::_try_merge_all_optimizers(
    std::unique_ptr<OptimizerOptions>& final_options,
    const OptimizerOptions& user_options,
    const OptimizerOptions& defaults) {
  // Clean one-liner replaces the entire repetitive dispatch chain
  _try_merge_all_optimizer_types<
      SGDOptions,
      AdamOptions,
      AdamWOptions,
      AdagradOptions,
      RMSpropOptions,
      LBFGSOptions>(final_options, user_options, defaults);
}

bool OptimizerParamGroup::has_options() const {
  return options_ != nullptr;
}
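One consequence of the comparison-based merge defined above is worth spelling out before the add_param_group hunk below: a field counts as "explicitly set" only if its value differs from the option type's constructor default, so a user who sets a field to exactly that constructor default is indistinguishable from one who never set it, and the optimizer-level default wins. A toy model of the rule (plain C++, not the torch types), just to illustrate the comparison logic:

```cpp
#include <cassert>

// Toy stand-in for an options type whose constructor defaults are
// lr = 1e-3 and weight_decay = 0.0.
struct ToyOptions {
  double lr = 1e-3;
  double weight_decay = 0.0;
};

// Mirrors the deleted _merge_by_comparison rule: start from the optimizer
// defaults and copy over only the user fields that differ from a freshly
// constructed instance.
ToyOptions merge(const ToyOptions& optimizer_defaults, const ToyOptions& user) {
  const ToyOptions ctor_defaults{};
  ToyOptions result = optimizer_defaults;
  if (user.lr != ctor_defaults.lr) {
    result.lr = user.lr;
  }
  if (user.weight_decay != ctor_defaults.weight_decay) {
    result.weight_decay = user.weight_decay;
  }
  return result;
}

int main() {
  ToyOptions defaults; // optimizer-level defaults
  defaults.lr = 0.01;
  defaults.weight_decay = 0.05;

  ToyOptions group1; // user sets only weight_decay
  group1.weight_decay = 0.11;
  ToyOptions merged1 = merge(defaults, group1);
  assert(merged1.lr == 0.01 && merged1.weight_decay == 0.11);

  ToyOptions group2; // user "sets" lr to the constructor default
  group2.lr = 1e-3;
  ToyOptions merged2 = merge(defaults, group2);
  assert(merged2.lr == 0.01); // indistinguishable from unset, so defaults win

  return 0;
}
```

The deleted tests above only ever pick per-group values that differ from the constructor defaults, which is why the comparison scheme is sufficient for them.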
@@ -124,20 +101,9 @@ void Optimizer::add_param_group(const OptimizerParamGroup& param_group) {
  TORCH_INTERNAL_ASSERT(defaults_ != nullptr);
  OptimizerParamGroup param_group_(param_group.params());
  if (!param_group.has_options()) {
    // No options provided - use defaults directly
    param_group_.set_options(defaults_->clone());
  } else {
    // Options provided - merge user's explicit settings with defaults for
    // parameter group inheritance. This enables Python-C++ API parity by
    // honoring user intent while inheriting missing parameters.
    auto final_options = defaults_->clone();

    // Simple variadic dispatch - try all known optimizer types
    _try_merge_all_optimizers(final_options, param_group.options(), *defaults_);

    // If no merging was done (custom optimizer), final_options already contains
    // defaults
    param_group_.set_options(std::move(final_options));
    param_group_.set_options(param_group.options().clone());
  }
  for (const auto& p : param_group_.params()) {
    TORCH_CHECK(