Files
pytorch/torch/csrc/jit/frontend/exit_transforms.cpp
Michael Suo dbe850af5b [jit] do the code reorg (#33851)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33851

Rationale and context described in #33828.

Script to reproduce the move:
https://gist.github.com/suo/16cbefaaeb67ca5a7c6caffd49b7f6e9
ghstack-source-id: 99079645

Test Plan: Make sure CI passes

Reviewed By: jamesr66a

Differential Revision: D20133869

fbshipit-source-id: 390e9241a9c85366d9005c492ac31f10aa96488e
2020-02-27 13:02:51 -08:00

592 lines
20 KiB
C++

#include <torch/csrc/jit/frontend/exit_transforms.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/frontend/error_report.h>
namespace torch {
namespace jit {
// Classification of what a node/block does with respect to the exit kind
// currently being transformed:
// WILL states that a node/block must hit the exit, MIGHT that it may happen,
// WONT that it will not happen. THROWS states that a node/block always throws,
// and allows us to create better graphs by not conditionalizing execution
// when it is not necessary. It is an optimization; replacing it with WONT
// would preserve graph semantics.
enum class ExitStatus { WILL, MIGHT, WONT, THROWS };
// Names the two kinds of exit that ExitTransformer can be asked to rewrite
// (see transformReturnStmts / transformLoopContinuations).
// NOTE(review): nothing in the visible part of this file references
// Transform -- confirm whether it is still needed.
enum class Transform { Returns, LoopContinuations };
// hasExited() indicates whether or not an exit has been hit.
// The ExitTransform pass maintains a false boolean false_val_ && a true boolean
// true_val_, and an uninitialized boolean throws_val_.
// if hasExited() == true_val_ then we have exited, if hasExited() == false_val_
// we have not, hasExited() == throws_val_ we have hit a block that throws.
// Otherwise, we might have exited.
// exitValues() are the values that we are propagating to a destination block.
// this is used for block outputs of loops and outputs of functions & closures
struct ExitPair : public std::pair<Value*, std::vector<Value*>> {
using pair::pair;
ExitPair(Value* exit_v, at::ArrayRef<Value*> exit_val_ref) {
std::vector<Value*> exit_vals;
for (Value* v : exit_val_ref) {
exit_vals.push_back(v);
}
AT_ASSERT(exit_v->type() == BoolType::get());
this->first = exit_v;
this->second = std::move(exit_vals);
}
Value* hasExited() const {
return this->first;
}
std::vector<Value*> exitValues() const {
return this->second;
}
};
/**
* This pass currently transforms the Graph so that all exit nodes targeting
* a block location are removed from the graph and unified.
* The exit node for breaks/continues is LoopContinuation, and the exit for
* Graphs & Closures is ReturnStmt.
*
* Once we hit an Exit Node, we do not execute any further instructions
* until the exit target has been reached.
*
* For blocks and control flow nodes that have an exit statement that may
* have been hit, we conditionalize all execution on a boolean value that
* indicates whether we have hit the exit, hasExited().
*
* The pass keeps track of blocks that always throw, so that we can construct
* simpler graphs. For example, if in one block of an if statement we return
* and in the other we throw, we can treat the node as always returning instead
* of conditionalizing execution in the remainder of the block.
*/
struct ExitTransformer {
ExitTransformer(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {
WithInsertPoint guard(graph_->block()->nodes().front());
true_val_ = graph_->insertConstant(true);
false_val_ = graph_->insertConstant(false);
// this value will never be used, since we will always throw before it is
// accessed
throws_val_ = getUnitValue(BoolType::get());
};
void transformReturnStmts() {
current_exit_kind_ = prim::ReturnStmt;
transformExits(graph_->block());
}
void transformLoopContinuations() {
current_exit_kind_ = prim::LoopContinuation;
transformExits(graph_->block());
}
private:
// Pair whose flag is throws_val_ (classified as THROWS by getExitStatus());
// it carries no exit values.
ExitPair constructThrowsExitPair() {
  const std::vector<Value*> no_exit_values;
  return ExitPair(throws_val_, no_exit_values);
}

// Pair whose flag is false_val_ (classified as WONT by getExitStatus());
// it carries no exit values.
ExitPair constructWontExitPair() {
  const std::vector<Value*> no_exit_values;
  return ExitPair(false_val_, no_exit_values);
}

// Pair whose flag is true_val_ (classified as WILL by getExitStatus());
// it carries the given exit values.
ExitPair constructWillExitPair(at::ArrayRef<Value*> exit_val_ref) {
  return ExitPair(true_val_, exit_val_ref);
}
// Classifies an exit pair by comparing its hasExited() value, by identity,
// against the three sentinel values owned by this transformer. Any other
// value is a runtime boolean, so the exit MIGHT have been hit.
ExitStatus getExitStatus(ExitPair& exit_pair) {
  Value* const flag = exit_pair.hasExited();
  if (flag == true_val_) {
    return ExitStatus::WILL;
  }
  if (flag == false_val_) {
    return ExitStatus::WONT;
  }
  if (flag == throws_val_) {
    return ExitStatus::THROWS;
  }
  return ExitStatus::MIGHT;
}
// Kind of the node that owns `block`, or a default-constructed Symbol when
// the block has no owning node.
static Symbol owningNodeKind(Block* block) {
  Node* owner = block->owningNode();
  return owner ? owner->kind() : Symbol();
}
// True when `block` has no owning node, or when its owning node is a
// prim::Function.
static bool isGraphOrClosureBlock(Block* block) {
  if (block->owningNode() == nullptr) {
    return true;
  }
  return owningNodeKind(block) == prim::Function;
}
// Erases every output of `b`, always removing the one at index 0.
static void removeOutputs(Block* b) {
  while (!b->outputs().empty()) {
    b->eraseOutput(0);
  }
}

// Appends each value in `outs`, in order, as an output of `b`.
static void registerBlockOutputs(Block* b, at::ArrayRef<Value*> outs) {
  for (size_t idx = 0; idx < outs.size(); ++idx) {
    b->registerOutput(outs[idx]);
  }
}

// Makes `outs` the complete output list of `b`: existing outputs are removed
// first, then the new ones are registered.
static void replaceBlockOutputs(Block* b, at::ArrayRef<Value*> outs) {
  removeOutputs(b);
  registerBlockOutputs(b, outs);
}
// Appends one new output to the If node `n` for every (true_outs[i],
// false_outs[i]) pair: the i-th true value is registered on the then-block,
// the i-th false value on the else-block, and the new node output is typed
// with the unification of the two value types.
//
// All validation happens before the graph is touched. Previously a longer
// `false_outs` silently left the else-block with more outputs than the node,
// and a failed type unification dereferenced an empty optional.
static void addIfOutputs(
    Node* n,
    at::ArrayRef<Value*> true_outs,
    at::ArrayRef<Value*> false_outs) {
  // both branches must contribute exactly one value per new node output
  AT_ASSERT(true_outs.size() == false_outs.size());

  // compute the output types up front so that a failure leaves `n` unchanged
  std::vector<TypePtr> unified_types;
  unified_types.reserve(true_outs.size());
  for (size_t i = 0; i < true_outs.size(); ++i) {
    auto out_type =
        unifyTypes(true_outs.at(i)->type(), false_outs.at(i)->type());
    AT_ASSERT(out_type.has_value());
    unified_types.push_back(*out_type);
  }

  IfView if_view(n);
  registerBlockOutputs(if_view.thenBlock(), true_outs);
  registerBlockOutputs(if_view.elseBlock(), false_outs);
  for (const TypePtr& out_type : unified_types) {
    n->addOutput()->setType(out_type);
  }
}
// Produces, for each value in values_to_match, the cached uninitialized value
// of the same type (see getUnitValue()). The result has the same length and
// order as the input.
std::vector<Value*> matchValuesWithUnitialized(
    at::ArrayRef<Value*> values_to_match) {
  std::vector<Value*> placeholders;
  placeholders.reserve(values_to_match.size());
  for (size_t idx = 0; idx < values_to_match.size(); ++idx) {
    placeholders.push_back(getUnitValue(values_to_match[idx]->type()));
  }
  return placeholders;
}
// Transforms the exits inside the body of the prim::Loop `node`. When an exit
// can escape the body, the loop's continue condition is rewritten so the loop
// stops once the exit is hit, and the "has exited" flag plus every exit value
// are threaded through the loop as additional inputs/outputs.
// Returns the exit pair as seen by the block containing `node`: a WONT pair
// when nothing escapes the body, otherwise a pair built from the newly added
// loop outputs.
ExitPair transformLoop(Node* node) {
LoopView loop(node);
Block* body = loop.bodyBlock();
auto exit_pair = transformExits(body);
// if we're not exiting to outside the loop we don't need to do any work.
// since we may not enter the loop return WONT for the THROWS case.
if (getExitStatus(exit_pair) == ExitStatus::WONT ||
getExitStatus(exit_pair) == ExitStatus::THROWS) {
return constructWontExitPair();
}
// if we are, we need to update the loop continue condition so that
// we exit the loop if we've hit an exit
// and we need to propagate hasExited() and exitValues() outside the loop
// example:
// while i < 5:
// i += 1
// if j == 4:
// return 5
// -> becomes
//
// loop_continue = i < 5
// has_exited = false
// ret_val = uninitialized(int)
// while loop_continue:
// i += 1
// if j == 4:
// ret_val = 5
// has_exited = True
// else:
// ret_val = uninitialized(int)
// has_exited = False
// if has_exited:
// loop_continue = False
// else:
// loop_continue = i < 5
// update loop continuation condition so that we exit if we hit an exit
WithInsertPoint insert(body);
// new_if yields false_val_ when the exit was hit and the original next
// condition otherwise
auto new_if = graph_->insertNode(graph_->create(prim::If, 0));
new_if->addInput(exit_pair.hasExited());
new_if->addBlock()->registerOutput(false_val_);
new_if->addBlock()->registerOutput(loop.nextCond());
auto new_condition = new_if->addOutput()->setType(BoolType::get());
// body output 0 is replaced by the guarded condition
loop.bodyBlock()->eraseOutput(0);
loop.bodyBlock()->insertOutput(0, new_condition);
// add hasExited() to loop outputs, we didn't exit if we didn't enter the
// loop
node->addInput(false_val_);
body->addInput()->setType(BoolType::get());
body->registerOutput(exit_pair.hasExited());
Value* new_has_exited = node->addOutput()->setType(BoolType::get());
// add exit values: each one gets an uninitialized initial input, a body
// input, a body output and a loop output of the same type
for (Value* exit_value : exit_pair.exitValues()) {
auto typ = exit_value->type();
node->addInput(getUnitValue(typ));
node->addOutput()->setType(typ);
body->addInput()->setType(typ);
body->registerOutput(exit_value);
}
// the exit values are the trailing outputs that were just appended
auto exit_vals = node->outputs().slice(
node->outputs().size() - exit_pair.exitValues().size());
return ExitPair(new_has_exited, exit_vals);
}
// Combines the exit statuses of the two branches of an If node.
//  - A branch that throws contributes nothing, so the other branch's status
//    is taken as-is.
//  - Two WONT branches give WONT; two WILL branches give WILL.
//  - Every other combination gives MIGHT.
ExitStatus calcIfExitStatus(ExitStatus then_status, ExitStatus else_status) {
  if (then_status == ExitStatus::THROWS) {
    return else_status;
  }
  if (else_status == ExitStatus::THROWS) {
    return then_status;
  }

  const bool branches_agree = then_status == else_status;
  const bool definite = then_status == ExitStatus::WONT ||
      then_status == ExitStatus::WILL;
  if (branches_agree && definite) {
    return then_status;
  }
  return ExitStatus::MIGHT;
}
// Recursively transforms the if node: both branches are transformed first,
// their statuses are combined with calcIfExitStatus(), and -- unless the
// result is THROWS or WONT -- the has-exited flag and the exit values are
// appended to the node as extra outputs.
// Returns the exit pair describing the node as seen by its enclosing block.
ExitPair transformIf(Node* node) {
auto then_block = node->blocks().at(0);
auto else_block = node->blocks().at(1);
auto then_pair = transformExits(then_block);
auto else_pair = transformExits(else_block);
auto then_status = getExitStatus(then_pair);
auto else_status = getExitStatus(else_pair);
auto if_status = calcIfExitStatus(then_status, else_status);
if (if_status == ExitStatus::THROWS) {
return constructThrowsExitPair();
}
if (if_status == ExitStatus::WONT) {
return constructWontExitPair();
}
// for the block that is not exiting, its exit values will not get
// used, so we create uninitialized values of the same type as the other
// block.
if (then_status == ExitStatus::WONT || then_status == ExitStatus::THROWS) {
std::vector<Value*> exit_vals =
matchValuesWithUnitialized(else_pair.exitValues());
then_pair = ExitPair(then_pair.hasExited(), exit_vals);
} else if (
else_status == ExitStatus::WONT || else_status == ExitStatus::THROWS) {
std::vector<Value*> exit_vals =
matchValuesWithUnitialized(then_pair.exitValues());
else_pair = ExitPair(else_pair.hasExited(), exit_vals);
}
Value* has_exited;
if (if_status == ExitStatus::WILL) {
// Need to maintain the invariant that if hasExited() == true_val_
// then we have exited.
has_exited = true_val_;
} else {
// MIGHT: the flag becomes a new output of the node, and it is the last
// output at this point
addIfOutputs(node, {then_pair.hasExited()}, {else_pair.hasExited()});
has_exited = node->outputs().at(node->outputs().size() - 1);
}
addIfOutputs(node, then_pair.exitValues(), else_pair.exitValues());
// the exit values are the trailing outputs that were just appended
size_t num_exit_vals = then_pair.exitValues().size();
auto exit_vals =
node->outputs().slice(node->outputs().size() - num_exit_vals);
return ExitPair(has_exited, exit_vals);
}
// Guards the remaining nodes in the block with an if node that takes
// the has exited value as its conditional.
// Branch 0 of the new If is the "already exited" path: it re-emits the exit
// node with the known exit values and returns uninitialized stand-ins for the
// block outputs. Branch 1 receives every node from `iter` to the end of
// `block` and returns the original block outputs. The block's outputs are
// then redirected to the outputs of the new If, and the new If is handed to
// transformIf(). `iter` is left at the end of the block.
ExitPair guardBlockNodes(
    Block* block,
    const ExitPair& exit_pair,
    graph_node_list_iterator& iter) {
  Node* guard_if = graph_->create(prim::If, 0)->insertBefore(*iter);
  guard_if->addInput(exit_pair.hasExited());
  Block* exited_branch = guard_if->addBlock();
  Block* running_branch = guard_if->addBlock();

  // Relocate all remaining nodes into the not-exited branch. The iterator is
  // advanced before the node is moved out of the list it is walking.
  for (; iter != block->nodes().end();) {
    Node* moved = *iter;
    iter++;
    moved->moveBefore(running_branch->return_node());
  }

  // after an exit, the only values that will get used are the hasExited()
  // and exitValues(), so the existing block outputs are matched with
  // uninitialized values on the exited branch
  const std::vector<Value*> placeholder_outs =
      matchValuesWithUnitialized(block->outputs());

  // give the new if the same outputs as the original block ...
  const size_t num_outs = block->outputs().size();
  for (size_t idx = 0; idx < num_outs; ++idx) {
    Value* original_out = block->outputs().at(idx);
    exited_branch->registerOutput(placeholder_outs.at(idx));
    running_branch->registerOutput(original_out);
    guard_if->addOutput()->setType(original_out->type());
  }
  // ... then make the block return the new if's outputs instead
  replaceBlockOutputs(block, guard_if->outputs());

  graph_->create(current_exit_kind_, {exit_pair.exitValues()}, 0)
      ->insertBefore(exited_branch->return_node());
  return transformIf(guard_if);
}
// Destroys a node that sits after an exit. Such nodes may still have uses,
// as in the case:
// if i == 1:
//   break
// j = j + 1
// where the j + 1 value will be a block output. Since those values will
// never be used, it is safe to replace them with an uninitialized value of
// the same type before destroying the node.
void destroyNodeAfterExit(Node* n) {
  for (Value* produced : n->outputs()) {
    if (produced->uses().empty()) {
      continue;
    }
    produced->replaceAllUsesWith(getUnitValue(produced->type()));
  }
  n->destroy();
}
// Removes the nodes of `block` that follow an exit which is always taken.
// `iter` points at the first node after the exit; nothing is done when it is
// already at the end of the block. Nodes are destroyed via
// destroyNodeAfterExit(), walking backwards from the end of the block to
// `iter`, and the node at `iter` is destroyed last.
void deleteAfterExitNodes(Block* block, graph_node_list_iterator& iter) {
if (iter == block->nodes().end()) {
return;
}
// NOTE(review): getUnitValue() positions the nodes it creates explicitly, so
// it is not obvious from this file what consumes this insert point -- confirm
// whether it is still needed.
WithInsertPoint insert(*block->nodes().begin());
// need to destroy in reverse order so nodes have no uses when destroyed
for (auto it = block->nodes().reverse().begin(); it != iter;) {
// `it` is advanced before `n` is destroyed
Node* n = *it++;
// NOTE(review): this inspects the node `it` has moved on to, not `n`;
// presumably a guard against walking past the start of the list -- confirm.
if (*it != block->return_node()) {
destroyNodeAfterExit(n);
}
}
destroyNodeAfterExit(*iter);
}
// Makes `block` the current exit target when it is the destination of the
// exit kind being transformed:
//  - a Loop block while transforming LoopContinuations, or
//  - a Closure/Graph block while transforming ReturnStmts.
// In every other case target_block_ remains the same.
void updateTargetBlock(Block* block) {
  const bool loop_target = owningNodeKind(block) == prim::Loop &&
      current_exit_kind_ == prim::LoopContinuation;
  const bool return_target = isGraphOrClosureBlock(block) &&
      current_exit_kind_ == prim::ReturnStmt;
  if (loop_target || return_target) {
    target_block_ = block;
  }
}
// Core recursive driver of the pass. Walks the nodes of `block` in order,
// tracking in exit_pair what is known so far about the exit kind being
// transformed (current_exit_kind_):
//  - once a node WILL exit or THROWS, the rest of the block is deleted;
//  - once a node MIGHT exit, the rest of the block is wrapped in a guard If.
// If `block` is the target of the exit, its outputs are rewritten to the exit
// values and a WONT pair is returned, because the exit does not extend beyond
// its target. Otherwise the computed pair is returned to the caller.
// target_block_ is restored to its previous value before returning.
ExitPair transformExits(Block* block) {
Block* prev_target_block = target_block_;
updateTargetBlock(block);
ExitPair exit_pair = constructWontExitPair();
for (auto it = block->nodes().begin(); it != block->nodes().end();) {
Node* node = *it;
// advance first: `node` may be destroyed below, and `it` must point at the
// first node after it for the delete/guard helpers
it++;
switch (node->kind()) {
case prim::RaiseException: {
exit_pair = constructThrowsExitPair();
} break;
case prim::ReturnStmt:
case prim::LoopContinuation: {
// exit nodes of the other kind are left in place for the other pass
if (node->kind() == current_exit_kind_) {
exit_pair = constructWillExitPair(node->inputs());
node->destroy();
}
} break;
case prim::If: {
exit_pair = transformIf(node);
} break;
case prim::Function: {
// exits of closure declaration stay local to the closure
transformExits(node->blocks().at(0));
} break;
case prim::Loop: {
exit_pair = transformLoop(node);
} break;
}
// if we have hit a node that might exit, we need to conditionally execute
// all subsequent nodes in the block. if we've hit a node that will exit
// we can remove all subsequent nodes.
ExitStatus status = getExitStatus(exit_pair);
if (status == ExitStatus::WILL || status == ExitStatus::THROWS) {
deleteAfterExitNodes(block, it);
break;
}
if (status == ExitStatus::MIGHT) {
// nothing to guard when the exiting node was the last one in the block
if (it != block->nodes().end()) {
exit_pair = guardBlockNodes(block, exit_pair, it);
}
break;
}
}
// if we are targeting this block, update the output values to the
// exit values. since the exit does not extend outside this block,
// update returned exit to false. then, reset the target_block to whatever
// it was previously
if (target_block_ == block) {
// if we might have exited, use the new exit values if we did exit,
// otherwise use the existing block outputs.
if (getExitStatus(exit_pair) == ExitStatus::MIGHT) {
auto new_if =
graph_->create(prim::If, 0)->insertBefore(block->return_node());
new_if->addBlock();
new_if->addBlock();
new_if->addInput(exit_pair.hasExited());
addIfOutputs(new_if, exit_pair.exitValues(), block->outputs());
replaceBlockOutputs(block, new_if->outputs());
} else if (getExitStatus(exit_pair) == ExitStatus::WILL) {
replaceBlockOutputs(block, exit_pair.exitValues());
}
// reset the exiting status. an exit should only reach its target block.
// e.g. a continue only affects most recent loop, return in closure
// does not affect enclosing graph.
// Exceptions do not propagate from Loops bc we might not enter the loop,
// and not from closures bc the Function node is a declaration and not
// an invocation.
exit_pair = constructWontExitPair();
}
target_block_ = prev_target_block;
return exit_pair;
}
// Returns the uninitialized value cached for `type`, creating it on first
// request. A newly created prim::Uninitialized node is placed directly after
// the graph's param node, and its output is stored in unit_values_ under
// `type`.
Value* getUnitValue(const TypePtr& type) {
  const auto cached = unit_values_.find(type);
  if (cached != unit_values_.end()) {
    return cached->second;
  }

  Node* uninit_node = graph_->createUninitialized(type);
  uninit_node->insertAfter(graph_->param_node());
  Value* fresh = uninit_node->output();
  unit_values_[type] = fresh;
  return fresh;
}
// we create one uninitialized value per type, cache it here and reuse it
// (populated and read by getUnitValue())
std::unordered_map<TypePtr, Value*> unit_values_;
// can either be LoopContinuation/ReturnStmt; set by the public transform*
// entry points before transformExits() runs
Symbol current_exit_kind_;
// sentinel for "has exited": the true constant inserted by the constructor
Value* true_val_;
// sentinel for "has not exited": the false constant inserted by the
// constructor
Value* false_val_;
// sentinel for "always throws": an uninitialized boolean obtained from
// getUnitValue() in the constructor
Value* throws_val_;
// when we see current_exit_kind_, this is the block that the values are
// exiting to. For example when we are transforming LoopContinuations
// for i in range(5):
// while i < 3:
// continue
// break
// when we transform the for loop block, target_block_ will be set to the for
// block. then, when we enter the while loop, target_block_ will be the while
// loop block. when we are done transforming the while it will be set back to
// the for block.
Block* target_block_ = nullptr;
// the graph being transformed
std::shared_ptr<Graph> graph_;
};
// This pass takes in a graph where LoopContinuation & ReturnStmts exist in the
// graph and erases them in the graph, correctly setting block outputs.
// prim::LoopContinuation(*vals) means that the values are targeting the most
// recent loop block. prim::ReturnStmt(*vals) means that the values are
// targeting the most recent Closure or Graph Block. Once we hit an exit node,
// we do not execute any further instructions until the block exit reaches its
// destination. If we encounter a node that contains nested blocks that may
// have hit an exit node, such as an if statement that exits in one block
// and does not exit in the other, we use a boolean value to indicate if the
// exit has been hit or not. Then, we conditionalize further execution.
//
// Python example:
// while i < 5:
// if i == 3:
// i += 1
// continue
// i += 2
//
// -> transforms to:
//
// continue_loop = i < 5
// while continue_loop:
// if i == 3:
// i = i + 1
// continue_loop = i < 5
// did_exit = True
// if did_exit:
// pass
// else:
// i = i + 2
// continue_loop = i < 5
// IR as it enters pass:
// %36 : bool = aten::lt(%i.1, %3)
// %i : int = prim::Loop(%1, %36, %i.1)
// block0(%5 : int, %i.17 : int):
// %8 : bool = aten::eq(%i.17, %7)
// %i.16 : int = prim::If(%8)
// block0():
// %i.6 : int = aten::add(%i.17, %11)
// %33 : bool = aten::lt(%i.6, %3)
// = prim::LoopContinuation(%33, %i.6)
// -> (%i.6)
// block1():
// -> (%i.17)
// %i.13 : int = aten::add(%i.16, %19)
// %4 : bool = aten::lt(%i.13, %3)
// -> (%4, %i.13)
// return (%i)
//
// -> transforms to
//
// %false_val : bool = prim::Constant[value=0]()
// %true_val : bool = prim::Constant[value=1]()
// %40 : int = prim::Uninitialized()
// %39 : bool = prim::Uninitialized()
// %36 : bool = aten::lt(%i.1, %3)
// %i : int = prim::Loop(%1, %36, %i.1)
// block0(%5 : int, %i.17 : int):
// %8 : bool = aten::eq(%i.17, %7)
// %did_exit : bool, %continue_loop : bool, %43 : int, %i.16 : int =
// prim::If(%8)
// block0():
// %i.6 : int = aten::add(%i.17, %11)
// %33 : bool = aten::lt(%i.6, %3)
// -> (%true_val, %33, %i.6, %i.6)
// block1():
// -> (%false_val, %39, %40, %i.17)
// %44 : bool, %i : int = prim::If(%did_exit)
// block0():
// -> (%continue_loop, %43)
// block1():
// %i.13 : int = aten::add(%i.16, %19)
// %4 : bool = aten::lt(%i.13, %3)
// -> (%4, %i.13)
// -> (%44, %i)
void TransformExits(std::shared_ptr<Graph>& graph) {
ExitTransformer e_loop(graph);
e_loop.transformLoopContinuations();
ExitTransformer e_ret(graph);
e_ret.transformReturnStmts();
}
} // namespace jit
} // namespace torch