Fix LBFGS Wolfe max iteration (#161488)

Fixes #91581, based on #135026.
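
The fix forwards the remaining evaluation budget to the strong-Wolfe line search (`max_ls=max_eval - current_evals`), so a single `step()` no longer spends far more closure evaluations than `max_eval` when the line search struggles to converge. A minimal sketch of the user-facing behavior, assuming the public `torch.optim.LBFGS` API (the toy model and the `calls`/`closure` names are illustrative; the optimizer arguments mirror the new regression test below):

```python
import torch

# Illustrative setup: any small model works; the LBFGS arguments
# (max_iter=1, max_eval=5, strong_wolfe) match the regression test added here.
model = torch.nn.Linear(4, 1)
inp = torch.rand(8, 4)

optimizer = torch.optim.LBFGS(
    model.parameters(), max_iter=1, max_eval=5, line_search_fn="strong_wolfe"
)

calls = 0

def closure():
    global calls
    calls += 1  # count how often step() evaluates the objective
    optimizer.zero_grad()
    loss = model(inp).sum()
    loss.backward()
    return loss

optimizer.step(closure)
# With this fix the strong-Wolfe search respects the remaining max_eval
# budget, so `calls` stays close to max_eval instead of growing to the
# line search's own default iteration limit.
print(calls)
```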

## Test Result

```bash
pytest test/test_optim.py

.........
========================== 1473 passed, 242 skipped in 2412.49s (0:40:12) ===========================
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161488
Approved by: https://github.com/albanD
Author: zeshengzong
Date: 2025-09-16 12:07:46 +00:00
Committed by: PyTorch MergeBot
Parent: 6926710adf
Commit: fa127d9b20
2 changed files with 36 additions and 1 deletion

test/test_optim.py

@@ -2305,6 +2305,34 @@ class TestOptimRenewed(TestCase):
             for state in optim.state.values():
                 self.assertGreater(len(state), 0)
 
+    @parametrize("dtype", [torch.float32])
+    def test_step_iteration(self, device, dtype):
+        def _get_model_and_input_tensor(device, dtype):
+            model = torch.nn.Sequential(
+                torch.nn.Conv2d(4, 2, 1, stride=2),
+                torch.nn.BatchNorm2d(2, eps=1e-05, momentum=0.1),
+            )
+            input = torch.rand(1, 4, 16, 16, device=device, dtype=dtype)
+            model.to(dtype=dtype, device=device)
+            return model, input
+
+        counter = 0
+
+        def fwd_bwd(optim, mod, i):
+            nonlocal counter
+            counter += 1
+            optim.zero_grad()
+            loss = mod(i).sum()
+            loss.backward()
+            return loss
+
+        model, input = _get_model_and_input_tensor(device, dtype)
+        optimizer = torch.optim.LBFGS(
+            model.parameters(), max_iter=1, max_eval=5, line_search_fn="strong_wolfe"
+        )
+        optimizer.step(functools.partial(fwd_bwd, optimizer, model, input))
+        self.assertEqual(counter, 6)
+
 
 instantiate_device_type_tests(TestOptimRenewed, globals(), allow_mps=True)
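
With `max_iter=1` and `max_eval=5`, the test pins the total number of closure invocations per `step()` at 6: the initial evaluation made by `step()` plus the strong-Wolfe evaluations, which are now limited by the remaining budget. Before this change the line search fell back to its own default evaluation cap, so a single step could call the closure far more than `max_eval` times (the behavior reported in #91581).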

torch/optim/lbfgs.py

@@ -442,7 +442,14 @@ class LBFGS(Optimizer):
                         return self._directional_evaluate(closure, x, t, d)
 
                     loss, flat_grad, t, ls_func_evals = _strong_wolfe(
-                        obj_func, x_init, t, d, loss, flat_grad, gtd
+                        obj_func,
+                        x_init,
+                        t,
+                        d,
+                        loss,
+                        flat_grad,
+                        gtd,
+                        max_ls=max_eval - current_evals,
                     )
                 self._add_grad(t, d)
                 opt_cond = flat_grad.abs().max() <= tolerance_grad
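
The functional change in `torch/optim/lbfgs.py` is a single keyword argument: the call to `_strong_wolfe` now passes `max_ls=max_eval - current_evals`, i.e. the number of closure evaluations still available within this `step()`, which is what bounds the line search in the test above.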