mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[C++ API Parity] [Optimizers] added closure to optimizers (#34790)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34790 Differential Revision: D20468361 Pulled By: anjali411 fbshipit-source-id: 1c6115d735b211dc2bedf002d58931cb32cf657a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
bdd7dbfd4b
commit
762be86e63
@ -168,7 +168,7 @@ TEST(OptimTest, OptimizerAccessors) {
|
||||
TEST(OptimTest, BasicInterface) {
|
||||
struct MyOptimizer : Optimizer {
|
||||
using Optimizer::Optimizer;
|
||||
void step() override {}
|
||||
torch::Tensor step(LossClosure closure = nullptr) override { return {};}
|
||||
};
|
||||
std::vector<torch::Tensor> parameters = {
|
||||
torch::ones({2, 3}), torch::zeros({2, 3}), torch::rand({2, 3})};
|
||||
|
@ -100,7 +100,8 @@ void test_serialize_optimizer(DerivedOptimizerOptions options) {
|
||||
optimizer.zero_grad();
|
||||
auto y = model->forward(x).sum();
|
||||
y.backward();
|
||||
optimizer.step();
|
||||
auto closure = []() { return torch::tensor({10}); };
|
||||
optimizer.step(closure);
|
||||
};
|
||||
|
||||
// Do 2 steps of model1
|
||||
|
Reference in New Issue
Block a user