Revert "[optim] Fix: wrong ASGD implementation (#125440)"

This reverts commit 2c5ad9a3d7ea79ca897aec153a401f4b9175a717.

Reverted https://github.com/pytorch/pytorch/pull/125440 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it looks like there is a linter failure coming from this change ([comment](https://github.com/pytorch/pytorch/pull/125440#issuecomment-2113833108))
PyTorch MergeBot
2024-05-16 02:12:29 +00:00
parent 175c18af81
commit e3c5d1b7d7
4 changed files with 25 additions and 40 deletions


@@ -604,16 +604,8 @@ class TestOptimRenewed(TestCase):
for input, model, optimizer in zip(inputs, models, optimizers):
optimizer.zero_grad()
if i == 3:
# Freeze a layer to test if the step of this layer in 'fused' or 'foreach'
# is same as the step in 'forloop'.
model[2].requires_grad_(False)
if i == 5:
# Unfreeze the layer after 2 iters.
model[2].requires_grad_(True)
# Test that step behaves as expected (a no-op) when grads are set to None
if i != 2:
if i != 3:
output = model(input)
loss = output.sum()
loss.backward()
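The hunk above (presumably from test/test_optim.py) toggles between freezing a layer mid-training and skipping the backward pass on one iteration so that grads stay None. Below is a minimal, hedged sketch of that pattern outside the test harness; the model, iteration count, and assertion are illustrative assumptions, not the actual TestOptimRenewed test.

```python
import torch

# Minimal sketch: freeze a layer partway through training, and on one
# iteration skip backward so its grads stay None, then check that step()
# leaves those parameters untouched.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 4), torch.nn.ReLU(), torch.nn.Linear(4, 4)
)
optimizer = torch.optim.ASGD(model.parameters(), foreach=True)

for i in range(7):
    optimizer.zero_grad()  # set_to_none=True by default, so grads become None
    if i == 3:
        model[2].requires_grad_(False)  # freeze the last layer
    if i == 5:
        model[2].requires_grad_(True)   # unfreeze after 2 iters
    before = model[2].weight.detach().clone()
    if i != 2:  # on i == 2 no backward runs, so grads remain None
        loss = model(torch.randn(2, 4)).sum()
        loss.backward()
    optimizer.step()
    if i == 2:
        # step() is expected to be a no-op for params whose grad is None
        assert torch.equal(model[2].weight, before)
```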


@@ -19,7 +19,6 @@ from torch._prims_common import (
corresponding_real_dtype,
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
FloatLike,
IntLike,
make_contiguous_strides_for,
Number,
@@ -3287,15 +3286,6 @@ def _meta_foreach_inplace(*args, _scalar_op=None, **kwargs):
return
@register_meta([aten._foreach_pow_.Scalar])
def meta__foreach_pow__scalar(self, exponent):
torch._check(
isinstance(exponent, FloatLike),
lambda: f"exponent must be a float but got {type(exponent)}",
)
return
@register_meta([aten._foreach_pow.ScalarAndTensor])
def meta__foreach_pow_scalar_and_tensor(self, exponent):
# Only foreach_pow has a ScalarAndTensor method and needs special
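The lines touched above register a meta kernel for aten._foreach_pow_.Scalar (the surrounding file is presumably torch/_meta_registrations.py). Meta kernels only describe shapes and dtypes, which is what FakeTensor tracing and torch.compile rely on to reason about an op without executing it. A small, hedged illustration of that, assuming current PyTorch behavior on the "meta" device:

```python
import torch

# "meta" tensors carry only shape/dtype metadata; a registered meta kernel
# lets the in-place foreach op run on them without doing any real math.
xs = [torch.empty(3, device="meta"), torch.empty(5, 2, device="meta")]
torch._foreach_pow_(xs, 2.0)  # dispatches to the meta kernel for the Scalar overload
print([t.shape for t in xs])  # shapes are preserved; no values were computed
```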


@@ -22,6 +22,13 @@ from .optimizer import (
__all__ = ["ASGD", "asgd"]
def _to_tensor(x, device=None):
if not isinstance(x, torch.Tensor):
return torch.tensor(x, device=device)
return x
class ASGD(Optimizer):
def __init__(
self,
@@ -257,9 +264,9 @@ def _single_tensor_asgd(
mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t)))
else:
step = _get_value(step_t)
new_eta = torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha))
new_eta = _to_tensor(lr / ((1 + lambd * lr * step) ** alpha))
eta.copy_(new_eta)
new_mu = torch.as_tensor(1 / max(1, step - t0))
new_mu = _to_tensor(1 / max(1, step - t0))
mu.copy_(new_mu)
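The single-tensor path above swaps torch.as_tensor for the restored _to_tensor helper when refreshing eta and mu. A quick, hedged illustration of what that helper (copied from the earlier hunk in this file, presumably torch/optim/asgd.py) does with scalars versus tensors:

```python
import torch

def _to_tensor(x, device=None):
    # Copied from the hunk above: wrap plain Python scalars on the requested
    # device, but pass existing tensors through untouched.
    if not isinstance(x, torch.Tensor):
        return torch.tensor(x, device=device)
    return x

print(_to_tensor(0.5))                   # scalar -> tensor(0.5000) on the default device
t = torch.tensor(2.0)
print(_to_tensor(t, device="cpu") is t)  # True: already a tensor, returned as-is
```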
@@ -374,23 +381,27 @@ def _multi_tensor_asgd(
torch._foreach_copy_(grouped_mus, new_mus)
del new_mus
# update eta = lr / ((1 + lambd * lr * step)^alpha)
new_etas = torch._foreach_mul(grouped_state_steps, lambd)
# update eta = lr / (1 + lambd * lr * step^alpha)
new_etas = torch._foreach_pow(grouped_state_steps, alpha)
torch._foreach_mul_(new_etas, lambd)
torch._foreach_mul_(new_etas, lr)
torch._foreach_add_(new_etas, 1)
torch._foreach_pow_(new_etas, alpha)
torch._foreach_reciprocal_(new_etas)
torch._foreach_mul_(new_etas, lr)
torch._foreach_copy_(grouped_etas, new_etas)
else:
new_etas = [
torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha), device=device)
for step in grouped_state_steps
]
new_mus = [
torch.as_tensor(1 / max(1, step - t0), device=device)
for step in grouped_state_steps
]
step = grouped_state_steps[0].item()
new_etas = []
new_mus = []
for i in range(len(grouped_mus)):
new_eta = _to_tensor(
lr / (1 + lambd * lr * step**alpha), device=device
)
new_etas.append(new_eta)
new_mu = _to_tensor(1 / max(1, step - t0), device=device)
new_mus.append(new_mu)
torch._foreach_copy_(grouped_etas, new_etas)
torch._foreach_copy_(grouped_mus, new_mus)
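The capturable branch above is where the two sides of this revert disagree: the comments in the hunk spell out eta = lr / ((1 + lambd * lr * step)^alpha) versus eta = lr / (1 + lambd * lr * step^alpha). A hedged sketch (arbitrary hyperparameter values, single-element lists) reproducing both _foreach sequences side by side, to show that they compute different quantities:

```python
import torch

lr, lambd, alpha = 0.01, 1e-4, 0.75
steps = [torch.tensor(100.0)]

# eta = lr / ((1 + lambd * lr * step) ** alpha)  -- the parenthesized variant
eta_a = torch._foreach_mul(steps, lambd)
torch._foreach_mul_(eta_a, lr)
torch._foreach_add_(eta_a, 1)
torch._foreach_pow_(eta_a, alpha)
torch._foreach_reciprocal_(eta_a)
torch._foreach_mul_(eta_a, lr)

# eta = lr / (1 + lambd * lr * step ** alpha)    -- the unparenthesized variant
eta_b = torch._foreach_pow(steps, alpha)
torch._foreach_mul_(eta_b, lambd)
torch._foreach_mul_(eta_b, lr)
torch._foreach_add_(eta_b, 1)
torch._foreach_reciprocal_(eta_b)
torch._foreach_mul_(eta_b, lr)

step = 100.0
print(eta_a[0].item(), lr / ((1 + lambd * lr * step) ** alpha))  # these two match
print(eta_b[0].item(), lr / (1 + lambd * lr * step ** alpha))    # and so do these
```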


@@ -590,7 +590,6 @@ def optim_inputs_func_asgd(device, dtype=None):
]
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lambd": 0.1}, desc="non-default lambd"),
OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"),
OptimizerInput(params=None, kwargs={"t0": 100}, desc="t0"),
OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
@@ -1451,13 +1450,6 @@ optim_db: List[OptimizerInfo] = [
"TestOptimRenewed",
"test_defaults_changed_to_foreach",
),
DecorateInfo(
unittest.skip(
"ASGD internally changes the weights even with zero grad"
),
"TestOptimRenewed",
"test_step_is_noop_for_zero_grads",
),
),
),
OptimizerInfo(