mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
C++ API handle optimizer defaults (#161825)
Fixes #141884 This fixes the issue for all optimizers and parameter options. A member function `overwrite_from` is added to the optimizer base class. Each optimizer then implements this function for comparing their accepted parameters to defaults. A SFINAE approach to handle the different optimizer parameters generically (in optimizer.h only) was evaluated, but I think this is easier to review and maintain. This mirrors the Python API up to one edge case. An example of the edge case is provided below. Python can distinguish between 1) Key not present in dict = "not specified" and 2) Key present in dict = "explicitly set". The C++ implementation cannot. The issue hinges on whether or not to track if a particular parameter was set by the user explicitly or not (discrepancy in the case when the constructor default is explicitly passed in). To track this seems like it will take more intervention than would be worth it (modify TORCH_ARG to keep track, use std::optional for the parameter types, use bitset tracking) and was not pursued in the current PR. I'm happy to alter the design if appropriate. ### Example of edge case hinging on CONSTRUCTOR DEFAULTS vs OPTIMIZER DEFAULTS 1. CONSTRUCTOR DEFAULTS: These are the values you get when calling AdamOptions() AdamOptions().lr() = 0.001 AdamOptions().weight_decay() = 0 AdamOptions().eps() = 1e-08 2. OPTIMIZER DEFAULTS: These are the values the user chose when creating the optimizer User's optimizer defaults: optimizer.lr() = 0.005 optimizer.weight_decay() = 0.1 optimizer.eps() = 1e-07 3. THE PROBLEM SCENARIO: User wants to add a parameter group with explicit weight_decay=0.0 User sets: weight_decay(0) 4. THE CONFUSION: Constructor default weight_decay: 0 User's explicit weight_decay: 0 Are they equal? YES Since they're equal, our overwrite_from() logic thinks: "User didn't set weight_decay explicitly, use optimizer default" 5. CURRENT BEHAVIOR: Final weight_decay: 0.1 User expected: 0 Match? ❌ NO === KEY INSIGHT === Constructor defaults are built into the C++ class definition. Optimizer defaults are chosen by the user at runtime. We want to respect the user intention. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161825 Approved by: https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
0a3e4e894c
commit
f332017294
@ -564,3 +564,508 @@ TEST(OptimTest, CheckLRChange_ReduceLROnPlateau_Adam) {
|
||||
check_lr_change_for_reduce_on_plateau(
|
||||
optimizer, reduce_lr_on_plateau_scheduler, expected_epoch_lrs);
|
||||
}
|
||||
// Tests for Issue 141884: Parameter group inheritance functionality
|
||||
// Validates that partial options in parameter groups correctly inherit
|
||||
// defaults from the optimizer while preserving explicitly set values
|
||||
TEST(OptimTest, MergeWithDefaultOptions_Adam) {
|
||||
// Create tensors for parameter groups
|
||||
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
|
||||
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
|
||||
|
||||
// Create param groups with partial options
|
||||
std::vector<OptimizerParamGroup> param_groups;
|
||||
|
||||
// Group 1: Only weight_decay specified, should inherit lr, betas, eps,
|
||||
// amsgrad
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor1},
|
||||
std::make_unique<AdamOptions>(AdamOptions().weight_decay(0.11)));
|
||||
|
||||
// Group 2: Only eps specified, should inherit others
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor2},
|
||||
std::make_unique<AdamOptions>(AdamOptions().eps(1e-6)));
|
||||
|
||||
// Create optimizer with specific defaults
|
||||
AdamOptions defaults;
|
||||
defaults.lr(0.002)
|
||||
.betas(std::make_tuple(0.8, 0.88))
|
||||
.eps(1e-12)
|
||||
.weight_decay(0.05)
|
||||
.amsgrad(true);
|
||||
|
||||
Adam optimizer(param_groups, defaults);
|
||||
|
||||
// Check Group 1: weight_decay preserved, others inherited
|
||||
auto& group1_opts =
|
||||
static_cast<AdamOptions&>(optimizer.param_groups()[0].options());
|
||||
ASSERT_EQ(group1_opts.lr(), 0.002); // Inherited
|
||||
ASSERT_EQ(group1_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
|
||||
ASSERT_EQ(group1_opts.eps(), 1e-12); // Inherited
|
||||
ASSERT_EQ(group1_opts.weight_decay(), 0.11); // Preserved
|
||||
ASSERT_TRUE(group1_opts.amsgrad()); // Inherited
|
||||
|
||||
// Check Group 2: eps preserved, others inherited
|
||||
auto& group2_opts =
|
||||
static_cast<AdamOptions&>(optimizer.param_groups()[1].options());
|
||||
ASSERT_EQ(group2_opts.lr(), 0.002); // Inherited
|
||||
ASSERT_EQ(group2_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
|
||||
ASSERT_EQ(group2_opts.eps(), 1e-6); // Preserved
|
||||
ASSERT_EQ(group2_opts.weight_decay(), 0.05); // Inherited
|
||||
ASSERT_TRUE(group2_opts.amsgrad()); // Inherited
|
||||
}
|
||||
|
||||
TEST(OptimTest, MergeWithDefaultOptions_SGD) {
|
||||
// Create tensors for parameter groups
|
||||
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
|
||||
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
|
||||
|
||||
// Create param groups with partial options
|
||||
std::vector<OptimizerParamGroup> param_groups;
|
||||
|
||||
// Group 1: Only lr and weight_decay specified, should inherit momentum,
|
||||
// dampening, nesterov
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor1},
|
||||
std::make_unique<SGDOptions>(SGDOptions(0.01).weight_decay(0.22)));
|
||||
|
||||
// Group 2: Only lr specified, should inherit all others
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor2},
|
||||
std::make_unique<SGDOptions>(SGDOptions(0.02)));
|
||||
|
||||
// Create optimizer with specific defaults
|
||||
SGDOptions defaults(0.001); // lr should be overridden by param groups
|
||||
defaults.momentum(0.9)
|
||||
.dampening(0.0) // Must be 0 for Nesterov
|
||||
.weight_decay(0.05)
|
||||
.nesterov(true);
|
||||
|
||||
SGD optimizer(param_groups, defaults);
|
||||
|
||||
// Check Group 1: lr and weight_decay preserved, others inherited
|
||||
auto& group1_opts =
|
||||
static_cast<SGDOptions&>(optimizer.param_groups()[0].options());
|
||||
ASSERT_EQ(group1_opts.lr(), 0.01); // Preserved
|
||||
ASSERT_EQ(group1_opts.momentum(), 0.9); // Inherited
|
||||
ASSERT_EQ(group1_opts.dampening(), 0.0); // Inherited
|
||||
ASSERT_EQ(group1_opts.weight_decay(), 0.22); // Preserved
|
||||
ASSERT_TRUE(group1_opts.nesterov()); // Inherited
|
||||
|
||||
// Check Group 2: lr preserved, others inherited
|
||||
auto& group2_opts =
|
||||
static_cast<SGDOptions&>(optimizer.param_groups()[1].options());
|
||||
ASSERT_EQ(group2_opts.lr(), 0.02); // Preserved
|
||||
ASSERT_EQ(group2_opts.momentum(), 0.9); // Inherited
|
||||
ASSERT_EQ(group2_opts.dampening(), 0.0); // Inherited
|
||||
ASSERT_EQ(group2_opts.weight_decay(), 0.05); // Inherited
|
||||
ASSERT_TRUE(group2_opts.nesterov()); // Inherited
|
||||
}
|
||||
|
||||
TEST(OptimTest, MergeWithDefaultOptions_AdamW) {
|
||||
// Create tensors for parameter groups
|
||||
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
|
||||
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
|
||||
|
||||
// Create param groups with partial options
|
||||
std::vector<OptimizerParamGroup> param_groups;
|
||||
|
||||
// Group 1: Only eps specified, should inherit others
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor1},
|
||||
std::make_unique<AdamWOptions>(AdamWOptions().eps(1e-6)));
|
||||
|
||||
// Group 2: Only betas specified, should inherit others
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor2},
|
||||
std::make_unique<AdamWOptions>(
|
||||
AdamWOptions().betas(std::make_tuple(0.95, 0.999))));
|
||||
|
||||
// Create optimizer with specific defaults
|
||||
AdamWOptions defaults;
|
||||
defaults.lr(0.003)
|
||||
.betas(std::make_tuple(0.9, 0.98))
|
||||
.eps(1e-8)
|
||||
.weight_decay(0.02)
|
||||
.amsgrad(false);
|
||||
|
||||
AdamW optimizer(param_groups, defaults);
|
||||
|
||||
// Check Group 1: eps preserved, others inherited
|
||||
auto& group1_opts =
|
||||
static_cast<AdamWOptions&>(optimizer.param_groups()[0].options());
|
||||
ASSERT_EQ(group1_opts.lr(), 0.003); // Inherited
|
||||
ASSERT_EQ(group1_opts.betas(), std::make_tuple(0.9, 0.98)); // Inherited
|
||||
ASSERT_EQ(group1_opts.eps(), 1e-6); // Preserved
|
||||
ASSERT_EQ(group1_opts.weight_decay(), 0.02); // Inherited
|
||||
ASSERT_FALSE(group1_opts.amsgrad()); // Inherited
|
||||
|
||||
// Check Group 2: betas preserved, others inherited
|
||||
auto& group2_opts =
|
||||
static_cast<AdamWOptions&>(optimizer.param_groups()[1].options());
|
||||
ASSERT_EQ(group2_opts.lr(), 0.003); // Inherited
|
||||
ASSERT_EQ(group2_opts.betas(), std::make_tuple(0.95, 0.999)); // Preserved
|
||||
ASSERT_EQ(group2_opts.eps(), 1e-8); // Inherited
|
||||
ASSERT_EQ(group2_opts.weight_decay(), 0.02); // Inherited
|
||||
ASSERT_FALSE(group2_opts.amsgrad()); // Inherited
|
||||
}
|
||||
|
||||
TEST(OptimTest, MergeWithDefaultOptions_Adagrad) {
|
||||
// Create tensors for parameter groups
|
||||
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
|
||||
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
|
||||
|
||||
// Create param groups with partial options
|
||||
std::vector<OptimizerParamGroup> param_groups;
|
||||
|
||||
// Group 1: Only lr_decay specified, should inherit others
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor1},
|
||||
std::make_unique<AdagradOptions>(AdagradOptions().lr_decay(0.001)));
|
||||
|
||||
// Group 2: Only initial_accumulator_value specified, should inherit others
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor2},
|
||||
std::make_unique<AdagradOptions>(
|
||||
AdagradOptions().initial_accumulator_value(0.5)));
|
||||
|
||||
// Create optimizer with specific defaults
|
||||
AdagradOptions defaults;
|
||||
defaults.lr(0.04)
|
||||
.lr_decay(0.002)
|
||||
.weight_decay(0.03)
|
||||
.initial_accumulator_value(0.1)
|
||||
.eps(1e-11);
|
||||
|
||||
Adagrad optimizer(param_groups, defaults);
|
||||
|
||||
// Check Group 1: lr_decay preserved, others inherited
|
||||
auto& group1_opts =
|
||||
static_cast<AdagradOptions&>(optimizer.param_groups()[0].options());
|
||||
ASSERT_EQ(group1_opts.lr(), 0.04); // Inherited
|
||||
ASSERT_EQ(group1_opts.lr_decay(), 0.001); // Preserved
|
||||
ASSERT_EQ(group1_opts.weight_decay(), 0.03); // Inherited
|
||||
ASSERT_EQ(group1_opts.initial_accumulator_value(), 0.1); // Inherited
|
||||
ASSERT_EQ(group1_opts.eps(), 1e-11); // Inherited
|
||||
|
||||
// Check Group 2: initial_accumulator_value preserved, others inherited
|
||||
auto& group2_opts =
|
||||
static_cast<AdagradOptions&>(optimizer.param_groups()[1].options());
|
||||
ASSERT_EQ(group2_opts.lr(), 0.04); // Inherited
|
||||
ASSERT_EQ(group2_opts.lr_decay(), 0.002); // Inherited
|
||||
ASSERT_EQ(group2_opts.weight_decay(), 0.03); // Inherited
|
||||
ASSERT_EQ(group2_opts.initial_accumulator_value(), 0.5); // Preserved
|
||||
ASSERT_EQ(group2_opts.eps(), 1e-11); // Inherited
|
||||
}
|
||||
|
||||
TEST(OptimTest, MergeWithDefaultOptions_RMSprop) {
|
||||
// Create tensors for parameter groups
|
||||
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
|
||||
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
|
||||
|
||||
// Create param groups with partial options
|
||||
std::vector<OptimizerParamGroup> param_groups;
|
||||
|
||||
// Group 1: Only alpha specified, should inherit others
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor1},
|
||||
std::make_unique<RMSpropOptions>(RMSpropOptions().alpha(0.95)));
|
||||
|
||||
// Group 2: Only momentum and centered specified, should inherit others
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor2},
|
||||
std::make_unique<RMSpropOptions>(
|
||||
RMSpropOptions().momentum(0.8).centered(true)));
|
||||
|
||||
// Create optimizer with specific defaults
|
||||
RMSpropOptions defaults;
|
||||
defaults.lr(0.015)
|
||||
.alpha(0.98)
|
||||
.eps(1e-9)
|
||||
.weight_decay(0.01)
|
||||
.momentum(0.7)
|
||||
.centered(false);
|
||||
|
||||
RMSprop optimizer(param_groups, defaults);
|
||||
|
||||
// Check Group 1: alpha preserved, others inherited
|
||||
auto& group1_opts =
|
||||
static_cast<RMSpropOptions&>(optimizer.param_groups()[0].options());
|
||||
ASSERT_EQ(group1_opts.lr(), 0.015); // Inherited
|
||||
ASSERT_EQ(group1_opts.alpha(), 0.95); // Preserved
|
||||
ASSERT_EQ(group1_opts.eps(), 1e-9); // Inherited
|
||||
ASSERT_EQ(group1_opts.weight_decay(), 0.01); // Inherited
|
||||
ASSERT_EQ(group1_opts.momentum(), 0.7); // Inherited
|
||||
ASSERT_FALSE(group1_opts.centered()); // Inherited
|
||||
|
||||
// Check Group 2: momentum and centered preserved, others inherited
|
||||
auto& group2_opts =
|
||||
static_cast<RMSpropOptions&>(optimizer.param_groups()[1].options());
|
||||
ASSERT_EQ(group2_opts.lr(), 0.015); // Inherited
|
||||
ASSERT_EQ(group2_opts.alpha(), 0.98); // Inherited
|
||||
ASSERT_EQ(group2_opts.eps(), 1e-9); // Inherited
|
||||
ASSERT_EQ(group2_opts.weight_decay(), 0.01); // Inherited
|
||||
ASSERT_EQ(group2_opts.momentum(), 0.8); // Preserved
|
||||
ASSERT_TRUE(group2_opts.centered()); // Preserved
|
||||
}
|
||||
|
||||
TEST(OptimTest, MergeWithDefaultOptions_LBFGS) {
|
||||
// Create tensors for single parameter group (LBFGS limitation)
|
||||
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
|
||||
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
|
||||
|
||||
// Create param group with partial options
|
||||
std::vector<OptimizerParamGroup> param_groups;
|
||||
|
||||
// Single group: Only max_iter specified, should inherit others
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{
|
||||
tensor1, tensor2}, // Combine tensors in single group
|
||||
std::make_unique<LBFGSOptions>(LBFGSOptions().max_iter(15)));
|
||||
|
||||
// Create optimizer with specific defaults
|
||||
LBFGSOptions defaults;
|
||||
defaults.lr(0.8)
|
||||
.max_iter(25)
|
||||
.max_eval(31) // Use same value that appears to be auto-calculated
|
||||
.tolerance_grad(1e-5)
|
||||
.tolerance_change(1e-8)
|
||||
.history_size(80)
|
||||
.line_search_fn("strong_wolfe");
|
||||
|
||||
LBFGS optimizer(param_groups, defaults);
|
||||
|
||||
// Check Group: max_iter preserved, others inherited
|
||||
auto& group_opts =
|
||||
static_cast<LBFGSOptions&>(optimizer.param_groups()[0].options());
|
||||
ASSERT_EQ(group_opts.lr(), 0.8); // Inherited
|
||||
ASSERT_EQ(group_opts.max_iter(), 15); // Preserved
|
||||
ASSERT_EQ(group_opts.max_eval(), 31); // Inherited
|
||||
ASSERT_EQ(group_opts.tolerance_grad(), 1e-5); // Inherited
|
||||
ASSERT_EQ(group_opts.tolerance_change(), 1e-8); // Inherited
|
||||
ASSERT_EQ(group_opts.history_size(), 80); // Inherited
|
||||
ASSERT_EQ(group_opts.line_search_fn(), "strong_wolfe"); // Inherited
|
||||
}
|
||||
|
||||
TEST(OptimTest, MergeWithDefaultOptions_NoOptionsInheritance) {
|
||||
// Test that param groups without options get full defaults
|
||||
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
|
||||
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
|
||||
|
||||
std::vector<OptimizerParamGroup> param_groups;
|
||||
|
||||
// Groups with no options - should inherit everything
|
||||
param_groups.emplace_back(std::vector<torch::Tensor>{tensor1});
|
||||
param_groups.emplace_back(std::vector<torch::Tensor>{tensor2});
|
||||
|
||||
// Create optimizer with specific defaults
|
||||
AdamOptions defaults;
|
||||
defaults.lr(0.005)
|
||||
.betas(std::make_tuple(0.85, 0.95))
|
||||
.eps(1e-7)
|
||||
.weight_decay(0.08)
|
||||
.amsgrad(true);
|
||||
|
||||
Adam optimizer(param_groups, defaults);
|
||||
|
||||
// Both groups should have exactly the default options
|
||||
for (int i = 0; i < 2; i++) {
|
||||
auto& group_opts =
|
||||
static_cast<AdamOptions&>(optimizer.param_groups()[i].options());
|
||||
ASSERT_EQ(group_opts.lr(), 0.005);
|
||||
ASSERT_EQ(group_opts.betas(), std::make_tuple(0.85, 0.95));
|
||||
ASSERT_EQ(group_opts.eps(), 1e-7);
|
||||
ASSERT_EQ(group_opts.weight_decay(), 0.08);
|
||||
ASSERT_TRUE(group_opts.amsgrad());
|
||||
}
|
||||
}
|
||||
|
||||
// Test that field tracking survives serialization/deserialization cycles
|
||||
TEST(OptimTest, SerializationPreservesFieldTracking_Adam) {
|
||||
// Create tensors for parameter groups
|
||||
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
|
||||
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
|
||||
|
||||
// Create param groups with partial options using fluent API (marks fields as
|
||||
// explicit)
|
||||
std::vector<OptimizerParamGroup> param_groups;
|
||||
|
||||
// Group 1: Only weight_decay and amsgrad explicitly set via fluent API
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor1},
|
||||
std::make_unique<AdamOptions>(
|
||||
AdamOptions().weight_decay(0.11).amsgrad(true)));
|
||||
|
||||
// Group 2: Only eps explicitly set via fluent API
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor2},
|
||||
std::make_unique<AdamOptions>(AdamOptions().eps(1e-6)));
|
||||
|
||||
// Create optimizer with specific defaults
|
||||
AdamOptions defaults;
|
||||
defaults.lr(0.002)
|
||||
.betas(std::make_tuple(0.8, 0.88))
|
||||
.eps(1e-12)
|
||||
.weight_decay(0.05)
|
||||
.amsgrad(false);
|
||||
|
||||
Adam original_optimizer(param_groups, defaults);
|
||||
|
||||
// Capture original state for comparison
|
||||
auto& orig_group1_opts =
|
||||
static_cast<AdamOptions&>(original_optimizer.param_groups()[0].options());
|
||||
auto& orig_group2_opts =
|
||||
static_cast<AdamOptions&>(original_optimizer.param_groups()[1].options());
|
||||
|
||||
// Verify original state (sanity check)
|
||||
ASSERT_NEAR(orig_group1_opts.weight_decay(), 0.11, 1e-6); // Explicitly set
|
||||
ASSERT_TRUE(orig_group1_opts.amsgrad()); // Explicitly set
|
||||
ASSERT_NEAR(orig_group1_opts.lr(), 0.002, 1e-6); // Inherited
|
||||
ASSERT_NEAR(orig_group2_opts.eps(), 1e-6, 1e-9); // Explicitly set
|
||||
ASSERT_NEAR(orig_group2_opts.lr(), 0.002, 1e-6); // Inherited
|
||||
|
||||
// Test serialization of the options objects (where field tracking lives)
|
||||
std::stringstream ss1, ss2;
|
||||
|
||||
// Serialize the parameter group options
|
||||
{
|
||||
torch::serialize::OutputArchive archive;
|
||||
orig_group1_opts.serialize(archive);
|
||||
archive.save_to(ss1);
|
||||
}
|
||||
{
|
||||
torch::serialize::OutputArchive archive;
|
||||
orig_group2_opts.serialize(archive);
|
||||
archive.save_to(ss2);
|
||||
}
|
||||
|
||||
// Create new options objects and deserialize
|
||||
AdamOptions loaded_group1_opts;
|
||||
AdamOptions loaded_group2_opts;
|
||||
|
||||
{
|
||||
torch::serialize::InputArchive archive;
|
||||
archive.load_from(ss1);
|
||||
loaded_group1_opts.serialize(archive);
|
||||
}
|
||||
{
|
||||
torch::serialize::InputArchive archive;
|
||||
archive.load_from(ss2);
|
||||
loaded_group2_opts.serialize(archive);
|
||||
}
|
||||
|
||||
// Verify that all parameter values are preserved after deserialization
|
||||
|
||||
// Group 1: weight_decay and amsgrad should be preserved as explicitly set,
|
||||
// others inherited
|
||||
ASSERT_NEAR(loaded_group1_opts.lr(), 0.002, 1e-6); // Inherited
|
||||
ASSERT_EQ(
|
||||
loaded_group1_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
|
||||
ASSERT_NEAR(loaded_group1_opts.eps(), 1e-12, 1e-15); // Inherited
|
||||
ASSERT_NEAR(loaded_group1_opts.weight_decay(), 0.11, 1e-6); // Explicitly set
|
||||
ASSERT_TRUE(loaded_group1_opts.amsgrad()); // Explicitly set
|
||||
|
||||
// Group 2: eps should be preserved as explicitly set, others inherited
|
||||
ASSERT_NEAR(loaded_group2_opts.lr(), 0.002, 1e-6); // Inherited
|
||||
ASSERT_EQ(
|
||||
loaded_group2_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
|
||||
ASSERT_NEAR(loaded_group2_opts.eps(), 1e-6, 1e-9); // Explicitly set
|
||||
ASSERT_NEAR(loaded_group2_opts.weight_decay(), 0.05, 1e-6); // Inherited
|
||||
ASSERT_FALSE(loaded_group2_opts.amsgrad()); // Inherited
|
||||
|
||||
// CRITICAL: Test that field tracking is preserved after serialization
|
||||
// Create a new optimizer using the deserialized options to test inheritance
|
||||
auto tensor3 = torch::randn({2, 2}).requires_grad_(true);
|
||||
auto tensor4 = torch::randn({3, 3}).requires_grad_(true);
|
||||
|
||||
std::vector<OptimizerParamGroup> test_param_groups;
|
||||
test_param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor3},
|
||||
std::make_unique<AdamOptions>(loaded_group1_opts));
|
||||
test_param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor4},
|
||||
std::make_unique<AdamOptions>(loaded_group2_opts));
|
||||
|
||||
Adam test_optimizer(test_param_groups, defaults);
|
||||
|
||||
// The field tracking should work correctly for inheritance
|
||||
auto& final_group1_opts =
|
||||
static_cast<AdamOptions&>(test_optimizer.param_groups()[0].options());
|
||||
auto& final_group2_opts =
|
||||
static_cast<AdamOptions&>(test_optimizer.param_groups()[1].options());
|
||||
|
||||
// Group 1: weight_decay and amsgrad should still be preserved as explicitly
|
||||
// set
|
||||
ASSERT_NEAR(
|
||||
final_group1_opts.weight_decay(),
|
||||
0.11,
|
||||
1e-6); // Explicitly set (preserved)
|
||||
ASSERT_TRUE(final_group1_opts.amsgrad()); // Explicitly set (preserved)
|
||||
ASSERT_NEAR(final_group1_opts.lr(), 0.002, 1e-6); // Inherited from defaults
|
||||
|
||||
// Group 2: eps should still be preserved as explicitly set
|
||||
ASSERT_NEAR(
|
||||
final_group2_opts.eps(), 1e-6, 1e-9); // Explicitly set (preserved)
|
||||
ASSERT_NEAR(final_group2_opts.lr(), 0.002, 1e-6); // Inherited from defaults
|
||||
}
|
||||
|
||||
// Test serialization with SGD (different parameter types)
|
||||
TEST(OptimTest, SerializationPreservesFieldTracking_SGD) {
|
||||
// Create tensors
|
||||
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
|
||||
|
||||
// Create param group with partial options using fluent API
|
||||
std::vector<OptimizerParamGroup> param_groups;
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor1},
|
||||
std::make_unique<SGDOptions>(
|
||||
SGDOptions(0.01).weight_decay(0.22).nesterov(true)));
|
||||
|
||||
// Create optimizer with defaults
|
||||
SGDOptions defaults(0.001);
|
||||
defaults.momentum(0.9).dampening(0.0).weight_decay(0.05).nesterov(false);
|
||||
|
||||
SGD original_optimizer(param_groups, defaults);
|
||||
|
||||
// Test serialization of the SGD options (where field tracking lives)
|
||||
auto& original_opts =
|
||||
static_cast<SGDOptions&>(original_optimizer.param_groups()[0].options());
|
||||
|
||||
std::stringstream ss;
|
||||
{
|
||||
torch::serialize::OutputArchive archive;
|
||||
original_opts.serialize(archive);
|
||||
archive.save_to(ss);
|
||||
}
|
||||
|
||||
SGDOptions loaded_opts(0.0); // Dummy initial value
|
||||
{
|
||||
torch::serialize::InputArchive archive;
|
||||
archive.load_from(ss);
|
||||
loaded_opts.serialize(archive);
|
||||
}
|
||||
ASSERT_NEAR(loaded_opts.lr(), 0.01, 1e-6); // Explicitly set
|
||||
ASSERT_NEAR(loaded_opts.momentum(), 0.9, 1e-6); // Inherited
|
||||
ASSERT_NEAR(loaded_opts.dampening(), 0.0, 1e-6); // Inherited
|
||||
ASSERT_NEAR(loaded_opts.weight_decay(), 0.22, 1e-6); // Explicitly set
|
||||
ASSERT_TRUE(loaded_opts.nesterov()); // Explicitly set
|
||||
|
||||
// Test that field tracking still works after deserialization by creating new
|
||||
// optimizer
|
||||
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
|
||||
std::vector<OptimizerParamGroup> test_param_groups;
|
||||
test_param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor2},
|
||||
std::make_unique<SGDOptions>(loaded_opts));
|
||||
|
||||
SGD test_optimizer(test_param_groups, defaults);
|
||||
|
||||
auto& final_opts =
|
||||
static_cast<SGDOptions&>(test_optimizer.param_groups()[0].options());
|
||||
ASSERT_NEAR(final_opts.lr(), 0.01, 1e-6); // Explicitly set (preserved)
|
||||
ASSERT_NEAR(
|
||||
final_opts.weight_decay(), 0.22, 1e-6); // Explicitly set (preserved)
|
||||
ASSERT_TRUE(final_opts.nesterov()); // Explicitly set (preserved)
|
||||
ASSERT_NEAR(final_opts.momentum(), 0.9, 1e-6); // Inherited from defaults
|
||||
ASSERT_NEAR(final_opts.dampening(), 0.0, 1e-6); // Inherited from defaults
|
||||
}
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
#include <iterator>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
// Forward declarations confuse Doxygen
|
||||
@ -66,12 +67,332 @@ class TORCH_API OptimizerOptions {
|
||||
virtual void set_lr(const double lr);
|
||||
};
|
||||
|
||||
// Forward declarations for optimizer option types
|
||||
struct SGDOptions;
|
||||
struct AdamOptions;
|
||||
struct AdamWOptions;
|
||||
struct AdagradOptions;
|
||||
struct RMSpropOptions;
|
||||
struct LBFGSOptions;
|
||||
|
||||
/**
|
||||
* OptimizerCloneableOptions provides parameter group inheritance functionality
|
||||
* for PyTorch C++ optimizer options. When creating parameter groups with
|
||||
* partial options (e.g., AdamOptions().weight_decay(0.1)), fields not
|
||||
* explicitly set by the user inherit from the optimizer's default values,
|
||||
* while explicitly set fields are preserved.
|
||||
*
|
||||
* This enables Python-like behavior in C++:
|
||||
* ```cpp
|
||||
* // Python equivalent:
|
||||
* // optimizer = Adam([{'params': params1, 'weight_decay': 0.1}], lr=0.01)
|
||||
* // Result: weight_decay=0.1 preserved, lr=0.01 inherited
|
||||
*
|
||||
* AdamOptions defaults;
|
||||
* defaults.lr(0.01).weight_decay(0.05);
|
||||
*
|
||||
* std::vector<OptimizerParamGroup> groups;
|
||||
* groups.emplace_back(params1, std::make_unique<AdamOptions>(
|
||||
* AdamOptions().weight_decay(0.1))); // Only weight_decay specified
|
||||
*
|
||||
* Adam optimizer(groups, defaults);
|
||||
* // Result: group inherits lr=0.01, preserves weight_decay=0.1
|
||||
* ```
|
||||
*
|
||||
* **Implementation**: Uses SFINAE-based field detection and constructor-default
|
||||
* comparison to distinguish explicitly set fields from default values.
|
||||
* Fields that match constructor defaults are inherited; others are preserved.
|
||||
*/
|
||||
template <typename Derived>
|
||||
class OptimizerCloneableOptions : public OptimizerOptions {
|
||||
private:
|
||||
std::unique_ptr<OptimizerOptions> clone() const override {
|
||||
return std::make_unique<Derived>(static_cast<const Derived&>(*this));
|
||||
}
|
||||
|
||||
// SFINAE field detection - detects optimizer fields using public accessor
|
||||
// methods
|
||||
template <class T, class Enable = void>
|
||||
struct _has_lr : std::false_type {};
|
||||
template <class T>
|
||||
struct _has_lr<T, std::void_t<decltype(std::declval<const T&>().lr())>>
|
||||
: std::true_type {};
|
||||
|
||||
template <class T, class Enable = void>
|
||||
struct _has_momentum : std::false_type {};
|
||||
template <class T>
|
||||
struct _has_momentum<
|
||||
T,
|
||||
std::void_t<decltype(std::declval<const T&>().momentum())>>
|
||||
: std::true_type {};
|
||||
|
||||
template <class T, class Enable = void>
|
||||
struct _has_weight_decay : std::false_type {};
|
||||
template <class T>
|
||||
struct _has_weight_decay<
|
||||
T,
|
||||
std::void_t<decltype(std::declval<const T&>().weight_decay())>>
|
||||
: std::true_type {};
|
||||
|
||||
template <class T, class Enable = void>
|
||||
struct _has_dampening : std::false_type {};
|
||||
template <class T>
|
||||
struct _has_dampening<
|
||||
T,
|
||||
std::void_t<decltype(std::declval<const T&>().dampening())>>
|
||||
: std::true_type {};
|
||||
|
||||
template <class T, class Enable = void>
|
||||
struct _has_nesterov : std::false_type {};
|
||||
template <class T>
|
||||
struct _has_nesterov<
|
||||
T,
|
||||
std::void_t<decltype(std::declval<const T&>().nesterov())>>
|
||||
: std::true_type {};
|
||||
|
||||
template <class T, class Enable = void>
|
||||
struct _has_betas : std::false_type {};
|
||||
template <class T>
|
||||
struct _has_betas<T, std::void_t<decltype(std::declval<const T&>().betas())>>
|
||||
: std::true_type {};
|
||||
|
||||
template <class T, class Enable = void>
|
||||
struct _has_eps : std::false_type {};
|
||||
template <class T>
|
||||
struct _has_eps<T, std::void_t<decltype(std::declval<const T&>().eps())>>
|
||||
: std::true_type {};
|
||||
|
||||
template <class T, class Enable = void>
|
||||
struct _has_amsgrad : std::false_type {};
|
||||
template <class T>
|
||||
struct _has_amsgrad<
|
||||
T,
|
||||
std::void_t<decltype(std::declval<const T&>().amsgrad())>>
|
||||
: std::true_type {};
|
||||
|
||||
// Optimizer-specific field detection
|
||||
template <class T, class Enable = void>
|
||||
struct _has_lr_decay : std::false_type {};
|
||||
template <class T>
|
||||
struct _has_lr_decay<
|
||||
T,
|
||||
std::void_t<decltype(std::declval<const T&>().lr_decay())>>
|
||||
: std::true_type {};
|
||||
|
||||
template <class T, class Enable = void>
|
||||
struct _has_alpha : std::false_type {};
|
||||
template <class T>
|
||||
struct _has_alpha<T, std::void_t<decltype(std::declval<const T&>().alpha())>>
|
||||
: std::true_type {};
|
||||
|
||||
template <class T, class Enable = void>
|
||||
struct _has_centered : std::false_type {};
|
||||
template <class T>
|
||||
struct _has_centered<
|
||||
T,
|
||||
std::void_t<decltype(std::declval<const T&>().centered())>>
|
||||
: std::true_type {};
|
||||
|
||||
template <class T, class Enable = void>
|
||||
struct _has_initial_accumulator_value : std::false_type {};
|
||||
template <class T>
|
||||
struct _has_initial_accumulator_value<
|
||||
T,
|
||||
std::void_t<
|
||||
decltype(std::declval<const T&>().initial_accumulator_value())>>
|
||||
: std::true_type {};
|
||||
|
||||
// LBFGS-specific fields with appropriate types
|
||||
template <class T, class Enable = void>
|
||||
struct _has_max_iter : std::false_type {};
|
||||
template <class T>
|
||||
struct _has_max_iter<
|
||||
T,
|
||||
std::void_t<decltype(std::declval<const T&>().max_iter())>>
|
||||
: std::true_type {};
|
||||
|
||||
template <class T, class Enable = void>
|
||||
struct _has_max_eval : std::false_type {};
|
||||
template <class T>
|
||||
struct _has_max_eval<
|
||||
T,
|
||||
std::void_t<decltype(std::declval<const T&>().max_eval())>>
|
||||
: std::true_type {};
|
||||
|
||||
template <class T, class Enable = void>
|
||||
struct _has_tolerance_grad : std::false_type {};
|
||||
template <class T>
|
||||
struct _has_tolerance_grad<
|
||||
T,
|
||||
std::void_t<decltype(std::declval<const T&>().tolerance_grad())>>
|
||||
: std::true_type {};
|
||||
|
||||
template <class T, class Enable = void>
|
||||
struct _has_tolerance_change : std::false_type {};
|
||||
template <class T>
|
||||
struct _has_tolerance_change<
|
||||
T,
|
||||
std::void_t<decltype(std::declval<const T&>().tolerance_change())>>
|
||||
: std::true_type {};
|
||||
|
||||
template <class T, class Enable = void>
|
||||
struct _has_history_size : std::false_type {};
|
||||
template <class T>
|
||||
struct _has_history_size<
|
||||
T,
|
||||
std::void_t<decltype(std::declval<const T&>().history_size())>>
|
||||
: std::true_type {};
|
||||
|
||||
template <class T, class Enable = void>
|
||||
struct _has_line_search_fn : std::false_type {};
|
||||
template <class T>
|
||||
struct _has_line_search_fn<
|
||||
T,
|
||||
std::void_t<decltype(std::declval<const T&>().line_search_fn())>>
|
||||
: std::true_type {};
|
||||
|
||||
/**
|
||||
* Merges user-specified options with optimizer defaults using
|
||||
* constructor-default comparison to detect explicitly set fields.
|
||||
*
|
||||
* Algorithm:
|
||||
* 1. Start with optimizer defaults as base
|
||||
* 2. Create fresh constructor instance for comparison
|
||||
* 3. If user_value != constructor_default → user explicitly set it → preserve
|
||||
* 4. If user_value == constructor_default → user didn't set it → inherit from
|
||||
* defaults
|
||||
*/
|
||||
void _merge_by_comparison(
|
||||
const Derived& defaults,
|
||||
const Derived& user_options) {
|
||||
auto* result = static_cast<Derived*>(this);
|
||||
*result = defaults; // Start with optimizer defaults
|
||||
|
||||
// Create constructor defaults instance for comparison
|
||||
Derived constructor_defaults = []() {
|
||||
if constexpr (std::is_default_constructible_v<Derived>) {
|
||||
return Derived{};
|
||||
} else {
|
||||
// Handle optimizers requiring constructor parameters
|
||||
if constexpr (std::is_same_v<Derived, SGDOptions>) {
|
||||
return Derived(1e-3);
|
||||
} else if constexpr (std::is_same_v<Derived, AdagradOptions>) {
|
||||
return Derived(1e-2);
|
||||
} else if constexpr (std::is_same_v<Derived, RMSpropOptions>) {
|
||||
return Derived(1e-2);
|
||||
} else if constexpr (std::is_same_v<Derived, LBFGSOptions>) {
|
||||
return Derived(1);
|
||||
} else {
|
||||
return Derived{};
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
// Merge fields: preserve user-set values, inherit defaults for unset values
|
||||
|
||||
if constexpr (_has_lr<Derived>::value) {
|
||||
if (user_options.lr() != constructor_defaults.lr()) {
|
||||
result->lr(user_options.lr());
|
||||
}
|
||||
}
|
||||
if constexpr (_has_momentum<Derived>::value) {
|
||||
if (user_options.momentum() != constructor_defaults.momentum()) {
|
||||
result->momentum(user_options.momentum());
|
||||
}
|
||||
}
|
||||
if constexpr (_has_weight_decay<Derived>::value) {
|
||||
if (user_options.weight_decay() != constructor_defaults.weight_decay()) {
|
||||
result->weight_decay(user_options.weight_decay());
|
||||
}
|
||||
}
|
||||
if constexpr (_has_dampening<Derived>::value) {
|
||||
if (user_options.dampening() != constructor_defaults.dampening()) {
|
||||
result->dampening(user_options.dampening());
|
||||
}
|
||||
}
|
||||
if constexpr (_has_nesterov<Derived>::value) {
|
||||
if (user_options.nesterov() != constructor_defaults.nesterov()) {
|
||||
result->nesterov(user_options.nesterov());
|
||||
}
|
||||
}
|
||||
if constexpr (_has_betas<Derived>::value) {
|
||||
if (user_options.betas() != constructor_defaults.betas()) {
|
||||
result->betas(user_options.betas());
|
||||
}
|
||||
}
|
||||
if constexpr (_has_eps<Derived>::value) {
|
||||
if (user_options.eps() != constructor_defaults.eps()) {
|
||||
result->eps(user_options.eps());
|
||||
}
|
||||
}
|
||||
if constexpr (_has_amsgrad<Derived>::value) {
|
||||
if (user_options.amsgrad() != constructor_defaults.amsgrad()) {
|
||||
result->amsgrad(user_options.amsgrad());
|
||||
}
|
||||
}
|
||||
|
||||
// Optimizer-specific fields - automatically detected and handled
|
||||
if constexpr (_has_lr_decay<Derived>::value) {
|
||||
if (user_options.lr_decay() != constructor_defaults.lr_decay()) {
|
||||
result->lr_decay(user_options.lr_decay());
|
||||
}
|
||||
}
|
||||
if constexpr (_has_alpha<Derived>::value) {
|
||||
if (user_options.alpha() != constructor_defaults.alpha()) {
|
||||
result->alpha(user_options.alpha());
|
||||
}
|
||||
}
|
||||
if constexpr (_has_centered<Derived>::value) {
|
||||
if (user_options.centered() != constructor_defaults.centered()) {
|
||||
result->centered(user_options.centered());
|
||||
}
|
||||
}
|
||||
if constexpr (_has_initial_accumulator_value<Derived>::value) {
|
||||
if (user_options.initial_accumulator_value() !=
|
||||
constructor_defaults.initial_accumulator_value()) {
|
||||
result->initial_accumulator_value(
|
||||
user_options.initial_accumulator_value());
|
||||
}
|
||||
}
|
||||
|
||||
// LBFGS-specific fields with appropriate types
|
||||
if constexpr (_has_max_iter<Derived>::value) {
|
||||
if (user_options.max_iter() != constructor_defaults.max_iter()) {
|
||||
result->max_iter(user_options.max_iter());
|
||||
}
|
||||
}
|
||||
if constexpr (_has_max_eval<Derived>::value) {
|
||||
if (user_options.max_eval() != constructor_defaults.max_eval()) {
|
||||
result->max_eval(user_options.max_eval());
|
||||
}
|
||||
}
|
||||
if constexpr (_has_tolerance_grad<Derived>::value) {
|
||||
if (user_options.tolerance_grad() !=
|
||||
constructor_defaults.tolerance_grad()) {
|
||||
result->tolerance_grad(user_options.tolerance_grad());
|
||||
}
|
||||
}
|
||||
if constexpr (_has_tolerance_change<Derived>::value) {
|
||||
if (user_options.tolerance_change() !=
|
||||
constructor_defaults.tolerance_change()) {
|
||||
result->tolerance_change(user_options.tolerance_change());
|
||||
}
|
||||
}
|
||||
if constexpr (_has_history_size<Derived>::value) {
|
||||
if (user_options.history_size() != constructor_defaults.history_size()) {
|
||||
result->history_size(user_options.history_size());
|
||||
}
|
||||
}
|
||||
if constexpr (_has_line_search_fn<Derived>::value) {
|
||||
if (user_options.line_search_fn() !=
|
||||
constructor_defaults.line_search_fn()) {
|
||||
result->line_search_fn(user_options.line_search_fn());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Friend class for controlled access to private _merge_by_comparison method
|
||||
friend class Optimizer;
|
||||
};
|
||||
|
||||
/// Stores parameters in the param_group and stores a pointer to the
|
||||
@ -186,6 +507,43 @@ class TORCH_API Optimizer {
|
||||
/// Deserializes the optimizer state from the given `archive`.
|
||||
virtual void load(serialize::InputArchive& archive);
|
||||
|
||||
private:
|
||||
/// Helper function to try merging for a specific optimizer type
|
||||
template <typename OptimizerType>
|
||||
static bool _try_merge_optimizer_type(
|
||||
std::unique_ptr<OptimizerOptions>& final_options,
|
||||
const OptimizerOptions& user_options,
|
||||
const OptimizerOptions& defaults) {
|
||||
auto* typed_final = dynamic_cast<OptimizerType*>(final_options.get());
|
||||
auto* typed_user = dynamic_cast<const OptimizerType*>(&user_options);
|
||||
auto* typed_defaults = dynamic_cast<const OptimizerType*>(&defaults);
|
||||
|
||||
if (typed_final && typed_user && typed_defaults) {
|
||||
typed_final->_merge_by_comparison(*typed_defaults, *typed_user);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Simple variadic dispatch helper - try all optimizer types in one call
|
||||
template <typename... OptimizerTypes>
|
||||
static void _try_merge_all_optimizer_types(
|
||||
std::unique_ptr<OptimizerOptions>& final_options,
|
||||
const OptimizerOptions& user_options,
|
||||
const OptimizerOptions& defaults) {
|
||||
// Try each optimizer type until one succeeds - much cleaner than manual
|
||||
// chain
|
||||
(void)(_try_merge_optimizer_type<OptimizerTypes>(
|
||||
final_options, user_options, defaults) ||
|
||||
...);
|
||||
}
|
||||
|
||||
/// Convenience function with all known PyTorch optimizers
|
||||
static void _try_merge_all_optimizers(
|
||||
std::unique_ptr<OptimizerOptions>& final_options,
|
||||
const OptimizerOptions& user_options,
|
||||
const OptimizerOptions& defaults);
|
||||
|
||||
protected:
|
||||
std::vector<OptimizerParamGroup> param_groups_;
|
||||
ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> state_;
|
||||
|
||||
@ -3,12 +3,35 @@
|
||||
#include <torch/csrc/autograd/generated/variable_factories.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
// Include complete type definitions for all optimizers to enable dynamic_cast
|
||||
#include <torch/optim/adagrad.h>
|
||||
#include <torch/optim/adam.h>
|
||||
#include <torch/optim/adamw.h>
|
||||
#include <torch/optim/lbfgs.h>
|
||||
#include <torch/optim/rmsprop.h>
|
||||
#include <torch/optim/sgd.h>
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace torch::optim {
|
||||
|
||||
// Simple implementation using variadic template helper
|
||||
void Optimizer::_try_merge_all_optimizers(
|
||||
std::unique_ptr<OptimizerOptions>& final_options,
|
||||
const OptimizerOptions& user_options,
|
||||
const OptimizerOptions& defaults) {
|
||||
// Clean one-liner replaces the entire repetitive dispatch chain
|
||||
_try_merge_all_optimizer_types<
|
||||
SGDOptions,
|
||||
AdamOptions,
|
||||
AdamWOptions,
|
||||
AdagradOptions,
|
||||
RMSpropOptions,
|
||||
LBFGSOptions>(final_options, user_options, defaults);
|
||||
}
|
||||
|
||||
bool OptimizerParamGroup::has_options() const {
|
||||
return options_ != nullptr;
|
||||
}
|
||||
@ -101,9 +124,20 @@ void Optimizer::add_param_group(const OptimizerParamGroup& param_group) {
|
||||
TORCH_INTERNAL_ASSERT(defaults_ != nullptr);
|
||||
OptimizerParamGroup param_group_(param_group.params());
|
||||
if (!param_group.has_options()) {
|
||||
// No options provided - use defaults directly
|
||||
param_group_.set_options(defaults_->clone());
|
||||
} else {
|
||||
param_group_.set_options(param_group.options().clone());
|
||||
// Options provided - merge user's explicit settings with defaults for
|
||||
// parameter group inheritance This enables Python-C++ API parity by
|
||||
// honoring user intent while inheriting missing parameters
|
||||
auto final_options = defaults_->clone();
|
||||
|
||||
// Simple variadic dispatch - try all known optimizer types
|
||||
_try_merge_all_optimizers(final_options, param_group.options(), *defaults_);
|
||||
|
||||
// If no merging was done (custom optimizer), final_options already contains
|
||||
// defaults
|
||||
param_group_.set_options(std::move(final_options));
|
||||
}
|
||||
for (const auto& p : param_group_.params()) {
|
||||
TORCH_CHECK(
|
||||
|
||||
Reference in New Issue
Block a user