#include <torch/csrc/jit/frontend/convert_to_ssa.h>

#include <torch/csrc/jit/frontend/exit_transforms.h>
#include <torch/csrc/jit/frontend/inline_loop_condition.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/mini_environment.h>
#include <torch/csrc/jit/ir/ir.h>

namespace torch::jit {

// At the beginning of the pass the Graph has already undergone type checking,
// and reads from or writes to a variable are emitted as Loads and Stores in
// the graph.
// a = 1
// print(a)
// is represented as:
// %a.1 : int = prim::Constant[value=1]()
// prim::Store[name="a"](%a.1)
// %a : int = prim::Load[name="a"]()
// prim::Print(%a)
//
// First, this pass recursively adds the Loads & Stores to control flow nodes.
// Then the graph is converted to SSA form.

using ValueEnvironment = MiniEnvironment<Value*>;
using TypeEnvironment = MiniEnvironment<TypePtr>;

// Adds Loads & Stores to Loops & Ifs
struct ControlFlowLoadStores {
  static void addBlockInput(
      Block* b,
      const TypePtr& type,
      const std::string& name) {
    auto g = b->owningGraph();
    g->createStore(name, b->addInput(name)->setType(type))
        ->insertAfter(b->param_node());
  }

  static void addBlockOutput(
      Block* exit_block,
      const TypePtr& type,
      const std::string& name) {
    WithInsertPoint insert(exit_block);
    auto g = exit_block->owningGraph();
    auto block_exit = g->insertNode(g->createLoad(name, type))->output();
    exit_block->registerOutput(block_exit);
  }

  static void addNodeOutput(
      Node* n,
      const TypePtr& type,
      const std::string& name) {
    auto out = n->addOutput()->setType(type);
    if (meaningfulName(name)) {
      out->setDebugName(name);
    }
    auto g = n->owningGraph();
    g->createStore(name, out)->insertAfter(n);
  }

  static void addNodeInput(
      Node* n,
      const TypePtr& type,
      const std::string& name) {
    auto g = n->owningGraph();
    auto inp = g->createLoad(name, type)->insertBefore(n)->output();
    n->addInput(inp);
  }

  void addIfLoadStores(Node* n) {
    auto true_block = n->blocks().at(0);
    auto false_block = n->blocks().at(1);

    auto true_vars = addControlFlowLoadStores(true_block);
    auto false_vars = addControlFlowLoadStores(false_block);
    std::set<std::string> mutated_variables;

    for (auto& v : true_vars->definedVariables()) {
      if (false_vars->findInAnyFrame(v)) {
        mutated_variables.insert(v);
      }
    }
    for (auto& v : false_vars->definedVariables()) {
      if (true_vars->findInAnyFrame(v)) {
        mutated_variables.insert(v);
      }
    }

    // Following the same logic as emitIfElseBlocks in ir_emitter.cpp,
    // we emit a node output if the variable is defined in each block
    // and the types of each block can be unified.
    for (const auto& x : mutated_variables) {
      auto true_type = true_vars->findInAnyFrame(x);
      auto false_type = false_vars->findInAnyFrame(x);
      auto unified =
          unifyTypes(true_type, false_type, /*default_to_union=*/true);

      addBlockOutput(true_block, true_type, x);
      addBlockOutput(false_block, false_type, x);
      addNodeOutput(n, *unified, x);
    }
  }

  // loop_carried_outputs* = Loop(max_trip_count, start_condition,
  //                              loop_carried_inputs*)
  //   block0(loop_counter, loop_carried_block*) {
  //     <body>
  //     -> (continue_condition, loop_carried_block_outputs*)
  //   }
  // All loop_carried_... lists are the same length and represent the value of
  // loop-carried variables whose definitions are updated as the loop executes
  // in a way that ensures static single assignment.
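  // As a sketch of the transformation below (variable names and the exact
  // printed IR are illustrative, not what the compiler emits verbatim), a
  // variable `i` that is defined before a loop and reassigned inside it
  // becomes a loop-carried value:
  //   %i.load : int = prim::Load[name="i"]()          // addNodeInput
  //   %i.out : int = prim::Loop(%max, %cond, %i.load)
  //     block0(%ctr : int, %i.carried : int):
  //       prim::Store[name="i"](%i.carried)           // addBlockInput
  //       ... loop body ...
  //       %i.exit : int = prim::Load[name="i"]()      // addBlockOutput
  //       -> (%cond.next, %i.exit)
  //   prim::Store[name="i"](%i.out)                   // addNodeOutput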
  void addLoopLoadStores(Node* n) {
    auto body_block = n->blocks().at(0);
    auto loop_vars = addControlFlowLoadStores(body_block);

    for (const auto& name : loop_vars->definedVariables()) {
      // if the variable is local to the loop body, then
      // we do not need a loop-carried variable for it
      auto parent_type = environment_stack->findInAnyFrame(name);
      if (!parent_type) {
        continue;
      }

      // since the loop may execute 0 or many times, the output types
      // of the loop and the input loop-carried dependencies are conservatively
      // the union of the output of the body and the input to the loop
      auto block_type = loop_vars->findInThisFrame(name);
      auto unified_type = unifyTypes(parent_type, block_type).value();

      // Insert a store at the beginning of the loop block, so that all
      // loads of the variable will use the loop-carried value
      addNodeInput(n, parent_type, name);
      addBlockInput(body_block, unified_type, name);
      addBlockOutput(body_block, block_type, name);
      addNodeOutput(n, unified_type, name);
    }
  }

  std::shared_ptr<TypeEnvironment> addControlFlowLoadStores(Block* block) {
    pushFrame(block);
    for (Node* n : block->nodes()) {
      switch (n->kind()) {
        case prim::If: {
          addIfLoadStores(n);
        } break;
        case prim::Loop: {
          addLoopLoadStores(n);
        } break;
        case prim::Closure: {
          for (auto b : n->blocks()) {
            addControlFlowLoadStores(b);
          }
        } break;
        case prim::Store: {
          environment_stack->setVar(n->s(attr::name), n->input()->type());
        } break;
        case prim::ComprehensionScope: {
          addControlFlowLoadStores(n->blocks().at(0));
        } break;
      }
    }
    return popFrame();
  }

  void pushFrame(Block* b) {
    environment_stack =
        std::make_shared<TypeEnvironment>(b, environment_stack);
  }

  std::shared_ptr<TypeEnvironment> popFrame() {
    auto old_frame = environment_stack;
    environment_stack = environment_stack->next;
    return old_frame;
  }

  void run(std::shared_ptr<Graph>& graph) {
    addControlFlowLoadStores(graph->block());
  }

  std::shared_ptr<TypeEnvironment> environment_stack = nullptr;
};

// Given a graph where 1) outputs have been added to control flow nodes and
// 2) loads and stores are represented in the graph, erase the Loads & Stores.
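// Continuing the running example from the top of this file (a sketch; the
// exact printed IR may differ), erasing Loads & Stores rewrites
//   %a.1 : int = prim::Constant[value=1]()
//   prim::Store[name="a"](%a.1)
//   %a : int = prim::Load[name="a"]()
//   prim::Print(%a)
// into
//   %a.1 : int = prim::Constant[value=1]()
//   prim::Print(%a.1)
// because the Store records %a.1 as the current value of "a", and the Load's
// uses are replaced with that value before both nodes are destroyed.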
struct EraseLoadStores {
  void eraseBlockLoadStores(Block* block) {
    pushFrame(block);
    for (auto it = block->nodes().begin(); it != block->nodes().end();) {
      auto n = *it;
      it++;

      switch (n->kind()) {
        case prim::Store: {
          environment_stack->setVar(n->s(attr::name), n->input());
          n->destroy();
        } break;
        case prim::Load: {
          auto name = n->s(attr::name);
          auto var = environment_stack->findInAnyFrame(name);
          TORCH_INTERNAL_ASSERT(
              var, "Typechecking should ensure the variable name is set");
          n->output()->replaceAllUsesWith(var);
          n->destroy();
        } break;
        case prim::ComprehensionScope: {
          // writes within a local variable scope do not leak into
          // the rest of the graph
          auto body = n->blocks().at(0);
          eraseBlockLoadStores(body);

          // inline the local variable scope into the graph
          for (auto it_cmpr = body->nodes().begin();
               it_cmpr != body->nodes().end();) {
            Node* body_node = *it_cmpr;
            it_cmpr++;
            body_node->moveBefore(n);
          }
          n->destroy();
        } break;
        default: {
          for (auto b : n->blocks()) {
            eraseBlockLoadStores(b);
          }
        } break;
      }
    }
    popFrame();
  }

  void pushFrame(Block* b) {
    environment_stack =
        std::make_shared<ValueEnvironment>(b, environment_stack);
  }

  std::shared_ptr<ValueEnvironment> popFrame() {
    auto old_frame = environment_stack;
    environment_stack = environment_stack->next;
    return old_frame;
  }

  void run(std::shared_ptr<Graph>& graph) {
    eraseBlockLoadStores(graph->block());
  }

  std::shared_ptr<ValueEnvironment> environment_stack = nullptr;
};

// This pass transforms Breaks & Continues into LoopContinuations of the form
// LoopContinuation(%loop_continue_condition, *loop_carried_vars).
// Break statements have the condition set to false, and Continue statements
// inline the loop condition as the first input.
struct LoopContinuations {
 public:
  void run(std::shared_ptr<Graph>& graph) {
    run(graph->block());
  }

 private:
  void addLoopCarriedOutputs(Node* n) {
    auto g = n->owningGraph();
    WithInsertPoint insert(n);
    // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
    auto continuation = curr_loop_->blocks().at(0)->return_node();
    for (auto out : continuation->inputs()) {
      auto load_node = out->node();
      TORCH_INTERNAL_ASSERT(load_node->kind() == prim::Load);
      auto new_load =
          g->insertNode(g->createClone(load_node, [](Value* v) { return v; }));
      n->addInput(new_load->output());
    }
  }

  void assignExitContinuations(Block* block) {
    for (auto it = block->nodes().begin(); it != block->nodes().end();) {
      Node* n = *it;
      it++;
      switch (n->kind()) {
        case prim::If: {
          assignExitContinuations(n->blocks().at(0));
          assignExitContinuations(n->blocks().at(1));
        } break;
        case prim::Closure: {
          LoopContinuations closure_block;
          closure_block.run(n->blocks().at(0));
        } break;
        case prim::Loop: {
          Node* prev_loop = curr_loop_;
          curr_loop_ = n;
          assignExitContinuations(n->blocks().at(0));
          curr_loop_ = prev_loop;
        } break;
        case prim::ContinueStmt: {
          auto loop_continuation =
              graph_->create(prim::LoopContinuation, 0)->insertAfter(n);
          auto header_block = loop_continuation->addBlock();
          // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
          auto pre_header = curr_loop_->blocks().at(1);
          header_block->cloneFrom(pre_header, [](Value* v) { return v; });
          InlineBlockBeforeNode(n, header_block);
          loop_continuation->addInput(header_block->outputs().at(0));
          loop_continuation->eraseBlock(0);
          addLoopCarriedOutputs(loop_continuation);
          n->destroy();
        } break;
        case prim::BreakStmt: {
          auto loop_exit =
              graph_->create(prim::LoopContinuation, 0)->insertAfter(n);
          // first input is the loop continue condition - break sets false
          loop_exit->addInput(false_val_);
          addLoopCarriedOutputs(loop_exit);
          n->destroy();
        } break;
      }
    }
  }
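  // A sketch of the rewrite performed in assignExitContinuations (node and
  // value names are illustrative): inside a loop body,
  //   prim::BreakStmt()
  // becomes
  //   prim::LoopContinuation(%false, %loop_carried...)
  // and
  //   prim::ContinueStmt()
  // becomes
  //   prim::LoopContinuation(%cond.recomputed, %loop_carried...)
  // where %cond.recomputed is produced by cloning and inlining the loop's
  // condition (pre-header) block, and the loop-carried inputs are fresh
  // clones of the Loads feeding the loop body's return node.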
  void run(Block* b) {
    {
      graph_ = b->owningGraph();
      WithInsertPoint guard(b->nodes().front());
      false_val_ = graph_->insertConstant(false);
    }
    assignExitContinuations(b);
  }

  Graph* graph_ = nullptr;
  Value* false_val_ = nullptr;
  Node* curr_loop_ = nullptr;
};

// Converting to SSA works in multiple parts. First, we add control flow
// loads and stores to the graph. Now that control flow outputs are set,
// we rewrite Break & Continue statements into LoopContinuations that carry
// the correct continuations to the end of the block. Then we inline the loop
// condition into the graph, erase Loads & Stores, and finally remove
// LoopContinuations from the graph.
void ConvertToSSA(std::shared_ptr<Graph>& graph) {
  ControlFlowLoadStores ctrl;
  ctrl.run(graph);
  LoopContinuations exit_vars;
  exit_vars.run(graph);
  InlineLoopCondition(graph);
  EraseLoadStores erase_loads_stores;
  erase_loads_stores.run(graph);
  TransformExits(graph);
}

} // namespace torch::jit