Fix For Requires Grad Infinite Loop (#18361)

Summary:
Previously, we would keep re-running requires_grad analysis on a loop body for as long as the loop's inputs and outputs disagreed, which never terminates when a loop input requires grad but the corresponding body output does not. This adds a convergence check so that we stop iterating once the results haven't changed since the previous run.
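
The shape of the fix is a standard fixed-point iteration: rerun the analysis over the loop body and stop once the results have stopped changing between runs, rather than looping while the loop's inputs and outputs disagree. A minimal standalone sketch of that pattern (illustrative only; analyzeBody, bitwiseOr, and loopRequiresGrad below are stand-ins, not the functions from the actual pass):

#include <cstddef>
#include <vector>

// Stand-in for one propagation of requires_grad through the loop body:
// each body output simply mirrors the corresponding input flag.
static std::vector<bool> analyzeBody(const std::vector<bool>& inputs) {
  return inputs;
}

// Merge loop-carried state: a value requires grad if it did on entry
// or in any previous iteration's output.
static std::vector<bool> bitwiseOr(
    std::vector<bool> a,
    const std::vector<bool>& b) {
  for (std::size_t i = 0; i < a.size(); ++i) {
    a[i] = a[i] || b[i];
  }
  return a;
}

std::vector<bool> loopRequiresGrad(std::vector<bool> inputs) {
  std::vector<bool> outputs(inputs.size(), false);
  auto new_inputs = inputs;
  auto new_outputs = outputs;
  // Iterate until neither the inputs nor the outputs changed since the
  // previous run, instead of looping while inputs != outputs.
  do {
    inputs = new_inputs;
    outputs = new_outputs;
    new_inputs = bitwiseOr(inputs, outputs);
    new_outputs = analyzeBody(new_inputs);
  } while (new_inputs != inputs || new_outputs != outputs);
  return outputs;
}

With the old condition (loop while inputs differ from outputs), an input that requires grad feeding a body output that never does would spin forever; with the convergence check the iteration exits as soon as a run produces no new information.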

Fix for https://github.com/pytorch/pytorch/issues/18320
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18361

Differential Revision: D14584332

Pulled By: eellison

fbshipit-source-id: 696b225f80a2036318540946428b525985a9e735
Author: Elias Ellison
Date: 2019-03-24 14:28:22 -07:00
Committed by: Facebook Github Bot
Parent: 92c9fef860
Commit: ca962f0f95
3 changed files with 41 additions and 8 deletions


@@ -4594,6 +4594,32 @@ a")
         test_resize_as()
 
+    def test_requires_grad_loop(self):
+        @torch.jit.script
+        def test(x, y, z):
+            # type: (Tensor, Tensor, int) -> Tensor
+            for _ in range(z):
+                x = y
+            return x
+
+        # x requires grad, y does not
+        # testing that requires grad analysis correctly exits, with its input
+        # to the loop (x) requiring grad and its output to the loop not requiring grad
+        # and the output of the node conservatively setting grad to true
+        inps = (torch.tensor(1.0, requires_grad=True), torch.tensor(1), 10)
+        test(*inps)
+
+        graph = test.graph_for(*inps)
+        loop = graph.findNode("prim::Loop")
+        loop_body = next(loop.blocks())
+        loop_inputs = list(loop_body.inputs())
+        loop_outputs = list(loop_body.outputs())
+
+        self.assertTrue(loop_inputs[1].requires_grad())
+        self.assertFalse(loop_outputs[1].requires_grad())
+        self.assertTrue(loop.output().requires_grad())
+
     def test_view_shape_prop(self):
         cu = torch.jit.CompilationUnit('''
         def test_view_shape_prop(a):


@@ -1,7 +1,7 @@
+#include <ATen/core/jit_type.h>
 #include <torch/csrc/jit/argument_spec.h>
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/operator.h>
-#include <ATen/core/jit_type.h>
 
 #include <vector>
@@ -99,15 +99,23 @@ void PropagateRequiresGrad(Node* node) {
         fmap(node->inputs().slice(2), getRequiresGrad);
     std::vector<bool> body_outputs_require(node->outputs().size(), false);
 
-    while (body_inputs_require != body_outputs_require) {
-      body_inputs_require =
+    std::vector<bool> new_body_inputs_require = body_inputs_require;
+    std::vector<bool> new_body_outputs_require = body_outputs_require;
+
+    // continue iterating until the results have converged
+    do {
+      body_inputs_require = new_body_inputs_require;
+      body_outputs_require = new_body_outputs_require;
+
+      new_body_inputs_require =
           bitwiseOr(body_inputs_require, body_outputs_require);
       setRequiresGrad(
-          body->param_node()->outputs().slice(1), body_inputs_require);
+          body->param_node()->outputs().slice(1), new_body_inputs_require);
       PropagateRequiresGrad(body);
-      body_outputs_require =
+      new_body_outputs_require =
           fmap(body->return_node()->inputs().slice(1), getRequiresGrad);
-    }
+    } while (new_body_inputs_require != body_inputs_require &&
+             new_body_outputs_require != body_outputs_require);
 
     setRequiresGrad(node, body_outputs_require);
   } else {
@@ -120,12 +128,10 @@ void PropagateRequiresGrad(Block* block) {
     PropagateRequiresGrad(node);
   }
 }
 } // anonymous namespace
 
 void PropagateRequiresGrad(std::shared_ptr<Graph>& graph) {
   PropagateRequiresGrad(graph->block());
 }
 } // namespace jit
 } // namespace torch


@@ -385,6 +385,7 @@ void initPythonIRBindings(PyObject* module_) {
       })
       .VS(copyMetadata)
       .VS(isTensor)
+      .VS(requires_grad)
       .def("toIValue", [](Value& n) { return toIValue(&n); })
       .def("type", [](Value& v) { return v.type(); });
 #undef VS