Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 21:49:24 +08:00)
[pep8] Fix most lint automatically with autopep8
Here's the command I used to invoke autopep8 (in parallel!):

    git ls-files | grep '\.py$' | xargs -n1 -P`nproc` autopep8 -i

Several rules are ignored in setup.cfg. The goal is to let autopep8 handle everything it can handle safely, and to disable any rules that are tricky or controversial to address. We may want to come back and re-enable some of these rules later, but I'm trying to make this patch as safe as possible.

Also configures flake8 to match pep8's behavior.

Also configures TravisCI to check the whole project for lint.
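For orientation before the hunks below: the whole diff is mechanical re-indentation of continuation lines so they align with the opening parenthesis (the pycodestyle E127/E128 family of checks). A minimal, torch-free sketch of the before/after shape, with hypothetical function names:

    # Both variants are valid Python; autopep8 only changes layout.

    def check_before(params):
        # Hanging indent at a fixed step; pycodestyle flags this as an
        # E128-style continuation-line issue once the opening line has
        # content after the parenthesis.
        if isinstance(params, (int, float)):
            raise TypeError("params argument given to the optimizer should be "
                "an iterable of Variables or dicts, but got " +
                type(params).__name__)

    def check_after(params):
        # Continuation lines aligned under the opening parenthesis, the
        # form autopep8 rewrites every call in this file into.
        if isinstance(params, (int, float)):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Variables or dicts, but got " +
                            type(params).__name__)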
@@ -21,8 +21,8 @@ class Optimizer(object):
     def __init__(self, params, defaults):
         if isinstance(params, Variable) or torch.is_tensor(params):
             raise TypeError("params argument given to the optimizer should be "
-                "an iterable of Variables or dicts, but got " +
-                torch.typename(params))
+                            "an iterable of Variables or dicts, but got " +
+                            torch.typename(params))
 
         self.state = defaultdict(dict)
         self.param_groups = list(params)
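The hunk above only re-indents, but the guard it touches is worth seeing in action. A minimal standalone sketch (the TensorLike class and SketchOptimizer are hypothetical stand-ins, not the torch API):

    class TensorLike(object):
        """Hypothetical stand-in for a torch Variable/tensor."""

    class SketchOptimizer(object):
        def __init__(self, params, defaults):
            # Mirrors the check above: a bare tensor/Variable is rejected;
            # callers must pass an iterable of params (or of dicts).
            if isinstance(params, TensorLike):
                raise TypeError("params argument given to the optimizer should be "
                                "an iterable of Variables or dicts, but got " +
                                type(params).__name__)
            self.param_groups = list(params)

    SketchOptimizer([TensorLike()], defaults={})      # OK: iterable of params
    try:
        SketchOptimizer(TensorLike(), defaults={})    # bare param is rejected
    except TypeError as e:
        print(e)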
@@ -37,15 +37,15 @@ class Optimizer(object):
             group_set = set(group['params'])
             if not param_set.isdisjoint(group_set):
                 raise ValueError("some parameters appear in more than one "
-                    "parameter group")
+                                 "parameter group")
             param_set.update(group_set)
 
         for name, default in defaults.items():
             for i, group in enumerate(self.param_groups):
                 if default is required and name not in group:
                     raise ValueError("parameter group " + str(i) + " didn't "
-                        "specify a value of required optimization parameter "
-                        + name)
+                                     "specify a value of required optimization parameter "
+                                     + name)
                 else:
                     group.setdefault(name, default)
 
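Again only indentation changes here, but the two checks are easy to exercise standalone. A sketch assuming a `required` sentinel like the one torch.optim uses (merge_defaults is a hypothetical helper):

    required = object()  # sentinel: "no usable default, each group must supply one"

    def merge_defaults(param_groups, defaults):
        # No parameter may appear in more than one group...
        param_set = set()
        for group in param_groups:
            group_set = set(group['params'])
            if not param_set.isdisjoint(group_set):
                raise ValueError("some parameters appear in more than one "
                                 "parameter group")
            param_set.update(group_set)
        # ...and every `required` default must be given per group.
        for i, group in enumerate(param_groups):
            for name, default in defaults.items():
                if default is required and name not in group:
                    raise ValueError("parameter group %d didn't specify a value "
                                     "of required optimization parameter %s" % (i, name))
                group.setdefault(name, default)

    groups = [{'params': ['w1', 'w2']}, {'params': ['w2'], 'lr': 0.1}]  # 'w2' overlaps
    try:
        merge_defaults(groups, {'lr': required})
    except ValueError as e:
        print(e)  # -> some parameters appear in more than one parameter group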
@@ -53,10 +53,10 @@ class Optimizer(object):
             for param in group['params']:
                 if not isinstance(param, Variable):
                     raise TypeError("optimizer can only optimize Variables, "
-                        "but one of the params is " + torch.typename(param))
+                                    "but one of the params is " + torch.typename(param))
                 if not param.requires_grad:
                     raise ValueError("optimizing a parameter that doesn't "
-                        "require gradients")
+                                     "require gradients")
                 if param.creator is not None:
                     raise ValueError("can't optimize a non-leaf Variable")
 
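A companion sketch for the per-parameter checks above (the Param class is hypothetical; `creator` mirrors this era's autograd attribute, where a leaf Variable has `creator is None`):

    class Param(object):
        def __init__(self, requires_grad=True, creator=None):
            self.requires_grad = requires_grad
            self.creator = creator

    def validate(group):
        for param in group['params']:
            if not isinstance(param, Param):
                raise TypeError("optimizer can only optimize Variables, "
                                "but one of the params is " + type(param).__name__)
            if not param.requires_grad:
                raise ValueError("optimizing a parameter that doesn't "
                                 "require gradients")
            if param.creator is not None:
                raise ValueError("can't optimize a non-leaf Variable")

    validate({'params': [Param()]})                         # leaf, trainable: OK
    try:
        validate({'params': [Param(requires_grad=False)]})  # frozen param
    except ValueError as e:
        print(e)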
@@ -104,17 +104,17 @@ class Optimizer(object):
 
         if len(groups) != len(saved_groups):
             raise ValueError("loaded state dict has a different number of "
-                "parameter groups")
+                             "parameter groups")
         param_lens = (len(g['params']) for g in groups)
         saved_lens = (len(g['params']) for g in saved_groups)
         if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
             raise ValueError("loaded state dict contains a parameter group "
-                "that doesn't match the size of optimizer's group")
+                             "that doesn't match the size of optimizer's group")
 
         # Update the state
         id_map = {old_id: p for old_id, p in
-            zip(chain(*(g['params'] for g in saved_groups)),
-                chain(*(g['params'] for g in groups)))}
+                  zip(chain(*(g['params'] for g in saved_groups)),
+                      chain(*(g['params'] for g in groups)))}
         self.state = {id_map.get(k, k): v for k, v in state_dict['state'].items()}
 
         # Update parameter groups, setting their 'params' value
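The id_map comprehension in the last hunk is the densest piece of this file, so here is a runnable sketch of the remapping idiom (the concrete ids and state values are made up):

    from itertools import chain

    # Ids recorded in the checkpoint, and the live params they line up with.
    saved_groups = [{'params': [101, 102]}, {'params': [103]}]
    groups = [{'params': ['p1', 'p2']}, {'params': ['p3']}]

    # chain(*(...)) flattens the per-group lists; zip pairs old id -> live
    # param purely by position, which is why the group sizes were checked first.
    id_map = {old_id: p for old_id, p in
              zip(chain(*(g['params'] for g in saved_groups)),
                  chain(*(g['params'] for g in groups)))}

    saved_state = {101: {'momentum_buffer': [0.9]}, 999: {'momentum_buffer': [0.5]}}
    # Keys with no mapping fall through unchanged via id_map.get(k, k).
    state = {id_map.get(k, k): v for k, v in saved_state.items()}
    print(state)  # {'p1': {'momentum_buffer': [0.9]}, 999: {'momentum_buffer': [0.5]}}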