mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add AdamW to C++ frontend (#40009)
Summary: Slightly modified Adam, following the python implementation, and the `ProducesPyTorchValues` tests pass. I had a problem with another test though (see commit c1a6241676ab84fc531c1c3a10f964aa5704092e), it seems that optimizing for two steps with the same optimizer vs optimizing for two steps using freshly initialized objects will produce the same output. Pull Request resolved: https://github.com/pytorch/pytorch/pull/40009 Differential Revision: D22096053 Pulled By: glaringlee fbshipit-source-id: a31a8f5488cb37c53752ddf15436efabdba67dc4
This commit is contained in:
committed by
Facebook GitHub Bot
parent
89ef8f8141
commit
41f2dbde31
@ -538,6 +538,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
||||
${TORCH_SRC_DIR}/csrc/api/src/nn/options/vision.cpp
|
||||
${TORCH_SRC_DIR}/csrc/api/src/optim/adagrad.cpp
|
||||
${TORCH_SRC_DIR}/csrc/api/src/optim/adam.cpp
|
||||
${TORCH_SRC_DIR}/csrc/api/src/optim/adamw.cpp
|
||||
${TORCH_SRC_DIR}/csrc/api/src/optim/lbfgs.cpp
|
||||
${TORCH_SRC_DIR}/csrc/api/src/optim/optimizer.cpp
|
||||
${TORCH_SRC_DIR}/csrc/api/src/optim/rmsprop.cpp
|
||||
|
||||
@ -283,6 +283,31 @@ TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecayAndAMSGrad) {
|
||||
expected_parameters::Adam_with_weight_decay_and_amsgrad());
|
||||
}
|
||||
|
||||
TEST(OptimTest, XORConvergence_AdamW) {
|
||||
ASSERT_TRUE(test_optimizer_xor<AdamW>(AdamWOptions(0.1)));
|
||||
}
|
||||
|
||||
TEST(OptimTest, XORConvergence_AdamWWithAmsgrad) {
|
||||
ASSERT_TRUE(test_optimizer_xor<AdamW>(
|
||||
AdamWOptions(0.1).amsgrad(true)));
|
||||
}
|
||||
|
||||
TEST(OptimTest, ProducesPyTorchValues_AdamW) {
|
||||
check_exact_values<AdamW>(AdamWOptions(1.0), expected_parameters::AdamW());
|
||||
}
|
||||
|
||||
TEST(OptimTest, ProducesPyTorchValues_AdamWWithoutWeightDecay) {
|
||||
check_exact_values<AdamW>(
|
||||
AdamWOptions(1.0).weight_decay(0),
|
||||
expected_parameters::AdamW_without_weight_decay());
|
||||
}
|
||||
|
||||
TEST(OptimTest, ProducesPyTorchValues_AdamWWithAMSGrad) {
|
||||
check_exact_values<AdamW>(
|
||||
AdamWOptions(1.0).amsgrad(true),
|
||||
expected_parameters::AdamW_with_amsgrad());
|
||||
}
|
||||
|
||||
TEST(OptimTest, ProducesPyTorchValues_Adagrad) {
|
||||
check_exact_values<Adagrad>(
|
||||
AdagradOptions(1.0), expected_parameters::Adagrad());
|
||||
|
||||
@ -361,6 +361,219 @@ inline std::vector<std::vector<torch::Tensor>> Adam_with_weight_decay_and_amsgra
|
||||
};
|
||||
}
|
||||
|
||||
inline std::vector<std::vector<torch::Tensor>> AdamW() {
|
||||
return {
|
||||
{
|
||||
torch::tensor({0.7912062750121864, 0.5074166292785842, 0.8601202529258052, 0.6613910130887053, 0.7501593169903569, 1.6905808503961983}),
|
||||
torch::tensor({0.8925529482073002, 0.7050308347536254, 1.682309255842939}),
|
||||
torch::tensor({-1.05029506454492, -1.3901937990816595, -1.2814942017397601}),
|
||||
torch::tensor({-1.0704267290556988}),
|
||||
},
|
||||
{
|
||||
torch::tensor({3.3165329599188507, 3.223120441823618, 2.665544565239194, 2.6044341406663225, 2.479859063483047, 2.836831717112226}),
|
||||
torch::tensor({3.3885192024669744, 2.6544147219174556, 2.8709245656887328}),
|
||||
torch::tensor({-2.70172647102137, -2.836731459490802, -2.69652471546253}),
|
||||
torch::tensor({-2.575239255076019}),
|
||||
},
|
||||
{
|
||||
torch::tensor({2.231471944853865, 2.3549328325971755, 1.5699078054795328, 1.6160272935884685, 1.5339085081403547, 1.7397405105941612}),
|
||||
torch::tensor({2.8552579170807926, 1.8369866847839356, 1.9735168512425862}),
|
||||
torch::tensor({-2.6042083360293855, -2.6996673713262336, -1.8976087706977893}),
|
||||
torch::tensor({-1.6180915942867784}),
|
||||
},
|
||||
{
|
||||
torch::tensor({2.084688381515552, 2.3141612674892946, 1.4850714710140511, 1.5961047256668386, 1.440300645879787, 1.6065354941586025}),
|
||||
torch::tensor({3.0111385685659444, 1.955556497153507, 1.9596562467797627}),
|
||||
torch::tensor({-2.889337305884852, -2.965249100126337, -1.7721676671605975}),
|
||||
torch::tensor({-1.4001341655590005}),
|
||||
},
|
||||
{
|
||||
torch::tensor({2.0465343456006604, 2.311613891239368, 1.4666717526896398, 1.601383980913499, 1.4223660595993763, 1.5711552625612757}),
|
||||
torch::tensor({3.07151984580744, 2.0112690538174802, 1.9592484602763875}),
|
||||
torch::tensor({-3.0186469726426863, -3.093855445542849, -1.7367953899738784}),
|
||||
torch::tensor({-1.3299011560804312}),
|
||||
},
|
||||
{
|
||||
torch::tensor({2.039659777412556, 2.3178034179536273, 1.4654302718412722, 1.6094701969162322, 1.4230510816446773, 1.565168902852383}),
|
||||
torch::tensor({3.1007583934270064, 2.039757113618415, 1.9652096140698696}),
|
||||
torch::tensor({-3.0880626664330832, -3.166705422245348, -1.73538367534238}),
|
||||
torch::tensor({-1.3130428735015893}),
|
||||
},
|
||||
{
|
||||
torch::tensor({2.0413773043991963, 2.3251469369586366, 1.4690808101517236, 1.6174065798291044, 1.4280274009117935, 1.5682418226469732}),
|
||||
torch::tensor({3.118843540209399, 2.057729936485249, 1.9742319629710936}),
|
||||
torch::tensor({-3.1331019663177013, -3.2154332694373107, -1.7459831639793468}),
|
||||
torch::tensor({-1.3148644134154366}),
|
||||
},
|
||||
{
|
||||
torch::tensor({2.0452604138357113, 2.332074253989847, 1.4738845773449165, 1.6246403004735728, 1.4335712611625357, 1.573826630920094}),
|
||||
torch::tensor({3.1324088069784093, 2.0711763619826575, 1.9841582498316732}),
|
||||
torch::tensor({-3.16737058959847, -3.2529206463859146, -1.7602788393925501}),
|
||||
torch::tensor({-1.32281766461531}),
|
||||
},
|
||||
{
|
||||
torch::tensor({2.0495243704493262, 2.338413341249581, 1.4787599440132637, 1.631210274009555, 1.438849155552895, 1.5798736919537595}),
|
||||
torch::tensor({3.1438209015414227, 2.0823943437659658, 1.9940075805973108}),
|
||||
torch::tensor({-3.19628690363529, -3.2845941643030367, -1.7752333900055153}),
|
||||
torch::tensor({-1.332456718314933}),
|
||||
},
|
||||
{
|
||||
torch::tensor({2.0536979895206295, 2.3442520601250334, 1.4834272584224222, 1.6372462654983486, 1.4437517398490174, 1.585780877834892}),
|
||||
torch::tensor({3.1540081461072447, 2.0923381262560454, 2.0034284957296107}),
|
||||
torch::tensor({-3.222142968201519, -3.312867602521477, -1.7898220261118043}),
|
||||
torch::tensor({-1.3422692037690986}),
|
||||
},
|
||||
{
|
||||
torch::tensor({2.0576784836825315, 2.3496934395759377, 1.4878413407927933, 1.6428612479757005, 1.4483225979568104, 1.5914034339763325}),
|
||||
torch::tensor({3.163383747232199, 2.101446878895216, 2.012344413569353}),
|
||||
torch::tensor({-3.246000281299229, -3.338904166978488, -1.8037666936489785}),
|
||||
torch::tensor({-1.3517884775416527}),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
inline std::vector<std::vector<torch::Tensor>> AdamW_without_weight_decay() {
|
||||
return {
|
||||
{
|
||||
torch::tensor({0.7890972864438476, 0.5024410688121617, 0.858707331305558, 0.6579707241208395, 0.7476356819075531, 1.6975564206516922}),
|
||||
torch::tensor({0.891467636010675, 0.70205134975675, 1.689201270942895}),
|
||||
torch::tensor({-1.0508030958460797, -1.3941351509567654, -1.284337577714353}),
|
||||
torch::tensor({-1.071138110298716}),
|
||||
},
|
||||
{
|
||||
torch::tensor({8.233039313231831, 7.971150747377481, 6.643620950677599, 6.47097740790054, 6.170125488259256, 7.150739103343502}),
|
||||
torch::tensor({8.417695070103738, 6.597188212844593, 7.23175710827678}),
|
||||
torch::tensor({-6.729624357635757, -7.09743493108154, -6.753301896575352}),
|
||||
torch::tensor({-6.435639096011218}),
|
||||
},
|
||||
{
|
||||
torch::tensor({8.233424596059299, 7.971537360032308, 6.643920150720393, 6.471278075537239, 6.170405874224489, 7.151021086137983}),
|
||||
torch::tensor({8.418084791214298, 6.597493171180545, 7.232043740621598}),
|
||||
torch::tensor({-6.729918250724671, -7.097730102046093, -6.753584809755359}),
|
||||
torch::tensor({-6.4359165566974985}),
|
||||
},
|
||||
{
|
||||
torch::tensor({8.233424610557652, 7.971537374586563, 6.643920161995284, 6.471278086877828, 6.170405884785074, 7.151021096766406}),
|
||||
torch::tensor({8.418084805901906, 6.597493182713584, 7.2320437514477875}),
|
||||
torch::tensor({-6.72991829363266, -7.097730147102975, -6.753584838821182}),
|
||||
torch::tensor({-6.435916580217771}),
|
||||
},
|
||||
{
|
||||
torch::tensor({8.233424610575105, 7.971537374611125, 6.643920162027961, 6.471278086923277, 6.170405884809245, 7.151021096800041}),
|
||||
torch::tensor({8.418084805946393, 6.597493182796847, 7.232043751509309}),
|
||||
torch::tensor({-6.729918332327653, -7.097730188349552, -6.753584861205486}),
|
||||
torch::tensor({-6.435916596115672}),
|
||||
},
|
||||
{
|
||||
torch::tensor({8.233424610594861, 7.971537374639166, 6.64392016206557, 6.471278086975758, 6.170405884836981, 7.151021096838798}),
|
||||
torch::tensor({8.418084805997617, 6.59749318289335, 7.232043751580523}),
|
||||
torch::tensor({-6.72991837738045, -7.097730236373201, -6.753584887267492}),
|
||||
torch::tensor({-6.43591661462546}),
|
||||
},
|
||||
{
|
||||
torch::tensor({8.233424610617291, 7.971537374671012, 6.643920162108284, 6.471278087035361, 6.170405884868481, 7.151021096882812}),
|
||||
torch::tensor({8.418084806055798, 6.59749318300295, 7.232043751661401}),
|
||||
torch::tensor({-6.729918428547273, -7.09773029091405, -6.753584916866329}),
|
||||
torch::tensor({-6.4359166356471755}),
|
||||
},
|
||||
{
|
||||
torch::tensor({8.233424610642356, 7.9715373747065925, 6.643920162156006, 6.471278087101954, 6.1704058849036745, 7.15102109693199}),
|
||||
torch::tensor({8.418084806120802, 6.597493183125404, 7.232043751751764}),
|
||||
torch::tensor({-6.729918485714688, -7.0977303518511805, -6.753584949936365}),
|
||||
torch::tensor({-6.43591665913422}),
|
||||
},
|
||||
{
|
||||
torch::tensor({8.233424610670038, 7.97153737474589, 6.643920162208714, 6.471278087175501, 6.170405884942545, 7.151021096986303}),
|
||||
torch::tensor({8.418084806192596, 6.597493183260647, 7.232043751851564}),
|
||||
torch::tensor({-6.729918548853505, -7.097730419153473, -6.753584986460725}),
|
||||
torch::tensor({-6.435916685074594}),
|
||||
},
|
||||
{
|
||||
torch::tensor({8.233424610700352, 7.971537374788922, 6.643920162266432, 6.47127808725604, 6.17040588498511, 7.1510210970457795}),
|
||||
torch::tensor({8.418084806271217, 6.597493183408747, 7.232043751960854}),
|
||||
torch::tensor({-6.7299186179943, -7.097730492853521, -6.753585026457088}),
|
||||
torch::tensor({-6.435916713480863}),
|
||||
},
|
||||
{
|
||||
torch::tensor({8.23342461073333, 7.971537374835737, 6.643920162329224, 6.471278087343658, 6.170405885031416, 7.151021097110484}),
|
||||
torch::tensor({8.418084806356747, 6.597493183569867, 7.232043752079749}),
|
||||
torch::tensor({-6.729918693213275, -7.097730573032567, -6.753585069969552}),
|
||||
torch::tensor({-6.43591674438434}),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
inline std::vector<std::vector<torch::Tensor>> AdamW_with_amsgrad() {
|
||||
return {
|
||||
{
|
||||
torch::tensor({0.7912062750121864, 0.5074166292785842, 0.8601202529258052, 0.6613910130887053, 0.7501593169903569, 1.6905808503961983}),
|
||||
torch::tensor({0.8925529482073002, 0.7050308347536254, 1.682309255842939}),
|
||||
torch::tensor({-1.05029506454492, -1.3901937990816595, -1.2814942017397601}),
|
||||
torch::tensor({-1.0704267290556988}),
|
||||
},
|
||||
{
|
||||
torch::tensor({3.3017259270507915, 3.2082991753694565, 2.653930978510442, 2.5927674339810585, 2.4689608790182933, 2.825873703467739}),
|
||||
torch::tensor({3.373698198112671, 2.6425942964586664, 2.8597930424244304}),
|
||||
torch::tensor({-2.690360632302962, -2.8253191596069525, -2.6855499873057473}),
|
||||
torch::tensor({-2.5644658591929406}),
|
||||
},
|
||||
{
|
||||
torch::tensor({2.222607725541013, 2.3447188854637004, 1.5614270655258826, 1.606610018462357, 1.5260497191448619, 1.7309643622674138}),
|
||||
torch::tensor({2.84137462783552, 1.824806600633721, 1.9620493659996037}),
|
||||
torch::tensor({-2.576642773625787, -2.6706153846815766, -1.8799876863754623}),
|
||||
torch::tensor({-1.6044722984810953}),
|
||||
},
|
||||
{
|
||||
torch::tensor({2.0739558768648205, 2.3008338863863496, 1.4738888208638767, 1.5829485271829449, 1.4296176764284294, 1.5939984909850073}),
|
||||
torch::tensor({2.9908013612792415, 1.936590940953305, 1.941691630199464}),
|
||||
torch::tensor({-2.846562884997548, -2.9195962101501203, -1.746484716887341}),
|
||||
torch::tensor({-1.381525131003179}),
|
||||
},
|
||||
{
|
||||
torch::tensor({2.0333926094256953, 2.294977109171754, 1.452870514716895, 1.584853677999522, 1.4086299433181402, 1.5548201727855224}),
|
||||
torch::tensor({3.0454817801193976, 1.9867169062383696, 1.935312753106444}),
|
||||
torch::tensor({-2.9612116762746394, -3.0322275992001084, -1.7026114905180725}),
|
||||
torch::tensor({-1.30563541393247}),
|
||||
},
|
||||
{
|
||||
torch::tensor({2.02392417168201, 2.2977279587859, 1.4488511120131309, 1.5894646930743725, 1.406253686759073, 1.5450647949022756}),
|
||||
torch::tensor({3.069025602955343, 2.0096872138967488, 1.935438546309299}),
|
||||
torch::tensor({-3.016103148166836, -3.0893062953033583, -1.6925290685615872}),
|
||||
torch::tensor({-1.282870120405012}),
|
||||
},
|
||||
{
|
||||
torch::tensor({2.0230257817348316, 2.3016167065040647, 1.4496901978629444, 1.5939034289777392, 1.4082421794430946, 1.5444538756003756}),
|
||||
torch::tensor({3.0814016132787954, 2.022150201844143, 1.9387429991308658}),
|
||||
torch::tensor({-3.0466485946438406, -3.1223144611322446, -1.6944083009127773}),
|
||||
torch::tensor({-1.2786980736911064}),
|
||||
},
|
||||
{
|
||||
torch::tensor({2.024305922065404, 2.305101549510461, 1.4516863420588493, 1.5976447954882376, 1.4108596097183552, 1.5464236284303425}),
|
||||
torch::tensor({3.0892574233545065, 2.0300944858242236, 1.943040321845021}),
|
||||
torch::tensor({-3.0664189567897306, -3.1440820888166425, -1.6999750448893618}),
|
||||
torch::tensor({-1.2806281811826203}),
|
||||
},
|
||||
{
|
||||
torch::tensor({2.0259854116237364, 2.3080169152218293, 1.4537680296813915, 1.6007369432392426, 1.4132529823064277, 1.5489035046525346}),
|
||||
torch::tensor({3.0949717180851337, 2.0358251379915764, 1.9473249654893}),
|
||||
torch::tensor({-3.0808231377905426, -3.160021699873689, -1.7062031001273494}),
|
||||
torch::tensor({-1.28423401369705}),
|
||||
},
|
||||
{
|
||||
torch::tensor({2.0275923948348638, 2.3104512601389637, 1.455657715721078, 1.6033123357613526, 1.4153003204463288, 1.5512775896116622}),
|
||||
torch::tensor({3.099479021299846, 2.0403012223048775, 1.9512285931847464}),
|
||||
torch::tensor({-3.092151979336299, -3.1725453680885267, -1.7120689614428697}),
|
||||
torch::tensor({-1.2880095517062655}),
|
||||
},
|
||||
{
|
||||
torch::tensor({2.029022468328371, 2.3125066985045892, 1.4573100228823295, 1.605484259933419, 1.417038245960655, 1.5533932056240227}),
|
||||
torch::tensor({3.103195011518616, 2.043964003458376, 1.9546640840748621}),
|
||||
torch::tensor({-3.1014680843131184, -3.1828179298513968, -1.7172933346797972}),
|
||||
torch::tensor({-1.2914899987134136}),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
inline std::vector<std::vector<torch::Tensor>> Adagrad() {
|
||||
return {
|
||||
{
|
||||
|
||||
@ -26,6 +26,9 @@ OPTIMIZERS = {
|
||||
"Adam": lambda p: torch.optim.Adam(p, 1.0),
|
||||
"Adam_with_weight_decay": lambda p: torch.optim.Adam(p, 1.0, weight_decay=1e-2),
|
||||
"Adam_with_weight_decay_and_amsgrad": lambda p: torch.optim.Adam(p, 1.0, weight_decay=1e-6, amsgrad=True),
|
||||
"AdamW": lambda p: torch.optim.AdamW(p, 1.0),
|
||||
"AdamW_without_weight_decay": lambda p: torch.optim.AdamW(p, 1.0, weight_decay=0),
|
||||
"AdamW_with_amsgrad": lambda p: torch.optim.AdamW(p, 1.0, amsgrad=True),
|
||||
"Adagrad": lambda p: torch.optim.Adagrad(p, 1.0),
|
||||
"Adagrad_with_weight_decay": lambda p: torch.optim.Adagrad(p, 1.0, weight_decay=1e-2),
|
||||
"Adagrad_with_weight_decay_and_lr_decay": lambda p: torch.optim.Adagrad(p, 1.0, weight_decay=1e-6, lr_decay=1e-3),
|
||||
|
||||
@ -64,6 +64,7 @@ void is_optimizer_state_equal(
|
||||
|
||||
template <typename OptimizerClass, typename DerivedOptimizerOptions, typename DerivedOptimizerParamState>
|
||||
void test_serialize_optimizer(DerivedOptimizerOptions options, bool only_has_global_state = false) {
|
||||
torch::manual_seed(0);
|
||||
auto model1 = Linear(5, 2);
|
||||
auto model2 = Linear(5, 2);
|
||||
auto model3 = Linear(5, 2);
|
||||
@ -600,6 +601,56 @@ TEST(SerializeTest, Optim_Adam) {
|
||||
is_optimizer_state_equal<AdamParamState>(optim1.state(), optim1_2.state());
|
||||
}
|
||||
|
||||
TEST(SerializeTest, Optim_AdamW) {
|
||||
test_serialize_optimizer<AdamW, AdamWOptions, AdamWParamState>(AdamWOptions().lr(0.99999).amsgrad(true).betas(std::make_tuple(0.999, 0.1)));
|
||||
|
||||
// bc compatibility check
|
||||
auto model1 = Linear(5, 2);
|
||||
auto model1_params = model1->parameters();
|
||||
// added a tensor for lazy init check - when all params do not have entry in buffers
|
||||
model1_params.emplace_back(torch::randn({2,3}));
|
||||
auto optim1 = torch::optim::AdamW(model1_params, torch::optim::AdamWOptions().weight_decay(0.5));
|
||||
|
||||
auto x = torch::ones({10, 5});
|
||||
auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
|
||||
optimizer.zero_grad();
|
||||
auto y = model->forward(x).sum();
|
||||
y.backward();
|
||||
optimizer.step();
|
||||
};
|
||||
step(optim1, model1);
|
||||
|
||||
std::vector<int64_t> step_buffers;
|
||||
std::vector<at::Tensor> exp_average_buffers;
|
||||
std::vector<at::Tensor> exp_average_sq_buffers;
|
||||
std::vector<at::Tensor> max_exp_average_sq_buffers;
|
||||
const auto& params_ = optim1.param_groups()[0].params();
|
||||
const auto& optim1_state = optim1.state();
|
||||
for (size_t i = 0; i < params_.size(); i++) {
|
||||
if(i != (params_.size() - 1)) {
|
||||
auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl());
|
||||
const AdamWParamState& curr_state_ = static_cast<const AdamWParamState&>(*(optim1_state.at(key_).get()));
|
||||
step_buffers.emplace_back(curr_state_.step());
|
||||
exp_average_buffers.emplace_back(curr_state_.exp_avg());
|
||||
exp_average_sq_buffers.emplace_back(curr_state_.exp_avg_sq());
|
||||
if(curr_state_.max_exp_avg_sq().defined()) {
|
||||
max_exp_average_sq_buffers.emplace_back(curr_state_.max_exp_avg_sq());
|
||||
}
|
||||
}
|
||||
}
|
||||
// write buffers to the file
|
||||
auto optim_tempfile_old_format = c10::make_tempfile();
|
||||
torch::serialize::OutputArchive output_archive;
|
||||
write_step_buffers(output_archive, "step_buffers", step_buffers);
|
||||
write_tensors_to_archive(output_archive, "exp_average_buffers", exp_average_buffers);
|
||||
write_tensors_to_archive(output_archive, "exp_average_sq_buffers", exp_average_sq_buffers);
|
||||
write_tensors_to_archive(output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
|
||||
output_archive.save_to(optim_tempfile_old_format.name);
|
||||
auto optim1_2 = AdamW(model1_params, torch::optim::AdamWOptions());
|
||||
OLD_SERIALIZATION_LOGIC_WARNING_CHECK(torch::load, optim1_2, optim_tempfile_old_format.name);
|
||||
is_optimizer_state_equal<AdamWParamState>(optim1.state(), optim1_2.state());
|
||||
}
|
||||
|
||||
TEST(SerializeTest, Optim_RMSprop) {
|
||||
auto options = RMSpropOptions(0.1).momentum(0.9).centered(true);
|
||||
test_serialize_optimizer<RMSprop, RMSpropOptions, RMSpropParamState>(options);
|
||||
|
||||
@ -384,6 +384,7 @@ torch_cpp_srcs = [
|
||||
"torch/csrc/api/src/nn/options/vision.cpp",
|
||||
"torch/csrc/api/src/optim/adagrad.cpp",
|
||||
"torch/csrc/api/src/optim/adam.cpp",
|
||||
"torch/csrc/api/src/optim/adamw.cpp",
|
||||
"torch/csrc/api/src/optim/lbfgs.cpp",
|
||||
"torch/csrc/api/src/optim/optimizer.cpp",
|
||||
"torch/csrc/api/src/optim/rmsprop.cpp",
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
#include <torch/optim/adagrad.h>
|
||||
#include <torch/optim/adam.h>
|
||||
#include <torch/optim/adamw.h>
|
||||
#include <torch/optim/lbfgs.h>
|
||||
#include <torch/optim/optimizer.h>
|
||||
#include <torch/optim/rmsprop.h>
|
||||
|
||||
75
torch/csrc/api/include/torch/optim/adamw.h
Normal file
75
torch/csrc/api/include/torch/optim/adamw.h
Normal file
@ -0,0 +1,75 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/arg.h>
|
||||
#include <torch/nn/module.h>
|
||||
#include <torch/optim/optimizer.h>
|
||||
#include <torch/optim/serialize.h>
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace torch {
|
||||
namespace serialize {
|
||||
class OutputArchive;
|
||||
class InputArchive;
|
||||
} // namespace serialize
|
||||
} // namespace torch
|
||||
|
||||
namespace torch {
|
||||
namespace optim {
|
||||
|
||||
struct TORCH_API AdamWOptions : public OptimizerCloneableOptions<AdamWOptions> {
|
||||
AdamWOptions(double lr = 1e-3);
|
||||
TORCH_ARG(double, lr) = 1e-3;
|
||||
typedef std::tuple<double, double> betas_t;
|
||||
TORCH_ARG(betas_t, betas) = std::make_tuple(0.9, 0.999);
|
||||
TORCH_ARG(double, eps) = 1e-8;
|
||||
TORCH_ARG(double, weight_decay) = 1e-2;
|
||||
TORCH_ARG(bool, amsgrad) = false;
|
||||
public:
|
||||
void serialize(torch::serialize::InputArchive& archive) override;
|
||||
void serialize(torch::serialize::OutputArchive& archive) const override;
|
||||
TORCH_API friend bool operator==(const AdamWOptions& lhs, const AdamWOptions& rhs);
|
||||
~AdamWOptions() = default;
|
||||
};
|
||||
|
||||
struct TORCH_API AdamWParamState : public OptimizerCloneableParamState<AdamWParamState> {
|
||||
TORCH_ARG(int64_t, step) = 0;
|
||||
TORCH_ARG(torch::Tensor, exp_avg);
|
||||
TORCH_ARG(torch::Tensor, exp_avg_sq);
|
||||
TORCH_ARG(torch::Tensor, max_exp_avg_sq) = {};
|
||||
|
||||
public:
|
||||
void serialize(torch::serialize::InputArchive& archive) override;
|
||||
void serialize(torch::serialize::OutputArchive& archive) const override;
|
||||
TORCH_API friend bool operator==(const AdamWParamState& lhs, const AdamWParamState& rhs);
|
||||
~AdamWParamState() = default;
|
||||
};
|
||||
|
||||
class TORCH_API AdamW : public Optimizer {
|
||||
public:
|
||||
explicit AdamW(std::vector<OptimizerParamGroup> param_groups,
|
||||
AdamWOptions defaults = {}) : Optimizer(std::move(param_groups), std::make_unique<AdamWOptions>(defaults)) {
|
||||
TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr());
|
||||
TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps());
|
||||
auto betas = defaults.betas();
|
||||
TORCH_CHECK(0 <= std::get<0>(betas) && std::get<0>(betas) < 1.0, "Invalid beta parameter at index 0: ", std::get<0>(betas));
|
||||
TORCH_CHECK(0 <= std::get<1>(betas) && std::get<1>(betas) < 1.0, "Invalid beta parameter at index 1: ", std::get<1>(betas));
|
||||
TORCH_CHECK(defaults.weight_decay() >= 0, "Invalid weight_decay value: ", defaults.weight_decay());
|
||||
}
|
||||
explicit AdamW(
|
||||
std::vector<Tensor> params,
|
||||
AdamWOptions defaults = {}) : AdamW({std::move(OptimizerParamGroup(params))}, defaults) {}
|
||||
|
||||
torch::Tensor step(LossClosure closure = nullptr) override;
|
||||
void save(serialize::OutputArchive& archive) const override;
|
||||
void load(serialize::InputArchive& archive) override;
|
||||
|
||||
private:
|
||||
template <typename Self, typename Archive>
|
||||
static void serialize(Self& self, Archive& archive) {
|
||||
_TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(AdamW);
|
||||
}
|
||||
};
|
||||
} // namespace optim
|
||||
} // namespace torch
|
||||
170
torch/csrc/api/src/optim/adamw.cpp
Normal file
170
torch/csrc/api/src/optim/adamw.cpp
Normal file
@ -0,0 +1,170 @@
|
||||
#include <torch/optim/adamw.h>
|
||||
|
||||
#include <torch/csrc/autograd/variable.h>
|
||||
#include <torch/nn/module.h>
|
||||
#include <torch/serialize/archive.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
|
||||
namespace torch {
|
||||
namespace optim {
|
||||
|
||||
AdamWOptions::AdamWOptions(double lr) : lr_(lr) {}
|
||||
|
||||
bool operator==(const AdamWOptions& lhs, const AdamWOptions& rhs) {
|
||||
return (lhs.lr() == rhs.lr()) &&
|
||||
(std::get<0>(lhs.betas()) == std::get<0>(rhs.betas())) &&
|
||||
(std::get<1>(lhs.betas()) == std::get<1>(rhs.betas())) &&
|
||||
(lhs.eps() == rhs.eps()) &&
|
||||
(lhs.weight_decay() == rhs.weight_decay() &&
|
||||
(lhs.amsgrad() == rhs.amsgrad()));
|
||||
}
|
||||
|
||||
void AdamWOptions::serialize(torch::serialize::OutputArchive& archive) const {
|
||||
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(lr);
|
||||
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(betas);
|
||||
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(eps);
|
||||
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(weight_decay);
|
||||
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(amsgrad);
|
||||
}
|
||||
|
||||
void AdamWOptions::serialize(torch::serialize::InputArchive& archive) {
|
||||
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, lr);
|
||||
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(betas_t, betas);
|
||||
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, eps);
|
||||
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, weight_decay);
|
||||
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, amsgrad);
|
||||
}
|
||||
|
||||
bool operator==(const AdamWParamState& lhs, const AdamWParamState& rhs) {
|
||||
return (lhs.step() == rhs.step()) &&
|
||||
torch::equal(lhs.exp_avg(), rhs.exp_avg()) &&
|
||||
torch::equal(lhs.exp_avg_sq(), rhs.exp_avg_sq()) &&
|
||||
torch::equal_if_defined(lhs.max_exp_avg_sq(), rhs.max_exp_avg_sq());
|
||||
}
|
||||
|
||||
void AdamWParamState::serialize(torch::serialize::OutputArchive& archive) const {
|
||||
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(step);
|
||||
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(exp_avg);
|
||||
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(exp_avg_sq);
|
||||
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(max_exp_avg_sq);
|
||||
}
|
||||
|
||||
void AdamWParamState::serialize(torch::serialize::InputArchive& archive) {
|
||||
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, step);
|
||||
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, exp_avg);
|
||||
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, exp_avg_sq);
|
||||
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, max_exp_avg_sq);
|
||||
}
|
||||
|
||||
Tensor AdamW::step(LossClosure closure) {
|
||||
NoGradGuard no_grad;
|
||||
Tensor loss = {};
|
||||
if (closure != nullptr) {
|
||||
at::AutoGradMode enable_grad(true);
|
||||
loss = closure();
|
||||
}
|
||||
for (auto& group : param_groups_) {
|
||||
for (auto& p : group.params()) {
|
||||
if (!p.grad().defined()) {
|
||||
continue;
|
||||
}
|
||||
auto grad = p.grad();
|
||||
TORCH_CHECK(!grad.is_sparse(), "AdamW does not support sparse gradients"/*, please consider SparseAdamW instead*/);
|
||||
auto param_state = state_.find(c10::guts::to_string(p.unsafeGetTensorImpl()));
|
||||
auto& options = static_cast<AdamWOptions&>(group.options());
|
||||
|
||||
// Perform stepweight decay
|
||||
if(options.weight_decay() != 0) {
|
||||
p.mul_(1 - options.lr() * options.weight_decay());
|
||||
}
|
||||
|
||||
// State initialization
|
||||
if(param_state == state_.end()) {
|
||||
auto state = std::make_unique<AdamWParamState>();
|
||||
state->step(0);
|
||||
// Exponential moving average of gradient values
|
||||
state->exp_avg(torch::zeros_like(p, MemoryFormat::Preserve));
|
||||
// Exponential moving average of squared gradient values
|
||||
state->exp_avg_sq(torch::zeros_like(p, MemoryFormat::Preserve));
|
||||
if(options.amsgrad()) {
|
||||
// Maintains max of all exp. moving avg. of sq. grad. values
|
||||
state->max_exp_avg_sq(torch::zeros_like(p, MemoryFormat::Preserve));
|
||||
}
|
||||
state_[c10::guts::to_string(p.unsafeGetTensorImpl())] = std::move(state);
|
||||
}
|
||||
|
||||
auto& state = static_cast<AdamWParamState&>(*state_[c10::guts::to_string(p.unsafeGetTensorImpl())]);
|
||||
auto& exp_avg = state.exp_avg();
|
||||
auto& exp_avg_sq = state.exp_avg_sq();
|
||||
auto& max_exp_avg_sq = state.max_exp_avg_sq();
|
||||
|
||||
state.step(state.step()+1);
|
||||
auto beta1 = std::get<0>(options.betas());
|
||||
auto beta2 = std::get<1>(options.betas());
|
||||
|
||||
auto bias_correction1 = 1 - std::pow(beta1, state.step());
|
||||
auto bias_correction2 = 1 - std::pow(beta2, state.step());
|
||||
|
||||
// Decay the first and second moment running average coefficient
|
||||
exp_avg.mul_(beta1).add_(grad, 1 - beta1);
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, 1 - beta2);
|
||||
|
||||
Tensor denom;
|
||||
if(options.amsgrad()) {
|
||||
// Maintains the maximum of all 2nd moment running avg. till now
|
||||
torch::max_out(max_exp_avg_sq, exp_avg_sq, max_exp_avg_sq);
|
||||
// Use the max. for normalizing running avg. of gradient
|
||||
denom = (max_exp_avg_sq.sqrt() / sqrt(bias_correction2)).add_(options.eps());
|
||||
} else {
|
||||
denom = (exp_avg_sq.sqrt() / sqrt(bias_correction2)).add_(options.eps());
|
||||
}
|
||||
|
||||
auto step_size = options.lr() / bias_correction1;
|
||||
p.addcdiv_(exp_avg, denom, -step_size);
|
||||
}
|
||||
}
|
||||
return loss;
|
||||
}
|
||||
|
||||
void AdamW::save(serialize::OutputArchive& archive) const {
|
||||
serialize(*this, archive);
|
||||
}
|
||||
|
||||
void AdamW::load(serialize::InputArchive& archive) {
|
||||
IValue pytorch_version;
|
||||
if (archive.try_read("pytorch_version", pytorch_version)) {
|
||||
serialize(*this, archive);
|
||||
}
|
||||
else { // deserializing archives saved in old format (prior to version 1.5.0)
|
||||
TORCH_WARN(
|
||||
"Your serialized AdamW optimizer is still using the old serialization format. "
|
||||
"You should re-save your AdamW optimizer to use the new serialization format.");
|
||||
std::vector<int64_t> step_buffers;
|
||||
std::vector<at::Tensor> exp_average_buffers;
|
||||
std::vector<at::Tensor> exp_average_sq_buffers;
|
||||
std::vector<at::Tensor> max_exp_average_sq_buffers;
|
||||
torch::optim::serialize(archive, "step_buffers", step_buffers);
|
||||
torch::optim::serialize(archive, "exp_average_buffers", exp_average_buffers);
|
||||
torch::optim::serialize(archive, "exp_average_sq_buffers", exp_average_sq_buffers);
|
||||
torch::optim::serialize(archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
|
||||
// since there were no param_groups prior to version 1.5.0, assuming all tensors are now in one param_group
|
||||
std::vector<Tensor> params = param_groups_.at(0).params();
|
||||
for (size_t idx = 0; idx < step_buffers.size(); idx++) {
|
||||
auto state = std::make_unique<AdamWParamState>();
|
||||
state->step(step_buffers.at(idx));
|
||||
state->exp_avg(exp_average_buffers.at(idx));
|
||||
state->exp_avg_sq(exp_average_sq_buffers.at(idx));
|
||||
if (idx < max_exp_average_sq_buffers.size()) {
|
||||
state->max_exp_avg_sq(max_exp_average_sq_buffers.at(idx));
|
||||
}
|
||||
state_[c10::guts::to_string(params.at(idx).unsafeGetTensorImpl())] = std::move(state);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace optim
|
||||
} // namespace torch
|
||||
Reference in New Issue
Block a user