[C++ API Parity] Add xor_convergence test for lbfgs (#35001)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/35001

Differential Revision: D20548983

Pulled By: anjali411

fbshipit-source-id: 1f858635d0680c0109d1ef348b7df4d3844fe0a6
committed by: Facebook GitHub Bot
parent: 1c958f8ef9
commit: 781f590f33
@@ -26,7 +26,7 @@ bool test_optimizer_xor(Options options) {
       Linear(8, 1),
       Functional(torch::sigmoid));
 
-  const int64_t kBatchSize = 4;
+  const int64_t kBatchSize = 50;
   const int64_t kMaximumNumberOfEpochs = 3000;
 
   OptimizerClass optimizer(model->parameters(), options);
@@ -40,13 +40,21 @@ bool test_optimizer_xor(Options options) {
       inputs[i] = torch::randint(2, {2}, torch::kInt64);
       labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
     }
-    inputs.set_requires_grad(true);
-    optimizer.zero_grad();
-    auto x = model->forward(inputs);
-    torch::Tensor loss = torch::binary_cross_entropy(x, labels);
-    loss.backward();
-
-    optimizer.step();
+    inputs.set_requires_grad(true);
+
+    auto step = [&](OptimizerClass& optimizer, Sequential model, torch::Tensor inputs, torch::Tensor labels) {
+      auto closure = [&]() {
+        optimizer.zero_grad();
+        auto x = model->forward(inputs);
+        auto loss = torch::binary_cross_entropy(x, labels);
+        loss.backward();
+        return loss;
+      };
+      return optimizer.step(closure);
+    };
+
+    torch::Tensor loss = step(optimizer, model, inputs, labels);
 
     running_loss = running_loss * 0.99 + loss.item<float>() * 0.01;
     if (epoch > kMaximumNumberOfEpochs) {
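The hunk above moves the test to the closure-based overload of step(), which LBFGS requires: the optimizer may re-evaluate the objective several times per iteration, so it is handed a closure that recomputes the loss and gradients on each call. Below is a minimal standalone sketch of that pattern with the C++ frontend; it is not part of this diff, and the module, tensor shapes, and MSE loss are illustrative placeholders.

#include <torch/torch.h>

int main() {
  // Toy model and data; shapes and loss function are placeholders.
  torch::nn::Linear model(2, 1);
  auto inputs = torch::rand({4, 2});
  auto targets = torch::rand({4, 1});

  torch::optim::LBFGS optimizer(
      model->parameters(), torch::optim::LBFGSOptions(1.0));

  // LBFGS invokes the closure whenever it needs a fresh loss and
  // gradients, so the closure zeroes grads, runs the forward pass,
  // and backpropagates on every call.
  auto closure = [&]() -> torch::Tensor {
    optimizer.zero_grad();
    auto loss = torch::mse_loss(model->forward(inputs), targets);
    loss.backward();
    return loss;
  };

  torch::Tensor loss = optimizer.step(closure);  // returns the final loss
}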
@@ -198,6 +206,11 @@ TEST(OptimTest, XORConvergence_SGD) {
       SGDOptions(0.1).momentum(0.9).nesterov(true).weight_decay(1e-6)));
 }
 
+TEST(OptimTest, XORConvergence_LBFGS) {
+  ASSERT_TRUE(test_optimizer_xor<LBFGS>(LBFGSOptions(1.0)));
+  ASSERT_TRUE(test_optimizer_xor<LBFGS>(LBFGSOptions(1.0).line_search_fn("strong_wolfe")));
+}
+
 TEST(OptimTest, XORConvergence_Adagrad) {
   ASSERT_TRUE(test_optimizer_xor<Adagrad>(
       AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3)));
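The new test exercises LBFGS in two configurations. For reference, a hedged sketch of constructing both (the module here is a placeholder): the first uses a fixed step size of 1.0, while the second enables the strong-Wolfe line search via line_search_fn("strong_wolfe"), matching the Python API, where "strong_wolfe" is the accepted line-search function.

#include <torch/torch.h>

int main() {
  torch::nn::Linear model(2, 1);  // placeholder module

  // Fixed-step LBFGS, learning rate 1.0 (first assertion above).
  torch::optim::LBFGS plain(
      model->parameters(), torch::optim::LBFGSOptions(1.0));

  // Same optimizer with the strong-Wolfe line search enabled
  // (second assertion above).
  torch::optim::LBFGS wolfe(
      model->parameters(),
      torch::optim::LBFGSOptions(1.0).line_search_fn("strong_wolfe"));
}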