Fix LBFGS wolfe max iteration (#161488)
Fixes #91581, based on #135026

## Test Result

```bash
pytest test/test_optim.py
.........
========================== 1473 passed, 242 skipped in 2412.49s (0:40:12) ===========================
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161488
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: 6926710adf
Commit: fa127d9b20
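For context, the change makes the strong Wolfe line search respect the per-step `max_eval` budget. Below is a minimal, standalone sketch (not part of this commit; the quadratic objective and counting closure are illustrative) of the behaviour the fix targets:

```python
import torch

# Illustrative only: count how many times LBFGS evaluates the closure in one
# step() when the strong Wolfe line search is enabled and max_eval caps the
# per-step evaluation budget.
param = torch.nn.Parameter(torch.randn(10))
optimizer = torch.optim.LBFGS(
    [param], max_iter=1, max_eval=5, line_search_fn="strong_wolfe"
)

eval_count = 0

def closure():
    global eval_count
    eval_count += 1
    optimizer.zero_grad()
    loss = (param**2).sum()
    loss.backward()
    return loss

optimizer.step(closure)
# With the fix, the line search is handed the remaining budget
# (max_eval - current_evals), so the call count stays near max_eval (the new
# test in this commit expects 6 calls with max_eval=5) instead of being
# governed only by the line search's internal default.
print(eval_count)
```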
```diff
@@ -2305,6 +2305,34 @@ class TestOptimRenewed(TestCase):
         for state in optim.state.values():
             self.assertGreater(len(state), 0)

+    @parametrize("dtype", [torch.float32])
+    def test_step_iteration(self, device, dtype):
+        def _get_model_and_input_tensor(device, dtype):
+            model = torch.nn.Sequential(
+                torch.nn.Conv2d(4, 2, 1, stride=2),
+                torch.nn.BatchNorm2d(2, eps=1e-05, momentum=0.1),
+            )
+            input = torch.rand(1, 4, 16, 16, device=device, dtype=dtype)
+            model.to(dtype=dtype, device=device)
+            return model, input
+
+        counter = 0
+
+        def fwd_bwd(optim, mod, i):
+            nonlocal counter
+            counter += 1
+            optim.zero_grad()
+            loss = mod(i).sum()
+            loss.backward()
+            return loss
+
+        model, input = _get_model_and_input_tensor(device, dtype)
+        optimizer = torch.optim.LBFGS(
+            model.parameters(), max_iter=1, max_eval=5, line_search_fn="strong_wolfe"
+        )
+        optimizer.step(functools.partial(fwd_bwd, optimizer, model, input))
+        self.assertEqual(counter, 6)
+

 instantiate_device_type_tests(TestOptimRenewed, globals(), allow_mps=True)
```
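The new test can be selected directly; for example (invocation is illustrative, any standard pytest run of `test/test_optim.py` covers it):

```bash
pytest test/test_optim.py -k test_step_iteration
```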
```diff
@@ -442,7 +442,14 @@ class LBFGS(Optimizer):
                         return self._directional_evaluate(closure, x, t, d)

                     loss, flat_grad, t, ls_func_evals = _strong_wolfe(
-                        obj_func, x_init, t, d, loss, flat_grad, gtd
+                        obj_func,
+                        x_init,
+                        t,
+                        d,
+                        loss,
+                        flat_grad,
+                        gtd,
+                        max_ls=max_eval - current_evals,
                     )
                 self._add_grad(t, d)
                 opt_cond = flat_grad.abs().max() <= tolerance_grad
```
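The functional change is the added `max_ls=max_eval - current_evals` argument: rather than letting `_strong_wolfe` fall back to its own internal cap, the line search is limited to the closure evaluations still available within the step's `max_eval` budget. The new test pins the resulting behaviour down at 6 closure calls for `max_iter=1`, `max_eval=5`. A rough sketch of this budgeting pattern (names and structure are illustrative, not the actual LBFGS internals):

```python
# Illustrative sketch of the budgeting pattern, not the real LBFGS code.
def line_search(obj_func, max_ls):
    """Evaluate obj_func at most max_ls times, stopping early on success."""
    evals = 0
    for _ in range(max_ls):
        evals += 1
        if obj_func():  # stand-in for the Wolfe conditions being satisfied
            break
    return evals

def step(obj_func, max_eval):
    current_evals = 1  # the initial loss/gradient evaluation
    obj_func()
    # Hand the line search only the evaluations still allowed in this step,
    # so it cannot exceed the per-step budget on its own default cap.
    current_evals += line_search(obj_func, max_ls=max_eval - current_evals)
    return current_evals

print(step(lambda: False, max_eval=5))  # prints 5: stays within the budget
```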