[C++ API Parity] [Optimizers] added closure to optimizers (#34790)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34790

Differential Revision: D20468361

Pulled By: anjali411

fbshipit-source-id: 1c6115d735b211dc2bedf002d58931cb32cf657a
This commit is contained in:
anjali411
2020-03-16 07:48:27 -07:00
committed by Facebook GitHub Bot
parent bdd7dbfd4b
commit 762be86e63
11 changed files with 41 additions and 14 deletions

View File

@@ -168,7 +168,7 @@ TEST(OptimTest, OptimizerAccessors) {
TEST(OptimTest, BasicInterface) {
struct MyOptimizer : Optimizer {
using Optimizer::Optimizer;
void step() override {}
torch::Tensor step(LossClosure closure = nullptr) override { return {};}
};
std::vector<torch::Tensor> parameters = {
torch::ones({2, 3}), torch::zeros({2, 3}), torch::rand({2, 3})};

View File

@@ -100,7 +100,8 @@ void test_serialize_optimizer(DerivedOptimizerOptions options) {
optimizer.zero_grad();
auto y = model->forward(x).sum();
y.backward();
optimizer.step();
auto closure = []() { return torch::tensor({10}); };
optimizer.step(closure);
};
// Do 2 steps of model1