Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Refactor SparseAdam algorithm into functional form (#59171)
Summary: Adds Functional Interface for Sparse Adam Optimizer.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59171

Reviewed By: vincentqb

Differential Revision: D29360582

Pulled By: iramazanli

fbshipit-source-id: 5ceffd7f4b7abd1e0b758a5b8445abdf5555eba0
Committed by: Facebook GitHub Bot
Parent: 963c983366
Commit: 7c2938bf67
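For orientation before the diff: the change pulls the per-parameter update math out of SparseAdam.step() and routes it through a functional entry point in torch.optim._functional (imported as F below). The sketch of that interface here is inferred purely from the call site visible in the diff, not copied from _functional.py, so treat the exact signature and keyword order as an assumption:

from typing import List
from torch import Tensor

def sparse_adam(params: List[Tensor],
                grads: List[Tensor],
                exp_avgs: List[Tensor],
                exp_avg_sqs: List[Tensor],
                state_steps: List[int],
                *,
                beta1: float,
                beta2: float,
                lr: float,
                eps: float) -> None:
    # Performs the SparseAdam update for every tensor in `params`,
    # mutating params/exp_avgs/exp_avg_sqs in place.
    # (A reconstruction of the body appears after the diff below.)
    ...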
@@ -1,5 +1,5 @@
-import math
 import torch
+from . import _functional as F
 from .optimizer import Optimizer
 
 
@@ -64,59 +64,48 @@ class SparseAdam(Optimizer):
             loss = closure()
 
         for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            exp_avg_sqs = []
+            state_steps = []
+            eps = group['eps']
+            lr = group['lr']
+            beta1, beta2 = group['betas']
+
             for p in group['params']:
-                if p.grad is None:
-                    continue
-                grad = p.grad
-                if not grad.is_sparse:
-                    raise RuntimeError('SparseAdam does not support dense gradients, please consider Adam instead')
-
-                state = self.state[p]
-
-                # State initialization
-                if len(state) == 0:
-                    state['step'] = 0
-                    # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
-                    # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
-
-                state['step'] += 1
-
-                grad = grad.coalesce()  # the update is non-linear so indices must be unique
-                grad_indices = grad._indices()
-                grad_values = grad._values()
-                size = grad.size()
-
-                def make_sparse(values):
-                    constructor = grad.new
-                    if grad_indices.dim() == 0 or values.dim() == 0:
-                        return constructor().resize_as_(grad)
-                    return constructor(grad_indices, values, size)
-
-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
-                beta1, beta2 = group['betas']
-
-                # Decay the first and second moment running average coefficient
-                #      old <- b * old + (1 - b) * new
-                # <==> old += (1 - b) * (new - old)
-                old_exp_avg_values = exp_avg.sparse_mask(grad)._values()
-                exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1)
-                exp_avg.add_(make_sparse(exp_avg_update_values))
-                old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values()
-                exp_avg_sq_update_values = grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2)
-                exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values))
-
-                # Dense addition again is intended, avoiding another sparse_mask
-                numer = exp_avg_update_values.add_(old_exp_avg_values)
-                exp_avg_sq_update_values.add_(old_exp_avg_sq_values)
-                denom = exp_avg_sq_update_values.sqrt_().add_(group['eps'])
-                del exp_avg_update_values, exp_avg_sq_update_values
-
-                bias_correction1 = 1 - beta1 ** state['step']
-                bias_correction2 = 1 - beta2 ** state['step']
-                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
-
-                p.add_(make_sparse(-step_size * numer.div_(denom)))
+                if p.grad is not None:
+                    params_with_grad.append(p)
+                    if not p.grad.is_sparse:
+                        raise RuntimeError('SparseAdam does not support dense gradients, please consider Adam instead')
+                    grads.append(p.grad)
+
+                    state = self.state[p]
+
+                    # State initialization
+                    if len(state) == 0:
+                        state['step'] = 0
+                        # Exponential moving average of gradient values
+                        state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                        # Exponential moving average of squared gradient values
+                        state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+                    exp_avgs.append(state['exp_avg'])
+                    exp_avg_sqs.append(state['exp_avg_sq'])
+
+                    # update the steps for each param group update
+                    state['step'] += 1
+                    # record the step after step update
+                    state_steps.append(state['step'])
+
+            F.sparse_adam(params_with_grad,
+                          grads,
+                          exp_avgs,
+                          exp_avg_sqs,
+                          state_steps,
+                          beta1=beta1,
+                          beta2=beta2,
+                          lr=group['lr'],
+                          eps=group['eps'])
 
         return loss
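The block deleted above is, in essence, the body that this PR relocates into torch.optim._functional. Below is a minimal sketch of the relocated loop, reconstructed from the deleted lines with group['lr'] and group['eps'] replaced by the lr and eps arguments; the actual _functional.py code may differ in small details:

import math
import torch

def sparse_adam(params, grads, exp_avgs, exp_avg_sqs, state_steps, *,
                beta1, beta2, lr, eps):
    # Functional SparseAdam update, applied in place to each parameter.
    for i, param in enumerate(params):
        grad = grads[i].coalesce()  # the update is non-linear so indices must be unique
        grad_indices = grad._indices()
        grad_values = grad._values()
        size = grad.size()

        def make_sparse(values):
            # Rebuild a sparse tensor with the gradient's sparsity pattern.
            constructor = grad.new
            if grad_indices.dim() == 0 or values.dim() == 0:
                return constructor().resize_as_(grad)
            return constructor(grad_indices, values, size)

        exp_avg, exp_avg_sq = exp_avgs[i], exp_avg_sqs[i]
        step = state_steps[i]

        # Decay the first and second moment running average coefficient:
        #       old <- b * old + (1 - b) * new
        #  <==> old += (1 - b) * (new - old)
        old_exp_avg_values = exp_avg.sparse_mask(grad)._values()
        exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1)
        exp_avg.add_(make_sparse(exp_avg_update_values))
        old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values()
        exp_avg_sq_update_values = grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2)
        exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values))

        # Dense addition again is intended, avoiding another sparse_mask
        numer = exp_avg_update_values.add_(old_exp_avg_values)
        exp_avg_sq_update_values.add_(old_exp_avg_sq_values)
        denom = exp_avg_sq_update_values.sqrt_().add_(eps)

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step
        step_size = lr * math.sqrt(bias_correction2) / bias_correction1

        param.add_(make_sparse(-step_size * numer.div_(denom)))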
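The public API is unchanged by this refactor; SparseAdam is still driven the usual way, typically with modules that emit sparse gradients such as nn.Embedding(sparse=True). A small usage sketch, with sizes and hyperparameters chosen arbitrarily:

import torch
import torch.nn as nn

# nn.Embedding(..., sparse=True) produces sparse gradients, which SparseAdam expects
emb = nn.Embedding(num_embeddings=100, embedding_dim=16, sparse=True)
opt = torch.optim.SparseAdam(emb.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8)

idx = torch.randint(0, 100, (32,))
loss = emb(idx).pow(2).sum()

opt.zero_grad()
loss.backward()  # emb.weight.grad is a sparse tensor
opt.step()       # after this PR, step() dispatches to F.sparse_adam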