Files
pytorch/torch/autograd/engine.py
Luke Yeager e7c1e6a8e3 [pep8] Fix most lint automatically with autopep8
Here's the command I used to invoke autopep8 (in parallel!):

    git ls-files | grep '\.py$' | xargs -n1 -P`nproc` autopep8 -i

Several rules are ignored in setup.cfg. The goal is to let autopep8
handle everything which it can handle safely, and to disable any rules
which are tricky or controversial to address. We may want to come back
and re-enable some of these rules later, but I'm trying to make this
patch as safe as possible.

Also configures flake8 to match pep8's behavior.

Also configures TravisCI to check the whole project for lint.
2017-01-28 01:15:51 +01:00

86 lines
3.5 KiB
Python

from collections import deque, defaultdict
from torch._C import _ImperativeEngine as ImperativeEngine
from .variable import Variable
class BasicEngine(object):
    """Python implementation of the backward pass over a graph of functions.

    Starting from a variable's ``creator``, the engine follows
    ``previous_functions`` edges, calls ``_do_backward`` on each function
    once all gradients destined for it have arrived, and accumulates
    gradients that reach the same function output from several consumers.
    """

    def _compute_dependencies(self, function):
        """Count the incoming gradient edges of every function below ``function``.

        The graph is traversed through ``previous_functions``. Entries that
        do not require grad, and ``Variable`` instances, are neither counted
        nor traversed.

        Returns:
            A ``defaultdict(int)`` mapping each reachable function to the
            number of ``previous_functions`` edges that point at it.
        """
        dependencies = defaultdict(int)
        seen = {function}
        pending = [function]
        while pending:
            fn = pending.pop()
            for prev_fn, _ in fn.previous_functions:
                if not prev_fn.requires_grad or isinstance(prev_fn, Variable):
                    continue
                dependencies[prev_fn] += 1
                if prev_fn not in seen:
                    pending.append(prev_fn)
                    seen.add(prev_fn)
        return dependencies

    def _free_backward_dependency(self, dependencies, prev_fn):
        """Mark one incoming edge of ``prev_fn`` as satisfied.

        Returns:
            True when this was the last outstanding edge (the entry is then
            removed from ``dependencies``), False otherwise.
        """
        dependencies[prev_fn] -= 1
        if dependencies[prev_fn] == 0:
            del dependencies[prev_fn]
            return True
        return False

    def _add_grad(self, need_copy, prev_grad, output_nr, d_prev_fn):
        """Accumulate ``d_prev_fn`` into slot ``output_nr`` of ``prev_grad``.

        Args:
            need_copy: set of ``(id(buffer), output_nr)`` keys whose slot
                still holds the very object that was handed in, and thus
                must be cloned before it is modified in place.
            prev_grad: mutable sequence of per-output gradient slots.
            output_nr: index of the slot to update.
            d_prev_fn: gradient to store or add.
        """
        copy_id = (id(prev_grad), output_nr)
        grad_tensor = prev_grad[output_nr]
        # An empty slot is exactly ``None``. Testing truthiness instead would
        # invoke ``bool()`` on a stored gradient, which may be falsy or may
        # raise, depending on the gradient type.
        if grad_tensor is None:
            prev_grad[output_nr] = d_prev_fn
            need_copy.add(copy_id)
            return
        if copy_id in need_copy:
            # The slot aliases a gradient produced elsewhere; clone it once so
            # the in-place add below does not modify the producer's object.
            need_copy.remove(copy_id)
            grad_tensor = grad_tensor.clone()
            prev_grad[output_nr] = grad_tensor
        grad_tensor.add_(d_prev_fn)

    def run_backward(self, variable, grad, retain_variables):
        """Propagate ``grad`` backwards from ``variable`` through the graph.

        Args:
            variable: object whose ``creator`` (or, when ``creator`` is None,
                the variable itself) receives the initial gradient.
            grad: gradient with respect to ``variable``.
            retain_variables: forwarded unchanged to every ``_do_backward``.

        Raises:
            RuntimeError: if a function with more than one output becomes
                ready after receiving only a single gradient.
        """
        creator = variable.creator
        if creator is None:
            variable._do_backward((grad,), retain_variables)
            return

        initial_grad = [None] * creator.num_outputs
        initial_grad[variable.output_nr] = grad
        # ``ready`` is used first-in first-out: appendleft() / pop().
        ready = deque([(creator, initial_grad)])
        # function -> list of partially accumulated output gradients
        not_ready = {}
        need_copy = set()
        dependencies = self._compute_dependencies(creator)

        while ready:
            fn, fn_grad = ready.pop()
            grad_input = fn._do_backward(tuple(fn_grad), retain_variables)
            for (prev_fn, output_nr), d_prev_fn in zip(fn.previous_functions, grad_input):
                if not prev_fn.requires_grad:
                    # TODO: check that d_prev_fn is None and warn otherwise
                    continue
                if isinstance(prev_fn, Variable):
                    prev_fn._do_backward((d_prev_fn,), retain_variables)
                    continue

                if self._free_backward_dependency(dependencies, prev_fn):
                    if prev_fn in not_ready:
                        # Last gradient has arrived: take the buffer out of
                        # ``not_ready`` so it is not kept alive after use.
                        prev_grad = not_ready.pop(prev_fn)
                        self._add_grad(need_copy, prev_grad, output_nr, d_prev_fn)
                    else:
                        if prev_fn.num_outputs != 1:
                            raise RuntimeError("one of the function outputs "
                                               "wasn't used - this is an error now, but "
                                               "it's going to be fixed soon")
                        prev_grad = (d_prev_fn,)
                    ready.appendleft((prev_fn, prev_grad))
                else:
                    prev_grad = not_ready.get(prev_fn)
                    if prev_grad is None:
                        prev_grad = [None] * prev_fn.num_outputs
                        not_ready[prev_fn] = prev_grad
                    self._add_grad(need_copy, prev_grad, output_nr, d_prev_fn)