// Source: pytorch/torch/csrc/jit/frontend/exit_transforms.cpp
// (file-browser metadata: 845 lines, 28 KiB, C++)

#include <torch/csrc/jit/frontend/exit_transforms.h>
#include <ATen/core/jit_type.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
namespace torch::jit {
// Exit classification of a node or block with respect to the exit kind that
// is currently being transformed:
//   WILL   - the node/block must hit the exit,
//   MIGHT  - the exit may or may not be hit,
//   WONT   - the exit will not be hit,
//   THROWS - the node/block always throws.
// THROWS allows us to create better graphs by not conditionalizing execution
// when it is not necessary. It is an optimization; replacing it with WONT
// would preserve graph semantics.
enum class ExitStatus { WILL, MIGHT, WONT, THROWS };
// Names the two kinds of exits handled by this file: Returns and
// LoopContinuations.
// NOTE(review): this enum is not referenced anywhere else in the visible part
// of this file - confirm whether it is still needed.
enum class Transform { Returns, LoopContinuations };
// hasExited() indicates whether or not an exit has been hit.
// The ExitTransformer pass maintains a false boolean false_val_, a true
// boolean true_val_, and an uninitialized boolean throws_val_.
// If hasExited() == true_val_ then we have exited; if hasExited() == false_val_
// we have not; if hasExited() == throws_val_ we have hit a block that throws.
// Otherwise, we might have exited.
// exitValues() are the values that we are propagating to a destination block.
// this is used for block outputs of loops and outputs of functions & closures
struct ExitPair : public std::pair<Value*, std::vector<Value*>> {
using pair::pair;
ExitPair(Value* exit_v, at::ArrayRef<Value*> exit_val_ref) {
std::vector<Value*> exit_vals;
for (Value* v : exit_val_ref) {
exit_vals.push_back(v);
}
AT_ASSERT(exit_v->type() == BoolType::get());
this->first = exit_v;
this->second = std::move(exit_vals);
}
Value* hasExited() const {
return this->first;
}
std::vector<Value*> exitValues() const {
return this->second;
}
};
/**
* This pass currently transforms the Graph so that all exit nodes targeting
* a block location are removed from the graph and unified.
* The exit node for breaks/continues is LoopContinuation, and the exit for
* Graphs & Closures is ReturnStmt.
*
* Once we hit an Exit Node, we do not execute any further instructions
* until the exit target has been reached.
*
* For blocks and control flow nodes that have an exit statement that may
* have been hit, we conditionalize all execution on a boolean value that
* indicates whether we have hit the exit, hasExited().
*
* The pass keeps track of blocks that always throw, so that we can construct
* simpler graphs. For example, if in one block of an if statement we return
* and in the other we throw, we can treat the node as always returning instead
* of conditionalizing execution in the remainder of the block.
*/
struct ExitTransformer {
ExitTransformer(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {
  // The shared true/false constants are inserted in front of the first node
  // of the graph's top-level block.
  auto* first_node = graph_->block()->nodes().front();
  WithInsertPoint guard(first_node);
  true_val_ = graph_->insertConstant(true);
  false_val_ = graph_->insertConstant(false);
  // The "throws" sentinel is an uninitialized bool: it is never actually
  // read, because execution always throws before it could be accessed.
  throws_val_ = getUnitValue(BoolType::get());
}
// Rewrites prim::ReturnStmt exits: selects ReturnStmt as the exit kind being
// processed and then runs the transformation starting at the graph's
// top-level block.
void transformReturnStmts() {
current_exit_kind_ = prim::ReturnStmt;
transformExits(graph_->block());
}
// Rewrites prim::LoopContinuation exits: selects LoopContinuation as the exit
// kind being processed and then runs the transformation starting at the
// graph's top-level block.
void transformLoopContinuations() {
current_exit_kind_ = prim::LoopContinuation;
transformExits(graph_->block());
}
private:
// Exit pair for a node/block that always throws: hasExited() is the
// throws_val_ sentinel and there are no exit values.
ExitPair constructThrowsExitPair() {
return ExitPair(throws_val_, std::vector<Value*>({}));
}
// Exit pair for a node/block that does not exit: hasExited() is the
// false_val_ constant and there are no exit values.
ExitPair constructWontExitPair() {
return ExitPair(false_val_, std::vector<Value*>({}));
}
// Exit pair for a node/block that always exits: hasExited() is the true_val_
// constant and exit_val_ref supplies the values carried by the exit.
ExitPair constructWillExitPair(at::ArrayRef<Value*> exit_val_ref) {
return ExitPair(true_val_, exit_val_ref);
}
// Classifies an exit pair by comparing its hasExited() value against the
// transformer's sentinel values:
//   true_val_   -> WILL
//   false_val_  -> WONT
//   throws_val_ -> THROWS
//   any other value (i.e. a flag computed at runtime) -> MIGHT
// The pair is only read, so it is taken by const reference; this lets callers
// that hold a `const ExitPair&` or a temporary classify it as well.
ExitStatus getExitStatus(const ExitPair& exit_pair) {
  Value* exit_v = exit_pair.hasExited();
  if (exit_v == true_val_) {
    return ExitStatus::WILL;
  }
  if (exit_v == false_val_) {
    return ExitStatus::WONT;
  }
  if (exit_v == throws_val_) {
    return ExitStatus::THROWS;
  }
  return ExitStatus::MIGHT;
}
// Kind of the node that owns `block`, or a default-constructed Symbol when the
// block has no owning node.
static Symbol owningNodeKind(Block* block) {
  auto* owner = block->owningNode();
  return owner ? owner->kind() : Symbol();
}
// True when `block` has no owning node or is owned by a prim::Closure node.
static bool isGraphOrClosureBlock(Block* block) {
  if (block->owningNode() == nullptr) {
    return true;
  }
  return owningNodeKind(block) == prim::Closure;
}
// Erases every output of `b`, always removing the one at index 0.
static void removeOutputs(Block* b) {
  for (size_t remaining = b->outputs().size(); remaining > 0; --remaining) {
    b->eraseOutput(0);
  }
}
// Appends each value in `outs`, in order, to the outputs of `b`.
static void registerBlockOutputs(Block* b, at::ArrayRef<Value*> outs) {
  for (const auto idx : c10::irange(outs.size())) {
    b->registerOutput(outs[idx]);
  }
}
// Replaces all outputs of `b` with `outs`: the existing outputs are erased
// first, then `outs` are registered in order.
// NOTE(review): `outs` is read after the erase, so it presumably must not be a
// view over `b`'s own outputs - confirm against callers.
static void replaceBlockOutputs(Block* b, at::ArrayRef<Value*> outs) {
removeOutputs(b);
registerBlockOutputs(b, outs);
}
// Appends a group of outputs to the prim::If node `n`: `true_outs` become
// additional outputs of the then-block, `false_outs` of the else-block, and
// one node output is added per entry of `true_outs`, typed with the
// unification of the two branch types (defaulting to a union).
static void addIfOutputs(
    Node* n,
    at::ArrayRef<Value*> true_outs,
    at::ArrayRef<Value*> false_outs) {
  IfView view(n);
  registerBlockOutputs(view.thenBlock(), true_outs);
  registerBlockOutputs(view.elseBlock(), false_outs);

  const size_t num_new_outputs = true_outs.size();
  for (size_t idx = 0; idx < num_new_outputs; ++idx) {
    const auto& then_type = true_outs.at(idx)->type();
    const auto& else_type = false_outs.at(idx)->type();
    auto unified =
        unifyTypes(then_type, else_type, /*default_to_union=*/true);
    n->addOutput()->setType(*unified);
  }
}
// Returns, for every value in `values_to_match`, the cached uninitialized
// placeholder of the same type (same order, same length).
std::vector<Value*> matchValuesWithUnitialized(
    at::ArrayRef<Value*> values_to_match) {
  std::vector<Value*> placeholders;
  placeholders.reserve(values_to_match.size());
  for (const auto idx : c10::irange(values_to_match.size())) {
    placeholders.push_back(getUnitValue(values_to_match[idx]->type()));
  }
  return placeholders;
}
// Transforms the exits inside the body of the prim::Loop `node` and, when the
// body may exit to a target outside the loop, threads the "has exited" flag
// and the exit values through the loop as extra loop-carried values.
// Returns the ExitPair seen by the code following the loop: WONT when nothing
// escapes the loop, otherwise the new loop outputs holding the flag and the
// exit values.
ExitPair transformLoop(Node* node) {
LoopView loop(node);
Block* body = loop.bodyBlock();
auto exit_pair = transformExits(body);
// if we're not exiting to outside the loop we don't need to do any work.
// since we may not enter the loop return WONT for the THROWS case.
if (getExitStatus(exit_pair) == ExitStatus::WONT ||
getExitStatus(exit_pair) == ExitStatus::THROWS) {
return constructWontExitPair();
}
// if we are, we need to update the loop continue condition so that
// we exit the loop if we've hit an exit
// and we need to propagate hasExited() and exitValues() outside the loop
// example:
// while i < 5:
//     i += 1
//     if j == 4:
//         return 5
// -> becomes
//
// loop_continue = i < 5
// has_exited = false
// ret_val = uninitialized(int)
// while loop_continue:
//     i += 1
//     if j == 4:
//         ret_val = 5
//         has_exited = True
//     else:
//         ret_val = uninitialized(int)
//         has_exited = False
//     if has_exited:
//         loop_continue = False
//     else:
//         loop_continue = i < 5
// update loop continuation condition so that we exit if we hit an exit
WithInsertPoint insert(body);
// new_if yields false when the exit was hit and the original next-iteration
// condition otherwise.
auto new_if = graph_->insertNode(graph_->create(prim::If, 0));
new_if->addInput(exit_pair.hasExited());
new_if->addBlock()->registerOutput(false_val_);
new_if->addBlock()->registerOutput(loop.nextCond());
auto new_condition = new_if->addOutput()->setType(BoolType::get());
// body output 0 is swapped for the guarded condition computed above
loop.bodyBlock()->eraseOutput(0);
loop.bodyBlock()->insertOutput(0, new_condition);
// add hasExited() to loop outputs, we didn't exit if we didn't enter the
// loop
node->addInput(false_val_);
body->addInput()->setType(BoolType::get());
body->registerOutput(exit_pair.hasExited());
Value* new_has_exited = node->addOutput()->setType(BoolType::get());
// add exit values
// each exit value gets a loop input (initialized with an uninitialized
// placeholder), a matching body input, a body output and a loop output
for (Value* exit_value : exit_pair.exitValues()) {
auto typ = exit_value->type();
node->addInput(getUnitValue(typ));
node->addOutput()->setType(typ);
body->addInput()->setType(typ);
body->registerOutput(exit_value);
}
// the exit values are the trailing outputs that were just appended
auto exit_vals = node->outputs().slice(
node->outputs().size() - exit_pair.exitValues().size());
return ExitPair(new_has_exited, exit_vals);
}
// Combines the exit statuses of the two branches of an if node.
ExitStatus calcIfExitStatus(ExitStatus then_status, ExitStatus else_status) {
  // A branch that always throws does not constrain the result: the status of
  // the other branch is used as-is.
  if (then_status == ExitStatus::THROWS) {
    return else_status;
  }
  if (else_status == ExitStatus::THROWS) {
    return then_status;
  }
  // Both branches agreeing on WONT or on WILL yields that status; every other
  // combination can only be described as MIGHT.
  const bool branches_agree = then_status == else_status;
  if (branches_agree && then_status != ExitStatus::MIGHT) {
    return then_status;
  }
  return ExitStatus::MIGHT;
}
// Recursively transforms the if node.
// Both branch blocks are transformed first; their exit statuses are combined
// with calcIfExitStatus. When the if may or will exit, the exit values (and,
// in the MIGHT case, the "has exited" flag) are appended as extra outputs of
// the if node, and the returned ExitPair refers to those new outputs.
ExitPair transformIf(Node* node) {
auto then_block = node->blocks().at(0);
auto else_block = node->blocks().at(1);
auto then_pair = transformExits(then_block);
auto else_pair = transformExits(else_block);
auto then_status = getExitStatus(then_pair);
auto else_status = getExitStatus(else_pair);
auto if_status = calcIfExitStatus(then_status, else_status);
if (if_status == ExitStatus::THROWS) {
return constructThrowsExitPair();
}
if (if_status == ExitStatus::WONT) {
return constructWontExitPair();
}
// The exit values of the block that is not exiting will not get
// used, so we create uninitialized values of the same type as the other
// block.
if (then_status == ExitStatus::WONT || then_status == ExitStatus::THROWS) {
std::vector<Value*> exit_vals =
matchValuesWithUnitialized(else_pair.exitValues());
then_pair = ExitPair(then_pair.hasExited(), exit_vals);
} else if (
else_status == ExitStatus::WONT || else_status == ExitStatus::THROWS) {
std::vector<Value*> exit_vals =
matchValuesWithUnitialized(then_pair.exitValues());
else_pair = ExitPair(else_pair.hasExited(), exit_vals);
}
Value* has_exited = nullptr;
if (if_status == ExitStatus::WILL) {
// Need to maintain the invariant that if hasExited() == true_val_
// then we have exited.
has_exited = true_val_;
} else {
// MIGHT: the flag becomes a new (last) output of the if node, fed by the
// flag of whichever branch executes.
addIfOutputs(node, {then_pair.hasExited()}, {else_pair.hasExited()});
has_exited = node->outputs().at(node->outputs().size() - 1);
}
// the exit values are appended after the flag, so they are the trailing
// num_exit_vals outputs of the node
addIfOutputs(node, then_pair.exitValues(), else_pair.exitValues());
size_t num_exit_vals = then_pair.exitValues().size();
auto exit_vals =
node->outputs().slice(node->outputs().size() - num_exit_vals);
return ExitPair(has_exited, exit_vals);
}
// Recursively transforms the With node: only its first block (the body) is
// processed, and the body's exit pair is returned unchanged.
ExitPair transformWith(Node* node) {
  return transformExits(node->blocks().at(0));
}
// Wraps every node of `block` from `iter` onwards in a new prim::If whose
// condition is exit_pair.hasExited(): the first branch re-issues the exit,
// the second branch holds the moved nodes. The block's outputs are rerouted
// through the new if, which is then processed by transformIf; that result is
// returned. On return `iter` is at the end of `block`.
ExitPair guardBlockNodes(
    Block* block,
    const ExitPair& exit_pair,
    graph_node_list_iterator& iter) {
  Node* guard_if = graph_->create(prim::If, 0)->insertBefore(*iter);
  guard_if->addInput(exit_pair.hasExited());
  Block* exited_branch = guard_if->addBlock();
  Block* live_branch = guard_if->addBlock();

  // Relocate all remaining nodes, in order, to the end of the live branch.
  while (iter != block->nodes().end()) {
    Node* pending = *iter;
    ++iter;
    pending->moveBefore(live_branch->return_node());
  }

  // Once an exit has happened only hasExited() and exitValues() are read, so
  // the exited branch can stand in for each original block output with an
  // uninitialized value of the same type.
  const std::vector<Value*> placeholders =
      matchValuesWithUnitialized(block->outputs());

  // Mirror the original block outputs onto the new if node...
  const size_t num_outputs = block->outputs().size();
  for (size_t idx = 0; idx < num_outputs; ++idx) {
    Value* original_out = block->outputs().at(idx);
    exited_branch->registerOutput(placeholders.at(idx));
    live_branch->registerOutput(original_out);
    guard_if->addOutput()->setType(original_out->type());
  }
  // ...and make the block return the if node's outputs instead.
  replaceBlockOutputs(block, guard_if->outputs());

  // Re-create the exit inside the exited branch so that transformIf sees it.
  graph_->create(current_exit_kind_, {exit_pair.exitValues()}, 0)
      ->insertBefore(exited_branch->return_node());
  return transformIf(guard_if);
}
// Destroys a node that follows an exit. Such a node may still have uses,
// for example:
//   if i == 1:
//       break
//   j = j + 1
// where the j + 1 value is a block output. Since those values will never be
// read, every use is redirected to an uninitialized value of the same type
// before the node is destroyed.
void destroyNodeAfterExit(Node* n) {
  for (Value* produced : n->outputs()) {
    if (produced->uses().empty()) {
      continue;
    }
    Value* placeholder = getUnitValue(produced->type());
    produced->replaceAllUsesWith(placeholder);
  }
  n->destroy();
}
// Destroys the nodes of `block` from `iter` through the end of the block.
// Does nothing when `iter` is already at the end of the block.
void deleteAfterExitNodes(Block* block, graph_node_list_iterator& iter) {
if (iter == block->nodes().end()) {
return;
}
// NOTE(review): presumably this keeps the graph's insertion point away from
// the nodes destroyed below - confirm.
WithInsertPoint insert(*block->nodes().begin());
// need to destroy in reverse order so nodes have no uses when destroyed
for (auto it = block->nodes().reverse().begin(); it != iter;) {
// `n` is the node to destroy; `it` has already moved on to the node before it
Node* n = *it++;
// NOTE(review): the comparison against return_node() looks like a guard
// against having wrapped around the node list - confirm.
if (*it != block->return_node()) {
destroyNodeAfterExit(n);
}
}
// the loop stops at `iter` without destroying it, so destroy it here
destroyNodeAfterExit(*iter);
}
// Makes `block` the new target_block_ when exits of the current kind are
// aimed at it:
//  - a block owned by a Loop node while LoopContinuations are transformed, or
//  - a Graph/Closure block while ReturnStmts are transformed.
// In every other case target_block_ is left untouched.
void updateTargetBlock(Block* block) {
  const bool is_loop_target = owningNodeKind(block) == prim::Loop &&
      current_exit_kind_ == prim::LoopContinuation;
  const bool is_return_target = isGraphOrClosureBlock(block) &&
      current_exit_kind_ == prim::ReturnStmt;
  if (is_loop_target || is_return_target) {
    target_block_ = block;
  }
}
// Transforms the exits of `block` and returns the ExitPair describing how the
// block exits towards an enclosing target.
// Nodes are visited in order; the exit pair of the most recently processed
// exit-relevant node decides what happens to the nodes after it (deleted on
// WILL/THROWS, guarded by an if on MIGHT). If `block` is itself the target of
// the current exit kind, its outputs are rewritten from the exit values and
// WONT is returned to the caller.
ExitPair transformExits(Block* block) {
// remember the enclosing target so it can be restored before returning
Block* prev_target_block = target_block_;
updateTargetBlock(block);
ExitPair exit_pair = constructWontExitPair();
for (auto it = block->nodes().begin(); it != block->nodes().end();) {
Node* node = *it;
// advance first: `node` may be destroyed below, and `it` must designate the
// first node after it for deleteAfterExitNodes/guardBlockNodes
it++;
switch (node->kind()) {
case prim::RaiseException: {
exit_pair = constructThrowsExitPair();
} break;
case prim::ReturnStmt:
case prim::LoopContinuation: {
// only the exit kind currently being transformed is consumed here
if (node->kind() == current_exit_kind_) {
exit_pair = constructWillExitPair(node->inputs());
node->destroy();
}
} break;
case prim::If: {
exit_pair = transformIf(node);
} break;
case prim::With: {
exit_pair = transformWith(node);
} break;
case prim::Closure: {
// exits of closure declaration stay local to the closure
transformExits(node->blocks().at(0));
} break;
case prim::Loop: {
exit_pair = transformLoop(node);
} break;
}
// if we have hit a node that might exit, we need to conditionally execute
// all subsequent nodes in the block. if we've hit a node that will exit
// we can remove all subsequent nodes.
ExitStatus status = getExitStatus(exit_pair);
if (status == ExitStatus::WILL || status == ExitStatus::THROWS) {
deleteAfterExitNodes(block, it);
break;
}
if (status == ExitStatus::MIGHT) {
// guardBlockNodes moves the remaining nodes into a new if and processes
// them itself, so this loop is finished either way
if (it != block->nodes().end()) {
exit_pair = guardBlockNodes(block, exit_pair, it);
}
break;
}
}
// if we are targeting this block, update the output values to the
// exit values. since the exit does not extend outside this block,
// update returned exit to false. then, reset the target_block to whatever
// it was previously
if (target_block_ == block) {
// if we might have exited, use the new exit values if we did exit,
// otherwise use the existing block outputs.
if (getExitStatus(exit_pair) == ExitStatus::MIGHT) {
auto new_if =
graph_->create(prim::If, 0)->insertBefore(block->return_node());
new_if->addBlock();
new_if->addBlock();
new_if->addInput(exit_pair.hasExited());
addIfOutputs(new_if, exit_pair.exitValues(), block->outputs());
replaceBlockOutputs(block, new_if->outputs());
} else if (getExitStatus(exit_pair) == ExitStatus::WILL) {
replaceBlockOutputs(block, exit_pair.exitValues());
}
// reset the exiting status. an exit should only reach its target block.
// e.g. a continue only affects most recent loop, return in closure
// does not affect enclosing graph.
// Exceptions do not propagate from Loops bc we might not enter the loop,
// and not from closures bc the Function node is a declaration and not
// an invocation.
exit_pair = constructWontExitPair();
}
target_block_ = prev_target_block;
return exit_pair;
}
// Returns the uninitialized placeholder Value for `type`. The first request
// for a given type creates the value right after the graph's param node and
// caches it in unit_values_; later requests return the cached value.
Value* getUnitValue(const TypePtr& type) {
  const auto cached = unit_values_.find(type);
  if (cached == unit_values_.end()) {
    Node* uninit_node =
        graph_->createUninitialized(type)->insertAfter(graph_->param_node());
    Value* unit = uninit_node->output();
    unit_values_[type] = unit;
    return unit;
  }
  return cached->second;
}
// we create one uninitialized value per type, cache it here and reuse it
std::unordered_map<TypePtr, Value*> unit_values_;
// can either be LoopContinuation/ReturnStmt
Symbol current_exit_kind_;
// constant `true` inserted by the constructor; as a hasExited() value it
// means the exit is always taken (ExitStatus::WILL)
Value* true_val_;
// constant `false` inserted by the constructor; as a hasExited() value it
// means the exit is not taken (ExitStatus::WONT)
Value* false_val_;
// uninitialized bool used as the hasExited() value of code that always
// throws (ExitStatus::THROWS)
Value* throws_val_;
// when we see current_exit_kind_, this is the block that the values are
// exiting to. For example when we are transforming LoopContinuations
// for i in range(5):
//     while i < 3:
//         continue
//     break
// when we transform the for loop block, target_block_ will be set the for
// block. then, when we enter the while loop, target_block_ will be the while
// loop block. when we are done transforming the while it will be set back to
// the for block.
Block* target_block_ = nullptr;
// the graph being transformed
std::shared_ptr<Graph> graph_;
};
// Tries to fold the prim::If that directly follows `node` into `node`.
// Returns true when the two nodes were fused (the second if is destroyed),
// false when the pattern does not apply and nothing was changed.
static bool inlineConsecutiveIfs(Node* node) {
  if (node->kind() != prim::If) {
    return false;
  }
  Node* follower = node->next();
  if (follower->kind() != prim::If) {
    return false;
  }

  IfView first_if(node);
  IfView second_if(follower);

  // The condition of the second if has to be produced by the first if,
  // otherwise there is nothing to inline.
  if (second_if.cond()->node() != node) {
    return false;
  }

  // Each branch of the first if must feed a constant into that condition, and
  // the two constants must differ. Equal constants mean the second if will be
  // constant-propagated away anyway, and inlining it here would only double
  // the code size.
  const auto cond_offset = second_if.cond()->offset();
  const auto then_const = toIValue(first_if.thenOutputs().at(cond_offset));
  const auto else_const = toIValue(first_if.elseOutputs().at(cond_offset));
  if (!then_const || !else_const) {
    return false;
  }
  const bool then_value = then_const->toBool();
  const bool else_value = else_const->toBool();
  if (then_value == else_value) {
    return false;
  }

  // Copies the branch of the second if selected by `cond_is_true` onto the
  // end of `dest`, one of the first if's branches.
  const auto absorb_branch = [&](Block* dest, bool cond_is_true) {
    Block* src =
        cond_is_true ? second_if.thenBlock() : second_if.elseBlock();
    // Values the second if read from the first if's outputs are replaced by
    // the equivalent value in the scope of the block being copied into.
    auto remap = [&](Value* v) -> Value* {
      if (v->node() == first_if.node()) {
        return dest->outputs().at(v->offset());
      }
      return v;
    };
    // cloneFrom also copies the block outputs of `src` onto `dest`.
    dest->cloneFrom(src, remap);
  };
  absorb_branch(first_if.thenBlock(), then_value);
  absorb_branch(first_if.elseBlock(), else_value);

  // Expose the second if's results as new outputs of the first if.
  for (Value* old_out : second_if.outputs()) {
    Value* merged_out = first_if.node()->addOutput()->copyMetadata(old_out);
    old_out->replaceAllUsesWith(merged_out);
  }
  second_if.node()->destroy();
  return true;
}
// After an early return, all further execution is conditionalized.
// This means code like the following:
//   if x:
//       return 1
//   return 2
// is generated as one if statement checking `if x`, followed by a second if
// statement that conditionalizes the rest. Such pairs can be rewritten into a
// single if statement, so that the example above ends up looking like:
//   if x:
//       return 1
//   else:
//       return 2
static void inlineConsecutiveIfs(Block* block) {
  auto cursor = block->nodes().begin();
  const auto stop = block->nodes().end();
  while (cursor != stop) {
    // Handle nested blocks first.
    for (Block* nested : cursor->blocks()) {
      inlineConsecutiveIfs(nested);
    }
    // After a successful fusion stay on the same node: it has to be checked
    // again against its new successor.
    const bool fused = inlineConsecutiveIfs(*cursor);
    if (!fused) {
      ++cursor;
    }
  }
}
// Adds prim::With nodes to a graph to help handle early exits between
// prim::Enter and prim::Exit nodes. More specifically, it transforms
// IR that looks like this:
//
// %a = prim::Enter(%b)
// <code>
// %c = prim::Exit(%b)
//
// to this:
//
// %a = prim::Enter(%b)
// = prim::With()
// block0():
// <code>
// -> ()
// block1():
// %c = prim::Exit(%b)
// -> ()
//
// Wraps the code between each matching prim::Enter / prim::Exit pair in a
// prim::With node: the nodes in between go to the With's first block, the
// Exit node goes to its second block, and the With is placed right after the
// Enter.
static void convertEnterExitNodesToWithBlocks(std::shared_ptr<Graph>& graph) {
// First, find all Enter-Exit pairs up front to avoid iterator invalidation
// issues later when moving nodes around. Do this by iterating through the
// nodes of the graph while keeping a stack of encountered Enter nodes. Each
// time an Exit node is seen, its corresponding Enter node must be at the
// top of the stack. Pop it and record the pair.
std::vector<std::pair<Node*, Node*>> enter_exit_pairs;
std::vector<Node*> enter_node_stack;
DepthFirstGraphNodeIterator it(graph);
Node* node = it.next();
while (node) {
if (node->kind() == prim::Enter) {
enter_node_stack.emplace_back(node);
} else if (node->kind() == prim::Exit) {
// enter_node_stack should not be empty.
TORCH_INTERNAL_ASSERT(!enter_node_stack.empty());
// The input to this Exit node should be the same as that of the Enter
// node on the top of the enter_node_stack.
TORCH_INTERNAL_ASSERT(
enter_node_stack.back()->input(0) == node->input(0));
// Record the pair.
enter_exit_pairs.emplace_back(enter_node_stack.back(), node);
enter_node_stack.pop_back();
}
node = it.next();
}
// The stack should be empty; an Exit should have been found for every Enter.
TORCH_INTERNAL_ASSERT(enter_node_stack.empty());
// Now, add a With block for each Enter-Exit pair. The innermost pairs were
// found first, so they will be converted first.
for (auto& pair : enter_exit_pairs) {
Node* enter = pair.first;
Node* exit = pair.second;
auto* with = graph->create(prim::With, /*num_outputs=*/0);
auto* body_block = with->addBlock();
auto* exit_block = with->addBlock();
// Start at the first node after the Enter; insert_point tracks the last
// node placed in the body block so the original order is kept.
Node* cur = enter->next();
Node* insert_point = body_block->param_node();
// Move all of the nodes between the Enter and Exit into the body block.
while (cur != exit) {
// fetch the successor before `cur` is moved out of its current position
auto* next = cur->next();
cur->moveAfter(insert_point);
insert_point = insert_point->next();
cur = next;
}
// Move the Exit node into the exit block.
exit->moveAfter(exit_block->param_node());
// Finally, insert the With directly after the Enter.
with->insertAfter(enter);
}
}
// Removes prim::With nodes from a graph. More specifically, it transforms
// IR that looks like this:
//
// %a = prim::Enter(%b)
// = prim::With()
// block0():
// <code>
// -> ()
// block1():
// %c = prim::Exit(%b)
// ->()
//
// to this:
// %a = prim::Enter(%b)
// <code>
// %c = prim::Exit(%b)
//
static void convertWithBlocksToEnterExitNodes(std::shared_ptr<Graph>& graph) {
// First, find all With blocks to avoid iterator invalidation issues when
// moving nodes around later.
std::vector<Node*> with_nodes;
DepthFirstGraphNodeIterator it(graph);
Node* node = it.next();
while (node) {
if (node->kind() == prim::With) {
with_nodes.emplace_back(node);
}
node = it.next();
}
// For each With node:
for (auto& node : with_nodes) {
auto* body_block = node->blocks().at(0);
auto* exit_block = node->blocks().at(1);
std::vector<Node*> to_append;
// Record all nodes that need to be appended after the Enter that precedes
// the With block to avoid iterator invalidation issues later when moving
// nodes around.
for (auto body_node : body_block->nodes()) {
to_append.emplace_back(body_node);
}
for (auto exit_node : exit_block->nodes()) {
to_append.emplace_back(exit_node);
}
Node* cur = node->prev();
// Move all nodes inside the with block outside of it.
for (auto& node : to_append) {
node->moveAfter(cur);
cur = node;
}
node->destroy();
}
}
// This pass takes in a graph where LoopContinuation & ReturnStmts exist in the
// graph and erases them in the graph, correctly setting block outputs.
// prim::LoopContinuation(*vals) means that the values are targeting the most
// recent loop block. prim::ReturnStmt(*vals) means that the values are
// targeting the most recent Closure or Graph Block. Once we hit an exit node,
// we do not execute any further instructions until the block exit reaches its
// destination. If we encounter a node that contains nested blocks that may
// have hit an exit node, such as an if statement that exits in one block
// and does not exit in the other, we use a boolean value to indicate if the
// exit has been hit or not. Then, we conditionalize further execution.
//
// Python example:
// while i < 5:
// if i == 3:
// i += 1
// continue
// i += 2
//
// -> transforms to:
//
// continue_loop = i < 5
// while continue_loop:
// if i == 3:
// i = i + 1
// continue_loop = i < 5
// did_exit = True
// if did_exit:
// pass
// else:
// i = i + 2
// continue_loop = i < 5
// IR as it enters pass:
// %36 : bool = aten::lt(%i.1, %3)
// %i : int = prim::Loop(%1, %36, %i.1)
// block0(%5 : int, %i.17 : int):
// %8 : bool = aten::eq(%i.17, %7)
// %i.16 : int = prim::If(%8)
// block0():
// %i.6 : int = aten::add(%i.17, %11)
// %33 : bool = aten::lt(%i.6, %3)
// = prim::LoopContinuation(%33, %i.6)
// -> (%i.6)
// block1():
// -> (%i.17)
// %i.13 : int = aten::add(%i.16, %19)
// %4 : bool = aten::lt(%i.13, %3)
// -> (%4, %i.13)
// return (%i)
//
// -> transforms to
//
// %false_val : bool = prim::Constant[value=0]()
// %true_val : bool = prim::Constant[value=1]()
// %40 : int = prim::Uninitialized()
// %39 : bool = prim::Uninitialized()
// %36 : bool = aten::lt(%i.1, %3)
// %i : int = prim::Loop(%1, %36, %i.1)
// block0(%5 : int, %i.17 : int):
// %8 : bool = aten::eq(%i.17, %7)
// %did_exit : bool, %continue_loop : bool, %43 : int, %i.16 : int =
// prim::If(%8)
// block0():
// %i.6 : int = aten::add(%i.17, %11)
// %33 : bool = aten::lt(%i.6, %3)
// -> (%true_val, %33, %i.6, %i.6)
// block1():
// -> (%false_val, %39, %40, %i.17)
// %44 : bool, %i : int = prim::If(%did_exit)
// block0():
// -> (%continue_loop, %43)
// block1():
// %i.13 : int = aten::add(%i.16, %19)
// %4 : bool = aten::lt(%i.13, %3)
// -> (%4, %i.13)
// -> (%44, %i)
// Entry point of the pass: removes LoopContinuation and ReturnStmt exits from
// `graph`, as described in the comment above.
void TransformExits(std::shared_ptr<Graph>& graph) {
// wrap Enter/Exit regions in prim::With nodes for the duration of the pass
convertEnterExitNodesToWithBlocks(graph);
// loop continuations are transformed first, return statements second; each
// phase uses its own ExitTransformer instance
ExitTransformer e_loop(graph);
e_loop.transformLoopContinuations();
ExitTransformer e_ret(graph);
e_ret.transformReturnStmts();
// merge the back-to-back if nodes produced by conditionalizing execution
inlineConsecutiveIfs(graph->block());
// unwrap the prim::With nodes again
convertWithBlocksToEnterExitNodes(graph);
}
} // namespace torch::jit