Add support for differentiable weight decay (#143679)

The (actual) second PR in a larger project with @janeyx99 to broaden support for differentiable optimizers!

In this PR, I closely followed the patterns from the previous PR to add support for a differentiable weight_decay.

I also added a single new line on line 359 (previously line 352) to make the code from the last PR a little easier to read.

Continuation of progress on #141832
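For reference, here is a minimal usage sketch (not part of this PR's diff) of what the change enables: backpropagating through a single SGD step with respect to a tensor-valued weight_decay. The concrete values and the lr used below are illustrative assumptions.

    import torch
    from torch.optim import SGD

    # Hypothetical example: differentiate one SGD step w.r.t. weight_decay.
    weight_decay = torch.tensor(0.9, dtype=torch.float64, requires_grad=True)
    # Differentiable optimizers need a non-leaf param so the in-place update stays in the autograd graph.
    param = torch.rand(10, dtype=torch.float64, requires_grad=True).clone()
    param.grad = torch.rand_like(param)
    opt = SGD([param], lr=0.1, weight_decay=weight_decay, differentiable=True)
    opt.step()
    param.sum().backward()
    print(weight_decay.grad)  # gradient of the updated parameters w.r.t. weight_decay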

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143679
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Authored by: Emmett Bicker
Date: 2024-12-27 23:14:41 +00:00
Committed by: PyTorch MergeBot
Parent: c0c7f881da
Commit: 0de661dc27
2 changed files with 62 additions and 2 deletions

View File

@@ -426,6 +426,57 @@ class TestDifferentiableOptimizer(TestCase):
             ),
         )
 
+    def test_differentiable_weight_decay(self):
+        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
+        weight_decay = torch.tensor(0.9, requires_grad=True, dtype=torch.float64)
+        mbuff = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
+        state = {"momentum_buffer": mbuff}
+        kwargs: dict[str, Any] = {"weight_decay": weight_decay, "differentiable": True}
+
+        gradcheck(
+            _multistep_backprop_diff_hyperparams_fn,
+            (
+                params,
+                grad,
+                state,
+                SGD,
+                kwargs,  # includes weight_decay
+                *state.values(),
+                *kwargs.values(),
+            ),
+        )
+
+    def test_differentiable_weight_decay_and_lr(self):
+        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
+        weight_decay = torch.tensor(0.9, requires_grad=True, dtype=torch.float64)
+        lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
+        mbuff = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
+        state = {"momentum_buffer": mbuff}
+        kwargs: dict[str, Any] = {
+            "weight_decay": weight_decay,
+            "lr": lr,
+            "differentiable": True,
+        }
+
+        gradcheck(
+            _multistep_backprop_diff_hyperparams_fn,
+            (
+                params,
+                grad,
+                state,
+                SGD,
+                kwargs,  # includes lr & weight_decay
+                *state.values(),
+                *kwargs.values(),
+            ),
+        )
 
 
 if __name__ == "__main__":
     print("These tests should be run through test/test_optim.py instead")

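The tests above call _multistep_backprop_diff_hyperparams_fn, a helper defined elsewhere in the test file and not shown in this diff. As a rough sketch of the pattern it exercises (an assumption, not the actual helper), a single-step variant checked with gradcheck could look like the following; the function name and the fixed lr of 0.1 are made up for illustration.

    import torch
    from torch.autograd import gradcheck
    from torch.optim import SGD

    def _one_step_fn(params, grad, weight_decay):
        # Hypothetical single-step analogue of the multi-step helper used above.
        p = params.clone()     # non-leaf copy so the in-place parameter update is recorded by autograd
        p.grad = grad.clone()  # clone so the step's in-place ops don't modify gradcheck's input
        opt = SGD([p], lr=0.1, weight_decay=weight_decay, differentiable=True)
        opt.step()
        return p

    params = torch.rand(10, dtype=torch.float64, requires_grad=True)
    grad = torch.rand_like(params, requires_grad=True)
    weight_decay = torch.tensor(0.9, dtype=torch.float64, requires_grad=True)
    gradcheck(_one_step_fn, (params, grad, weight_decay))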

@@ -30,7 +30,7 @@ class SGD(Optimizer):  # noqa: D101
         lr: Union[float, Tensor] = 1e-3,
         momentum: float = 0,
         dampening: float = 0,
-        weight_decay: float = 0,
+        weight_decay: Union[float, Tensor] = 0,
         nesterov: bool = False,
         *,
         maximize: bool = False,
@@ -334,7 +334,15 @@ def _single_tensor_sgd(
         grad = grads[i] if not maximize else -grads[i]
 
         if weight_decay != 0:
-            grad = grad.add(param, alpha=weight_decay)
+            # Nested if is necessary to bypass jitscript rules
+            if isinstance(weight_decay, Tensor):
+                if weight_decay.requires_grad:
+                    # usually this is the differentiable path, which is why the param.clone() is needed
+                    grad = grad.addcmul_(param.clone(), weight_decay)
+                else:
+                    grad = grad.add(param, alpha=weight_decay)
+            else:
+                grad = grad.add(param, alpha=weight_decay)
 
         if momentum != 0:
             buf = momentum_buffer_list[i]
@@ -349,6 +357,7 @@ def _single_tensor_sgd(
                 grad = grad.add(buf, alpha=momentum)
             else:
                 grad = buf
 
+        # Nested if is necessary to bypass jitscript rules
         if isinstance(lr, Tensor):
             if lr.requires_grad: