Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-23 14:59:34 +08:00)
Here's the command I used to invoke autopep8 (in parallel!):

    git ls-files | grep '\.py$' | xargs -n1 -P`nproc` autopep8 -i

Several rules are ignored in setup.cfg. The goal is to let autopep8 handle everything that it can handle safely, and to disable any rules that are tricky or controversial to address. We may want to come back and re-enable some of these rules later, but I'm trying to make this patch as safe as possible.

Also configures flake8 to match pep8's behavior, and configures Travis CI to check the whole project for lint.
141 lines · 5.6 KiB · Python
from collections import defaultdict
from copy import deepcopy
from itertools import chain

import torch
from torch.autograd import Variable

# Sentinel default for optimization options that a subclass requires the user
# to specify, either for the whole optimizer or per parameter group
# (see __init__ below).
required = object()

class Optimizer(object):
    """Base class for all optimizers.

    Arguments:
        params (iterable): an iterable of :class:`Variable` s or
            :class:`dict` s. Specifies what Variables should be optimized.
        defaults (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).
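
    Example (illustrative; ``SGD`` stands for any concrete subclass and
    ``model`` for an existing module, since the base class is not
    instantiated directly):

        >>> optimizer = SGD(model.parameters(), lr=1e-2)
        >>> # per-parameter-group options are also possible:
        >>> optimizer = SGD([
        ...     {'params': model.base.parameters()},
        ...     {'params': model.classifier.parameters(), 'lr': 1e-3},
        ... ], lr=1e-2)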
"""
|
|
|
|
    def __init__(self, params, defaults):
        if isinstance(params, Variable) or torch.is_tensor(params):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Variables or dicts, but got " +
                            torch.typename(params))

        self.state = defaultdict(dict)
        self.param_groups = list(params)
        if len(self.param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(self.param_groups[0], dict):
            # A plain iterable of Variables becomes a single parameter group
            self.param_groups = [{'params': self.param_groups}]

        # Each parameter may belong to at most one group
        param_set = set()
        for group in self.param_groups:
            group['params'] = list(group['params'])
            group_set = set(group['params'])
            if not param_set.isdisjoint(group_set):
                raise ValueError("some parameters appear in more than one "
                                 "parameter group")
            param_set.update(group_set)

        # Fill in per-group options from the defaults; options whose default
        # is `required` must be given explicitly by every group
        for name, default in defaults.items():
            for i, group in enumerate(self.param_groups):
                if default is required and name not in group:
                    raise ValueError("parameter group " + str(i) + " didn't "
                                     "specify a value of required optimization "
                                     "parameter " + name)
                else:
                    group.setdefault(name, default)

        # Only leaf Variables that require gradients can be optimized
        for group in self.param_groups:
            for param in group['params']:
                if not isinstance(param, Variable):
                    raise TypeError("optimizer can only optimize Variables, "
                                    "but one of the params is " +
                                    torch.typename(param))
                if not param.requires_grad:
                    raise ValueError("optimizing a parameter that doesn't "
                                     "require gradients")
                if param.creator is not None:
                    raise ValueError("can't optimize a non-leaf Variable")
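
    # Sketch of how a concrete subclass is expected to supply ``defaults``
    # (``SGD`` and its hyperparameters are illustrative, not necessarily the
    # real implementation):
    #
    #     class SGD(Optimizer):
    #         def __init__(self, params, lr=required, momentum=0):
    #             defaults = dict(lr=lr, momentum=momentum)
    #             super(SGD, self).__init__(params, defaults)
    #
    # If the caller does not pass ``lr``, its value stays ``required`` and
    # __init__ raises a ValueError for any parameter group that does not
    # specify its own ``lr``.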

    def __getstate__(self):
        # Used by pickle/copy; unlike state_dict() this keeps the Variable
        # references themselves
        return {
            'state': self.state,
            'param_groups': self.param_groups,
        }

    def state_dict(self):
        """Returns the state of the optimizer as a :class:`dict`.

        It contains two entries:

        * state - a dict holding current optimization state. Its content
          differs between optimizer classes.
        * param_groups - a list containing all parameter groups
        """
        # Save ids instead of Variables
        def pack_group(group):
            packed = {k: v for k, v in group.items() if k != 'params'}
            packed['params'] = [id(p) for p in group['params']]
            return packed
        param_groups = [pack_group(g) for g in self.param_groups]
        # Remap state to use ids as keys
        packed_state = {(id(k) if isinstance(k, Variable) else k): v
                        for k, v in self.state.items()}
        return {
            'state': packed_state,
            'param_groups': param_groups,
        }
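
    # Checkpointing sketch: because the returned dict refers to parameters by
    # ``id()`` rather than by the Variables themselves, it can be serialized
    # on its own (the file name below is illustrative):
    #
    #     torch.save(optimizer.state_dict(), 'optimizer.pth')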

    def load_state_dict(self, state_dict):
        """Loads the optimizer state.

        Arguments:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)
        # Validate the state_dict
        groups = self.param_groups
        saved_groups = state_dict['param_groups']

        if len(groups) != len(saved_groups):
            raise ValueError("loaded state dict has a different number of "
                             "parameter groups")
        param_lens = (len(g['params']) for g in groups)
        saved_lens = (len(g['params']) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError("loaded state dict contains a parameter group "
                             "that doesn't match the size of the optimizer's "
                             "group")

        # Update the state, mapping saved parameter ids back to the Variables
        # currently held by this optimizer (groups are matched by position)
        id_map = {old_id: p for old_id, p in
                  zip(chain(*(g['params'] for g in saved_groups)),
                      chain(*(g['params'] for g in groups)))}
        self.state = {id_map.get(k, k): v
                      for k, v in state_dict['state'].items()}

        # Update parameter groups, setting their 'params' value
        def update_group(group, new_group):
            new_group['params'] = group['params']
            return new_group
        self.param_groups = [
            update_group(g, ng) for g, ng in zip(groups, saved_groups)]
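
    # Restoring sketch, mirroring the save above (``SGD``, ``model`` and the
    # file name are illustrative): the optimizer must first be constructed
    # with the same parameter groups as the one that produced the state dict.
    #
    #     optimizer = SGD(model.parameters(), lr=1e-2)
    #     optimizer.load_state_dict(torch.load('optimizer.pth'))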

    def zero_grad(self):
        """Clears the gradients of all optimized :class:`Variable` s."""
        for group in self.param_groups:
            for param in group['params']:
                param.grad.data.zero_()

    def step(self, closure):
        """Performs a single optimization step (parameter update).

        Arguments:
            closure (callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.
        """
        raise NotImplementedError
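

# Typical use of the interface defined above (sketch; ``model``, ``loss_fn``,
# ``input`` and ``target`` are assumed to exist, and ``optimizer`` is an
# instance of a concrete subclass for which ``closure`` is optional):
#
#     optimizer.zero_grad()
#     loss = loss_fn(model(input), target)
#     loss.backward()
#     optimizer.step()
#
# Optimizers that need to re-evaluate the loss several times per update take
# a closure instead:
#
#     def closure():
#         optimizer.zero_grad()
#         loss = loss_fn(model(input), target)
#         loss.backward()
#         return loss
#
#     optimizer.step(closure)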