Add support for differentiable weight decay (#143679)
(Actual) second PR in a larger project to broaden support for differentiable optimizers with @janeyx99! In this PR, I did a lot of pattern matching from the previous PR to add support for differentiable weight_decay, and also added a single new line on line 359 (previously line 352) to make the code from the last PR a little easier to read.

Continuation of progress on #141832

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143679
Approved by: https://github.com/janeyx99
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Committed by PyTorch MergeBot
Parent: c0c7f881da
Commit: 0de661dc27
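For orientation, a minimal usage sketch (illustrative only, not code from this commit) of what the change enables: weight_decay passed as a tensor with requires_grad=True, with gradients flowing back to it through an SGD step taken with differentiable=True. The parameter is cloned first so the in-place update does not touch a leaf tensor; all variable names below are made up for the example.

import torch
from torch.optim import SGD

# Illustrative sketch: differentiate one SGD step w.r.t. a tensor weight_decay.
w = torch.rand(10, dtype=torch.float64, requires_grad=True)
weight_decay = torch.tensor(0.9, dtype=torch.float64, requires_grad=True)

p = w.clone()                    # non-leaf copy, so the in-place update is allowed
p.grad = torch.rand_like(p)      # stand-in gradient from some inner loss
opt = SGD([p], lr=1e-3, weight_decay=weight_decay, differentiable=True)
opt.step()                       # with differentiable=True the step stays on the autograd graph

meta_loss = p.sum()              # any scalar function of the updated parameters
meta_loss.backward()
print(weight_decay.grad)         # gradient of meta_loss w.r.t. weight_decay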
@@ -426,6 +426,57 @@ class TestDifferentiableOptimizer(TestCase):
             ),
         )
 
+    def test_differentiable_weight_decay(self):
+        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
+        weight_decay = torch.tensor(0.9, requires_grad=True, dtype=torch.float64)
+
+        mbuff = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
+        state = {"momentum_buffer": mbuff}
+        kwargs: dict[str, Any] = {"weight_decay": weight_decay, "differentiable": True}
+
+        gradcheck(
+            _multistep_backprop_diff_hyperparams_fn,
+            (
+                params,
+                grad,
+                state,
+                SGD,
+                kwargs,  # includes weight_decay
+                *state.values(),
+                *kwargs.values(),
+            ),
+        )
+
+    def test_differentiable_weight_decay_and_lr(self):
+        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
+
+        weight_decay = torch.tensor(0.9, requires_grad=True, dtype=torch.float64)
+        lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
+
+        mbuff = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
+        state = {"momentum_buffer": mbuff}
+
+        kwargs: dict[str, Any] = {
+            "weight_decay": weight_decay,
+            "lr": lr,
+            "differentiable": True,
+        }
+
+        gradcheck(
+            _multistep_backprop_diff_hyperparams_fn,
+            (
+                params,
+                grad,
+                state,
+                SGD,
+                kwargs,  # includes lr & weight_decay
+                *state.values(),
+                *kwargs.values(),
+            ),
+        )
+
 
 if __name__ == "__main__":
     print("These tests should be run through test/test_optim.py instead")
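As rough intuition for what these tests verify, here is a simplified stand-in (not the _multistep_backprop_diff_hyperparams_fn helper itself, whose definition is not shown here): gradcheck compares the analytic gradient of a weight-decayed SGD-style update with respect to the hyperparameter tensor against finite differences.

import torch
from torch.autograd import gradcheck

# Simplified stand-in for the multistep helper: one momentum-free update.
def one_sgd_step(param, grad, weight_decay, lr=1e-3):
    g = grad + param * weight_decay   # L2 weight decay folded into the gradient
    return param - lr * g             # plain SGD update

params = torch.rand(10, dtype=torch.float64, requires_grad=True)
grad = torch.rand(10, dtype=torch.float64, requires_grad=True)
weight_decay = torch.tensor(0.9, dtype=torch.float64, requires_grad=True)

assert gradcheck(one_sgd_step, (params, grad, weight_decay))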
@@ -30,7 +30,7 @@ class SGD(Optimizer):  # noqa: D101
         lr: Union[float, Tensor] = 1e-3,
         momentum: float = 0,
         dampening: float = 0,
-        weight_decay: float = 0,
+        weight_decay: Union[float, Tensor] = 0,
         nesterov: bool = False,
         *,
         maximize: bool = False,
@@ -334,7 +334,15 @@ def _single_tensor_sgd(
         grad = grads[i] if not maximize else -grads[i]
 
         if weight_decay != 0:
-            grad = grad.add(param, alpha=weight_decay)
+            # Nested if is necessary to bypass jitscript rules
+            if isinstance(weight_decay, Tensor):
+                if weight_decay.requires_grad:
+                    # usually this is the differentiable path, which is why the param.clone() is needed
+                    grad = grad.addcmul_(param.clone(), weight_decay)
+                else:
+                    grad = grad.add(param, alpha=weight_decay)
+            else:
+                grad = grad.add(param, alpha=weight_decay)
 
         if momentum != 0:
             buf = momentum_buffer_list[i]
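Why the differentiable branch is written this way: add(param, alpha=weight_decay) treats alpha as a plain scalar argument, so autograd cannot propagate a gradient back to weight_decay, whereas addcmul_ keeps weight_decay on the graph as a tensor operand; the param.clone() protects the value autograd saves for backward from being clobbered by the in-place parameter update later in the step. A small illustrative check (not from the PR):

import torch

# Illustrative: weight_decay only receives a gradient when it participates
# as a Tensor operand, as in addcmul.
param = torch.rand(4, dtype=torch.float64, requires_grad=True)
grad = torch.rand(4, dtype=torch.float64)
weight_decay = torch.tensor(0.9, dtype=torch.float64, requires_grad=True)

# grad + param * weight_decay, computed on a cloned param so a later in-place
# update of param could not invalidate the value saved for backward.
g = grad.clone().addcmul_(param.clone(), weight_decay)
g.sum().backward()
print(weight_decay.grad)   # equals param.sum(): the hyperparameter is trainable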
@@ -349,6 +357,7 @@ def _single_tensor_sgd(
                 grad = grad.add(buf, alpha=momentum)
             else:
                 grad = buf
+
         # Nested if is necessary to bypass jitscript rules
         if isinstance(lr, Tensor):
             if lr.requires_grad:
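Putting the two differentiable hyperparameters together, a sketch mirroring test_differentiable_weight_decay_and_lr above (illustrative only, not code from this commit): both lr and weight_decay can be passed as tensors and receive gradients through a step.

import torch
from torch.optim import SGD

# Illustrative: gradients for both lr and weight_decay through one SGD step.
w = torch.rand(10, dtype=torch.float64, requires_grad=True)
lr = torch.tensor(0.001, dtype=torch.float64, requires_grad=True)
weight_decay = torch.tensor(0.9, dtype=torch.float64, requires_grad=True)

p = w.clone()
p.grad = torch.rand_like(p)
opt = SGD([p], lr=lr, weight_decay=weight_decay, differentiable=True)
opt.step()

d_lr, d_wd = torch.autograd.grad(p.sum(), (lr, weight_decay))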