Add AdamW to C++ frontend (#40009)

Summary:
This is a slightly modified Adam that follows the Python implementation, and the `ProducesPyTorchValues` tests pass. I had a problem with another test, though (see commit c1a6241676ab84fc531c1c3a10f964aa5704092e): it seems that optimizing for two steps with the same optimizer and optimizing for two steps using freshly initialized objects will produce the same output.
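
For reference, a minimal usage sketch of the new optimizer from the C++ frontend; the toy model, tensor shapes, and hyperparameter values below are illustrative only and are not part of this change:

#include <torch/torch.h>

int main() {
  torch::manual_seed(0);
  auto model = torch::nn::Linear(10, 1);

  // Decoupled weight decay defaults to 0.01; amsgrad is opt-in.
  torch::optim::AdamW optimizer(
      model->parameters(),
      torch::optim::AdamWOptions(1e-3).weight_decay(1e-2).amsgrad(true));

  auto input = torch::randn({8, 10});
  auto target = torch::randn({8, 1});
  for (int iter = 0; iter < 100; ++iter) {
    optimizer.zero_grad();
    auto loss = torch::mse_loss(model->forward(input), target);
    loss.backward();
    optimizer.step();
  }
}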
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40009

Differential Revision: D22096053

Pulled By: glaringlee

fbshipit-source-id: a31a8f5488cb37c53752ddf15436efabdba67dc4
Author: Sotiris Lamprinidis
Date: 2020-06-18 15:26:21 -07:00
Committed by: Facebook GitHub Bot
Parent: 89ef8f8141
Commit: 41f2dbde31
9 changed files with 540 additions and 0 deletions


@@ -538,6 +538,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/api/src/nn/options/vision.cpp
${TORCH_SRC_DIR}/csrc/api/src/optim/adagrad.cpp
${TORCH_SRC_DIR}/csrc/api/src/optim/adam.cpp
${TORCH_SRC_DIR}/csrc/api/src/optim/adamw.cpp
${TORCH_SRC_DIR}/csrc/api/src/optim/lbfgs.cpp
${TORCH_SRC_DIR}/csrc/api/src/optim/optimizer.cpp
${TORCH_SRC_DIR}/csrc/api/src/optim/rmsprop.cpp


@@ -283,6 +283,31 @@ TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecayAndAMSGrad) {
      expected_parameters::Adam_with_weight_decay_and_amsgrad());
}

TEST(OptimTest, XORConvergence_AdamW) {
  ASSERT_TRUE(test_optimizer_xor<AdamW>(AdamWOptions(0.1)));
}

TEST(OptimTest, XORConvergence_AdamWWithAmsgrad) {
  ASSERT_TRUE(test_optimizer_xor<AdamW>(
      AdamWOptions(0.1).amsgrad(true)));
}

TEST(OptimTest, ProducesPyTorchValues_AdamW) {
  check_exact_values<AdamW>(AdamWOptions(1.0), expected_parameters::AdamW());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWWithoutWeightDecay) {
  check_exact_values<AdamW>(
      AdamWOptions(1.0).weight_decay(0),
      expected_parameters::AdamW_without_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWWithAMSGrad) {
  check_exact_values<AdamW>(
      AdamWOptions(1.0).amsgrad(true),
      expected_parameters::AdamW_with_amsgrad());
}

TEST(OptimTest, ProducesPyTorchValues_Adagrad) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0), expected_parameters::Adagrad());


@@ -361,6 +361,219 @@ inline std::vector<std::vector<torch::Tensor>> Adam_with_weight_decay_and_amsgra
};
}
inline std::vector<std::vector<torch::Tensor>> AdamW() {
return {
{
torch::tensor({0.7912062750121864, 0.5074166292785842, 0.8601202529258052, 0.6613910130887053, 0.7501593169903569, 1.6905808503961983}),
torch::tensor({0.8925529482073002, 0.7050308347536254, 1.682309255842939}),
torch::tensor({-1.05029506454492, -1.3901937990816595, -1.2814942017397601}),
torch::tensor({-1.0704267290556988}),
},
{
torch::tensor({3.3165329599188507, 3.223120441823618, 2.665544565239194, 2.6044341406663225, 2.479859063483047, 2.836831717112226}),
torch::tensor({3.3885192024669744, 2.6544147219174556, 2.8709245656887328}),
torch::tensor({-2.70172647102137, -2.836731459490802, -2.69652471546253}),
torch::tensor({-2.575239255076019}),
},
{
torch::tensor({2.231471944853865, 2.3549328325971755, 1.5699078054795328, 1.6160272935884685, 1.5339085081403547, 1.7397405105941612}),
torch::tensor({2.8552579170807926, 1.8369866847839356, 1.9735168512425862}),
torch::tensor({-2.6042083360293855, -2.6996673713262336, -1.8976087706977893}),
torch::tensor({-1.6180915942867784}),
},
{
torch::tensor({2.084688381515552, 2.3141612674892946, 1.4850714710140511, 1.5961047256668386, 1.440300645879787, 1.6065354941586025}),
torch::tensor({3.0111385685659444, 1.955556497153507, 1.9596562467797627}),
torch::tensor({-2.889337305884852, -2.965249100126337, -1.7721676671605975}),
torch::tensor({-1.4001341655590005}),
},
{
torch::tensor({2.0465343456006604, 2.311613891239368, 1.4666717526896398, 1.601383980913499, 1.4223660595993763, 1.5711552625612757}),
torch::tensor({3.07151984580744, 2.0112690538174802, 1.9592484602763875}),
torch::tensor({-3.0186469726426863, -3.093855445542849, -1.7367953899738784}),
torch::tensor({-1.3299011560804312}),
},
{
torch::tensor({2.039659777412556, 2.3178034179536273, 1.4654302718412722, 1.6094701969162322, 1.4230510816446773, 1.565168902852383}),
torch::tensor({3.1007583934270064, 2.039757113618415, 1.9652096140698696}),
torch::tensor({-3.0880626664330832, -3.166705422245348, -1.73538367534238}),
torch::tensor({-1.3130428735015893}),
},
{
torch::tensor({2.0413773043991963, 2.3251469369586366, 1.4690808101517236, 1.6174065798291044, 1.4280274009117935, 1.5682418226469732}),
torch::tensor({3.118843540209399, 2.057729936485249, 1.9742319629710936}),
torch::tensor({-3.1331019663177013, -3.2154332694373107, -1.7459831639793468}),
torch::tensor({-1.3148644134154366}),
},
{
torch::tensor({2.0452604138357113, 2.332074253989847, 1.4738845773449165, 1.6246403004735728, 1.4335712611625357, 1.573826630920094}),
torch::tensor({3.1324088069784093, 2.0711763619826575, 1.9841582498316732}),
torch::tensor({-3.16737058959847, -3.2529206463859146, -1.7602788393925501}),
torch::tensor({-1.32281766461531}),
},
{
torch::tensor({2.0495243704493262, 2.338413341249581, 1.4787599440132637, 1.631210274009555, 1.438849155552895, 1.5798736919537595}),
torch::tensor({3.1438209015414227, 2.0823943437659658, 1.9940075805973108}),
torch::tensor({-3.19628690363529, -3.2845941643030367, -1.7752333900055153}),
torch::tensor({-1.332456718314933}),
},
{
torch::tensor({2.0536979895206295, 2.3442520601250334, 1.4834272584224222, 1.6372462654983486, 1.4437517398490174, 1.585780877834892}),
torch::tensor({3.1540081461072447, 2.0923381262560454, 2.0034284957296107}),
torch::tensor({-3.222142968201519, -3.312867602521477, -1.7898220261118043}),
torch::tensor({-1.3422692037690986}),
},
{
torch::tensor({2.0576784836825315, 2.3496934395759377, 1.4878413407927933, 1.6428612479757005, 1.4483225979568104, 1.5914034339763325}),
torch::tensor({3.163383747232199, 2.101446878895216, 2.012344413569353}),
torch::tensor({-3.246000281299229, -3.338904166978488, -1.8037666936489785}),
torch::tensor({-1.3517884775416527}),
},
};
}
inline std::vector<std::vector<torch::Tensor>> AdamW_without_weight_decay() {
return {
{
torch::tensor({0.7890972864438476, 0.5024410688121617, 0.858707331305558, 0.6579707241208395, 0.7476356819075531, 1.6975564206516922}),
torch::tensor({0.891467636010675, 0.70205134975675, 1.689201270942895}),
torch::tensor({-1.0508030958460797, -1.3941351509567654, -1.284337577714353}),
torch::tensor({-1.071138110298716}),
},
{
torch::tensor({8.233039313231831, 7.971150747377481, 6.643620950677599, 6.47097740790054, 6.170125488259256, 7.150739103343502}),
torch::tensor({8.417695070103738, 6.597188212844593, 7.23175710827678}),
torch::tensor({-6.729624357635757, -7.09743493108154, -6.753301896575352}),
torch::tensor({-6.435639096011218}),
},
{
torch::tensor({8.233424596059299, 7.971537360032308, 6.643920150720393, 6.471278075537239, 6.170405874224489, 7.151021086137983}),
torch::tensor({8.418084791214298, 6.597493171180545, 7.232043740621598}),
torch::tensor({-6.729918250724671, -7.097730102046093, -6.753584809755359}),
torch::tensor({-6.4359165566974985}),
},
{
torch::tensor({8.233424610557652, 7.971537374586563, 6.643920161995284, 6.471278086877828, 6.170405884785074, 7.151021096766406}),
torch::tensor({8.418084805901906, 6.597493182713584, 7.2320437514477875}),
torch::tensor({-6.72991829363266, -7.097730147102975, -6.753584838821182}),
torch::tensor({-6.435916580217771}),
},
{
torch::tensor({8.233424610575105, 7.971537374611125, 6.643920162027961, 6.471278086923277, 6.170405884809245, 7.151021096800041}),
torch::tensor({8.418084805946393, 6.597493182796847, 7.232043751509309}),
torch::tensor({-6.729918332327653, -7.097730188349552, -6.753584861205486}),
torch::tensor({-6.435916596115672}),
},
{
torch::tensor({8.233424610594861, 7.971537374639166, 6.64392016206557, 6.471278086975758, 6.170405884836981, 7.151021096838798}),
torch::tensor({8.418084805997617, 6.59749318289335, 7.232043751580523}),
torch::tensor({-6.72991837738045, -7.097730236373201, -6.753584887267492}),
torch::tensor({-6.43591661462546}),
},
{
torch::tensor({8.233424610617291, 7.971537374671012, 6.643920162108284, 6.471278087035361, 6.170405884868481, 7.151021096882812}),
torch::tensor({8.418084806055798, 6.59749318300295, 7.232043751661401}),
torch::tensor({-6.729918428547273, -7.09773029091405, -6.753584916866329}),
torch::tensor({-6.4359166356471755}),
},
{
torch::tensor({8.233424610642356, 7.9715373747065925, 6.643920162156006, 6.471278087101954, 6.1704058849036745, 7.15102109693199}),
torch::tensor({8.418084806120802, 6.597493183125404, 7.232043751751764}),
torch::tensor({-6.729918485714688, -7.0977303518511805, -6.753584949936365}),
torch::tensor({-6.43591665913422}),
},
{
torch::tensor({8.233424610670038, 7.97153737474589, 6.643920162208714, 6.471278087175501, 6.170405884942545, 7.151021096986303}),
torch::tensor({8.418084806192596, 6.597493183260647, 7.232043751851564}),
torch::tensor({-6.729918548853505, -7.097730419153473, -6.753584986460725}),
torch::tensor({-6.435916685074594}),
},
{
torch::tensor({8.233424610700352, 7.971537374788922, 6.643920162266432, 6.47127808725604, 6.17040588498511, 7.1510210970457795}),
torch::tensor({8.418084806271217, 6.597493183408747, 7.232043751960854}),
torch::tensor({-6.7299186179943, -7.097730492853521, -6.753585026457088}),
torch::tensor({-6.435916713480863}),
},
{
torch::tensor({8.23342461073333, 7.971537374835737, 6.643920162329224, 6.471278087343658, 6.170405885031416, 7.151021097110484}),
torch::tensor({8.418084806356747, 6.597493183569867, 7.232043752079749}),
torch::tensor({-6.729918693213275, -7.097730573032567, -6.753585069969552}),
torch::tensor({-6.43591674438434}),
},
};
}
inline std::vector<std::vector<torch::Tensor>> AdamW_with_amsgrad() {
return {
{
torch::tensor({0.7912062750121864, 0.5074166292785842, 0.8601202529258052, 0.6613910130887053, 0.7501593169903569, 1.6905808503961983}),
torch::tensor({0.8925529482073002, 0.7050308347536254, 1.682309255842939}),
torch::tensor({-1.05029506454492, -1.3901937990816595, -1.2814942017397601}),
torch::tensor({-1.0704267290556988}),
},
{
torch::tensor({3.3017259270507915, 3.2082991753694565, 2.653930978510442, 2.5927674339810585, 2.4689608790182933, 2.825873703467739}),
torch::tensor({3.373698198112671, 2.6425942964586664, 2.8597930424244304}),
torch::tensor({-2.690360632302962, -2.8253191596069525, -2.6855499873057473}),
torch::tensor({-2.5644658591929406}),
},
{
torch::tensor({2.222607725541013, 2.3447188854637004, 1.5614270655258826, 1.606610018462357, 1.5260497191448619, 1.7309643622674138}),
torch::tensor({2.84137462783552, 1.824806600633721, 1.9620493659996037}),
torch::tensor({-2.576642773625787, -2.6706153846815766, -1.8799876863754623}),
torch::tensor({-1.6044722984810953}),
},
{
torch::tensor({2.0739558768648205, 2.3008338863863496, 1.4738888208638767, 1.5829485271829449, 1.4296176764284294, 1.5939984909850073}),
torch::tensor({2.9908013612792415, 1.936590940953305, 1.941691630199464}),
torch::tensor({-2.846562884997548, -2.9195962101501203, -1.746484716887341}),
torch::tensor({-1.381525131003179}),
},
{
torch::tensor({2.0333926094256953, 2.294977109171754, 1.452870514716895, 1.584853677999522, 1.4086299433181402, 1.5548201727855224}),
torch::tensor({3.0454817801193976, 1.9867169062383696, 1.935312753106444}),
torch::tensor({-2.9612116762746394, -3.0322275992001084, -1.7026114905180725}),
torch::tensor({-1.30563541393247}),
},
{
torch::tensor({2.02392417168201, 2.2977279587859, 1.4488511120131309, 1.5894646930743725, 1.406253686759073, 1.5450647949022756}),
torch::tensor({3.069025602955343, 2.0096872138967488, 1.935438546309299}),
torch::tensor({-3.016103148166836, -3.0893062953033583, -1.6925290685615872}),
torch::tensor({-1.282870120405012}),
},
{
torch::tensor({2.0230257817348316, 2.3016167065040647, 1.4496901978629444, 1.5939034289777392, 1.4082421794430946, 1.5444538756003756}),
torch::tensor({3.0814016132787954, 2.022150201844143, 1.9387429991308658}),
torch::tensor({-3.0466485946438406, -3.1223144611322446, -1.6944083009127773}),
torch::tensor({-1.2786980736911064}),
},
{
torch::tensor({2.024305922065404, 2.305101549510461, 1.4516863420588493, 1.5976447954882376, 1.4108596097183552, 1.5464236284303425}),
torch::tensor({3.0892574233545065, 2.0300944858242236, 1.943040321845021}),
torch::tensor({-3.0664189567897306, -3.1440820888166425, -1.6999750448893618}),
torch::tensor({-1.2806281811826203}),
},
{
torch::tensor({2.0259854116237364, 2.3080169152218293, 1.4537680296813915, 1.6007369432392426, 1.4132529823064277, 1.5489035046525346}),
torch::tensor({3.0949717180851337, 2.0358251379915764, 1.9473249654893}),
torch::tensor({-3.0808231377905426, -3.160021699873689, -1.7062031001273494}),
torch::tensor({-1.28423401369705}),
},
{
torch::tensor({2.0275923948348638, 2.3104512601389637, 1.455657715721078, 1.6033123357613526, 1.4153003204463288, 1.5512775896116622}),
torch::tensor({3.099479021299846, 2.0403012223048775, 1.9512285931847464}),
torch::tensor({-3.092151979336299, -3.1725453680885267, -1.7120689614428697}),
torch::tensor({-1.2880095517062655}),
},
{
torch::tensor({2.029022468328371, 2.3125066985045892, 1.4573100228823295, 1.605484259933419, 1.417038245960655, 1.5533932056240227}),
torch::tensor({3.103195011518616, 2.043964003458376, 1.9546640840748621}),
torch::tensor({-3.1014680843131184, -3.1828179298513968, -1.7172933346797972}),
torch::tensor({-1.2914899987134136}),
},
};
}
inline std::vector<std::vector<torch::Tensor>> Adagrad() {
return {
{


@@ -26,6 +26,9 @@ OPTIMIZERS = {
"Adam": lambda p: torch.optim.Adam(p, 1.0),
"Adam_with_weight_decay": lambda p: torch.optim.Adam(p, 1.0, weight_decay=1e-2),
"Adam_with_weight_decay_and_amsgrad": lambda p: torch.optim.Adam(p, 1.0, weight_decay=1e-6, amsgrad=True),
"AdamW": lambda p: torch.optim.AdamW(p, 1.0),
"AdamW_without_weight_decay": lambda p: torch.optim.AdamW(p, 1.0, weight_decay=0),
"AdamW_with_amsgrad": lambda p: torch.optim.AdamW(p, 1.0, amsgrad=True),
"Adagrad": lambda p: torch.optim.Adagrad(p, 1.0),
"Adagrad_with_weight_decay": lambda p: torch.optim.Adagrad(p, 1.0, weight_decay=1e-2),
"Adagrad_with_weight_decay_and_lr_decay": lambda p: torch.optim.Adagrad(p, 1.0, weight_decay=1e-6, lr_decay=1e-3),


@@ -64,6 +64,7 @@ void is_optimizer_state_equal(
template <typename OptimizerClass, typename DerivedOptimizerOptions, typename DerivedOptimizerParamState>
void test_serialize_optimizer(DerivedOptimizerOptions options, bool only_has_global_state = false) {
torch::manual_seed(0);
auto model1 = Linear(5, 2);
auto model2 = Linear(5, 2);
auto model3 = Linear(5, 2);
@@ -600,6 +601,56 @@ TEST(SerializeTest, Optim_Adam) {
is_optimizer_state_equal<AdamParamState>(optim1.state(), optim1_2.state());
}
TEST(SerializeTest, Optim_AdamW) {
  test_serialize_optimizer<AdamW, AdamWOptions, AdamWParamState>(
      AdamWOptions().lr(0.99999).amsgrad(true).betas(std::make_tuple(0.999, 0.1)));

  // bc compatibility check
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // added a tensor for lazy init check - when all params do not have entry in buffers
  model1_params.emplace_back(torch::randn({2,3}));
  auto optim1 = torch::optim::AdamW(model1_params, torch::optim::AdamWOptions().weight_decay(0.5));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);

  std::vector<int64_t> step_buffers;
  std::vector<at::Tensor> exp_average_buffers;
  std::vector<at::Tensor> exp_average_sq_buffers;
  std::vector<at::Tensor> max_exp_average_sq_buffers;
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  for (size_t i = 0; i < params_.size(); i++) {
    if(i != (params_.size() - 1)) {
      auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl());
      const AdamWParamState& curr_state_ = static_cast<const AdamWParamState&>(*(optim1_state.at(key_).get()));
      step_buffers.emplace_back(curr_state_.step());
      exp_average_buffers.emplace_back(curr_state_.exp_avg());
      exp_average_sq_buffers.emplace_back(curr_state_.exp_avg_sq());
      if(curr_state_.max_exp_avg_sq().defined()) {
        max_exp_average_sq_buffers.emplace_back(curr_state_.max_exp_avg_sq());
      }
    }
  }
  // write buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_step_buffers(output_archive, "step_buffers", step_buffers);
  write_tensors_to_archive(output_archive, "exp_average_buffers", exp_average_buffers);
  write_tensors_to_archive(output_archive, "exp_average_sq_buffers", exp_average_sq_buffers);
  write_tensors_to_archive(output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
  output_archive.save_to(optim_tempfile_old_format.name);

  auto optim1_2 = AdamW(model1_params, torch::optim::AdamWOptions());
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<AdamWParamState>(optim1.state(), optim1_2.state());
}

TEST(SerializeTest, Optim_RMSprop) {
  auto options = RMSpropOptions(0.1).momentum(0.9).centered(true);
  test_serialize_optimizer<RMSprop, RMSpropOptions, RMSpropParamState>(options);


@@ -384,6 +384,7 @@ torch_cpp_srcs = [
"torch/csrc/api/src/nn/options/vision.cpp",
"torch/csrc/api/src/optim/adagrad.cpp",
"torch/csrc/api/src/optim/adam.cpp",
"torch/csrc/api/src/optim/adamw.cpp",
"torch/csrc/api/src/optim/lbfgs.cpp",
"torch/csrc/api/src/optim/optimizer.cpp",
"torch/csrc/api/src/optim/rmsprop.cpp",


@@ -2,6 +2,7 @@
#include <torch/optim/adagrad.h>
#include <torch/optim/adam.h>
#include <torch/optim/adamw.h>
#include <torch/optim/lbfgs.h>
#include <torch/optim/optimizer.h>
#include <torch/optim/rmsprop.h>


@@ -0,0 +1,75 @@
#pragma once

#include <torch/arg.h>
#include <torch/nn/module.h>
#include <torch/optim/optimizer.h>
#include <torch/optim/serialize.h>

#include <utility>
#include <vector>

namespace torch {
namespace serialize {
class OutputArchive;
class InputArchive;
} // namespace serialize
} // namespace torch

namespace torch {
namespace optim {

struct TORCH_API AdamWOptions : public OptimizerCloneableOptions<AdamWOptions> {
  AdamWOptions(double lr = 1e-3);
  TORCH_ARG(double, lr) = 1e-3;
  typedef std::tuple<double, double> betas_t;
  TORCH_ARG(betas_t, betas) = std::make_tuple(0.9, 0.999);
  TORCH_ARG(double, eps) = 1e-8;
  TORCH_ARG(double, weight_decay) = 1e-2;
  TORCH_ARG(bool, amsgrad) = false;
public:
  void serialize(torch::serialize::InputArchive& archive) override;
  void serialize(torch::serialize::OutputArchive& archive) const override;
  TORCH_API friend bool operator==(const AdamWOptions& lhs, const AdamWOptions& rhs);
  ~AdamWOptions() = default;
};

struct TORCH_API AdamWParamState : public OptimizerCloneableParamState<AdamWParamState> {
  TORCH_ARG(int64_t, step) = 0;
  TORCH_ARG(torch::Tensor, exp_avg);
  TORCH_ARG(torch::Tensor, exp_avg_sq);
  TORCH_ARG(torch::Tensor, max_exp_avg_sq) = {};
public:
  void serialize(torch::serialize::InputArchive& archive) override;
  void serialize(torch::serialize::OutputArchive& archive) const override;
  TORCH_API friend bool operator==(const AdamWParamState& lhs, const AdamWParamState& rhs);
  ~AdamWParamState() = default;
};

class TORCH_API AdamW : public Optimizer {
 public:
  explicit AdamW(std::vector<OptimizerParamGroup> param_groups,
      AdamWOptions defaults = {}) : Optimizer(std::move(param_groups), std::make_unique<AdamWOptions>(defaults)) {
    TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr());
    TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps());
    auto betas = defaults.betas();
    TORCH_CHECK(0 <= std::get<0>(betas) && std::get<0>(betas) < 1.0, "Invalid beta parameter at index 0: ", std::get<0>(betas));
    TORCH_CHECK(0 <= std::get<1>(betas) && std::get<1>(betas) < 1.0, "Invalid beta parameter at index 1: ", std::get<1>(betas));
    TORCH_CHECK(defaults.weight_decay() >= 0, "Invalid weight_decay value: ", defaults.weight_decay());
  }
  explicit AdamW(
      std::vector<Tensor> params,
      AdamWOptions defaults = {}) : AdamW({std::move(OptimizerParamGroup(params))}, defaults) {}

  torch::Tensor step(LossClosure closure = nullptr) override;
  void save(serialize::OutputArchive& archive) const override;
  void load(serialize::InputArchive& archive) override;

 private:
  template <typename Self, typename Archive>
  static void serialize(Self& self, Archive& archive) {
    _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(AdamW);
  }
};

} // namespace optim
} // namespace torch
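
As an aside, here is a hedged sketch of how the save/load hooks declared above are intended to be used through torch::save / torch::load; the checkpoint file name and the elided training steps are placeholders, not part of this change:

#include <torch/torch.h>

int main() {
  auto model = torch::nn::Linear(5, 2);
  torch::optim::AdamW optim(
      model->parameters(),
      torch::optim::AdamWOptions(1e-3).betas(std::make_tuple(0.9, 0.999)).amsgrad(true));

  // ... a few optim.step() calls would populate the exp_avg / exp_avg_sq state here ...

  // torch::save and torch::load dispatch to AdamW::save and AdamW::load.
  torch::save(optim, "adamw_checkpoint.pt");

  // State entries are keyed by parameter, so load into an optimizer over the same parameters.
  torch::optim::AdamW restored(model->parameters(), torch::optim::AdamWOptions(1e-3));
  torch::load(restored, "adamw_checkpoint.pt");
}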


@@ -0,0 +1,170 @@
#include <torch/optim/adamw.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/nn/module.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>
#include <ATen/ATen.h>
#include <cmath>
#include <functional>
namespace torch {
namespace optim {
AdamWOptions::AdamWOptions(double lr) : lr_(lr) {}
bool operator==(const AdamWOptions& lhs, const AdamWOptions& rhs) {
  return (lhs.lr() == rhs.lr()) &&
         (std::get<0>(lhs.betas()) == std::get<0>(rhs.betas())) &&
         (std::get<1>(lhs.betas()) == std::get<1>(rhs.betas())) &&
         (lhs.eps() == rhs.eps()) &&
         (lhs.weight_decay() == rhs.weight_decay() &&
         (lhs.amsgrad() == rhs.amsgrad()));
}

void AdamWOptions::serialize(torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(lr);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(betas);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(eps);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(weight_decay);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(amsgrad);
}

void AdamWOptions::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, lr);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(betas_t, betas);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, eps);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, weight_decay);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, amsgrad);
}

bool operator==(const AdamWParamState& lhs, const AdamWParamState& rhs) {
  return (lhs.step() == rhs.step()) &&
         torch::equal(lhs.exp_avg(), rhs.exp_avg()) &&
         torch::equal(lhs.exp_avg_sq(), rhs.exp_avg_sq()) &&
         torch::equal_if_defined(lhs.max_exp_avg_sq(), rhs.max_exp_avg_sq());
}

void AdamWParamState::serialize(torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(step);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(exp_avg);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(exp_avg_sq);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(max_exp_avg_sq);
}

void AdamWParamState::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, step);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, exp_avg);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, exp_avg_sq);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, max_exp_avg_sq);
}
Tensor AdamW::step(LossClosure closure) {
  NoGradGuard no_grad;
  Tensor loss = {};
  if (closure != nullptr) {
    at::AutoGradMode enable_grad(true);
    loss = closure();
  }
  for (auto& group : param_groups_) {
    for (auto& p : group.params()) {
      if (!p.grad().defined()) {
        continue;
      }
      auto grad = p.grad();
      TORCH_CHECK(!grad.is_sparse(), "AdamW does not support sparse gradients"/*, please consider SparseAdamW instead*/);
      auto param_state = state_.find(c10::guts::to_string(p.unsafeGetTensorImpl()));
      auto& options = static_cast<AdamWOptions&>(group.options());

      // Perform stepweight decay
      if(options.weight_decay() != 0) {
        p.mul_(1 - options.lr() * options.weight_decay());
      }

      // State initialization
      if(param_state == state_.end()) {
        auto state = std::make_unique<AdamWParamState>();
        state->step(0);
        // Exponential moving average of gradient values
        state->exp_avg(torch::zeros_like(p, MemoryFormat::Preserve));
        // Exponential moving average of squared gradient values
        state->exp_avg_sq(torch::zeros_like(p, MemoryFormat::Preserve));
        if(options.amsgrad()) {
          // Maintains max of all exp. moving avg. of sq. grad. values
          state->max_exp_avg_sq(torch::zeros_like(p, MemoryFormat::Preserve));
        }
        state_[c10::guts::to_string(p.unsafeGetTensorImpl())] = std::move(state);
      }

      auto& state = static_cast<AdamWParamState&>(*state_[c10::guts::to_string(p.unsafeGetTensorImpl())]);
      auto& exp_avg = state.exp_avg();
      auto& exp_avg_sq = state.exp_avg_sq();
      auto& max_exp_avg_sq = state.max_exp_avg_sq();

      state.step(state.step()+1);
      auto beta1 = std::get<0>(options.betas());
      auto beta2 = std::get<1>(options.betas());

      auto bias_correction1 = 1 - std::pow(beta1, state.step());
      auto bias_correction2 = 1 - std::pow(beta2, state.step());

      // Decay the first and second moment running average coefficient
      exp_avg.mul_(beta1).add_(grad, 1 - beta1);
      exp_avg_sq.mul_(beta2).addcmul_(grad, grad, 1 - beta2);

      Tensor denom;
      if(options.amsgrad()) {
        // Maintains the maximum of all 2nd moment running avg. till now
        torch::max_out(max_exp_avg_sq, exp_avg_sq, max_exp_avg_sq);
        // Use the max. for normalizing running avg. of gradient
        denom = (max_exp_avg_sq.sqrt() / sqrt(bias_correction2)).add_(options.eps());
      } else {
        denom = (exp_avg_sq.sqrt() / sqrt(bias_correction2)).add_(options.eps());
      }

      auto step_size = options.lr() / bias_correction1;
      p.addcdiv_(exp_avg, denom, -step_size);
    }
  }
  return loss;
}
void AdamW::save(serialize::OutputArchive& archive) const {
  serialize(*this, archive);
}

void AdamW::load(serialize::InputArchive& archive) {
  IValue pytorch_version;
  if (archive.try_read("pytorch_version", pytorch_version)) {
    serialize(*this, archive);
  }
  else { // deserializing archives saved in old format (prior to version 1.5.0)
    TORCH_WARN(
      "Your serialized AdamW optimizer is still using the old serialization format. "
      "You should re-save your AdamW optimizer to use the new serialization format.");
    std::vector<int64_t> step_buffers;
    std::vector<at::Tensor> exp_average_buffers;
    std::vector<at::Tensor> exp_average_sq_buffers;
    std::vector<at::Tensor> max_exp_average_sq_buffers;
    torch::optim::serialize(archive, "step_buffers", step_buffers);
    torch::optim::serialize(archive, "exp_average_buffers", exp_average_buffers);
    torch::optim::serialize(archive, "exp_average_sq_buffers", exp_average_sq_buffers);
    torch::optim::serialize(archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
    // since there were no param_groups prior to version 1.5.0, assuming all tensors are now in one param_group
    std::vector<Tensor> params = param_groups_.at(0).params();
    for (size_t idx = 0; idx < step_buffers.size(); idx++) {
      auto state = std::make_unique<AdamWParamState>();
      state->step(step_buffers.at(idx));
      state->exp_avg(exp_average_buffers.at(idx));
      state->exp_avg_sq(exp_average_sq_buffers.at(idx));
      if (idx < max_exp_average_sq_buffers.size()) {
        state->max_exp_avg_sq(max_exp_average_sq_buffers.at(idx));
      }
      state_[c10::guts::to_string(params.at(idx).unsafeGetTensorImpl())] = std::move(state);
    }
  }
}
} // namespace optim
} // namespace torch
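
For readers diffing this against adam.cpp: the only substantive difference is the decoupled weight decay applied at the top of AdamW::step (plain Adam instead folds weight decay into the gradient). As a summary of the code above, not new behavior, the per-parameter update with gradient g, moments m and v, and decay lambda is:

\begin{aligned}
\theta &\leftarrow \theta\,(1 - \mathrm{lr}\cdot\lambda) && \text{(decoupled weight decay)}\\
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\,g_t\\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2\\
\hat m_t &= m_t/(1-\beta_1^t), \qquad \hat v_t = v_t/(1-\beta_2^t)\\
\theta &\leftarrow \theta - \mathrm{lr}\,\hat m_t/(\sqrt{\hat v_t} + \epsilon)
\end{aligned}

With amsgrad(true), the denominator uses the running maximum max_exp_avg_sq in place of v_t, as in Adam.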