# Fix LBFGS wolfe max iteration (#161488)
Fixes #91581, based on #135026

## Test Result

```bash
pytest test/test_optim.py
.........
========================== 1473 passed, 242 skipped in 2412.49s (0:40:12) ===========================
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161488
Approved by: https://github.com/albanD
Committed by PyTorch MergeBot · parent 6926710adf · commit fa127d9b20
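For context, the change makes LBFGS honor its `max_eval` budget when the strong-Wolfe line search is active. Below is a minimal sketch of the user-facing scenario, using only the public `torch.optim.LBFGS` API; the model, data, and variable names are illustrative, and the counting pattern mirrors the new test added in this commit:

```python
import torch

# Count how many times LBFGS evaluates the closure in a single step().
model = torch.nn.Linear(4, 1)
data = torch.rand(8, 4)
n_calls = 0

def closure():
    global n_calls
    n_calls += 1
    optimizer.zero_grad()
    loss = model(data).sum()
    loss.backward()
    return loss

optimizer = torch.optim.LBFGS(
    model.parameters(), max_iter=1, max_eval=5, line_search_fn="strong_wolfe"
)
optimizer.step(closure)

# With this fix the Wolfe search receives the remaining evaluation budget,
# so n_calls stays bounded by max_eval plus the initial closure call;
# previously the search used its own internal default and could exceed it.
print(n_calls)
```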
test/test_optim.py

```diff
@@ -2305,6 +2305,34 @@ class TestOptimRenewed(TestCase):
         for state in optim.state.values():
             self.assertGreater(len(state), 0)
 
+    @parametrize("dtype", [torch.float32])
+    def test_step_iteration(self, device, dtype):
+        def _get_model_and_input_tensor(device, dtype):
+            model = torch.nn.Sequential(
+                torch.nn.Conv2d(4, 2, 1, stride=2),
+                torch.nn.BatchNorm2d(2, eps=1e-05, momentum=0.1),
+            )
+            input = torch.rand(1, 4, 16, 16, device=device, dtype=dtype)
+            model.to(dtype=dtype, device=device)
+            return model, input
+
+        counter = 0
+
+        def fwd_bwd(optim, mod, i):
+            nonlocal counter
+            counter += 1
+            optim.zero_grad()
+            loss = mod(i).sum()
+            loss.backward()
+            return loss
+
+        model, input = _get_model_and_input_tensor(device, dtype)
+        optimizer = torch.optim.LBFGS(
+            model.parameters(), max_iter=1, max_eval=5, line_search_fn="strong_wolfe"
+        )
+        optimizer.step(functools.partial(fwd_bwd, optimizer, model, input))
+        self.assertEqual(counter, 6)
+
 
 instantiate_device_type_tests(TestOptimRenewed, globals(), allow_mps=True)
 
```
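The expected count of 6 lines up with a rough accounting, assuming the bookkeeping in current sources (the strong-Wolfe helper spends one evaluation before its bracketing loop, then at most `max_ls` more): `step()` calls the closure once up front, so `current_evals = 1`; the line search then receives `max_ls = max_eval - current_evals = 4` and may use up to `1 + 4 = 5` evaluations; this test's model exhausts that budget, giving `1 + 5 = 6` closure calls in total.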
torch/optim/lbfgs.py

```diff
@@ -442,7 +442,14 @@ class LBFGS(Optimizer):
                         return self._directional_evaluate(closure, x, t, d)
 
                     loss, flat_grad, t, ls_func_evals = _strong_wolfe(
-                        obj_func, x_init, t, d, loss, flat_grad, gtd
+                        obj_func,
+                        x_init,
+                        t,
+                        d,
+                        loss,
+                        flat_grad,
+                        gtd,
+                        max_ls=max_eval - current_evals,
                     )
                 self._add_grad(t, d)
                 opt_cond = flat_grad.abs().max() <= tolerance_grad
```
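Threading `max_ls=max_eval - current_evals` into the call is the crux of the fix: without it, `_strong_wolfe` falls back to its own default budget (`max_ls=25` in current sources), so a single `step()` with `line_search_fn="strong_wolfe"` could evaluate the closure far more often than `max_eval` allows.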