pytorch/torch/optim/optimizer.py
Commit e7c1e6a8e3 by Luke Yeager (2017-01-28 01:15:51 +01:00): [pep8] Fix most lint automatically with autopep8
Here's the command I used to invoke autopep8 (in parallel!):

    git ls-files | grep '\.py$' | xargs -n1 -P`nproc` autopep8 -i

Several rules are ignored in setup.cfg. The goal is to let autopep8
handle everything which it can handle safely, and to disable any rules
which are tricky or controversial to address. We may want to come back
and re-enable some of these rules later, but I'm trying to make this
patch as safe as possible.

Also configures flake8 to match pep8's behavior.

Also configures TravisCI to check the whole project for lint.
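The message above references a setup.cfg that disables some checks and a flake8 configuration aligned with pep8. As a rough sketch of what such a configuration could look like (the line length and ignore codes below are illustrative guesses, not the values from this commit):

    [pep8]
    max-line-length = 120
    ignore = E402,E721

    [flake8]
    max-line-length = 120
    ignore = E402,E721

Both pep8/pycodestyle and flake8 can read their sections from setup.cfg, and keeping the two lists identical is what makes flake8 report the same things autopep8 is allowed to fix.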


from collections import defaultdict

import torch
from copy import deepcopy
from itertools import chain
from torch.autograd import Variable

required = object()


class Optimizer(object):
    """Base class for all optimizers.

    Arguments:
        params (iterable): an iterable of :class:`Variable` s or
            :class:`dict` s. Specifies what Variables should be optimized.
        defaults (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).
    """

    def __init__(self, params, defaults):
        if isinstance(params, Variable) or torch.is_tensor(params):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Variables or dicts, but got " +
                            torch.typename(params))

        self.state = defaultdict(dict)
        self.param_groups = list(params)
        if len(self.param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(self.param_groups[0], dict):
            self.param_groups = [{'params': self.param_groups}]

        param_set = set()
        for group in self.param_groups:
            group['params'] = list(group['params'])
            group_set = set(group['params'])
            if not param_set.isdisjoint(group_set):
                raise ValueError("some parameters appear in more than one "
                                 "parameter group")
            param_set.update(group_set)

        for name, default in defaults.items():
            for i, group in enumerate(self.param_groups):
                if default is required and name not in group:
                    raise ValueError("parameter group " + str(i) + " didn't "
                                     "specify a value of required optimization parameter "
                                     + name)
                else:
                    group.setdefault(name, default)

        for group in self.param_groups:
            for param in group['params']:
                if not isinstance(param, Variable):
                    raise TypeError("optimizer can only optimize Variables, "
                                    "but one of the params is " + torch.typename(param))
                if not param.requires_grad:
                    raise ValueError("optimizing a parameter that doesn't "
                                     "require gradients")
                if param.creator is not None:
                    raise ValueError("can't optimize a non-leaf Variable")

    def __getstate__(self):
        return {
            'state': self.state,
            'param_groups': self.param_groups,
        }

    def state_dict(self):
        """Returns the state of the optimizer as a :class:`dict`.

        It contains two entries:

        * state - a dict holding current optimization state. Its content
            differs between optimizer classes.
        * param_groups - a list containing all parameter groups
        """
        # Save ids instead of Variables
        def pack_group(group):
            packed = {k: v for k, v in group.items() if k != 'params'}
            packed['params'] = [id(p) for p in group['params']]
            return packed
        param_groups = [pack_group(g) for g in self.param_groups]
        # Remap state to use ids as keys
        packed_state = {(id(k) if isinstance(k, Variable) else k): v
                        for k, v in self.state.items()}
        return {
            'state': packed_state,
            'param_groups': param_groups,
        }

    def load_state_dict(self, state_dict):
        """Loads the optimizer state.

        Arguments:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)
        # Validate the state_dict
        groups = self.param_groups
        saved_groups = state_dict['param_groups']

        if len(groups) != len(saved_groups):
            raise ValueError("loaded state dict has a different number of "
                             "parameter groups")
        param_lens = (len(g['params']) for g in groups)
        saved_lens = (len(g['params']) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError("loaded state dict contains a parameter group "
                             "that doesn't match the size of optimizer's group")

        # Update the state
        id_map = {old_id: p for old_id, p in
                  zip(chain(*(g['params'] for g in saved_groups)),
                      chain(*(g['params'] for g in groups)))}
        self.state = {id_map.get(k, k): v for k, v in state_dict['state'].items()}

        # Update parameter groups, setting their 'params' value
        def update_group(group, new_group):
            new_group['params'] = group['params']
            return new_group
        self.param_groups = [
            update_group(g, ng) for g, ng in zip(groups, saved_groups)]
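    # Illustrative checkpoint round trip (hypothetical optimizer instance and
    # file name):
    #
    #   torch.save(optimizer.state_dict(), 'optim.pth')
    #   ...
    #   optimizer.load_state_dict(torch.load('optim.pth'))
    #
    # state_dict() stores parameter ids rather than the Variables themselves;
    # load_state_dict() maps those ids back onto this optimizer's params via
    # id_map, so the groups must line up with the ones that were saved.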

    def zero_grad(self):
        """Clears the gradients of all optimized :class:`Variable` s."""
        for group in self.param_groups:
            for param in group['params']:
                param.grad.data.zero_()

    def step(self, closure):
        """Performs a single optimization step (parameter update).

        Arguments:
            closure (callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.
        """
        raise NotImplementedError
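
step() is the hook that concrete optimizers override. Below is a minimal sketch of such a subclass, assuming the Variable-era API used in this file; the class name and the plain gradient-descent update are illustrative, not an actual torch.optim optimizer:

    class PlainGradientDescent(Optimizer):

        def __init__(self, params, lr=required):
            # `required` makes the base class raise if any group omits 'lr'
            super(PlainGradientDescent, self).__init__(params, {'lr': lr})

        def step(self, closure=None):
            loss = None
            if closure is not None:
                # some optimizers re-evaluate the model through the closure
                loss = closure()
            for group in self.param_groups:
                for param in group['params']:
                    # param.data <- param.data - lr * grad
                    param.data.add_(-group['lr'], param.grad.data)
            return loss

A typical iteration with it would look like the following, where model, input, target and loss_fn are hypothetical:

    optimizer = PlainGradientDescent(model.parameters(), lr=0.1)
    optimizer.zero_grad()
    loss = loss_fn(model(input), target)
    loss.backward()
    optimizer.step()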