[optim] Fix: wrong ASGD implementation (#125440)

> previous: Originally, the variables `new_eta` and `new_mu` would be constructed `len(grouped_mus)` times, but each of their values is the same and won't change. Therefore, it can be simplified using Python list multiplication, which constructs only one tensor.

Fixed in this PR:
- [x] Incorrect assumption that every param will have the same step.
- [x] Different implementation between `foreach=True` and `foreach=False`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125440
Approved by: https://github.com/janeyx99
Committed by PyTorch MergeBot. Parent: 5af4b49285. Commit: 2c5ad9a3d7.
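To make the `foreach=True` vs `foreach=False` discrepancy concrete, here is a minimal numeric sketch (not from the patch; the constants are made up) contrasting the eta decay the multi-tensor path used to compute with the one the single-tensor path, and now both paths, compute:

```python
# Hedged sketch: illustrative constants only, not taken from the PR.
lr, lambd, alpha, step = 0.01, 1e-4, 0.75, 1000

old_foreach_eta = lr / (1 + lambd * lr * step**alpha)    # previous multi-tensor formula
corrected_eta = lr / ((1 + lambd * lr * step) ** alpha)  # forloop formula, now used everywhere

print(old_foreach_eta, corrected_eta)  # the two values differ, so the paths diverged
```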
@@ -604,8 +604,16 @@ class TestOptimRenewed(TestCase):
 for input, model, optimizer in zip(inputs, models, optimizers):
     optimizer.zero_grad()
 
+    if i == 3:
+        # Freeze a layer to test if the step of this layer in 'fused' or 'foreach'
+        # is same as the step in 'forloop'.
+        model[2].requires_grad_(False)
+    if i == 5:
+        # Unfreeze the layer after 2 iters.
+        model[2].requires_grad_(True)
+
     # Test that step behaves as expected (a no-op) when grads are set to None
-    if i != 3:
+    if i != 2:
         output = model(input)
         loss = output.sum()
         loss.backward()
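The freeze/unfreeze above is what makes the per-parameter `step` counts diverge, which the old multi-tensor path could not represent. A rough standalone sketch of that effect (not the actual test; the model, lr, and iteration count are made up):

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.ReLU(), torch.nn.Linear(2, 2))
opt = torch.optim.ASGD(model.parameters(), lr=0.01)

for i in range(6):
    opt.zero_grad()
    if i == 3:
        model[2].requires_grad_(False)  # freeze the last Linear for two iterations
    if i == 5:
        model[2].requires_grad_(True)
    model(torch.randn(4, 2)).sum().backward()
    opt.step()

# Frozen params keep None grads for two iterations, so their recorded step count
# falls behind the others; the fixed foreach path must honor that per-param state.
print([opt.state[p]["step"].item() for p in model.parameters()])
```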
@@ -19,6 +19,7 @@ from torch._prims_common import (
     corresponding_real_dtype,
     elementwise_dtypes,
     ELEMENTWISE_TYPE_PROMOTION_KIND,
+    FloatLike,
     IntLike,
     make_contiguous_strides_for,
     Number,
@@ -3286,6 +3287,15 @@ def _meta_foreach_inplace(*args, _scalar_op=None, **kwargs):
     return
 
 
+@register_meta([aten._foreach_pow_.Scalar])
+def meta__foreach_pow__scalar(self, exponent):
+    torch._check(
+        isinstance(exponent, FloatLike),
+        lambda: f"exponent must be a float but got {type(exponent)}",
+    )
+    return
+
+
 @register_meta([aten._foreach_pow.ScalarAndTensor])
 def meta__foreach_pow_scalar_and_tensor(self, exponent):
     # Only foreach_pow has a ScalarAndTensor method and needs special
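The new meta kernel is presumably what lets the rewritten foreach path, which now calls `torch._foreach_pow_` with a scalar exponent, run under meta/fake tensors (e.g. for tracing or `torch.compile`). A minimal sketch of that assumption, exercising the op on the meta device:

```python
import torch

# Shapes and the exponent are made up; the point is only that the in-place
# scalar pow dispatches without real data once the meta registration exists.
steps = [torch.ones((), device="meta"), torch.ones((), device="meta")]
torch._foreach_pow_(steps, 0.75)  # should resolve via the new Scalar meta kernel
```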
@@ -22,13 +22,6 @@ from .optimizer import (
 __all__ = ["ASGD", "asgd"]
 
 
-def _to_tensor(x, device=None):
-    if not isinstance(x, torch.Tensor):
-        return torch.tensor(x, device=device)
-
-    return x
-
-
 class ASGD(Optimizer):
     def __init__(
         self,
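Dropping the `_to_tensor` helper in favor of `torch.as_tensor` looks safe because the built-in covers the same two cases; a quick check of that assumption (my own snippet, not from the diff):

```python
import torch

t = torch.tensor(2.0)
assert torch.as_tensor(t).data_ptr() == t.data_ptr()  # existing tensors are not copied
assert torch.as_tensor(0.5).dim() == 0                # Python scalars become 0-dim tensors
print(torch.as_tensor(0.5, device="cpu"))             # a device can still be requested
```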
@@ -264,9 +257,9 @@ def _single_tensor_asgd(
             mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t)))
         else:
             step = _get_value(step_t)
-            new_eta = _to_tensor(lr / ((1 + lambd * lr * step) ** alpha))
+            new_eta = torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha))
             eta.copy_(new_eta)
-            new_mu = _to_tensor(1 / max(1, step - t0))
+            new_mu = torch.as_tensor(1 / max(1, step - t0))
             mu.copy_(new_mu)
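As a sanity check on the single-tensor path (my sketch, not part of the PR; `t0` and the step values are arbitrary), the capturable and non-capturable branches should agree on `mu = 1 / max(1, step - t0)`:

```python
import torch

t0 = 5
for step in (3.0, 6.0, 42.0):
    step_t = torch.tensor(step)
    mu_capturable = 1 / torch.maximum(step_t - t0, torch.ones_like(step_t))
    mu_plain = torch.as_tensor(1 / max(1, step - t0))
    assert torch.allclose(mu_capturable, mu_plain)
```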
@@ -381,27 +374,23 @@ def _multi_tensor_asgd(
             torch._foreach_copy_(grouped_mus, new_mus)
             del new_mus
 
-            # update eta = lr / (1 + lambd * lr * step^alpha)
-            new_etas = torch._foreach_pow(grouped_state_steps, alpha)
-            torch._foreach_mul_(new_etas, lambd)
+            # update eta = lr / ((1 + lambd * lr * step)^alpha)
+            new_etas = torch._foreach_mul(grouped_state_steps, lambd)
             torch._foreach_mul_(new_etas, lr)
             torch._foreach_add_(new_etas, 1)
+            torch._foreach_pow_(new_etas, alpha)
             torch._foreach_reciprocal_(new_etas)
             torch._foreach_mul_(new_etas, lr)
             torch._foreach_copy_(grouped_etas, new_etas)
         else:
-            step = grouped_state_steps[0].item()
-            new_etas = []
-            new_mus = []
-
-            for i in range(len(grouped_mus)):
-                new_eta = _to_tensor(
-                    lr / (1 + lambd * lr * step**alpha), device=device
-                )
-                new_etas.append(new_eta)
-                new_mu = _to_tensor(1 / max(1, step - t0), device=device)
-                new_mus.append(new_mu)
-
+            new_etas = [
+                torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha), device=device)
+                for step in grouped_state_steps
+            ]
+            new_mus = [
+                torch.as_tensor(1 / max(1, step - t0), device=device)
+                for step in grouped_state_steps
+            ]
             torch._foreach_copy_(grouped_etas, new_etas)
             torch._foreach_copy_(grouped_mus, new_mus)
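A small equivalence check (my sketch; the constants and step values are made up) that the rewritten foreach sequence matches the closed-form `eta = lr / ((1 + lambd * lr * step) ** alpha)` even when per-param steps differ:

```python
import torch

lr, lambd, alpha = 0.01, 1e-4, 0.75
steps = [torch.tensor(3.0), torch.tensor(7.0)]  # per-param step counts need not match

etas = torch._foreach_mul(steps, lambd)
torch._foreach_mul_(etas, lr)
torch._foreach_add_(etas, 1)
torch._foreach_pow_(etas, alpha)
torch._foreach_reciprocal_(etas)
torch._foreach_mul_(etas, lr)

for eta, step in zip(etas, steps):
    expected = lr / ((1 + lambd * lr * step.item()) ** alpha)
    assert torch.allclose(eta, torch.tensor(expected))
```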
@@ -590,6 +590,7 @@ def optim_inputs_func_asgd(device, dtype=None):
     ]
     return [
         OptimizerInput(params=None, kwargs={}, desc="default"),
         OptimizerInput(params=None, kwargs={"lambd": 0.1}, desc="non-default lambd"),
         OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"),
         OptimizerInput(params=None, kwargs={"t0": 100}, desc="t0"),
         OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
@@ -1432,6 +1433,13 @@ optim_db: List[OptimizerInfo] = [
                 "TestOptimRenewed",
                 "test_defaults_changed_to_foreach",
             ),
+            DecorateInfo(
+                unittest.skip(
+                    "ASGD internally changes the weights even with zero grad"
+                ),
+                "TestOptimRenewed",
+                "test_step_is_noop_for_zero_grads",
+            ),
         ),
     ),
     OptimizerInfo(