Refactor the Sparse Adam algorithm into functional form (#59171)

Summary:
Adds a functional interface for the SparseAdam optimizer: SparseAdam.step() now gathers parameters, gradients, and per-parameter state into lists and delegates the update math to F.sparse_adam in torch.optim._functional.
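
For context, a minimal usage sketch (illustrative only, not part of this change; the model and tensor names are made up). The public API of torch.optim.SparseAdam is unchanged; only the internals of step() are reorganized.

    import torch
    import torch.nn as nn

    # A toy model whose parameters produce sparse gradients.
    emb = nn.Embedding(1000, 16, sparse=True)
    opt = torch.optim.SparseAdam(emb.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8)

    idx = torch.randint(0, 1000, (32,))
    loss = emb(idx).pow(2).sum()   # toy objective producing sparse gradients
    opt.zero_grad()
    loss.backward()                # emb.weight.grad is a sparse tensor
    opt.step()                     # step() now gathers state and calls F.sparse_adam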

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59171

Reviewed By: vincentqb

Differential Revision: D29360582

Pulled By: iramazanli

fbshipit-source-id: 5ceffd7f4b7abd1e0b758a5b8445abdf5555eba0
Author: Ilqar Ramazanli
Date: 2021-06-25 06:34:03 -07:00
Committed by: Facebook GitHub Bot
Parent: 963c983366
Commit: 7c2938bf67
2 changed files with 93 additions and 49 deletions


@@ -1,5 +1,5 @@
-import math
 import torch
+from . import _functional as F
 from .optimizer import Optimizer
@@ -64,59 +64,48 @@ class SparseAdam(Optimizer):
                 loss = closure()
         for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            exp_avg_sqs = []
+            state_steps = []
+            eps = group['eps']
+            lr = group['lr']
+            beta1, beta2 = group['betas']
             for p in group['params']:
-                if p.grad is None:
-                    continue
-                grad = p.grad
-                if not grad.is_sparse:
-                    raise RuntimeError('SparseAdam does not support dense gradients, please consider Adam instead')
+                if p.grad is not None:
+                    params_with_grad.append(p)
+                    if not p.grad.is_sparse:
+                        raise RuntimeError('SparseAdam does not support dense gradients, please consider Adam instead')
+                    grads.append(p.grad)
-                state = self.state[p]
+                    state = self.state[p]
-                # State initialization
-                if len(state) == 0:
-                    state['step'] = 0
-                    # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
-                    # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    # State initialization
+                    if len(state) == 0:
+                        state['step'] = 0
+                        # Exponential moving average of gradient values
+                        state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                        # Exponential moving average of squared gradient values
+                        state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
-                state['step'] += 1
+                    exp_avgs.append(state['exp_avg'])
+                    exp_avg_sqs.append(state['exp_avg_sq'])
-                grad = grad.coalesce()  # the update is non-linear so indices must be unique
-                grad_indices = grad._indices()
-                grad_values = grad._values()
-                size = grad.size()
+                    # update the steps for each param group update
+                    state['step'] += 1
+                    # record the step after step update
+                    state_steps.append(state['step'])
-                def make_sparse(values):
-                    constructor = grad.new
-                    if grad_indices.dim() == 0 or values.dim() == 0:
-                        return constructor().resize_as_(grad)
-                    return constructor(grad_indices, values, size)
-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
-                beta1, beta2 = group['betas']
-                # Decay the first and second moment running average coefficient
-                #      old <- b * old + (1 - b) * new
-                # <==> old += (1 - b) * (new - old)
-                old_exp_avg_values = exp_avg.sparse_mask(grad)._values()
-                exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1)
-                exp_avg.add_(make_sparse(exp_avg_update_values))
-                old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values()
-                exp_avg_sq_update_values = grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2)
-                exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values))
-                # Dense addition again is intended, avoiding another sparse_mask
-                numer = exp_avg_update_values.add_(old_exp_avg_values)
-                exp_avg_sq_update_values.add_(old_exp_avg_sq_values)
-                denom = exp_avg_sq_update_values.sqrt_().add_(group['eps'])
-                del exp_avg_update_values, exp_avg_sq_update_values
-                bias_correction1 = 1 - beta1 ** state['step']
-                bias_correction2 = 1 - beta2 ** state['step']
-                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
-                p.add_(make_sparse(-step_size * numer.div_(denom)))
+            F.sparse_adam(params_with_grad,
+                          grads,
+                          exp_avgs,
+                          exp_avg_sqs,
+                          state_steps,
+                          beta1=beta1,
+                          beta2=beta2,
+                          lr=group['lr'],
+                          eps=group['eps'])
         return loss
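
The second changed file, the _functional module imported above as F, is not shown in this view. Reconstructed from the per-parameter loop removed above, the new functional sparse_adam presumably looks roughly like the sketch below; the exact signature and type annotations beyond what the call site shows are assumptions.

    import math
    from typing import List

    from torch import Tensor


    def sparse_adam(params: List[Tensor],
                    grads: List[Tensor],
                    exp_avgs: List[Tensor],
                    exp_avg_sqs: List[Tensor],
                    state_steps: List[int],
                    *,
                    beta1: float,
                    beta2: float,
                    lr: float,
                    eps: float):
        r"""Functional sketch of the Sparse Adam update (reconstructed, not verbatim)."""
        for i, param in enumerate(params):
            grad = grads[i]
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()
            size = grad.size()

            def make_sparse(values):
                # Rebuild a sparse tensor with the gradient's sparsity pattern.
                constructor = grad.new
                if grad_indices.dim() == 0 or values.dim() == 0:
                    return constructor().resize_as_(grad)
                return constructor(grad_indices, values, size)

            exp_avg, exp_avg_sq = exp_avgs[i], exp_avg_sqs[i]
            step = state_steps[i]

            # Decay the first and second moment running averages, touching only
            # the entries present in this sparse gradient:
            #      old <- b * old + (1 - b) * new
            # <==> old += (1 - b) * (new - old)
            old_exp_avg_values = exp_avg.sparse_mask(grad)._values()
            exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1)
            exp_avg.add_(make_sparse(exp_avg_update_values))
            old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values()
            exp_avg_sq_update_values = grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2)
            exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values))

            # Dense addition is intended here, avoiding another sparse_mask
            numer = exp_avg_update_values.add_(old_exp_avg_values)
            exp_avg_sq_update_values.add_(old_exp_avg_sq_values)
            denom = exp_avg_sq_update_values.sqrt_().add_(eps)
            del exp_avg_update_values, exp_avg_sq_update_values

            bias_correction1 = 1 - beta1 ** step
            bias_correction2 = 1 - beta2 ** step
            step_size = lr * math.sqrt(bias_correction2) / bias_correction1

            param.add_(make_sparse(-step_size * numer.div_(denom)))

Keeping the math in a stateless function that takes explicit lists of tensors lets it be tested and reused independently of the Optimizer bookkeeping in step(). Note that state_steps records state['step'] after the increment, so the bias corrections use the already-updated step count.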