Adafactor forloop basic impl (#129905)
#109581

At this point, the vanilla implementation (the default) is good. Docs: https://docs-preview.pytorch.org/pytorch/pytorch/129905/generated/torch.optim.Adafactor.html#torch.optim.Adafactor

Specifically, the impl in this PR, which attempts to replicate the paper,
```
optim = torch.optim.Adafactor([weight])
```
is close enough to https://pytorch-optimizers.readthedocs.io/en/latest/optimizer/#pytorch_optimizer.AdaFactor
```
optim_c = AdaFactor([weight], betas=(0, 0.999), scale_parameter=False)
```
which is in turn close enough to https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adafactor
```
optim = keras.optimizers.Adafactor(learning_rate=0.01)
```

The three results, respectively, for the same randomly generated weights:
```
# ours
tensor([[ 0.3807594, -0.3912092],
        [ 0.0762539,  0.5377805],
        [ 0.2459473,  0.4662207]])

# pytorch-optimizer
tensor([[ 0.3807592, -0.3912172],
        [ 0.0762507,  0.5377818],
        [ 0.2459457,  0.4662213]])

# keras
array([[ 0.38076326, -0.39121315],
       [ 0.0762547 ,  0.5377859 ],
       [ 0.24594972,  0.46622536]], dtype=float32)
```

This gives me confidence to move forward in speeding up the implementation now that a baseline has been established.

If you're curious about the differences:
* Keras assigns the step size (`rho_t` in their code) to `min(lr, 1 / sqrt(step))`, whereas the original impl hardcodes 0.01 in place of `lr`. We do the same thing as Keras, but our `lr` default is 0.01 (a small sketch of this rule follows the description).
* We differ from the pytorch-optimizer default in that our default does not track momentum (thus `beta1=0`) and we do not apply parameter scaling.

<details>

Keras Colab: https://colab.research.google.com/drive/1i3xF8ChL7TWKJGV_5v_5nMhXKnYmQQ06?usp=sharing

My script repro:
```
import torch
from pytorch_optimizer import AdaFactor

torch.set_printoptions(precision=7)

weight = torch.tensor([[ 0.37697506, -0.39500135],
                       [ 0.07246649,  0.53399765],
                       [ 0.24216151,  0.46243715]], dtype=torch.float32)
# bias = torch.tensor([0, 0], dtype=torch.float32)
weight.grad = torch.tensor([[-0.5940447, -0.7743838],
                            [-0.5940447, -0.7743838],
                            [-0.5940447, -0.7743838]], dtype=torch.float32)
# bias.grad = torch.tensor([-2.5027974, 1.5422692], dtype=torch.float32)

weight_c = weight.clone()
weight_c.grad = weight.grad.clone()

optim = torch.optim.Adafactor([weight])
optim.step()
print(weight)

optim_c = AdaFactor([weight_c], betas=(0, 0.999), scale_parameter=False)
optim_c.step()
print(weight_c)
```

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129905
Approved by: https://github.com/albanD
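For illustration, here is a minimal standalone sketch of the step-size rule mentioned in the notes above, rho_t = min(lr, 1 / sqrt(t)). This is not the actual `torch.optim.Adafactor` internals; `relative_step_size` is a hypothetical helper name used only for this sketch.
```
import math

def relative_step_size(step: int, lr: float = 0.01) -> float:
    # rho_t = min(lr, 1 / sqrt(t)): lr caps the 1/sqrt(t) schedule.
    # The paper hardcodes 0.01 where lr appears here; this PR, like Keras,
    # lets lr play that role, with 0.01 as the default.
    return min(lr, 1.0 / math.sqrt(step))

# With the default lr=0.01, the cap binds for roughly the first 10,000 steps
# (1 / sqrt(10_000) == 0.01); after that the 1/sqrt(t) decay takes over.
print(relative_step_size(1))       # 0.01
print(relative_step_size(40_000))  # 0.005
```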
Committed by: PyTorch MergeBot
Parent: e8956c9fe6
Commit: 9c4cf866c2
@@ -1383,7 +1383,6 @@ class TestOptimRenewed(TestCase):
     @optims(optim_db, dtypes=[torch.float32])
     def test_can_load_older_state_dict(self, device, dtype, optim_info):
-        new_flags = ["maximize", "foreach", "fused", "differentiable", "capturable"]
         optim_cls = optim_info.optim_cls

         # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
@@ -1417,7 +1416,7 @@ class TestOptimRenewed(TestCase):
         old_state_dict = deepcopy(optimizer.state_dict())
         old_state_dict_pg = old_state_dict["param_groups"]
         for group in old_state_dict_pg:
-            for flag in new_flags:
+            for flag in optim_info.not_og_supported_flags:
                 if flag in group:
                     del group[flag]
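For context, the hunks above swap the local `new_flags` list for `optim_info.not_og_supported_flags` when the test strips flags that older releases did not serialize. A self-contained sketch of that pruning pattern, using the flag names from the removed `new_flags` list as a stand-in for `optim_info.not_og_supported_flags`, might look like this:
```
from copy import deepcopy

import torch

# Stand-in for optim_info.not_og_supported_flags; these names come from the
# removed new_flags list in the hunk above.
NOT_OG_SUPPORTED_FLAGS = ("maximize", "foreach", "fused", "differentiable", "capturable")

def simulate_older_state_dict(optimizer: torch.optim.Optimizer) -> dict:
    # Copy the state_dict and drop flags an older release would not have saved.
    old = deepcopy(optimizer.state_dict())
    for group in old["param_groups"]:
        for flag in NOT_OG_SUPPORTED_FLAGS:
            if flag in group:
                del group[flag]
    return old

# Usage sketch: the test's expectation is that loading such a stripped state
# dict still works, with the optimizer falling back to defaults for any
# missing flags.
param = torch.zeros(3, requires_grad=True)
opt = torch.optim.SGD([param], lr=0.1)
opt.load_state_dict(simulate_older_state_dict(opt))
```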