C++ API handle optimizer defaults (#161825)

Fixes #141884

This fixes the issue for all optimizers and parameter options.
A member function `overwrite_from` is added to the optimizer base class, and each optimizer then implements this function to compare its accepted parameters against the defaults. A SFINAE approach that handles the different optimizer parameters generically (in optimizer.h only) was also evaluated, but I think this approach is easier to review and maintain.

This mirrors the Python API up to one edge case. An example of the edge case is provided below.

Python can distinguish between 1) a key not present in the dict ("not specified") and 2) a key present in the dict ("explicitly set"). The C++ implementation cannot.
The issue hinges on whether to track that a particular parameter was explicitly set by the user (the discrepancy arises when the constructor default is explicitly passed in).

Tracking this seems to take more intervention than it would be worth (modifying TORCH_ARG to record assignments, using std::optional for the parameter types, or keeping a bitset of set fields), so it was not pursued in the current PR. I'm happy to alter the design if appropriate.
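
For illustration, a minimal sketch of what std::optional-based tracking could look like (hypothetical and not part of this PR; the names below are invented):

```cpp
#include <optional>

// Hypothetical sketch only -- not what this PR implements. The names
// TrackedAdamOptions and resolve_weight_decay are made up for illustration.
// A std::optional member can distinguish "never set" from "explicitly set to
// the constructor default", which a value comparison cannot.
struct TrackedAdamOptions {
  std::optional<double> weight_decay_;  // nullopt == "not specified"

  TrackedAdamOptions& weight_decay(double v) {
    weight_decay_ = v;  // any setter call marks the field as explicitly set
    return *this;
  }

  // Fall back to the optimizer default only when the setter was never called,
  // even if the user passed a value equal to the constructor default.
  double resolve_weight_decay(double optimizer_default) const {
    return weight_decay_.value_or(optimizer_default);
  }
};
```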

### Example of the edge case hinging on CONSTRUCTOR DEFAULTS vs OPTIMIZER DEFAULTS

1. CONSTRUCTOR DEFAULTS:
   These are the values you get when calling AdamOptions()
   AdamOptions().lr() = 0.001
   AdamOptions().weight_decay() = 0
   AdamOptions().eps() = 1e-08

2. OPTIMIZER DEFAULTS:
   These are the values the user chose when creating the optimizer
   User's optimizer defaults:
   optimizer.lr() = 0.005
   optimizer.weight_decay() = 0.1
   optimizer.eps() = 1e-07

3. THE PROBLEM SCENARIO:
   User wants to add a parameter group with explicit weight_decay=0.0
   User sets: weight_decay(0)

4. THE CONFUSION:
   Constructor default weight_decay: 0
   User's explicit weight_decay:     0
   Are they equal? YES

   Since they're equal, our overwrite_from() logic thinks:
   "User didn't set weight_decay explicitly, use optimizer default"

5. CURRENT BEHAVIOR:
   Final weight_decay: 0.1
   User expected:      0
   Match?  NO

=== KEY INSIGHT ===
Constructor defaults are built into the C++ class definition.
Optimizer defaults are chosen by the user at runtime. We want to respect the user's intent.
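
For concreteness, a minimal sketch of the edge case, mirroring the API usage in the tests below (the printed result follows the CURRENT BEHAVIOR described above):

```cpp
#include <torch/torch.h>

#include <iostream>
#include <memory>
#include <vector>

using namespace torch::optim;

int main() {
  auto param = torch::randn({2, 2}).requires_grad_(true);

  // Optimizer defaults chosen by the user at runtime.
  AdamOptions defaults;
  defaults.lr(0.005).weight_decay(0.1).eps(1e-7);

  // The user explicitly asks for weight_decay = 0, but 0 is also the
  // AdamOptions constructor default, so the merge logic cannot tell
  // "explicitly set to 0" apart from "not specified".
  std::vector<OptimizerParamGroup> param_groups;
  param_groups.emplace_back(
      std::vector<torch::Tensor>{param},
      std::make_unique<AdamOptions>(AdamOptions().weight_decay(0.0)));

  Adam optimizer(param_groups, defaults);

  auto& opts =
      static_cast<AdamOptions&>(optimizer.param_groups()[0].options());
  // Prints 0.1 (the optimizer default), not the 0.0 the user passed in.
  std::cout << opts.weight_decay() << '\n';
}
```
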
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161825
Approved by: https://github.com/janeyx99
Author: Sean McGovern, 2025-10-08 16:40:36 +00:00
Committed by: PyTorch MergeBot
Commit: f332017294 (parent 0a3e4e894c)
3 changed files with 898 additions and 1 deletion

@@ -564,3 +564,508 @@ TEST(OptimTest, CheckLRChange_ReduceLROnPlateau_Adam) {
check_lr_change_for_reduce_on_plateau(
optimizer, reduce_lr_on_plateau_scheduler, expected_epoch_lrs);
}
// Tests for Issue 141884: Parameter group inheritance functionality
// Validates that partial options in parameter groups correctly inherit
// defaults from the optimizer while preserving explicitly set values
TEST(OptimTest, MergeWithDefaultOptions_Adam) {
// Create tensors for parameter groups
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
// Create param groups with partial options
std::vector<OptimizerParamGroup> param_groups;
// Group 1: Only weight_decay specified, should inherit lr, betas, eps,
// amsgrad
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor1},
std::make_unique<AdamOptions>(AdamOptions().weight_decay(0.11)));
// Group 2: Only eps specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor2},
std::make_unique<AdamOptions>(AdamOptions().eps(1e-6)));
// Create optimizer with specific defaults
AdamOptions defaults;
defaults.lr(0.002)
.betas(std::make_tuple(0.8, 0.88))
.eps(1e-12)
.weight_decay(0.05)
.amsgrad(true);
Adam optimizer(param_groups, defaults);
// Check Group 1: weight_decay preserved, others inherited
auto& group1_opts =
static_cast<AdamOptions&>(optimizer.param_groups()[0].options());
ASSERT_EQ(group1_opts.lr(), 0.002); // Inherited
ASSERT_EQ(group1_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
ASSERT_EQ(group1_opts.eps(), 1e-12); // Inherited
ASSERT_EQ(group1_opts.weight_decay(), 0.11); // Preserved
ASSERT_TRUE(group1_opts.amsgrad()); // Inherited
// Check Group 2: eps preserved, others inherited
auto& group2_opts =
static_cast<AdamOptions&>(optimizer.param_groups()[1].options());
ASSERT_EQ(group2_opts.lr(), 0.002); // Inherited
ASSERT_EQ(group2_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
ASSERT_EQ(group2_opts.eps(), 1e-6); // Preserved
ASSERT_EQ(group2_opts.weight_decay(), 0.05); // Inherited
ASSERT_TRUE(group2_opts.amsgrad()); // Inherited
}
TEST(OptimTest, MergeWithDefaultOptions_SGD) {
// Create tensors for parameter groups
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
// Create param groups with partial options
std::vector<OptimizerParamGroup> param_groups;
// Group 1: Only lr and weight_decay specified, should inherit momentum,
// dampening, nesterov
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor1},
std::make_unique<SGDOptions>(SGDOptions(0.01).weight_decay(0.22)));
// Group 2: Only lr specified, should inherit all others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor2},
std::make_unique<SGDOptions>(SGDOptions(0.02)));
// Create optimizer with specific defaults
SGDOptions defaults(0.001); // lr should be overridden by param groups
defaults.momentum(0.9)
.dampening(0.0) // Must be 0 for Nesterov
.weight_decay(0.05)
.nesterov(true);
SGD optimizer(param_groups, defaults);
// Check Group 1: lr and weight_decay preserved, others inherited
auto& group1_opts =
static_cast<SGDOptions&>(optimizer.param_groups()[0].options());
ASSERT_EQ(group1_opts.lr(), 0.01); // Preserved
ASSERT_EQ(group1_opts.momentum(), 0.9); // Inherited
ASSERT_EQ(group1_opts.dampening(), 0.0); // Inherited
ASSERT_EQ(group1_opts.weight_decay(), 0.22); // Preserved
ASSERT_TRUE(group1_opts.nesterov()); // Inherited
// Check Group 2: lr preserved, others inherited
auto& group2_opts =
static_cast<SGDOptions&>(optimizer.param_groups()[1].options());
ASSERT_EQ(group2_opts.lr(), 0.02); // Preserved
ASSERT_EQ(group2_opts.momentum(), 0.9); // Inherited
ASSERT_EQ(group2_opts.dampening(), 0.0); // Inherited
ASSERT_EQ(group2_opts.weight_decay(), 0.05); // Inherited
ASSERT_TRUE(group2_opts.nesterov()); // Inherited
}
TEST(OptimTest, MergeWithDefaultOptions_AdamW) {
// Create tensors for parameter groups
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
// Create param groups with partial options
std::vector<OptimizerParamGroup> param_groups;
// Group 1: Only eps specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor1},
std::make_unique<AdamWOptions>(AdamWOptions().eps(1e-6)));
// Group 2: Only betas specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor2},
std::make_unique<AdamWOptions>(
AdamWOptions().betas(std::make_tuple(0.95, 0.999))));
// Create optimizer with specific defaults
AdamWOptions defaults;
defaults.lr(0.003)
.betas(std::make_tuple(0.9, 0.98))
.eps(1e-8)
.weight_decay(0.02)
.amsgrad(false);
AdamW optimizer(param_groups, defaults);
// Check Group 1: eps preserved, others inherited
auto& group1_opts =
static_cast<AdamWOptions&>(optimizer.param_groups()[0].options());
ASSERT_EQ(group1_opts.lr(), 0.003); // Inherited
ASSERT_EQ(group1_opts.betas(), std::make_tuple(0.9, 0.98)); // Inherited
ASSERT_EQ(group1_opts.eps(), 1e-6); // Preserved
ASSERT_EQ(group1_opts.weight_decay(), 0.02); // Inherited
ASSERT_FALSE(group1_opts.amsgrad()); // Inherited
// Check Group 2: betas preserved, others inherited
auto& group2_opts =
static_cast<AdamWOptions&>(optimizer.param_groups()[1].options());
ASSERT_EQ(group2_opts.lr(), 0.003); // Inherited
ASSERT_EQ(group2_opts.betas(), std::make_tuple(0.95, 0.999)); // Preserved
ASSERT_EQ(group2_opts.eps(), 1e-8); // Inherited
ASSERT_EQ(group2_opts.weight_decay(), 0.02); // Inherited
ASSERT_FALSE(group2_opts.amsgrad()); // Inherited
}
TEST(OptimTest, MergeWithDefaultOptions_Adagrad) {
// Create tensors for parameter groups
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
// Create param groups with partial options
std::vector<OptimizerParamGroup> param_groups;
// Group 1: Only lr_decay specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor1},
std::make_unique<AdagradOptions>(AdagradOptions().lr_decay(0.001)));
// Group 2: Only initial_accumulator_value specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor2},
std::make_unique<AdagradOptions>(
AdagradOptions().initial_accumulator_value(0.5)));
// Create optimizer with specific defaults
AdagradOptions defaults;
defaults.lr(0.04)
.lr_decay(0.002)
.weight_decay(0.03)
.initial_accumulator_value(0.1)
.eps(1e-11);
Adagrad optimizer(param_groups, defaults);
// Check Group 1: lr_decay preserved, others inherited
auto& group1_opts =
static_cast<AdagradOptions&>(optimizer.param_groups()[0].options());
ASSERT_EQ(group1_opts.lr(), 0.04); // Inherited
ASSERT_EQ(group1_opts.lr_decay(), 0.001); // Preserved
ASSERT_EQ(group1_opts.weight_decay(), 0.03); // Inherited
ASSERT_EQ(group1_opts.initial_accumulator_value(), 0.1); // Inherited
ASSERT_EQ(group1_opts.eps(), 1e-11); // Inherited
// Check Group 2: initial_accumulator_value preserved, others inherited
auto& group2_opts =
static_cast<AdagradOptions&>(optimizer.param_groups()[1].options());
ASSERT_EQ(group2_opts.lr(), 0.04); // Inherited
ASSERT_EQ(group2_opts.lr_decay(), 0.002); // Inherited
ASSERT_EQ(group2_opts.weight_decay(), 0.03); // Inherited
ASSERT_EQ(group2_opts.initial_accumulator_value(), 0.5); // Preserved
ASSERT_EQ(group2_opts.eps(), 1e-11); // Inherited
}
TEST(OptimTest, MergeWithDefaultOptions_RMSprop) {
// Create tensors for parameter groups
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
// Create param groups with partial options
std::vector<OptimizerParamGroup> param_groups;
// Group 1: Only alpha specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor1},
std::make_unique<RMSpropOptions>(RMSpropOptions().alpha(0.95)));
// Group 2: Only momentum and centered specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor2},
std::make_unique<RMSpropOptions>(
RMSpropOptions().momentum(0.8).centered(true)));
// Create optimizer with specific defaults
RMSpropOptions defaults;
defaults.lr(0.015)
.alpha(0.98)
.eps(1e-9)
.weight_decay(0.01)
.momentum(0.7)
.centered(false);
RMSprop optimizer(param_groups, defaults);
// Check Group 1: alpha preserved, others inherited
auto& group1_opts =
static_cast<RMSpropOptions&>(optimizer.param_groups()[0].options());
ASSERT_EQ(group1_opts.lr(), 0.015); // Inherited
ASSERT_EQ(group1_opts.alpha(), 0.95); // Preserved
ASSERT_EQ(group1_opts.eps(), 1e-9); // Inherited
ASSERT_EQ(group1_opts.weight_decay(), 0.01); // Inherited
ASSERT_EQ(group1_opts.momentum(), 0.7); // Inherited
ASSERT_FALSE(group1_opts.centered()); // Inherited
// Check Group 2: momentum and centered preserved, others inherited
auto& group2_opts =
static_cast<RMSpropOptions&>(optimizer.param_groups()[1].options());
ASSERT_EQ(group2_opts.lr(), 0.015); // Inherited
ASSERT_EQ(group2_opts.alpha(), 0.98); // Inherited
ASSERT_EQ(group2_opts.eps(), 1e-9); // Inherited
ASSERT_EQ(group2_opts.weight_decay(), 0.01); // Inherited
ASSERT_EQ(group2_opts.momentum(), 0.8); // Preserved
ASSERT_TRUE(group2_opts.centered()); // Preserved
}
TEST(OptimTest, MergeWithDefaultOptions_LBFGS) {
// Create tensors for single parameter group (LBFGS limitation)
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
// Create param group with partial options
std::vector<OptimizerParamGroup> param_groups;
// Single group: Only max_iter specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{
tensor1, tensor2}, // Combine tensors in single group
std::make_unique<LBFGSOptions>(LBFGSOptions().max_iter(15)));
// Create optimizer with specific defaults
LBFGSOptions defaults;
defaults.lr(0.8)
.max_iter(25)
.max_eval(31) // Use same value that appears to be auto-calculated
.tolerance_grad(1e-5)
.tolerance_change(1e-8)
.history_size(80)
.line_search_fn("strong_wolfe");
LBFGS optimizer(param_groups, defaults);
// Check Group: max_iter preserved, others inherited
auto& group_opts =
static_cast<LBFGSOptions&>(optimizer.param_groups()[0].options());
ASSERT_EQ(group_opts.lr(), 0.8); // Inherited
ASSERT_EQ(group_opts.max_iter(), 15); // Preserved
ASSERT_EQ(group_opts.max_eval(), 31); // Inherited
ASSERT_EQ(group_opts.tolerance_grad(), 1e-5); // Inherited
ASSERT_EQ(group_opts.tolerance_change(), 1e-8); // Inherited
ASSERT_EQ(group_opts.history_size(), 80); // Inherited
ASSERT_EQ(group_opts.line_search_fn(), "strong_wolfe"); // Inherited
}
TEST(OptimTest, MergeWithDefaultOptions_NoOptionsInheritance) {
// Test that param groups without options get full defaults
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
std::vector<OptimizerParamGroup> param_groups;
// Groups with no options - should inherit everything
param_groups.emplace_back(std::vector<torch::Tensor>{tensor1});
param_groups.emplace_back(std::vector<torch::Tensor>{tensor2});
// Create optimizer with specific defaults
AdamOptions defaults;
defaults.lr(0.005)
.betas(std::make_tuple(0.85, 0.95))
.eps(1e-7)
.weight_decay(0.08)
.amsgrad(true);
Adam optimizer(param_groups, defaults);
// Both groups should have exactly the default options
for (int i = 0; i < 2; i++) {
auto& group_opts =
static_cast<AdamOptions&>(optimizer.param_groups()[i].options());
ASSERT_EQ(group_opts.lr(), 0.005);
ASSERT_EQ(group_opts.betas(), std::make_tuple(0.85, 0.95));
ASSERT_EQ(group_opts.eps(), 1e-7);
ASSERT_EQ(group_opts.weight_decay(), 0.08);
ASSERT_TRUE(group_opts.amsgrad());
}
}
// Test that field tracking survives serialization/deserialization cycles
TEST(OptimTest, SerializationPreservesFieldTracking_Adam) {
// Create tensors for parameter groups
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
// Create param groups with partial options using fluent API (marks fields as
// explicit)
std::vector<OptimizerParamGroup> param_groups;
// Group 1: Only weight_decay and amsgrad explicitly set via fluent API
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor1},
std::make_unique<AdamOptions>(
AdamOptions().weight_decay(0.11).amsgrad(true)));
// Group 2: Only eps explicitly set via fluent API
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor2},
std::make_unique<AdamOptions>(AdamOptions().eps(1e-6)));
// Create optimizer with specific defaults
AdamOptions defaults;
defaults.lr(0.002)
.betas(std::make_tuple(0.8, 0.88))
.eps(1e-12)
.weight_decay(0.05)
.amsgrad(false);
Adam original_optimizer(param_groups, defaults);
// Capture original state for comparison
auto& orig_group1_opts =
static_cast<AdamOptions&>(original_optimizer.param_groups()[0].options());
auto& orig_group2_opts =
static_cast<AdamOptions&>(original_optimizer.param_groups()[1].options());
// Verify original state (sanity check)
ASSERT_NEAR(orig_group1_opts.weight_decay(), 0.11, 1e-6); // Explicitly set
ASSERT_TRUE(orig_group1_opts.amsgrad()); // Explicitly set
ASSERT_NEAR(orig_group1_opts.lr(), 0.002, 1e-6); // Inherited
ASSERT_NEAR(orig_group2_opts.eps(), 1e-6, 1e-9); // Explicitly set
ASSERT_NEAR(orig_group2_opts.lr(), 0.002, 1e-6); // Inherited
// Test serialization of the options objects (where field tracking lives)
std::stringstream ss1, ss2;
// Serialize the parameter group options
{
torch::serialize::OutputArchive archive;
orig_group1_opts.serialize(archive);
archive.save_to(ss1);
}
{
torch::serialize::OutputArchive archive;
orig_group2_opts.serialize(archive);
archive.save_to(ss2);
}
// Create new options objects and deserialize
AdamOptions loaded_group1_opts;
AdamOptions loaded_group2_opts;
{
torch::serialize::InputArchive archive;
archive.load_from(ss1);
loaded_group1_opts.serialize(archive);
}
{
torch::serialize::InputArchive archive;
archive.load_from(ss2);
loaded_group2_opts.serialize(archive);
}
// Verify that all parameter values are preserved after deserialization
// Group 1: weight_decay and amsgrad should be preserved as explicitly set,
// others inherited
ASSERT_NEAR(loaded_group1_opts.lr(), 0.002, 1e-6); // Inherited
ASSERT_EQ(
loaded_group1_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
ASSERT_NEAR(loaded_group1_opts.eps(), 1e-12, 1e-15); // Inherited
ASSERT_NEAR(loaded_group1_opts.weight_decay(), 0.11, 1e-6); // Explicitly set
ASSERT_TRUE(loaded_group1_opts.amsgrad()); // Explicitly set
// Group 2: eps should be preserved as explicitly set, others inherited
ASSERT_NEAR(loaded_group2_opts.lr(), 0.002, 1e-6); // Inherited
ASSERT_EQ(
loaded_group2_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
ASSERT_NEAR(loaded_group2_opts.eps(), 1e-6, 1e-9); // Explicitly set
ASSERT_NEAR(loaded_group2_opts.weight_decay(), 0.05, 1e-6); // Inherited
ASSERT_FALSE(loaded_group2_opts.amsgrad()); // Inherited
// CRITICAL: Test that field tracking is preserved after serialization
// Create a new optimizer using the deserialized options to test inheritance
auto tensor3 = torch::randn({2, 2}).requires_grad_(true);
auto tensor4 = torch::randn({3, 3}).requires_grad_(true);
std::vector<OptimizerParamGroup> test_param_groups;
test_param_groups.emplace_back(
std::vector<torch::Tensor>{tensor3},
std::make_unique<AdamOptions>(loaded_group1_opts));
test_param_groups.emplace_back(
std::vector<torch::Tensor>{tensor4},
std::make_unique<AdamOptions>(loaded_group2_opts));
Adam test_optimizer(test_param_groups, defaults);
// The field tracking should work correctly for inheritance
auto& final_group1_opts =
static_cast<AdamOptions&>(test_optimizer.param_groups()[0].options());
auto& final_group2_opts =
static_cast<AdamOptions&>(test_optimizer.param_groups()[1].options());
// Group 1: weight_decay and amsgrad should still be preserved as explicitly
// set
ASSERT_NEAR(
final_group1_opts.weight_decay(),
0.11,
1e-6); // Explicitly set (preserved)
ASSERT_TRUE(final_group1_opts.amsgrad()); // Explicitly set (preserved)
ASSERT_NEAR(final_group1_opts.lr(), 0.002, 1e-6); // Inherited from defaults
// Group 2: eps should still be preserved as explicitly set
ASSERT_NEAR(
final_group2_opts.eps(), 1e-6, 1e-9); // Explicitly set (preserved)
ASSERT_NEAR(final_group2_opts.lr(), 0.002, 1e-6); // Inherited from defaults
}
// Test serialization with SGD (different parameter types)
TEST(OptimTest, SerializationPreservesFieldTracking_SGD) {
// Create tensors
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
// Create param group with partial options using fluent API
std::vector<OptimizerParamGroup> param_groups;
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor1},
std::make_unique<SGDOptions>(
SGDOptions(0.01).weight_decay(0.22).nesterov(true)));
// Create optimizer with defaults
SGDOptions defaults(0.001);
defaults.momentum(0.9).dampening(0.0).weight_decay(0.05).nesterov(false);
SGD original_optimizer(param_groups, defaults);
// Test serialization of the SGD options (where field tracking lives)
auto& original_opts =
static_cast<SGDOptions&>(original_optimizer.param_groups()[0].options());
std::stringstream ss;
{
torch::serialize::OutputArchive archive;
original_opts.serialize(archive);
archive.save_to(ss);
}
SGDOptions loaded_opts(0.0); // Dummy initial value
{
torch::serialize::InputArchive archive;
archive.load_from(ss);
loaded_opts.serialize(archive);
}
ASSERT_NEAR(loaded_opts.lr(), 0.01, 1e-6); // Explicitly set
ASSERT_NEAR(loaded_opts.momentum(), 0.9, 1e-6); // Inherited
ASSERT_NEAR(loaded_opts.dampening(), 0.0, 1e-6); // Inherited
ASSERT_NEAR(loaded_opts.weight_decay(), 0.22, 1e-6); // Explicitly set
ASSERT_TRUE(loaded_opts.nesterov()); // Explicitly set
// Test that field tracking still works after deserialization by creating new
// optimizer
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
std::vector<OptimizerParamGroup> test_param_groups;
test_param_groups.emplace_back(
std::vector<torch::Tensor>{tensor2},
std::make_unique<SGDOptions>(loaded_opts));
SGD test_optimizer(test_param_groups, defaults);
auto& final_opts =
static_cast<SGDOptions&>(test_optimizer.param_groups()[0].options());
ASSERT_NEAR(final_opts.lr(), 0.01, 1e-6); // Explicitly set (preserved)
ASSERT_NEAR(
final_opts.weight_decay(), 0.22, 1e-6); // Explicitly set (preserved)
ASSERT_TRUE(final_opts.nesterov()); // Explicitly set (preserved)
ASSERT_NEAR(final_opts.momentum(), 0.9, 1e-6); // Inherited from defaults
ASSERT_NEAR(final_opts.dampening(), 0.0, 1e-6); // Inherited from defaults
}

@@ -12,6 +12,7 @@
#include <iterator>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>
// Forward declarations confuse Doxygen
@@ -66,12 +67,332 @@ class TORCH_API OptimizerOptions {
virtual void set_lr(const double lr);
};
// Forward declarations for optimizer option types
struct SGDOptions;
struct AdamOptions;
struct AdamWOptions;
struct AdagradOptions;
struct RMSpropOptions;
struct LBFGSOptions;
/**
* OptimizerCloneableOptions provides parameter group inheritance functionality
* for PyTorch C++ optimizer options. When creating parameter groups with
* partial options (e.g., AdamOptions().weight_decay(0.1)), fields not
* explicitly set by the user inherit from the optimizer's default values,
* while explicitly set fields are preserved.
*
* This enables Python-like behavior in C++:
* ```cpp
* // Python equivalent:
* // optimizer = Adam([{'params': params1, 'weight_decay': 0.1}], lr=0.01)
* // Result: weight_decay=0.1 preserved, lr=0.01 inherited
*
* AdamOptions defaults;
* defaults.lr(0.01).weight_decay(0.05);
*
* std::vector<OptimizerParamGroup> groups;
* groups.emplace_back(params1, std::make_unique<AdamOptions>(
* AdamOptions().weight_decay(0.1))); // Only weight_decay specified
*
* Adam optimizer(groups, defaults);
* // Result: group inherits lr=0.01, preserves weight_decay=0.1
* ```
*
* **Implementation**: Uses SFINAE-based field detection and constructor-default
* comparison to distinguish explicitly set fields from default values.
* Fields that match constructor defaults are inherited; others are preserved.
*/
template <typename Derived>
class OptimizerCloneableOptions : public OptimizerOptions {
private:
std::unique_ptr<OptimizerOptions> clone() const override {
return std::make_unique<Derived>(static_cast<const Derived&>(*this));
}
// SFINAE field detection - detects optimizer fields using public accessor
// methods
template <class T, class Enable = void>
struct _has_lr : std::false_type {};
template <class T>
struct _has_lr<T, std::void_t<decltype(std::declval<const T&>().lr())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_momentum : std::false_type {};
template <class T>
struct _has_momentum<
T,
std::void_t<decltype(std::declval<const T&>().momentum())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_weight_decay : std::false_type {};
template <class T>
struct _has_weight_decay<
T,
std::void_t<decltype(std::declval<const T&>().weight_decay())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_dampening : std::false_type {};
template <class T>
struct _has_dampening<
T,
std::void_t<decltype(std::declval<const T&>().dampening())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_nesterov : std::false_type {};
template <class T>
struct _has_nesterov<
T,
std::void_t<decltype(std::declval<const T&>().nesterov())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_betas : std::false_type {};
template <class T>
struct _has_betas<T, std::void_t<decltype(std::declval<const T&>().betas())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_eps : std::false_type {};
template <class T>
struct _has_eps<T, std::void_t<decltype(std::declval<const T&>().eps())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_amsgrad : std::false_type {};
template <class T>
struct _has_amsgrad<
T,
std::void_t<decltype(std::declval<const T&>().amsgrad())>>
: std::true_type {};
// Optimizer-specific field detection
template <class T, class Enable = void>
struct _has_lr_decay : std::false_type {};
template <class T>
struct _has_lr_decay<
T,
std::void_t<decltype(std::declval<const T&>().lr_decay())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_alpha : std::false_type {};
template <class T>
struct _has_alpha<T, std::void_t<decltype(std::declval<const T&>().alpha())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_centered : std::false_type {};
template <class T>
struct _has_centered<
T,
std::void_t<decltype(std::declval<const T&>().centered())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_initial_accumulator_value : std::false_type {};
template <class T>
struct _has_initial_accumulator_value<
T,
std::void_t<
decltype(std::declval<const T&>().initial_accumulator_value())>>
: std::true_type {};
// LBFGS-specific fields with appropriate types
template <class T, class Enable = void>
struct _has_max_iter : std::false_type {};
template <class T>
struct _has_max_iter<
T,
std::void_t<decltype(std::declval<const T&>().max_iter())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_max_eval : std::false_type {};
template <class T>
struct _has_max_eval<
T,
std::void_t<decltype(std::declval<const T&>().max_eval())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_tolerance_grad : std::false_type {};
template <class T>
struct _has_tolerance_grad<
T,
std::void_t<decltype(std::declval<const T&>().tolerance_grad())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_tolerance_change : std::false_type {};
template <class T>
struct _has_tolerance_change<
T,
std::void_t<decltype(std::declval<const T&>().tolerance_change())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_history_size : std::false_type {};
template <class T>
struct _has_history_size<
T,
std::void_t<decltype(std::declval<const T&>().history_size())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_line_search_fn : std::false_type {};
template <class T>
struct _has_line_search_fn<
T,
std::void_t<decltype(std::declval<const T&>().line_search_fn())>>
: std::true_type {};
/**
* Merges user-specified options with optimizer defaults using
* constructor-default comparison to detect explicitly set fields.
*
* Algorithm:
* 1. Start with optimizer defaults as base
* 2. Create fresh constructor instance for comparison
* 3. If user_value != constructor_default → user explicitly set it → preserve
* 4. If user_value == constructor_default → user didn't set it → inherit from
* defaults
*/
void _merge_by_comparison(
const Derived& defaults,
const Derived& user_options) {
auto* result = static_cast<Derived*>(this);
*result = defaults; // Start with optimizer defaults
// Create constructor defaults instance for comparison
Derived constructor_defaults = []() {
if constexpr (std::is_default_constructible_v<Derived>) {
return Derived{};
} else {
// Handle optimizers requiring constructor parameters
if constexpr (std::is_same_v<Derived, SGDOptions>) {
return Derived(1e-3);
} else if constexpr (std::is_same_v<Derived, AdagradOptions>) {
return Derived(1e-2);
} else if constexpr (std::is_same_v<Derived, RMSpropOptions>) {
return Derived(1e-2);
} else if constexpr (std::is_same_v<Derived, LBFGSOptions>) {
return Derived(1);
} else {
return Derived{};
}
}
}();
// Merge fields: preserve user-set values, inherit defaults for unset values
if constexpr (_has_lr<Derived>::value) {
if (user_options.lr() != constructor_defaults.lr()) {
result->lr(user_options.lr());
}
}
if constexpr (_has_momentum<Derived>::value) {
if (user_options.momentum() != constructor_defaults.momentum()) {
result->momentum(user_options.momentum());
}
}
if constexpr (_has_weight_decay<Derived>::value) {
if (user_options.weight_decay() != constructor_defaults.weight_decay()) {
result->weight_decay(user_options.weight_decay());
}
}
if constexpr (_has_dampening<Derived>::value) {
if (user_options.dampening() != constructor_defaults.dampening()) {
result->dampening(user_options.dampening());
}
}
if constexpr (_has_nesterov<Derived>::value) {
if (user_options.nesterov() != constructor_defaults.nesterov()) {
result->nesterov(user_options.nesterov());
}
}
if constexpr (_has_betas<Derived>::value) {
if (user_options.betas() != constructor_defaults.betas()) {
result->betas(user_options.betas());
}
}
if constexpr (_has_eps<Derived>::value) {
if (user_options.eps() != constructor_defaults.eps()) {
result->eps(user_options.eps());
}
}
if constexpr (_has_amsgrad<Derived>::value) {
if (user_options.amsgrad() != constructor_defaults.amsgrad()) {
result->amsgrad(user_options.amsgrad());
}
}
// Optimizer-specific fields - automatically detected and handled
if constexpr (_has_lr_decay<Derived>::value) {
if (user_options.lr_decay() != constructor_defaults.lr_decay()) {
result->lr_decay(user_options.lr_decay());
}
}
if constexpr (_has_alpha<Derived>::value) {
if (user_options.alpha() != constructor_defaults.alpha()) {
result->alpha(user_options.alpha());
}
}
if constexpr (_has_centered<Derived>::value) {
if (user_options.centered() != constructor_defaults.centered()) {
result->centered(user_options.centered());
}
}
if constexpr (_has_initial_accumulator_value<Derived>::value) {
if (user_options.initial_accumulator_value() !=
constructor_defaults.initial_accumulator_value()) {
result->initial_accumulator_value(
user_options.initial_accumulator_value());
}
}
// LBFGS-specific fields with appropriate types
if constexpr (_has_max_iter<Derived>::value) {
if (user_options.max_iter() != constructor_defaults.max_iter()) {
result->max_iter(user_options.max_iter());
}
}
if constexpr (_has_max_eval<Derived>::value) {
if (user_options.max_eval() != constructor_defaults.max_eval()) {
result->max_eval(user_options.max_eval());
}
}
if constexpr (_has_tolerance_grad<Derived>::value) {
if (user_options.tolerance_grad() !=
constructor_defaults.tolerance_grad()) {
result->tolerance_grad(user_options.tolerance_grad());
}
}
if constexpr (_has_tolerance_change<Derived>::value) {
if (user_options.tolerance_change() !=
constructor_defaults.tolerance_change()) {
result->tolerance_change(user_options.tolerance_change());
}
}
if constexpr (_has_history_size<Derived>::value) {
if (user_options.history_size() != constructor_defaults.history_size()) {
result->history_size(user_options.history_size());
}
}
if constexpr (_has_line_search_fn<Derived>::value) {
if (user_options.line_search_fn() !=
constructor_defaults.line_search_fn()) {
result->line_search_fn(user_options.line_search_fn());
}
}
}
// Friend class for controlled access to private _merge_by_comparison method
friend class Optimizer;
};
/// Stores parameters in the param_group and stores a pointer to the
@@ -186,6 +507,43 @@ class TORCH_API Optimizer {
/// Deserializes the optimizer state from the given `archive`.
virtual void load(serialize::InputArchive& archive);
private:
/// Helper function to try merging for a specific optimizer type
template <typename OptimizerType>
static bool _try_merge_optimizer_type(
std::unique_ptr<OptimizerOptions>& final_options,
const OptimizerOptions& user_options,
const OptimizerOptions& defaults) {
auto* typed_final = dynamic_cast<OptimizerType*>(final_options.get());
auto* typed_user = dynamic_cast<const OptimizerType*>(&user_options);
auto* typed_defaults = dynamic_cast<const OptimizerType*>(&defaults);
if (typed_final && typed_user && typed_defaults) {
typed_final->_merge_by_comparison(*typed_defaults, *typed_user);
return true;
}
return false;
}
/// Simple variadic dispatch helper - try all optimizer types in one call
template <typename... OptimizerTypes>
static void _try_merge_all_optimizer_types(
std::unique_ptr<OptimizerOptions>& final_options,
const OptimizerOptions& user_options,
const OptimizerOptions& defaults) {
// Try each optimizer type until one succeeds - much cleaner than manual
// chain
(void)(_try_merge_optimizer_type<OptimizerTypes>(
final_options, user_options, defaults) ||
...);
}
/// Convenience function with all known PyTorch optimizers
static void _try_merge_all_optimizers(
std::unique_ptr<OptimizerOptions>& final_options,
const OptimizerOptions& user_options,
const OptimizerOptions& defaults);
protected:
std::vector<OptimizerParamGroup> param_groups_;
ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> state_;

@@ -3,12 +3,35 @@
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/types.h>
// Include complete type definitions for all optimizers to enable dynamic_cast
#include <torch/optim/adagrad.h>
#include <torch/optim/adam.h>
#include <torch/optim/adamw.h>
#include <torch/optim/lbfgs.h>
#include <torch/optim/rmsprop.h>
#include <torch/optim/sgd.h>
#include <string>
#include <utility>
#include <vector>
namespace torch::optim {
// Simple implementation using variadic template helper
void Optimizer::_try_merge_all_optimizers(
std::unique_ptr<OptimizerOptions>& final_options,
const OptimizerOptions& user_options,
const OptimizerOptions& defaults) {
// Clean one-liner replaces the entire repetitive dispatch chain
_try_merge_all_optimizer_types<
SGDOptions,
AdamOptions,
AdamWOptions,
AdagradOptions,
RMSpropOptions,
LBFGSOptions>(final_options, user_options, defaults);
}
bool OptimizerParamGroup::has_options() const {
return options_ != nullptr;
}
@@ -101,9 +124,20 @@ void Optimizer::add_param_group(const OptimizerParamGroup& param_group) {
TORCH_INTERNAL_ASSERT(defaults_ != nullptr);
OptimizerParamGroup param_group_(param_group.params());
if (!param_group.has_options()) {
// No options provided - use defaults directly
param_group_.set_options(defaults_->clone());
} else {
param_group_.set_options(param_group.options().clone());
// Options provided - merge user's explicit settings with defaults for
// parameter group inheritance. This enables Python-C++ API parity by
// honoring user intent while inheriting missing parameters.
auto final_options = defaults_->clone();
// Simple variadic dispatch - try all known optimizer types
_try_merge_all_optimizers(final_options, param_group.options(), *defaults_);
// If no merging was done (custom optimizer), final_options already contains
// defaults
param_group_.set_options(std::move(final_options));
}
for (const auto& p : param_group_.params()) {
TORCH_CHECK(