Revert "C++ API handle optimizer defaults (#161825)"

This reverts commit f33201729416ed17467228e80b04d01d4d02b5f3.

Reverted https://github.com/pytorch/pytorch/pull/161825 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/161825#issuecomment-3391506427))
PyTorch MergeBot
2025-10-10 17:56:08 +00:00
parent 4cd06dc82c
commit b67785d9eb
3 changed files with 1 addition and 898 deletions
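
For context, a minimal sketch (not part of the diff) of the behavioral difference exercised by the reverted tests below, assuming the standard C++ frontend API and that `AdamOptions` defaults `lr` to `1e-3`: with #161825, a parameter group that set only some fields inherited the rest from the optimizer's defaults; after this revert, `Optimizer::add_param_group` clones the group's options verbatim, so unset fields fall back to the option type's constructor defaults instead.

```cpp
#include <torch/torch.h>

#include <memory>
#include <vector>

int main() {
  using namespace torch::optim;
  auto p = torch::randn({2, 2}).requires_grad_(true);

  // The group specifies only weight_decay; lr, betas, eps, and amsgrad are
  // left at the AdamOptions constructor defaults.
  std::vector<OptimizerParamGroup> groups;
  groups.emplace_back(
      std::vector<torch::Tensor>{p},
      std::make_unique<AdamOptions>(AdamOptions().weight_decay(0.11)));

  // Optimizer-level defaults.
  Adam optimizer(groups, AdamOptions().lr(0.002).weight_decay(0.05));

  auto& opts =
      static_cast<AdamOptions&>(optimizer.param_groups()[0].options());
  // With #161825: opts.lr() == 0.002 (inherited from the optimizer defaults).
  // After this revert: opts.lr() == 1e-3 (the constructor default), because
  // the group's options are cloned as-is.
  return 0;
}
```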

View File

@@ -564,508 +564,3 @@ TEST(OptimTest, CheckLRChange_ReduceLROnPlateau_Adam) {
check_lr_change_for_reduce_on_plateau(
optimizer, reduce_lr_on_plateau_scheduler, expected_epoch_lrs);
}
// Tests for Issue 141884: Parameter group inheritance functionality
// Validates that partial options in parameter groups correctly inherit
// defaults from the optimizer while preserving explicitly set values
TEST(OptimTest, MergeWithDefaultOptions_Adam) {
// Create tensors for parameter groups
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
// Create param groups with partial options
std::vector<OptimizerParamGroup> param_groups;
// Group 1: Only weight_decay specified, should inherit lr, betas, eps,
// amsgrad
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor1},
std::make_unique<AdamOptions>(AdamOptions().weight_decay(0.11)));
// Group 2: Only eps specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor2},
std::make_unique<AdamOptions>(AdamOptions().eps(1e-6)));
// Create optimizer with specific defaults
AdamOptions defaults;
defaults.lr(0.002)
.betas(std::make_tuple(0.8, 0.88))
.eps(1e-12)
.weight_decay(0.05)
.amsgrad(true);
Adam optimizer(param_groups, defaults);
// Check Group 1: weight_decay preserved, others inherited
auto& group1_opts =
static_cast<AdamOptions&>(optimizer.param_groups()[0].options());
ASSERT_EQ(group1_opts.lr(), 0.002); // Inherited
ASSERT_EQ(group1_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
ASSERT_EQ(group1_opts.eps(), 1e-12); // Inherited
ASSERT_EQ(group1_opts.weight_decay(), 0.11); // Preserved
ASSERT_TRUE(group1_opts.amsgrad()); // Inherited
// Check Group 2: eps preserved, others inherited
auto& group2_opts =
static_cast<AdamOptions&>(optimizer.param_groups()[1].options());
ASSERT_EQ(group2_opts.lr(), 0.002); // Inherited
ASSERT_EQ(group2_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
ASSERT_EQ(group2_opts.eps(), 1e-6); // Preserved
ASSERT_EQ(group2_opts.weight_decay(), 0.05); // Inherited
ASSERT_TRUE(group2_opts.amsgrad()); // Inherited
}
TEST(OptimTest, MergeWithDefaultOptions_SGD) {
// Create tensors for parameter groups
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
// Create param groups with partial options
std::vector<OptimizerParamGroup> param_groups;
// Group 1: Only lr and weight_decay specified, should inherit momentum,
// dampening, nesterov
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor1},
std::make_unique<SGDOptions>(SGDOptions(0.01).weight_decay(0.22)));
// Group 2: Only lr specified, should inherit all others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor2},
std::make_unique<SGDOptions>(SGDOptions(0.02)));
// Create optimizer with specific defaults
SGDOptions defaults(0.001); // lr should be overridden by param groups
defaults.momentum(0.9)
.dampening(0.0) // Must be 0 for Nesterov
.weight_decay(0.05)
.nesterov(true);
SGD optimizer(param_groups, defaults);
// Check Group 1: lr and weight_decay preserved, others inherited
auto& group1_opts =
static_cast<SGDOptions&>(optimizer.param_groups()[0].options());
ASSERT_EQ(group1_opts.lr(), 0.01); // Preserved
ASSERT_EQ(group1_opts.momentum(), 0.9); // Inherited
ASSERT_EQ(group1_opts.dampening(), 0.0); // Inherited
ASSERT_EQ(group1_opts.weight_decay(), 0.22); // Preserved
ASSERT_TRUE(group1_opts.nesterov()); // Inherited
// Check Group 2: lr preserved, others inherited
auto& group2_opts =
static_cast<SGDOptions&>(optimizer.param_groups()[1].options());
ASSERT_EQ(group2_opts.lr(), 0.02); // Preserved
ASSERT_EQ(group2_opts.momentum(), 0.9); // Inherited
ASSERT_EQ(group2_opts.dampening(), 0.0); // Inherited
ASSERT_EQ(group2_opts.weight_decay(), 0.05); // Inherited
ASSERT_TRUE(group2_opts.nesterov()); // Inherited
}
TEST(OptimTest, MergeWithDefaultOptions_AdamW) {
// Create tensors for parameter groups
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
// Create param groups with partial options
std::vector<OptimizerParamGroup> param_groups;
// Group 1: Only eps specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor1},
std::make_unique<AdamWOptions>(AdamWOptions().eps(1e-6)));
// Group 2: Only betas specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor2},
std::make_unique<AdamWOptions>(
AdamWOptions().betas(std::make_tuple(0.95, 0.999))));
// Create optimizer with specific defaults
AdamWOptions defaults;
defaults.lr(0.003)
.betas(std::make_tuple(0.9, 0.98))
.eps(1e-8)
.weight_decay(0.02)
.amsgrad(false);
AdamW optimizer(param_groups, defaults);
// Check Group 1: eps preserved, others inherited
auto& group1_opts =
static_cast<AdamWOptions&>(optimizer.param_groups()[0].options());
ASSERT_EQ(group1_opts.lr(), 0.003); // Inherited
ASSERT_EQ(group1_opts.betas(), std::make_tuple(0.9, 0.98)); // Inherited
ASSERT_EQ(group1_opts.eps(), 1e-6); // Preserved
ASSERT_EQ(group1_opts.weight_decay(), 0.02); // Inherited
ASSERT_FALSE(group1_opts.amsgrad()); // Inherited
// Check Group 2: betas preserved, others inherited
auto& group2_opts =
static_cast<AdamWOptions&>(optimizer.param_groups()[1].options());
ASSERT_EQ(group2_opts.lr(), 0.003); // Inherited
ASSERT_EQ(group2_opts.betas(), std::make_tuple(0.95, 0.999)); // Preserved
ASSERT_EQ(group2_opts.eps(), 1e-8); // Inherited
ASSERT_EQ(group2_opts.weight_decay(), 0.02); // Inherited
ASSERT_FALSE(group2_opts.amsgrad()); // Inherited
}
TEST(OptimTest, MergeWithDefaultOptions_Adagrad) {
// Create tensors for parameter groups
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
// Create param groups with partial options
std::vector<OptimizerParamGroup> param_groups;
// Group 1: Only lr_decay specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor1},
std::make_unique<AdagradOptions>(AdagradOptions().lr_decay(0.001)));
// Group 2: Only initial_accumulator_value specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor2},
std::make_unique<AdagradOptions>(
AdagradOptions().initial_accumulator_value(0.5)));
// Create optimizer with specific defaults
AdagradOptions defaults;
defaults.lr(0.04)
.lr_decay(0.002)
.weight_decay(0.03)
.initial_accumulator_value(0.1)
.eps(1e-11);
Adagrad optimizer(param_groups, defaults);
// Check Group 1: lr_decay preserved, others inherited
auto& group1_opts =
static_cast<AdagradOptions&>(optimizer.param_groups()[0].options());
ASSERT_EQ(group1_opts.lr(), 0.04); // Inherited
ASSERT_EQ(group1_opts.lr_decay(), 0.001); // Preserved
ASSERT_EQ(group1_opts.weight_decay(), 0.03); // Inherited
ASSERT_EQ(group1_opts.initial_accumulator_value(), 0.1); // Inherited
ASSERT_EQ(group1_opts.eps(), 1e-11); // Inherited
// Check Group 2: initial_accumulator_value preserved, others inherited
auto& group2_opts =
static_cast<AdagradOptions&>(optimizer.param_groups()[1].options());
ASSERT_EQ(group2_opts.lr(), 0.04); // Inherited
ASSERT_EQ(group2_opts.lr_decay(), 0.002); // Inherited
ASSERT_EQ(group2_opts.weight_decay(), 0.03); // Inherited
ASSERT_EQ(group2_opts.initial_accumulator_value(), 0.5); // Preserved
ASSERT_EQ(group2_opts.eps(), 1e-11); // Inherited
}
TEST(OptimTest, MergeWithDefaultOptions_RMSprop) {
// Create tensors for parameter groups
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
// Create param groups with partial options
std::vector<OptimizerParamGroup> param_groups;
// Group 1: Only alpha specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor1},
std::make_unique<RMSpropOptions>(RMSpropOptions().alpha(0.95)));
// Group 2: Only momentum and centered specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor2},
std::make_unique<RMSpropOptions>(
RMSpropOptions().momentum(0.8).centered(true)));
// Create optimizer with specific defaults
RMSpropOptions defaults;
defaults.lr(0.015)
.alpha(0.98)
.eps(1e-9)
.weight_decay(0.01)
.momentum(0.7)
.centered(false);
RMSprop optimizer(param_groups, defaults);
// Check Group 1: alpha preserved, others inherited
auto& group1_opts =
static_cast<RMSpropOptions&>(optimizer.param_groups()[0].options());
ASSERT_EQ(group1_opts.lr(), 0.015); // Inherited
ASSERT_EQ(group1_opts.alpha(), 0.95); // Preserved
ASSERT_EQ(group1_opts.eps(), 1e-9); // Inherited
ASSERT_EQ(group1_opts.weight_decay(), 0.01); // Inherited
ASSERT_EQ(group1_opts.momentum(), 0.7); // Inherited
ASSERT_FALSE(group1_opts.centered()); // Inherited
// Check Group 2: momentum and centered preserved, others inherited
auto& group2_opts =
static_cast<RMSpropOptions&>(optimizer.param_groups()[1].options());
ASSERT_EQ(group2_opts.lr(), 0.015); // Inherited
ASSERT_EQ(group2_opts.alpha(), 0.98); // Inherited
ASSERT_EQ(group2_opts.eps(), 1e-9); // Inherited
ASSERT_EQ(group2_opts.weight_decay(), 0.01); // Inherited
ASSERT_EQ(group2_opts.momentum(), 0.8); // Preserved
ASSERT_TRUE(group2_opts.centered()); // Preserved
}
TEST(OptimTest, MergeWithDefaultOptions_LBFGS) {
// Create tensors for single parameter group (LBFGS limitation)
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
// Create param group with partial options
std::vector<OptimizerParamGroup> param_groups;
// Single group: Only max_iter specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{
tensor1, tensor2}, // Combine tensors in single group
std::make_unique<LBFGSOptions>(LBFGSOptions().max_iter(15)));
// Create optimizer with specific defaults
LBFGSOptions defaults;
defaults.lr(0.8)
.max_iter(25)
.max_eval(31) // Use same value that appears to be auto-calculated
.tolerance_grad(1e-5)
.tolerance_change(1e-8)
.history_size(80)
.line_search_fn("strong_wolfe");
LBFGS optimizer(param_groups, defaults);
// Check Group: max_iter preserved, others inherited
auto& group_opts =
static_cast<LBFGSOptions&>(optimizer.param_groups()[0].options());
ASSERT_EQ(group_opts.lr(), 0.8); // Inherited
ASSERT_EQ(group_opts.max_iter(), 15); // Preserved
ASSERT_EQ(group_opts.max_eval(), 31); // Inherited
ASSERT_EQ(group_opts.tolerance_grad(), 1e-5); // Inherited
ASSERT_EQ(group_opts.tolerance_change(), 1e-8); // Inherited
ASSERT_EQ(group_opts.history_size(), 80); // Inherited
ASSERT_EQ(group_opts.line_search_fn(), "strong_wolfe"); // Inherited
}
TEST(OptimTest, MergeWithDefaultOptions_NoOptionsInheritance) {
// Test that param groups without options get full defaults
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
std::vector<OptimizerParamGroup> param_groups;
// Groups with no options - should inherit everything
param_groups.emplace_back(std::vector<torch::Tensor>{tensor1});
param_groups.emplace_back(std::vector<torch::Tensor>{tensor2});
// Create optimizer with specific defaults
AdamOptions defaults;
defaults.lr(0.005)
.betas(std::make_tuple(0.85, 0.95))
.eps(1e-7)
.weight_decay(0.08)
.amsgrad(true);
Adam optimizer(param_groups, defaults);
// Both groups should have exactly the default options
for (int i = 0; i < 2; i++) {
auto& group_opts =
static_cast<AdamOptions&>(optimizer.param_groups()[i].options());
ASSERT_EQ(group_opts.lr(), 0.005);
ASSERT_EQ(group_opts.betas(), std::make_tuple(0.85, 0.95));
ASSERT_EQ(group_opts.eps(), 1e-7);
ASSERT_EQ(group_opts.weight_decay(), 0.08);
ASSERT_TRUE(group_opts.amsgrad());
}
}
// Test that field tracking survives serialization/deserialization cycles
TEST(OptimTest, SerializationPreservesFieldTracking_Adam) {
// Create tensors for parameter groups
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
// Create param groups with partial options using fluent API (marks fields as
// explicit)
std::vector<OptimizerParamGroup> param_groups;
// Group 1: Only weight_decay and amsgrad explicitly set via fluent API
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor1},
std::make_unique<AdamOptions>(
AdamOptions().weight_decay(0.11).amsgrad(true)));
// Group 2: Only eps explicitly set via fluent API
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor2},
std::make_unique<AdamOptions>(AdamOptions().eps(1e-6)));
// Create optimizer with specific defaults
AdamOptions defaults;
defaults.lr(0.002)
.betas(std::make_tuple(0.8, 0.88))
.eps(1e-12)
.weight_decay(0.05)
.amsgrad(false);
Adam original_optimizer(param_groups, defaults);
// Capture original state for comparison
auto& orig_group1_opts =
static_cast<AdamOptions&>(original_optimizer.param_groups()[0].options());
auto& orig_group2_opts =
static_cast<AdamOptions&>(original_optimizer.param_groups()[1].options());
// Verify original state (sanity check)
ASSERT_NEAR(orig_group1_opts.weight_decay(), 0.11, 1e-6); // Explicitly set
ASSERT_TRUE(orig_group1_opts.amsgrad()); // Explicitly set
ASSERT_NEAR(orig_group1_opts.lr(), 0.002, 1e-6); // Inherited
ASSERT_NEAR(orig_group2_opts.eps(), 1e-6, 1e-9); // Explicitly set
ASSERT_NEAR(orig_group2_opts.lr(), 0.002, 1e-6); // Inherited
// Test serialization of the options objects (where field tracking lives)
std::stringstream ss1, ss2;
// Serialize the parameter group options
{
torch::serialize::OutputArchive archive;
orig_group1_opts.serialize(archive);
archive.save_to(ss1);
}
{
torch::serialize::OutputArchive archive;
orig_group2_opts.serialize(archive);
archive.save_to(ss2);
}
// Create new options objects and deserialize
AdamOptions loaded_group1_opts;
AdamOptions loaded_group2_opts;
{
torch::serialize::InputArchive archive;
archive.load_from(ss1);
loaded_group1_opts.serialize(archive);
}
{
torch::serialize::InputArchive archive;
archive.load_from(ss2);
loaded_group2_opts.serialize(archive);
}
// Verify that all parameter values are preserved after deserialization
// Group 1: weight_decay and amsgrad should be preserved as explicitly set,
// others inherited
ASSERT_NEAR(loaded_group1_opts.lr(), 0.002, 1e-6); // Inherited
ASSERT_EQ(
loaded_group1_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
ASSERT_NEAR(loaded_group1_opts.eps(), 1e-12, 1e-15); // Inherited
ASSERT_NEAR(loaded_group1_opts.weight_decay(), 0.11, 1e-6); // Explicitly set
ASSERT_TRUE(loaded_group1_opts.amsgrad()); // Explicitly set
// Group 2: eps should be preserved as explicitly set, others inherited
ASSERT_NEAR(loaded_group2_opts.lr(), 0.002, 1e-6); // Inherited
ASSERT_EQ(
loaded_group2_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
ASSERT_NEAR(loaded_group2_opts.eps(), 1e-6, 1e-9); // Explicitly set
ASSERT_NEAR(loaded_group2_opts.weight_decay(), 0.05, 1e-6); // Inherited
ASSERT_FALSE(loaded_group2_opts.amsgrad()); // Inherited
// CRITICAL: Test that field tracking is preserved after serialization
// Create a new optimizer using the deserialized options to test inheritance
auto tensor3 = torch::randn({2, 2}).requires_grad_(true);
auto tensor4 = torch::randn({3, 3}).requires_grad_(true);
std::vector<OptimizerParamGroup> test_param_groups;
test_param_groups.emplace_back(
std::vector<torch::Tensor>{tensor3},
std::make_unique<AdamOptions>(loaded_group1_opts));
test_param_groups.emplace_back(
std::vector<torch::Tensor>{tensor4},
std::make_unique<AdamOptions>(loaded_group2_opts));
Adam test_optimizer(test_param_groups, defaults);
// The field tracking should work correctly for inheritance
auto& final_group1_opts =
static_cast<AdamOptions&>(test_optimizer.param_groups()[0].options());
auto& final_group2_opts =
static_cast<AdamOptions&>(test_optimizer.param_groups()[1].options());
// Group 1: weight_decay and amsgrad should still be preserved as explicitly
// set
ASSERT_NEAR(
final_group1_opts.weight_decay(),
0.11,
1e-6); // Explicitly set (preserved)
ASSERT_TRUE(final_group1_opts.amsgrad()); // Explicitly set (preserved)
ASSERT_NEAR(final_group1_opts.lr(), 0.002, 1e-6); // Inherited from defaults
// Group 2: eps should still be preserved as explicitly set
ASSERT_NEAR(
final_group2_opts.eps(), 1e-6, 1e-9); // Explicitly set (preserved)
ASSERT_NEAR(final_group2_opts.lr(), 0.002, 1e-6); // Inherited from defaults
}
// Test serialization with SGD (different parameter types)
TEST(OptimTest, SerializationPreservesFieldTracking_SGD) {
// Create tensors
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
// Create param group with partial options using fluent API
std::vector<OptimizerParamGroup> param_groups;
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor1},
std::make_unique<SGDOptions>(
SGDOptions(0.01).weight_decay(0.22).nesterov(true)));
// Create optimizer with defaults
SGDOptions defaults(0.001);
defaults.momentum(0.9).dampening(0.0).weight_decay(0.05).nesterov(false);
SGD original_optimizer(param_groups, defaults);
// Test serialization of the SGD options (where field tracking lives)
auto& original_opts =
static_cast<SGDOptions&>(original_optimizer.param_groups()[0].options());
std::stringstream ss;
{
torch::serialize::OutputArchive archive;
original_opts.serialize(archive);
archive.save_to(ss);
}
SGDOptions loaded_opts(0.0); // Dummy initial value
{
torch::serialize::InputArchive archive;
archive.load_from(ss);
loaded_opts.serialize(archive);
}
ASSERT_NEAR(loaded_opts.lr(), 0.01, 1e-6); // Explicitly set
ASSERT_NEAR(loaded_opts.momentum(), 0.9, 1e-6); // Inherited
ASSERT_NEAR(loaded_opts.dampening(), 0.0, 1e-6); // Inherited
ASSERT_NEAR(loaded_opts.weight_decay(), 0.22, 1e-6); // Explicitly set
ASSERT_TRUE(loaded_opts.nesterov()); // Explicitly set
// Test that field tracking still works after deserialization by creating new
// optimizer
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
std::vector<OptimizerParamGroup> test_param_groups;
test_param_groups.emplace_back(
std::vector<torch::Tensor>{tensor2},
std::make_unique<SGDOptions>(loaded_opts));
SGD test_optimizer(test_param_groups, defaults);
auto& final_opts =
static_cast<SGDOptions&>(test_optimizer.param_groups()[0].options());
ASSERT_NEAR(final_opts.lr(), 0.01, 1e-6); // Explicitly set (preserved)
ASSERT_NEAR(
final_opts.weight_decay(), 0.22, 1e-6); // Explicitly set (preserved)
ASSERT_TRUE(final_opts.nesterov()); // Explicitly set (preserved)
ASSERT_NEAR(final_opts.momentum(), 0.9, 1e-6); // Inherited from defaults
ASSERT_NEAR(final_opts.dampening(), 0.0, 1e-6); // Inherited from defaults
}

View File

@@ -12,7 +12,6 @@
#include <iterator>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>
// Forward declarations confuse Doxygen
@@ -67,332 +66,12 @@ class TORCH_API OptimizerOptions {
virtual void set_lr(const double lr);
};
// Forward declarations for optimizer option types
struct SGDOptions;
struct AdamOptions;
struct AdamWOptions;
struct AdagradOptions;
struct RMSpropOptions;
struct LBFGSOptions;
/**
* OptimizerCloneableOptions provides parameter group inheritance functionality
* for PyTorch C++ optimizer options. When creating parameter groups with
* partial options (e.g., AdamOptions().weight_decay(0.1)), fields not
* explicitly set by the user inherit from the optimizer's default values,
* while explicitly set fields are preserved.
*
* This enables Python-like behavior in C++:
* ```cpp
* // Python equivalent:
* // optimizer = Adam([{'params': params1, 'weight_decay': 0.1}], lr=0.01)
* // Result: weight_decay=0.1 preserved, lr=0.01 inherited
*
* AdamOptions defaults;
* defaults.lr(0.01).weight_decay(0.05);
*
* std::vector<OptimizerParamGroup> groups;
* groups.emplace_back(params1, std::make_unique<AdamOptions>(
* AdamOptions().weight_decay(0.1))); // Only weight_decay specified
*
* Adam optimizer(groups, defaults);
* // Result: group inherits lr=0.01, preserves weight_decay=0.1
* ```
*
* **Implementation**: Uses SFINAE-based field detection and constructor-default
* comparison to distinguish explicitly set fields from default values.
* Fields that match constructor defaults are inherited; others are preserved.
*/
template <typename Derived>
class OptimizerCloneableOptions : public OptimizerOptions {
private:
std::unique_ptr<OptimizerOptions> clone() const override {
return std::make_unique<Derived>(static_cast<const Derived&>(*this));
}
// SFINAE field detection - detects optimizer fields using public accessor
// methods
template <class T, class Enable = void>
struct _has_lr : std::false_type {};
template <class T>
struct _has_lr<T, std::void_t<decltype(std::declval<const T&>().lr())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_momentum : std::false_type {};
template <class T>
struct _has_momentum<
T,
std::void_t<decltype(std::declval<const T&>().momentum())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_weight_decay : std::false_type {};
template <class T>
struct _has_weight_decay<
T,
std::void_t<decltype(std::declval<const T&>().weight_decay())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_dampening : std::false_type {};
template <class T>
struct _has_dampening<
T,
std::void_t<decltype(std::declval<const T&>().dampening())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_nesterov : std::false_type {};
template <class T>
struct _has_nesterov<
T,
std::void_t<decltype(std::declval<const T&>().nesterov())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_betas : std::false_type {};
template <class T>
struct _has_betas<T, std::void_t<decltype(std::declval<const T&>().betas())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_eps : std::false_type {};
template <class T>
struct _has_eps<T, std::void_t<decltype(std::declval<const T&>().eps())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_amsgrad : std::false_type {};
template <class T>
struct _has_amsgrad<
T,
std::void_t<decltype(std::declval<const T&>().amsgrad())>>
: std::true_type {};
// Optimizer-specific field detection
template <class T, class Enable = void>
struct _has_lr_decay : std::false_type {};
template <class T>
struct _has_lr_decay<
T,
std::void_t<decltype(std::declval<const T&>().lr_decay())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_alpha : std::false_type {};
template <class T>
struct _has_alpha<T, std::void_t<decltype(std::declval<const T&>().alpha())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_centered : std::false_type {};
template <class T>
struct _has_centered<
T,
std::void_t<decltype(std::declval<const T&>().centered())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_initial_accumulator_value : std::false_type {};
template <class T>
struct _has_initial_accumulator_value<
T,
std::void_t<
decltype(std::declval<const T&>().initial_accumulator_value())>>
: std::true_type {};
// LBFGS-specific fields with appropriate types
template <class T, class Enable = void>
struct _has_max_iter : std::false_type {};
template <class T>
struct _has_max_iter<
T,
std::void_t<decltype(std::declval<const T&>().max_iter())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_max_eval : std::false_type {};
template <class T>
struct _has_max_eval<
T,
std::void_t<decltype(std::declval<const T&>().max_eval())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_tolerance_grad : std::false_type {};
template <class T>
struct _has_tolerance_grad<
T,
std::void_t<decltype(std::declval<const T&>().tolerance_grad())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_tolerance_change : std::false_type {};
template <class T>
struct _has_tolerance_change<
T,
std::void_t<decltype(std::declval<const T&>().tolerance_change())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_history_size : std::false_type {};
template <class T>
struct _has_history_size<
T,
std::void_t<decltype(std::declval<const T&>().history_size())>>
: std::true_type {};
template <class T, class Enable = void>
struct _has_line_search_fn : std::false_type {};
template <class T>
struct _has_line_search_fn<
T,
std::void_t<decltype(std::declval<const T&>().line_search_fn())>>
: std::true_type {};
/**
* Merges user-specified options with optimizer defaults using
* constructor-default comparison to detect explicitly set fields.
*
* Algorithm:
* 1. Start with optimizer defaults as base
* 2. Create fresh constructor instance for comparison
* 3. If user_value != constructor_default → user explicitly set it → preserve
* 4. If user_value == constructor_default → user didn't set it → inherit from
* defaults
*/
void _merge_by_comparison(
const Derived& defaults,
const Derived& user_options) {
auto* result = static_cast<Derived*>(this);
*result = defaults; // Start with optimizer defaults
// Create constructor defaults instance for comparison
Derived constructor_defaults = []() {
if constexpr (std::is_default_constructible_v<Derived>) {
return Derived{};
} else {
// Handle optimizers requiring constructor parameters
if constexpr (std::is_same_v<Derived, SGDOptions>) {
return Derived(1e-3);
} else if constexpr (std::is_same_v<Derived, AdagradOptions>) {
return Derived(1e-2);
} else if constexpr (std::is_same_v<Derived, RMSpropOptions>) {
return Derived(1e-2);
} else if constexpr (std::is_same_v<Derived, LBFGSOptions>) {
return Derived(1);
} else {
return Derived{};
}
}
}();
// Merge fields: preserve user-set values, inherit defaults for unset values
if constexpr (_has_lr<Derived>::value) {
if (user_options.lr() != constructor_defaults.lr()) {
result->lr(user_options.lr());
}
}
if constexpr (_has_momentum<Derived>::value) {
if (user_options.momentum() != constructor_defaults.momentum()) {
result->momentum(user_options.momentum());
}
}
if constexpr (_has_weight_decay<Derived>::value) {
if (user_options.weight_decay() != constructor_defaults.weight_decay()) {
result->weight_decay(user_options.weight_decay());
}
}
if constexpr (_has_dampening<Derived>::value) {
if (user_options.dampening() != constructor_defaults.dampening()) {
result->dampening(user_options.dampening());
}
}
if constexpr (_has_nesterov<Derived>::value) {
if (user_options.nesterov() != constructor_defaults.nesterov()) {
result->nesterov(user_options.nesterov());
}
}
if constexpr (_has_betas<Derived>::value) {
if (user_options.betas() != constructor_defaults.betas()) {
result->betas(user_options.betas());
}
}
if constexpr (_has_eps<Derived>::value) {
if (user_options.eps() != constructor_defaults.eps()) {
result->eps(user_options.eps());
}
}
if constexpr (_has_amsgrad<Derived>::value) {
if (user_options.amsgrad() != constructor_defaults.amsgrad()) {
result->amsgrad(user_options.amsgrad());
}
}
// Optimizer-specific fields - automatically detected and handled
if constexpr (_has_lr_decay<Derived>::value) {
if (user_options.lr_decay() != constructor_defaults.lr_decay()) {
result->lr_decay(user_options.lr_decay());
}
}
if constexpr (_has_alpha<Derived>::value) {
if (user_options.alpha() != constructor_defaults.alpha()) {
result->alpha(user_options.alpha());
}
}
if constexpr (_has_centered<Derived>::value) {
if (user_options.centered() != constructor_defaults.centered()) {
result->centered(user_options.centered());
}
}
if constexpr (_has_initial_accumulator_value<Derived>::value) {
if (user_options.initial_accumulator_value() !=
constructor_defaults.initial_accumulator_value()) {
result->initial_accumulator_value(
user_options.initial_accumulator_value());
}
}
// LBFGS-specific fields with appropriate types
if constexpr (_has_max_iter<Derived>::value) {
if (user_options.max_iter() != constructor_defaults.max_iter()) {
result->max_iter(user_options.max_iter());
}
}
if constexpr (_has_max_eval<Derived>::value) {
if (user_options.max_eval() != constructor_defaults.max_eval()) {
result->max_eval(user_options.max_eval());
}
}
if constexpr (_has_tolerance_grad<Derived>::value) {
if (user_options.tolerance_grad() !=
constructor_defaults.tolerance_grad()) {
result->tolerance_grad(user_options.tolerance_grad());
}
}
if constexpr (_has_tolerance_change<Derived>::value) {
if (user_options.tolerance_change() !=
constructor_defaults.tolerance_change()) {
result->tolerance_change(user_options.tolerance_change());
}
}
if constexpr (_has_history_size<Derived>::value) {
if (user_options.history_size() != constructor_defaults.history_size()) {
result->history_size(user_options.history_size());
}
}
if constexpr (_has_line_search_fn<Derived>::value) {
if (user_options.line_search_fn() !=
constructor_defaults.line_search_fn()) {
result->line_search_fn(user_options.line_search_fn());
}
}
}
// Friend class for controlled access to private _merge_by_comparison method
friend class Optimizer;
};
/// Stores parameters in the param_group and stores a pointer to the
@@ -507,43 +186,6 @@ class TORCH_API Optimizer {
/// Deserializes the optimizer state from the given `archive`.
virtual void load(serialize::InputArchive& archive);
private:
/// Helper function to try merging for a specific optimizer type
template <typename OptimizerType>
static bool _try_merge_optimizer_type(
std::unique_ptr<OptimizerOptions>& final_options,
const OptimizerOptions& user_options,
const OptimizerOptions& defaults) {
auto* typed_final = dynamic_cast<OptimizerType*>(final_options.get());
auto* typed_user = dynamic_cast<const OptimizerType*>(&user_options);
auto* typed_defaults = dynamic_cast<const OptimizerType*>(&defaults);
if (typed_final && typed_user && typed_defaults) {
typed_final->_merge_by_comparison(*typed_defaults, *typed_user);
return true;
}
return false;
}
/// Simple variadic dispatch helper - try all optimizer types in one call
template <typename... OptimizerTypes>
static void _try_merge_all_optimizer_types(
std::unique_ptr<OptimizerOptions>& final_options,
const OptimizerOptions& user_options,
const OptimizerOptions& defaults) {
// Try each optimizer type until one succeeds - much cleaner than manual
// chain
(void)(_try_merge_optimizer_type<OptimizerTypes>(
final_options, user_options, defaults) ||
...);
}
/// Convenience function with all known PyTorch optimizers
static void _try_merge_all_optimizers(
std::unique_ptr<OptimizerOptions>& final_options,
const OptimizerOptions& user_options,
const OptimizerOptions& defaults);
protected:
std::vector<OptimizerParamGroup> param_groups_;
ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> state_;

View File

@@ -3,35 +3,12 @@
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/types.h>
// Include complete type definitions for all optimizers to enable dynamic_cast
#include <torch/optim/adagrad.h>
#include <torch/optim/adam.h>
#include <torch/optim/adamw.h>
#include <torch/optim/lbfgs.h>
#include <torch/optim/rmsprop.h>
#include <torch/optim/sgd.h>
#include <string>
#include <utility>
#include <vector>
namespace torch::optim {
// Simple implementation using variadic template helper
void Optimizer::_try_merge_all_optimizers(
std::unique_ptr<OptimizerOptions>& final_options,
const OptimizerOptions& user_options,
const OptimizerOptions& defaults) {
// Clean one-liner replaces the entire repetitive dispatch chain
_try_merge_all_optimizer_types<
SGDOptions,
AdamOptions,
AdamWOptions,
AdagradOptions,
RMSpropOptions,
LBFGSOptions>(final_options, user_options, defaults);
}
bool OptimizerParamGroup::has_options() const {
return options_ != nullptr;
}
@@ -124,20 +101,9 @@ void Optimizer::add_param_group(const OptimizerParamGroup& param_group) {
TORCH_INTERNAL_ASSERT(defaults_ != nullptr);
OptimizerParamGroup param_group_(param_group.params());
if (!param_group.has_options()) {
// No options provided - use defaults directly
param_group_.set_options(defaults_->clone());
} else {
// Options provided - merge user's explicit settings with defaults for
// parameter group inheritance This enables Python-C++ API parity by
// honoring user intent while inheriting missing parameters
auto final_options = defaults_->clone();
// Simple variadic dispatch - try all known optimizer types
_try_merge_all_optimizers(final_options, param_group.options(), *defaults_);
// If no merging was done (custom optimizer), final_options already contains
// defaults
param_group_.set_options(std::move(final_options));
param_group_.set_options(param_group.options().clone());
}
for (const auto& p : param_group_.params()) {
TORCH_CHECK(