Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
When I use named_parameters to modify the lr and weight decay, I run into a bug, because the value returned by named_parameters is a torch.nn.parameter.Parameter, not a generator of Parameters (see the sketch below).
153 lines · 6.1 KiB · Python
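The note above is about building per-parameter-group options from named_parameters(). A minimal sketch of that pattern, assuming placeholder choices (the nn.Linear model, the lr/weight_decay values, and the bias/weight split are illustrative, not taken from this repository):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # placeholder model

# named_parameters() yields (name, Parameter) pairs; collect the Parameter
# objects themselves into each group's 'params' list.
decay, no_decay = [], []
for name, param in model.named_parameters():
    (no_decay if name.endswith('bias') else decay).append(param)

optimizer = torch.optim.SGD(
    [{'params': decay, 'weight_decay': 1e-4},
     {'params': no_decay, 'weight_decay': 0.0}],
    lr=0.1)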
from collections import defaultdict

import torch
from copy import deepcopy
from itertools import chain
from torch.autograd import Variable

required = object()


class Optimizer(object):
    """Base class for all optimizers.

    Arguments:
        params (iterable): an iterable of :class:`Variable` s or
            :class:`dict` s. Specifies what Variables should be optimized.
        defaults (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).
    """

    def __init__(self, params, defaults):
        if isinstance(params, Variable) or torch.is_tensor(params):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Variables or dicts, but got " +
                            torch.typename(params))

        self.state = defaultdict(dict)
        self.param_groups = list(params)
        if len(self.param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(self.param_groups[0], dict):
            self.param_groups = [{'params': self.param_groups}]

        param_set = set()
        for group in self.param_groups:
            if isinstance(group['params'], torch.autograd.Variable):
                group['params'] = [group['params']]
            else:
                group['params'] = list(group['params'])
            group_set = set(group['params'])
            if not param_set.isdisjoint(group_set):
                raise ValueError("some parameters appear in more than one "
                                 "parameter group")
            param_set.update(group_set)

        for name, default in defaults.items():
            for i, group in enumerate(self.param_groups):
                if default is required and name not in group:
                    raise ValueError("parameter group " + str(i) + " didn't "
                                     "specify a value of required optimization parameter " +
                                     name)
                else:
                    group.setdefault(name, default)

        for group in self.param_groups:
            for param in group['params']:
                if not isinstance(param, Variable):
                    raise TypeError("optimizer can only optimize Variables, "
                                    "but one of the params is " + torch.typename(param))
                if not param.requires_grad:
                    raise ValueError("optimizing a parameter that doesn't "
                                     "require gradients")
                if not param.is_leaf:
                    raise ValueError("can't optimize a non-leaf Variable")

    def __getstate__(self):
        return {
            'state': self.state,
            'param_groups': self.param_groups,
        }

    def __setstate__(self, state):
        self.__dict__.update(state)

    def state_dict(self):
        """Returns the state of the optimizer as a :class:`dict`.

        It contains two entries:

        * state - a dict holding current optimization state. Its content
            differs between optimizer classes.
        * param_groups - a dict containing all parameter groups
        """
        # Save ids instead of Variables
        def pack_group(group):
            packed = {k: v for k, v in group.items() if k != 'params'}
            packed['params'] = [id(p) for p in group['params']]
            return packed
        param_groups = [pack_group(g) for g in self.param_groups]
        # Remap state to use ids as keys
        packed_state = {(id(k) if isinstance(k, Variable) else k): v
                        for k, v in self.state.items()}
        return {
            'state': packed_state,
            'param_groups': param_groups,
        }

    def load_state_dict(self, state_dict):
        """Loads the optimizer state.

        Arguments:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)
        # Validate the state_dict
        groups = self.param_groups
        saved_groups = state_dict['param_groups']

        if len(groups) != len(saved_groups):
            raise ValueError("loaded state dict has a different number of "
                             "parameter groups")
        param_lens = (len(g['params']) for g in groups)
        saved_lens = (len(g['params']) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError("loaded state dict contains a parameter group "
                             "that doesn't match the size of optimizer's group")

        # Update the state
        id_map = {old_id: p for old_id, p in
                  zip(chain(*(g['params'] for g in saved_groups)),
                      chain(*(g['params'] for g in groups)))}
        state = {id_map.get(k, k): v for k, v in state_dict['state'].items()}

        # Update parameter groups, setting their 'params' value
        def update_group(group, new_group):
            new_group['params'] = group['params']
            return new_group
        param_groups = [
            update_group(g, ng) for g, ng in zip(groups, saved_groups)]
        self.__setstate__({'state': state, 'param_groups': param_groups})

    def zero_grad(self):
        """Clears the gradients of all optimized :class:`Variable` s."""
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    if p.grad.volatile:
                        p.grad.data.zero_()
                    else:
                        data = p.grad.data
                        p.grad = Variable(data.new().resize_as_(data).zero_())

    def step(self, closure):
        """Performs a single optimization step (parameter update).

        Arguments:
            closure (callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.
        """
        raise NotImplementedError
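A concrete optimizer only has to pass its hyperparameter defaults to this constructor and implement step(). The sketch below is a hypothetical minimal SGD written against this Variable-era base class; ToySGD and its update rule are illustrative, not an optimizer that ships with PyTorch:

class ToySGD(Optimizer):
    def __init__(self, params, lr=required):
        # 'lr' uses the module-level `required` sentinel above, so every
        # parameter group must either inherit or explicitly specify it.
        super(ToySGD, self).__init__(params, {'lr': lr})

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()  # re-evaluate the model to get a fresh loss
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                # Plain gradient descent: p <- p - lr * grad
                p.data -= group['lr'] * p.grad.data
        return loss

A typical training iteration would then call optimizer.zero_grad(), compute the loss, call loss.backward(), and finish with optimizer.step().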
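Because state_dict() replaces each Variable with its id() and load_state_dict() remaps those ids onto the current parameter groups by position, the saved state can be restored into a freshly constructed optimizer over the same parameters. A short sketch of that round trip, using torch.optim.SGD and an nn.Linear model as stand-ins:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# ... train for a while, then capture the optimizer state ...
saved = opt.state_dict()          # parameters are replaced by their ids

# Later: rebuild an optimizer over the same parameters and restore the state.
opt2 = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
opt2.load_state_dict(saved)       # ids are remapped onto opt2's param groups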