#include <torch/csrc/jit/passes/loop_unrolling.h>

#include <ATen/core/interned_strings.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>

namespace torch {
namespace jit {

namespace {
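
// Tuning constants for this pass (see their uses below): kUnrollFactor is how
// many copies of the body an unrolled loop receives, kMaxBodySize bounds the
// size of bodies we are willing to unroll, and kMaxBodyRepeats bounds the
// constant trip counts we are willing to unroll completely.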
static constexpr int64_t kUnrollFactor = 8;
static constexpr int64_t kMaxBodySize = 32;
static constexpr int64_t kMaxBodyRepeats = 64;
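
// True iff `val` is a constant bool holding the value `true`.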
bool isTrueConstant(Value* val) {
  c10::optional<bool> maybe_value = constant_as<bool>(val);
  return maybe_value && *maybe_value;
}
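
// A prim::Loop counts as a for-loop when both its starting condition and the
// continue-condition produced by its body are the constant `true`, i.e. only
// the trip count decides when it stops.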
bool isForLoop(Node* node) {
  if (node->kind() != prim::Loop)
    return false;
  Value* start_cond = node->inputs().at(1);
  Value* continue_cond = node->blocks().at(0)->outputs().at(0);
  return isTrueConstant(start_cond) && isTrueConstant(continue_cond);
}

// Counts the size of this block, stopping and returning once the count
// reaches `limit` instructions.
int64_t limitedBlockSize(Block* body, int64_t limit) {
  auto it = body->nodes().begin();
  auto end = body->nodes().end();
  for (int64_t i = 0; i < limit; ++it) {
    for (Block* subblock : it->blocks()) {
      i += limitedBlockSize(subblock, limit - i);
    }
    if (!it->notExecutedOp()) {
      ++i;
    }
    if (it == end) {
      return i;
    }
  }
  return limit;
}

bool isSmallBlock(Block* body) {
  return limitedBlockSize(body, kMaxBodySize + 1) <= kMaxBodySize;
}
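
// Splices a single copy of the loop body in place of the loop node and rewires
// the loop's outputs to the corresponding body outputs.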
// XXX: This function can only be called with a loop that is guaranteed to
// execute EXACTLY ONCE.
void inlineBody(Node* loop) {
  auto graph = loop->owningGraph();
  auto body = loop->blocks().at(0);
  WithInsertPoint insert_point_guard{loop};

  std::unordered_map<Value*, Value*> value_map;
  auto get_value = [&](Value* v) {
    auto it = value_map.find(v);
    if (it != value_map.end())
      return it->second;
    return v;
  };

  // Loop node has extra (max_iters, initial_cond) inputs,
  // body has an extra (loop_counter) input.
  for (size_t i = 2; i < loop->inputs().size(); ++i) {
    value_map[body->inputs()[i - 1]] = loop->inputs()[i];
  }

  for (Node* orig : body->nodes()) {
    Node* clone = graph->insertNode(graph->createClone(orig, get_value));
    for (size_t i = 0; i < orig->outputs().size(); ++i) {
      value_map[orig->outputs()[i]] = clone->outputs()[i];
    }
  }
  for (size_t i = 0; i < loop->outputs().size(); ++i) {
    loop->outputs().at(i)->replaceAllUsesWith(
        get_value(body->outputs().at(i + 1)));
  }
  // XXX: it is extremely important to destroy the loop in here. DCE might not
  // be able to conclude that it's safe, because the loop might contain side
  // effects.
  loop->destroy();
}

// Inserts a copy of `body`, mapping its block inputs to `inputs`, and returns
// the list of Values corresponding to the block's outputs.
std::vector<Value*> insertBlockCopy(
    Graph& graph,
    Block* body,
    at::ArrayRef<Value*> inputs) {
  TORCH_INTERNAL_ASSERT(inputs.size() == body->inputs().size());
  std::unordered_map<Value*, Value*> value_map;
  auto get_value = [&](Value* v) {
    auto it = value_map.find(v);
    if (it != value_map.end())
      return it->second;
    return v;
  };
  auto inputs_it = inputs.begin();
  for (Value* input : body->inputs()) {
    value_map[input] = *inputs_it++;
  }
  for (Node* node : body->nodes()) {
    Node* new_node = graph.insertNode(graph.createClone(node, get_value));
    auto outputs_it = new_node->outputs().begin();
    for (Value* output : node->outputs()) {
      value_map[output] = *outputs_it++;
    }
  }
  return fmap(body->outputs(), get_value);
}
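
// Copies `body` into `dest` `times` times, chaining the copies: the outputs of
// one copy are fed as the carried inputs of the next.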
void repeatBody(Block* body, size_t times, Block* dest) {
  auto graph = body->owningGraph();
  WithInsertPoint insert_point_guard(dest);
  for (Value* input : body->inputs()) {
    dest->addInput()->copyMetadata(input);
  }

  std::vector<Value*> io = dest->inputs().vec();
  TORCH_INTERNAL_ASSERT(
      !body->inputs().at(0)->hasUses(), "loop counter should be unused");
  for (const auto i : c10::irange(times)) {
    (void)i; // Suppress unused variable warning
    io[0] = body->inputs().at(0);
    io = insertBlockCopy(*graph, body, io);
  }
  for (Value* output : io) {
    dest->registerOutput(output);
  }

  // It's likely that we have some dead nodes now - for example the "true"
  // constant that prevents the loop from breaking. We shouldn't wait too long
  // before removing them because they might artificially increase the loop size
  // and prevent outer loop unrolling.
  EliminateDeadCode(dest, false);
}

// Replaces the builtin loop counter with a "mutable" variable outside of the
// loop.
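// The counter becomes a loop-carried value that the body increments by one,
// which lets the unrolled loop and its epilogue share it.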
void replaceLoopCounter(Node* loop) {
  Graph* graph = loop->owningGraph();
  Block* body = loop->blocks().at(0);
  WithInsertPoint guard(loop);
  Value* init_counter = graph->insertConstant(0);

  loop->insertInput(2, init_counter);
  loop->insertOutput(0)->setType(IntType::get());

  Value* internal_counter = body->insertInput(1)->setType(init_counter->type());
  body->inputs()[0]->replaceAllUsesWith(internal_counter);

  WithInsertPoint insertPointGuard{body->return_node()};
  Value* result = graph->insert(aten::add, {internal_counter, 1});
  body->insertOutput(1, result);
}
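
// Unrolls `loop` in place. Loops with a small constant trip count are unrolled
// completely; otherwise the body is repeated kUnrollFactor times and a clone
// of the original loop is kept after it as an epilogue for the remaining
// iterations.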
void unroll(Node* loop) {
  Graph* graph = loop->owningGraph();
  Block* body = loop->blocks().at(0);

  // We will be using a "mutable" counter outside of the loop instead of the
  // default one, because this will allow us to share it between the unrolled
  // loop and its epilogue. This is necessary only if the loop counter is
  // actually used in the body.
  if (body->inputs()[0]->uses().size() > 0)
    replaceLoopCounter(loop);

  // Some optimization for constant-length loops. If we know they won't run too
  // many times, then we can unroll them entirely.
  Value* trip_count = loop->inputs().at(0);
  c10::optional<int64_t> const_len = constant_as<int64_t>(trip_count);
  if (const_len && *const_len < kMaxBodyRepeats) {
    Block* dest = loop->addBlock();
    repeatBody(body, *const_len, dest);
    loop->eraseBlock(0);
    inlineBody(loop);
    return;
  }

  WithInsertPoint insert_point_guard{loop};

  // Clone the loop before we unroll it. The clone will become the epilogue.
  Node* loop_epilogue =
      graph->createClone(loop, [](Value* v) { return v; })->insertAfter(loop);
  for (size_t i = 0; i < loop->outputs().size(); ++i) {
    loop->outputs()[i]->replaceAllUsesWith(loop_epilogue->outputs()[i]);
    loop_epilogue->replaceInput(i + 2, loop->outputs()[i]);
  }

  Block* dest = loop->addBlock();
  repeatBody(body, kUnrollFactor, dest);
  loop->eraseBlock(0);

  // Change the iteration counts of both loops
  Value* iter_count = loop->inputs().at(0);
  Value* unrolled_iter_count = graph->insert(
      aten::__round_to_zero_floordiv, {iter_count, kUnrollFactor});
  loop->replaceInput(0, unrolled_iter_count);
  loop_epilogue->replaceInput(
      0,
      graph->insert(
          aten::sub,
          {iter_count,
           graph->insert(aten::mul, {unrolled_iter_count, kUnrollFactor})}));
}
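
// Recursively visits every block under `block` and unrolls each for-loop it
// finds. With `constant_only` set, only loops whose trip count is a constant
// are unrolled; otherwise only loops with a small body are. Returns true if
// anything changed.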
bool UnrollLoops(Block* block, bool constant_only) {
  bool changed = false;
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    // XXX: unroll might destroy the current node, so we need to pre-increment
    // the iterator
    Node* node = *it;
    ++it;
    for (Block* subblock : node->blocks()) {
      changed |= UnrollLoops(subblock, constant_only);
    }
    if (!isForLoop(node)) {
      continue;
    }
    if (constant_only) {
      if (node->inputs().at(0)->node()->kind() != prim::Constant) {
        continue;
      }
    } else if (!isSmallBlock(node->blocks().at(0))) {
      continue;
    }

    unroll(node);
    changed = true;
  }
  return changed;
}

} // anonymous namespace
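
// Threads the loop's input condition through the body and exposes the
// continue-condition as an extra loop output, so a peeled copy can hand the
// condition it finished with to the loop that follows it.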
static void addCondAsOutput(Node* loop) {
  LoopView loop_view(loop);
  loop->addInput(loop_view.inputCond());
  auto block_cond_input = loop_view.bodyBlock()->addInput();
  block_cond_input->copyMetadata(loop_view.inputCond());
  auto cond_output_index =
      loop_view.bodyBlock()->registerOutput(loop_view.nextCond());
  loop_view.bodyBlock()->outputs()[cond_output_index]->copyMetadata(
      loop_view.nextCond());
  auto cond_output = loop->addOutput();
  cond_output->copyMetadata(loop_view.nextCond());
}

bool LoopsPeeler::run(const std::shared_ptr<Graph>& graph) {
  GRAPH_DUMP("Before LoopsPeeler", graph);
  collectLoops(graph->block());
  peelLoops();
  GRAPH_DUMP("After LoopsPeeler", graph);
  return true;
}
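
// If the callback fires for `n` while we are inside a loop, the innermost
// enclosing loop is recorded for peeling (and cleared, so it is recorded at
// most once).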
void LoopsPeeler::collectLoop(Node* n) {
  if (callback_(n)) {
    if (in_loop_) {
      GRAPH_DEBUG("Loop ", getHeader(in_loop_), " will be unrolled");
      loops_to_peel_.push_back(in_loop_);
      in_loop_ = nullptr;
    }
  }
}

void LoopsPeeler::collectLoops(Block* block) {
  // we do a pre-order traversal to reduce the number
  // of peeled loops.
  for (auto n : block->nodes()) {
    collectLoop(n);
  }
  collectLoop(block->return_node());

  // process child blocks
  for (auto n : block->nodes()) {
    auto old_in_loop_ = in_loop_;
    if (n->kind() == prim::Loop) {
      in_loop_ = n;
    }
    for (auto b : n->blocks()) {
      collectLoops(b);
    }
    in_loop_ = old_in_loop_;
  }
}

void LoopsPeeler::peelLoops() {
  for (auto loop : loops_to_peel_) {
    PeelLoop(loop, num_iterations_);
  }
}
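
// Peels every loop that has at least one Tensor among its inputs.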
bool PeelProfilingLoops(const std::shared_ptr<Graph>& graph) {
  auto peel_predicate = [](Node* n) {
    for (auto i : n->inputs()) {
      if (i->type()->isSubtypeOf(*TensorType::get())) {
        return true;
      }
    }

    return false;
  };

  LoopsPeeler lp(peel_predicate);
  return lp.run(graph);
}
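
// Peels the first min(times, maxTripCount) iterations of `n` into a separate
// loop inserted just before it, then rewires `n` to consume the peeled loop's
// outputs and to run only the remaining iterations. Returns the peeled copy.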
Node* PeelLoop(Node* n, size_t times) {
  GRAPH_DEBUG("Peeling the loop ", getHeader(n), " ", times, " times");

  auto graph = n->owningGraph();
  auto orig_loop = LoopView(n);

  WithInsertPoint wip(n);
  auto times_const = graph->insertConstant(static_cast<int64_t>(times));
  // N.B. even though a caller may request to peel `times` iterations,
  // `maxTripCount` of the original loop might be less than that,
  // so we should take the minimum of the two
  auto min_trip_count =
      graph->insert(prim::min, {orig_loop.maxTripCount(), times_const});

  // make the peeled clone
  auto peeled_copy = graph->createClone(n, [](Value* v) { return v; });
  addCondAsOutput(peeled_copy);

  LoopView new_lv(peeled_copy);
  graph->insertNode(peeled_copy);
  // only run until the peeled count
  new_lv.replaceMaxTripCount(min_trip_count);

  // subtract the number of iterations the peeled loop runs from the
  // `maxTripCount` of the original loop
  auto new_max_trip_count =
      graph->insert(aten::sub, {orig_loop.maxTripCount(), min_trip_count});
  orig_loop.replaceMaxTripCount(new_max_trip_count);
  // update the termination condition
  auto cond_index = peeled_copy->outputs().size() - 1;
  orig_loop.replaceInputCondition(peeled_copy->output(cond_index));

  static const size_t LOOP_DEPS_WITH_COND_OFFSET = 2;
  for (size_t i = 0; i < peeled_copy->outputs().size() -
           1 /* leave off the termination condition */;
       i++) {
    n->replaceInput(LOOP_DEPS_WITH_COND_OFFSET + i, peeled_copy->output(i));
  }

  // the induction variable also needs to be adjusted by the number of
  // iterations the peeled loop runs
  {
    WithInsertPoint peeled_wip(*orig_loop.bodyBlock()->nodes().begin());
    // we can't create the expression `new_counter = old_counter + min_trip_count`
    // directly, because when we run
    // `old_counter->replaceAllUsesWith(new_counter)` we would get
    // `new_counter = new_counter + min_trip_count`, so we build the add with a
    // placeholder input first and patch it afterwards
    auto adjusted_iter_counter =
        graph->insert(aten::add, {min_trip_count, min_trip_count});
    orig_loop.currentTripCount()->replaceAllUsesWith(adjusted_iter_counter);
    adjusted_iter_counter->node()->replaceInput(
        0, orig_loop.currentTripCount());
  }

  return peeled_copy;
}
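
// A minimal usage sketch for the pass entry points defined below (assuming an
// already-constructed Graph `g`; not part of this file):
//
//   std::shared_ptr<Graph> g = ...;
//   UnrollLoops(g);           // unroll small for-loops everywhere
//   UnrollConstantLoops(g);   // or: only loops with constant trip counts
//   PeelProfilingLoops(g);    // peel loops that carry Tensor inputs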
bool UnrollLoops(std::shared_ptr<Graph>& graph) {
  bool changed = UnrollLoops(graph->block(), false);
  if (changed) {
    EliminateDeadCode(graph);
  }
  return changed;
}

bool UnrollConstantLoops(std::shared_ptr<Graph>& graph) {
  bool changed = UnrollLoops(graph->block(), true);
  if (changed) {
    EliminateDeadCode(graph);
  }
  return changed;
}

} // namespace jit
} // namespace torch