[C++ API Parity] Add xor_convergence test for lbfgs (#35001)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/35001

Differential Revision: D20548983

Pulled By: anjali411

fbshipit-source-id: 1f858635d0680c0109d1ef348b7df4d3844fe0a6
This commit is contained in:
anjali411
2020-03-20 06:54:15 -07:00
committed by Facebook GitHub Bot
parent 1c958f8ef9
commit 781f590f33

View File

@ -26,7 +26,7 @@ bool test_optimizer_xor(Options options) {
Linear(8, 1),
Functional(torch::sigmoid));
const int64_t kBatchSize = 4;
const int64_t kBatchSize = 50;
const int64_t kMaximumNumberOfEpochs = 3000;
OptimizerClass optimizer(model->parameters(), options);
@ -40,13 +40,21 @@ bool test_optimizer_xor(Options options) {
inputs[i] = torch::randint(2, {2}, torch::kInt64);
labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
}
inputs.set_requires_grad(true);
optimizer.zero_grad();
auto x = model->forward(inputs);
torch::Tensor loss = torch::binary_cross_entropy(x, labels);
loss.backward();
optimizer.step();
inputs.set_requires_grad(true);
auto step = [&](OptimizerClass& optimizer, Sequential model, torch::Tensor inputs, torch::Tensor labels) {
auto closure = [&]() {
optimizer.zero_grad();
auto x = model->forward(inputs);
auto loss = torch::binary_cross_entropy(x, labels);
loss.backward();
return loss;
};
return optimizer.step(closure);
};
torch::Tensor loss = step(optimizer, model, inputs, labels);
running_loss = running_loss * 0.99 + loss.item<float>() * 0.01;
if (epoch > kMaximumNumberOfEpochs) {
@ -198,6 +206,11 @@ TEST(OptimTest, XORConvergence_SGD) {
SGDOptions(0.1).momentum(0.9).nesterov(true).weight_decay(1e-6)));
}
TEST(OptimTest, XORConvergence_LBFGS) {
ASSERT_TRUE(test_optimizer_xor<LBFGS>(LBFGSOptions(1.0)));
ASSERT_TRUE(test_optimizer_xor<LBFGS>(LBFGSOptions(1.0).line_search_fn("strong_wolfe")));
}
TEST(OptimTest, XORConvergence_Adagrad) {
ASSERT_TRUE(test_optimizer_xor<Adagrad>(
AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3)));