Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Adding support for differentiable lr, weight_decay, and betas in Adam/AdamW (#143726)
Third PR in a series of PRs to broaden differentiable optimizer support w/ @janeyx99 (sorry for pinging over the holidays! I just wanted to put this one out, but I'm definitely not asking for review or anything like that right now). This will also probably be my last PR before the holidays!

Note: This is a branch of #143710. I've never worked on a branch of a branch before, so I wasn't sure about the protocol; I thought I'd just make the PR and wait until that one gets merged.

This adds support for differentiable lr, weight_decay, and betas to Adam and AdamW (but after refactoring AdamW into an Adam subclass, it's really just changing code in torch/optim/adam.py).

One main thing I was wondering about: Adam already has a differentiable flag built in, so I have code like this

```py
if differentiable and isinstance(beta2, Tensor):
    if beta2.requires_grad:
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj().mul(1 - beta2))
    else:
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
else:
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```

that I could definitely simplify to just

```py
if differentiable and isinstance(beta2, Tensor):
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj().mul(1 - beta2))
else:
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```

The simpler version would be a little slower in the case where the optimizer is differentiable but beta2 doesn't need a grad, but the code would also be a lot clearer, so I'm weighing speed against future code usability.

Also, this line from the example above:

```py
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj().mul(1 - beta2))
```

concerned me because it is considerably more expensive than `value=1 - beta2`, but I couldn't think of a better way to do it.

Further work on #141832

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143726
Approved by: https://github.com/janeyx99
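For context, here is a minimal sketch (not from the PR) of what a differentiable hyperparameter enables: backpropagating a meta-loss through one Adam step to a tensor `lr`. It loosely follows the pattern of the test helper shown in the diff below; the shapes, the synthetic gradient, and the quadratic meta-loss are illustrative assumptions.

```py
import torch
from torch.optim import Adam

lr = torch.tensor(1e-3, dtype=torch.float64, requires_grad=True)

# Clone so the in-place update happens on a non-leaf tensor, and attach a
# pre-computed gradient, mirroring how the differentiable optimizer tests set things up.
params = torch.rand(10, dtype=torch.float64, requires_grad=True).clone()
params.grad = torch.rand_like(params)

opt = Adam([params], lr=lr, differentiable=True)
opt.step()  # with differentiable=True, the update stays on the autograd graph

meta_loss = params.pow(2).sum()
meta_loss.backward(inputs=(lr,))
print(lr.grad)  # gradient of the meta-loss w.r.t. the learning rate
```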
committed by PyTorch MergeBot
parent a7915c56f6
commit 92d8965082
@@ -76,12 +76,22 @@ def _multistep_backprop_diff_hyperparams_fn(

    # This copy is necessary so the update on line 78 doesn't overwrite the original kwargs values
    kwargs = kwargs.copy()

    # Have to pass in beta1 and beta2 separately
    # so they're passed in as Tensors (not a tuple) and recognized by gradcheck
    if "beta1" in kwargs or "beta2" in kwargs:
        # Prevent just one beta kwarg from being passed in
        assert (
            "beta1" in kwargs and "beta2" in kwargs
        ), "Both betas should be defined in kwargs"
        kwargs.update({"betas": (kwargs.pop("beta1"), kwargs.pop("beta2"))})

    kwargs.update(
        {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
    )
    differentiable_kwargs = [
        v for v in kwargs.values() if isinstance(v, torch.Tensor) and v.requires_grad
    ]
    ] + (list(kwargs["betas"]) if "betas" in kwargs else [])

    criterion = nn.MSELoss()
@@ -104,6 +114,10 @@ def _multistep_backprop_diff_hyperparams_fn(
    meta_loss = loss
    meta_loss.backward(inputs=(*differentiable_kwargs,), create_graph=True)

    # Extra check to make sure the test properly computed a gradient for all kwargs
    for kwarg in differentiable_kwargs:
        assert kwarg.grad is not None

    return (
        (meta_loss,)
        + tuple(
@@ -111,11 +125,7 @@ def _multistep_backprop_diff_hyperparams_fn(
            for v in optimizer.state[params].values()
            if isinstance(v, torch.Tensor) and v.requires_grad
        )
        + tuple(
            v
            for v in kwargs.values()
            if isinstance(v, torch.Tensor) and v.requires_grad
        )
        + tuple(differentiable_kwargs)
    )
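The helper passes the hyperparameter tensors twice: once inside the kwargs dict (so the optimizer actually uses them) and again as explicit trailing arguments (`*kwargs.values()` in the tests below), because gradcheck only perturbs tensor arguments that require grad. A toy illustration of that pattern, with a hypothetical `scale_fn` standing in for the real helper:

```py
import torch
from torch.autograd import gradcheck

def scale_fn(x, kwargs, *differentiable_hyperparams):
    # The trailing tensor args exist only so gradcheck perturbs and checks them;
    # the function itself reads the hyperparameter out of kwargs.
    return x * kwargs["lr"]

x = torch.rand(3, dtype=torch.float64, requires_grad=True)
lr = torch.tensor(0.5, dtype=torch.float64, requires_grad=True)
kwargs = {"lr": lr}

# Perturbing the trailing lr also perturbs the one in kwargs (same tensor object),
# so the numerical and analytical Jacobians w.r.t. lr line up.
assert gradcheck(scale_fn, (x, kwargs, *kwargs.values()))
```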
@@ -404,6 +414,276 @@ class TestDifferentiableOptimizer(TestCase):
            ),
        )

    def test_adam_differentiable_lr(self):
        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
        lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)

        state = {}
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        kwargs: dict[str, Any] = {"lr": lr, "differentiable": True}

        gradcheck(
            _multistep_backprop_diff_hyperparams_fn,
            (
                params,
                grad,
                state,
                Adam,
                kwargs,  # includes lr
                *state.values(),
                *kwargs.values(),
            ),
        )

    def test_adam_differentiable_weight_decay(self):
        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
        weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)

        state = {}
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        kwargs: dict[str, Any] = {"weight_decay": weight_decay, "differentiable": True}

        gradcheck(
            _multistep_backprop_diff_hyperparams_fn,
            (
                params,
                grad,
                state,
                Adam,
                kwargs,  # includes weight_decay
                *state.values(),
                *kwargs.values(),
            ),
        )

    def test_adam_differentiable_betas(self):
        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)

        lr = torch.tensor([0.001], requires_grad=True, dtype=torch.float64)
        betas = (
            torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
            torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
        )
        state = {}
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )

        # Have to pass in beta1 and beta2 separately
        # so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
        # In the test, this is called: kwargs.update({betas: (beta1, beta2)})
        kwargs: dict[str, Any] = {
            "beta1": betas[0],
            "beta2": betas[1],
            "lr": lr,
            "differentiable": True,
        }

        gradcheck(
            _multistep_backprop_diff_hyperparams_fn,
            (
                params,
                grad,
                state,
                Adam,
                kwargs,  # includes betas
                *state.values(),
                *kwargs.values(),
            ),
        )

    def test_adam_differentiable_all_hyperparams(self):
        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)

        lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
        weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)
        betas = (
            torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
            torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
        )
        state = {}
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )

        # Have to pass in beta1 and beta2 separately
        # so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
        # In the test, this is called: kwargs.update({betas: (beta1, beta2)})
        kwargs: dict[str, Any] = {
            "lr": lr,
            "weight_decay": weight_decay,
            "beta1": betas[0],
            "beta2": betas[1],
            "differentiable": True,
        }

        gradcheck(
            _multistep_backprop_diff_hyperparams_fn,
            (
                params,
                grad,
                state,
                Adam,
                kwargs,  # includes betas
                *state.values(),
                *kwargs.values(),
            ),
        )

    def test_adamw_differentiable_lr(self):
        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
        lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)

        state = {}
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        kwargs: dict[str, Any] = {"lr": lr, "differentiable": True}

        gradcheck(
            _multistep_backprop_diff_hyperparams_fn,
            (
                params,
                grad,
                state,
                AdamW,
                kwargs,  # includes lr
                *state.values(),
                *kwargs.values(),
            ),
        )

    def test_adamw_differentiable_weight_decay(self):
        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
        weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)

        state = {}
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        kwargs: dict[str, Any] = {"weight_decay": weight_decay, "differentiable": True}

        gradcheck(
            _multistep_backprop_diff_hyperparams_fn,
            (
                params,
                grad,
                state,
                AdamW,
                kwargs,  # includes weight_decay
                *state.values(),
                *kwargs.values(),
            ),
        )

    def test_adamw_differentiable_betas(self):
        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)

        betas = (
            torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
            torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
        )
        state = {}
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )

        # Have to pass in beta1 and beta2 separately
        # so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
        # In the test, this is called: kwargs.update({betas: (beta1, beta2)})
        kwargs: dict[str, Any] = {
            "beta1": betas[0],
            "beta2": betas[1],
            "differentiable": True,
        }

        gradcheck(
            _multistep_backprop_diff_hyperparams_fn,
            (
                params,
                grad,
                state,
                AdamW,
                kwargs,  # includes betas
                *state.values(),
                *kwargs.values(),
            ),
        )

    def test_adamw_differentiable_all_hyperparams(self):
        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)

        lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
        weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)
        betas = (
            torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
            torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
        )
        state = {}
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )

        # Have to pass in beta1 and beta2 separately
        # so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
        # In the test, this is called: kwargs.update({betas: (beta1, beta2)})
        kwargs: dict[str, Any] = {
            "lr": lr,
            "weight_decay": weight_decay,
            "beta1": betas[0],
            "beta2": betas[1],
            "differentiable": True,
        }

        gradcheck(
            _multistep_backprop_diff_hyperparams_fn,
            (
                params,
                grad,
                state,
                AdamW,
                kwargs,  # includes betas
                *state.values(),
                *kwargs.values(),
            ),
        )

    def test_differentiable_lr(self):
        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
@@ -402,7 +402,14 @@ def _single_tensor_adam(
        # Perform stepweight decay
        param.mul_(1 - lr * weight_decay)
    else:
        grad = grad.add(param, alpha=weight_decay)
        # Nested if is necessary to bypass jitscript rules
        if differentiable and isinstance(weight_decay, Tensor):
            if weight_decay.requires_grad:
                grad = grad.addcmul_(param.clone(), weight_decay)
            else:
                grad = grad.add(param, alpha=weight_decay)
        else:
            grad = grad.add(param, alpha=weight_decay)

if torch.is_complex(param):
    grad = torch.view_as_real(grad)
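The differentiable weight-decay path above replaces `grad.add(param, alpha=weight_decay)` (whose `alpha` cannot carry gradients) with `grad.addcmul_(param.clone(), weight_decay)`, which computes the same `grad + weight_decay * param`. A quick numerical sanity check of that equivalence (standalone sketch, not part of the PR):

```py
import torch

grad = torch.rand(5, dtype=torch.float64)
param = torch.rand(5, dtype=torch.float64)
wd = 0.999
wd_t = torch.tensor(wd, dtype=torch.float64)

plain_path = grad.add(param, alpha=wd)          # grad + wd * param, wd as a Python number
diff_path = grad.clone().addcmul_(param, wd_t)  # grad + param * wd, wd as a (broadcast) tensor

assert torch.allclose(plain_path, diff_path)
```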
@@ -429,13 +436,43 @@ def _single_tensor_adam(
    # Decay the first and second moment running average coefficient
    exp_avg.lerp_(grad, 1 - device_beta1)

    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
    # Nested if is necessary to bypass jitscript rules
    if differentiable and isinstance(beta2, Tensor):
        if beta2.requires_grad:
            # Using lerp to only use 2 operations bc addcmul's value cannot be a tensor
            # Showing equivalence of differentiable path and nondifferentiable path
            # expavg * b2 + grad^2 * (1-b2)
            # add expavg * (1-b2) - expavg * (1-b2) = 0
            # expavg * b2 + expavg * (1-b2) - expavg * (1-b2) + grad^2 * (1-b2)
            # expavg - expavg * (1-b2) + grad^2 * (1-b2)
            # expavg + (grad^2 - expavg) * (1-b2)
            # expavg.lerp(grad^2, 1-beta2)
            exp_avg_sq.lerp_(torch.square(grad), weight=1 - beta2)
        else:
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    else:
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    if capturable or differentiable:
        step = step_t

        bias_correction1 = 1 - beta1**step
        bias_correction2 = 1 - beta2**step
        # Nested if is necessary to bypass jitscript rules
        if differentiable and isinstance(beta1, Tensor):
            if beta1.requires_grad:
                bias_correction1 = 1 - beta1 ** step.clone()
            else:
                bias_correction1 = 1 - beta1**step
        else:
            bias_correction1 = 1 - beta1**step

        # Nested if is necessary to bypass jitscript rules
        if differentiable and isinstance(beta2, Tensor):
            if beta2.requires_grad:
                bias_correction2 = 1 - beta2 ** step.clone()
            else:
                bias_correction2 = 1 - beta2**step
        else:
            bias_correction2 = 1 - beta2**step

        step_size = lr / bias_correction1
        step_size_neg = step_size.neg()
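The comment block above derives why, when `beta2` needs a gradient, `exp_avg_sq.lerp_(grad**2, 1 - beta2)` can stand in for `mul_(beta2).addcmul_(grad, grad, value=1 - beta2)` (since `addcmul_`'s `value` cannot be a tensor). A standalone numerical check of that identity (not part of the PR):

```py
import torch

exp_avg_sq = torch.rand(5, dtype=torch.float64)
grad = torch.rand(5, dtype=torch.float64)
beta2 = 0.999
beta2_t = torch.tensor(beta2, dtype=torch.float64)

# v * b2 + g^2 * (1 - b2)  ==  v + (g^2 - v) * (1 - b2)  ==  v.lerp(g^2, 1 - b2)
plain_path = exp_avg_sq.mul(beta2).addcmul(grad, grad, value=1 - beta2)
lerp_path = exp_avg_sq.lerp(torch.square(grad), 1 - beta2_t)

assert torch.allclose(plain_path, lerp_path)
```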
@@ -462,7 +499,10 @@ def _single_tensor_adam(
        exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
    ).add_(eps / step_size_neg)

    param.addcdiv_(exp_avg, denom)
    if differentiable:
        param.addcdiv_(exp_avg.clone(), denom)
    else:
        param.addcdiv_(exp_avg, denom)
else:
    step = _get_value(step_t)