Fix For Requires Grad Infinite Loop (#18361)
Summary: Previously, requires_grad analysis would keep re-running over a loop body whenever the loop's outputs and inputs disagreed, which could iterate forever. This adds a check so that we stop once the results have not changed since the last run.

Fix for https://github.com/pytorch/pytorch/issues/18320
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18361
Differential Revision: D14584332
Pulled By: eellison
fbshipit-source-id: 696b225f80a2036318540946428b525985a9e735
committed by Facebook Github Bot
parent 92c9fef860
commit ca962f0f95
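To make the convergence check described in the summary concrete, here is a minimal Python sketch of the fixed-point iteration the analysis now performs for a prim::Loop node. The names (propagate_loop_requires_grad, run_body_analysis, the boolean mask lists) are hypothetical stand-ins for illustration only; the actual change is in the C++ requires_grad propagation hunk further down.

# Hypothetical sketch of the convergence loop added by this PR; the real
# implementation is PropagateRequiresGrad(Node*) in the C++ diff below.
def propagate_loop_requires_grad(loop_input_masks, run_body_analysis):
    # loop_input_masks: requires_grad flags of the values carried into the loop.
    # run_body_analysis: hypothetical callback that propagates requires_grad
    # through the loop body once and returns the flags of the body outputs.
    body_inputs_require = list(loop_input_masks)
    body_outputs_require = [False] * len(body_inputs_require)

    new_body_inputs_require = list(body_inputs_require)
    new_body_outputs_require = list(body_outputs_require)

    while True:  # written as a do/while in the C++ version
        body_inputs_require = new_body_inputs_require
        body_outputs_require = new_body_outputs_require

        # A carried value needs grad inside the body if it needed grad on
        # entry or if the previous pass said its body output needs grad.
        new_body_inputs_require = [
            a or b for a, b in zip(body_inputs_require, body_outputs_require)
        ]
        new_body_outputs_require = run_body_analysis(new_body_inputs_require)

        # The fix: keep iterating only while both masks are still changing
        # (mirrors the do/while condition in the diff). The old code instead
        # looped while inputs != outputs, which never terminates when the
        # body maps a grad-requiring input to a non-grad output.
        if (new_body_inputs_require == body_inputs_require
                or new_body_outputs_require == body_outputs_require):
            break

    return new_body_outputs_require

With the masks from the new test below ([True] for x going in, and a body that always reports [False] because x is rebound to y), this sketch terminates after a single pass, whereas the old "while (inputs != outputs)" condition would never become false.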
@@ -4594,6 +4594,32 @@ a")
         test_resize_as()
 
+    def test_requires_grad_loop(self):
+        @torch.jit.script
+        def test(x, y, z):
+            # type: (Tensor, Tensor, int) -> Tensor
+            for _ in range(z):
+                x = y
+            return x
+
+        # x requires grad, y does not
+        # testing that requires grad analysis correctly exits, with its input
+        # to the loop (x) requiring grad and its output to the loop not requiring grad
+        # and the output of the node conservatively setting grad to true
+
+        inps = (torch.tensor(1.0, requires_grad=True), torch.tensor(1), 10)
+        test(*inps)
+
+        graph = test.graph_for(*inps)
+        loop = graph.findNode("prim::Loop")
+        loop_body = next(loop.blocks())
+        loop_inputs = list(loop_body.inputs())
+        loop_outputs = list(loop_body.outputs())
+
+        self.assertTrue(loop_inputs[1].requires_grad())
+        self.assertFalse(loop_outputs[1].requires_grad())
+        self.assertTrue(loop.output().requires_grad())
+
     def test_view_shape_prop(self):
         cu = torch.jit.CompilationUnit('''
         def test_view_shape_prop(a):
@@ -1,7 +1,7 @@
+#include <ATen/core/jit_type.h>
 #include <torch/csrc/jit/argument_spec.h>
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/operator.h>
-#include <ATen/core/jit_type.h>
 
 #include <vector>
 
@@ -99,15 +99,23 @@ void PropagateRequiresGrad(Node* node) {
         fmap(node->inputs().slice(2), getRequiresGrad);
     std::vector<bool> body_outputs_require(node->outputs().size(), false);
 
-    while (body_inputs_require != body_outputs_require) {
-      body_inputs_require =
+    std::vector<bool> new_body_inputs_require = body_inputs_require;
+    std::vector<bool> new_body_outputs_require = body_outputs_require;
+
+    // continue iterating until the results have converged
+    do {
+      body_inputs_require = new_body_inputs_require;
+      body_outputs_require = new_body_outputs_require;
+
+      new_body_inputs_require =
           bitwiseOr(body_inputs_require, body_outputs_require);
       setRequiresGrad(
-          body->param_node()->outputs().slice(1), body_inputs_require);
+          body->param_node()->outputs().slice(1), new_body_inputs_require);
       PropagateRequiresGrad(body);
-      body_outputs_require =
+      new_body_outputs_require =
          fmap(body->return_node()->inputs().slice(1), getRequiresGrad);
-    }
+    } while (new_body_inputs_require != body_inputs_require &&
+             new_body_outputs_require != body_outputs_require);
 
     setRequiresGrad(node, body_outputs_require);
   } else {
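For the new test this converges immediately: x requires grad going into the loop while the body rebinds x = y, so body_inputs_require starts as true for that value and every analysis pass reports false for the corresponding body output. The old while (body_inputs_require != body_outputs_require) condition therefore never became false, which is the hang reported in issue #18320. With the do/while above, the first pass leaves both masks unchanged relative to the previous values, the convergence check exits, and the loop node's own output still conservatively reports requires_grad, as the test asserts.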
@@ -120,12 +128,10 @@ void PropagateRequiresGrad(Block* block) {
     PropagateRequiresGrad(node);
   }
 }
-
 } // anonymous namespace
 
 void PropagateRequiresGrad(std::shared_ptr<Graph>& graph) {
   PropagateRequiresGrad(graph->block());
 }
-
 } // namespace jit
 } // namespace torch
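The final hunk exposes requires_grad on Value in the Python IR bindings (initPythonIRBindings); the new test relies on this when it calls .requires_grad() on the loop's block inputs, block outputs, and node output.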
@@ -385,6 +385,7 @@ void initPythonIRBindings(PyObject* module_) {
       })
       .VS(copyMetadata)
       .VS(isTensor)
+      .VS(requires_grad)
       .def("toIValue", [](Value& n) { return toIValue(&n); })
       .def("type", [](Value& v) { return v.type(); });
 #undef VS