Files
pytorch/torch/csrc/jit/frontend/convert_to_ssa.cpp
Ansley Ussery c619892482 Fix errata (#49903)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49903

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D25718411

Pulled By: ansley

fbshipit-source-id: 0cc365c5a53077752dc1c5a5c4a65b873baa3604
2020-12-28 20:40:41 -08:00

350 lines
11 KiB
C++

#include <torch/csrc/jit/frontend/convert_to_ssa.h>
#include <torch/csrc/jit/frontend/exit_transforms.h>
#include <torch/csrc/jit/frontend/inline_loop_condition.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/mini_environment.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
namespace torch {
namespace jit {
// At the beginning of the pass the Graph has already undergone type checking,
// and writes or reads to a variable are emitted as Loads and Stores in the
// graph.
// a = 1
// print(a)
// is represented as:
// %a.1 : int = prim::Constant[value=1]()
// prim::Store[name="a"](%a.1)
// %a : int = prim::Load[name="a"]()
// prim::Print(%a)
//
// First, this pass recursively adds the Loads & Stores to control flow nodes
// Then the graph is converted to SSA form.
using ValueEnvironment = MiniEnvironment<Value*>;
using TypeEnvironment = MiniEnvironment<TypePtr>;
// Adds Loads & Stores to Loops & Ifs
struct ControlFlowLoadStores {
static void addBlockInput(
Block* b,
const TypePtr& type,
const std::string& name) {
auto g = b->owningGraph();
g->createStore(name, b->addInput(name)->setType(type))
->insertAfter(b->param_node());
}
static void addBlockOutput(
Block* exit_block,
const TypePtr& type,
const std::string& name) {
WithInsertPoint insert(exit_block);
auto g = exit_block->owningGraph();
auto block_exit = g->insertNode(g->createLoad(name, type))->output();
exit_block->registerOutput(block_exit);
}
static void addNodeOutput(
Node* n,
const TypePtr& type,
const std::string& name) {
auto out = n->addOutput()->setType(type);
if (meaningfulName(name)) {
out->setDebugName(name);
}
auto g = n->owningGraph();
g->createStore(name, out)->insertAfter(n);
}
static void addNodeInput(
Node* n,
const TypePtr& type,
const std::string& name) {
auto g = n->owningGraph();
auto inp = g->createLoad(name, type)->insertBefore(n)->output();
n->addInput(inp);
}
void addIfLoadStores(Node* n) {
auto true_block = n->blocks().at(0);
auto false_block = n->blocks().at(1);
auto true_vars = addControlFlowLoadStores(true_block);
auto false_vars = addControlFlowLoadStores(false_block);
std::set<std::string> mutated_variables;
for (auto& v : true_vars->definedVariables()) {
if (false_vars->findInAnyFrame(v)) {
mutated_variables.insert(v);
}
}
for (auto& v : false_vars->definedVariables()) {
if (true_vars->findInAnyFrame(v)) {
mutated_variables.insert(v);
}
}
// Following the same logic as emitIfElseBlocks in ir_emitter.cpp,
// we emit a node output if the variable is defined in each block
// and the types of each block can be unified
for (const auto& x : mutated_variables) {
auto true_type = true_vars->findInAnyFrame(x);
auto false_type = false_vars->findInAnyFrame(x);
auto unified = unifyTypes(true_type, false_type);
if (!unified) {
continue;
}
addBlockOutput(true_block, true_type, x);
addBlockOutput(false_block, false_type, x);
addNodeOutput(n, *unified, x);
}
}
// loop_carried_outputs* = Loop(max_trip_count, start_condition,
// loop_carried_inputs*)
// block0(loop_counter, loop_carried_block*) {
// <body>
// -> (continue_condition, loop_carried_block_outputs*)
// }
// all loop_carried_... lists are the same length and represent the value of
// loop-carried variables whose definitions are updated as the loop executes
// in a way that ensure single static assignment.
void addLoopLoadStores(Node* n) {
auto body_block = n->blocks().at(0);
auto loop_vars = addControlFlowLoadStores(body_block);
for (const auto& name : loop_vars->definedVariables()) {
// if the variable local to the loop body, then
// we do not need a loop carried variable for it
auto parent_type = environment_stack->findInAnyFrame(name);
if (!parent_type) {
continue;
}
// since the loop may execute 0 or many times, the output types
// of the loop and the input loop carried dependencies are conservatively
// the union of the output of the body and the input to the loop
auto block_type = loop_vars->findInThisFrame(name);
auto unified_type = unifyTypes(parent_type, block_type).value();
// Insert a store at the beginning of the loop block, so that all
// loads of the variable will use the loop carried value
addNodeInput(n, parent_type, name);
addBlockInput(body_block, unified_type, name);
addBlockOutput(body_block, block_type, name);
addNodeOutput(n, unified_type, name);
}
}
std::shared_ptr<TypeEnvironment> addControlFlowLoadStores(Block* block) {
pushFrame(block);
for (Node* n : block->nodes()) {
switch (n->kind()) {
case prim::If: {
addIfLoadStores(n);
} break;
case prim::Loop: {
addLoopLoadStores(n);
} break;
case prim::Closure: {
for (auto b : n->blocks()) {
addControlFlowLoadStores(b);
}
} break;
case prim::Store: {
environment_stack->setVar(n->s(attr::name), n->input()->type());
} break;
case prim::ComprehensionScope: {
addControlFlowLoadStores(n->blocks().at(0));
} break;
}
}
return popFrame();
}
void pushFrame(Block* b) {
environment_stack = std::make_shared<TypeEnvironment>(b, environment_stack);
}
std::shared_ptr<TypeEnvironment> popFrame() {
auto old_frame = environment_stack;
environment_stack = environment_stack->next;
return old_frame;
}
void run(std::shared_ptr<Graph>& graph) {
addControlFlowLoadStores(graph->block());
}
std::shared_ptr<TypeEnvironment> environment_stack = nullptr;
};
// Given a graph where 1) outputs have been added to control flow nodes and
// 2) loads and stores are represented in the graph, erase the Loads & Stores.
struct EraseLoadStores {
void eraseBlockLoadStores(Block* block) {
pushFrame(block);
for (auto it = block->nodes().begin(); it != block->nodes().end();) {
auto n = *it;
it++;
switch (n->kind()) {
case prim::Store: {
environment_stack->setVar(n->s(attr::name), n->input());
n->destroy();
} break;
case prim::Load: {
auto name = n->s(attr::name);
auto var = environment_stack->findInAnyFrame(name);
TORCH_INTERNAL_ASSERT(
var, "Typechecking should ensure the variable name is set");
n->output()->replaceAllUsesWith(var);
n->destroy();
} break;
case prim::ComprehensionScope: {
// writes within a local variable scope do not leak into
// the rest of the graph
auto body = n->blocks().at(0);
eraseBlockLoadStores(body);
// inline the local variable scope into the graph
for (auto it_cmpr = body->nodes().begin();
it_cmpr != body->nodes().end();) {
Node* body_node = *it_cmpr;
it_cmpr++;
body_node->moveBefore(n);
}
n->destroy();
} break;
default: {
for (auto b : n->blocks()) {
eraseBlockLoadStores(b);
}
} break;
}
}
popFrame();
}
void pushFrame(Block* b) {
environment_stack =
std::make_shared<ValueEnvironment>(b, environment_stack);
}
std::shared_ptr<ValueEnvironment> popFrame() {
auto old_frame = environment_stack;
environment_stack = environment_stack->next;
return old_frame;
}
void run(std::shared_ptr<Graph>& graph) {
eraseBlockLoadStores(graph->block());
}
std::shared_ptr<ValueEnvironment> environment_stack = nullptr;
};
// This pass transforms Breaks & Continues to be LoopContinuations,
// of the form LoopContinuations(%loop_continue_condition, *loop_carried_vars)
// Break Statements have the condition set to false, and Continue statements
// inline the loop condition as the first input.
struct LoopContinuations {
public:
void run(std::shared_ptr<Graph>& graph) {
run(graph->block());
}
private:
void addLoopCarriedOutputs(Node* n) {
auto g = n->owningGraph();
WithInsertPoint insert(n);
auto continuation = curr_loop_->blocks().at(0)->return_node();
for (auto out : continuation->inputs()) {
auto load_node = out->node();
TORCH_INTERNAL_ASSERT(load_node->kind() == prim::Load);
auto new_load =
g->insertNode(g->createClone(load_node, [](Value* v) { return v; }));
n->addInput(new_load->output());
}
}
void assignExitContinuations(Block* block) {
for (auto it = block->nodes().begin(); it != block->nodes().end();) {
Node* n = *it;
it++;
switch (n->kind()) {
case prim::If: {
assignExitContinuations(n->blocks().at(0));
assignExitContinuations(n->blocks().at(1));
} break;
case prim::Closure: {
LoopContinuations closure_block;
closure_block.run(n->blocks().at(0));
} break;
case prim::Loop: {
Node* prev_loop = curr_loop_;
curr_loop_ = n;
assignExitContinuations(n->blocks().at(0));
curr_loop_ = prev_loop;
} break;
case prim::ContinueStmt: {
auto loop_continuation =
graph_->create(prim::LoopContinuation, 0)->insertAfter(n);
auto header_block = loop_continuation->addBlock();
auto pre_header = curr_loop_->blocks().at(1);
header_block->cloneFrom(pre_header, [](Value* v) { return v; });
InlineBlockBeforeNode(n, header_block);
loop_continuation->addInput(header_block->outputs().at(0));
loop_continuation->eraseBlock(0);
addLoopCarriedOutputs(loop_continuation);
n->destroy();
} break;
case prim::BreakStmt: {
auto loop_exit =
graph_->create(prim::LoopContinuation, 0)->insertAfter(n);
// first input is the loop continue condition - break sets false
loop_exit->addInput(false_val_);
addLoopCarriedOutputs(loop_exit);
n->destroy();
} break;
}
}
}
void run(Block* b) {
{
graph_ = b->owningGraph();
WithInsertPoint guard(b->nodes().front());
false_val_ = graph_->insertConstant(false);
}
assignExitContinuations(b);
}
Graph* graph_ = nullptr;
Value* false_val_ = nullptr;
Node* curr_loop_ = nullptr;
};
// Converting to SSA works in multiple parts. First, we add control flow
// loads and stores to the graph. Now that control flow outputs are set,
// we can set remove Break & Continue to have the correct continuations to the
// end of the block (LoopContinuation). Then we inline the loop condition into
// the graph. Then, we erase Loads & Stores. Finally, we remove
// LoopContinuations from the graph.
void ConvertToSSA(std::shared_ptr<Graph>& graph) {
ControlFlowLoadStores ctrl;
ctrl.run(graph);
LoopContinuations exit_vars;
exit_vars.run(graph);
InlineLoopCondition(graph);
EraseLoadStores erase_loads_stores;
erase_loads_stores.run(graph);
TransformExits(graph);
}
} // namespace jit
} // namespace torch