Liveness for BailOut graphs

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/21615

Differential Revision: D15793434

Pulled By: Krovatkin

fbshipit-source-id: d89f1bf61ea57a1e3b75f8e2b200c27beb8b46cf
This commit is contained in:
Nikolay Korovaiko
2019-06-12 17:19:26 -07:00
committed by Facebook Github Bot
parent 8c57ce87b0
commit 8dd670657b
7 changed files with 257 additions and 1 deletions

View File

@ -579,6 +579,11 @@
return changed;
}
// Intersect our bitmap with the RHS and return true if ours changed.
bool operator-=(const SparseBitVector &RHS) {
return intersectWithComplement(RHS);
}
// Intersect our bitmap with the RHS and return true if ours changed.
bool operator&=(const SparseBitVector &RHS) {
if (this == &RHS)
@ -867,5 +872,26 @@
return Result;
}
template <unsigned ElementSize>
std::ostream& operator<<(std::ostream& stream, const SparseBitVector<ElementSize>& vec) {
} // end namespace llvm
bool first = true;
stream << "{";
for (auto el : vec)
{
if (first)
{
first = false;
}
else
{
stream << ", ";
}
stream << el;
}
stream << "}";
return stream;
}
} // end namespace c10

View File

@ -402,6 +402,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/passes/graph_fuser.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/guard_elimination.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/inplace_check.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/liveness.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/loop_unrolling.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/lower_grad_of.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/lower_tuples.cpp

View File

@ -77,6 +77,8 @@ namespace jit {
_(NoneSchemaMatch) \
_(ClassParser) \
_(Profiler) \
_(LivenessIf) \
_(LivenessFor) \
_(InsertAndEliminateGuards) \
_(InsertBailOuts) \
_(PeepholeOptimize) \

View File

@ -26,6 +26,7 @@
#include "torch/csrc/jit/passes/graph_fuser.h"
#include "torch/csrc/jit/passes/guard_elimination.h"
#include "torch/csrc/jit/passes/insert_guards.h"
#include "torch/csrc/jit/passes/liveness.h"
#include "torch/csrc/jit/passes/lower_grad_of.h"
#include "torch/csrc/jit/passes/lower_tuples.h"
#include "torch/csrc/jit/passes/requires_grad_analysis.h"
@ -908,6 +909,92 @@ static void checkShape(
ASSERT_EQ(ptp->sizes().concrete_sizes().value(), expected);
}
std::vector<std::size_t> values_to_value_ids(
const std::vector<Value*>& values) {
std::vector<std::size_t> result;
for (auto v : values) {
result.push_back(v->unique());
}
return result;
};
void testLivenessIf() {
static const auto basic_example = R"JIT(
def test_if_liveness(x, y, z):
# type: (Tensor, Tensor, bool) -> Tensor
c = torch.empty(2)
a = x + y
if z:
t1 = x * 2
t2 = t1 + 1
c = t1
else:
d1 = y * 3
d2 = d1 + 2
c = d2
return c * a
)JIT";
std::vector<std::size_t> expected_liveness_add_before_if{0, 1, 2, 3, 16, 28};
std::vector<std::size_t> expected_liveness_if{0, 1, 2, 3, 16, 17, 28};
std::vector<std::size_t> expected_first_node_liveness{0, 1, 2};
auto cu = compile(basic_example);
auto& fun = cu->get_function("test_if_liveness");
auto liveness = BuildLivenessSets(fun.graph());
auto nodes = fun.graph()->nodes();
auto if_node = std::find_if(nodes.begin(), nodes.end(), [](Node* n) {
return n->kind() == prim::If;
});
auto first_node = *fun.graph()->nodes().begin();
auto actual_first_node_liveness = values_to_value_ids(liveness[first_node]);
ASSERT_EQ(actual_first_node_liveness, expected_first_node_liveness);
auto actual_if_liveness = values_to_value_ids(liveness[*if_node]);
ASSERT_EQ(actual_if_liveness, expected_liveness_if);
auto add = if_node->prev();
auto actual_add_before_if_liveness = values_to_value_ids(liveness[add]);
ASSERT_EQ(actual_add_before_if_liveness, expected_liveness_add_before_if);
}
void testLivenessFor() {
static const auto basic_example = R"JIT(
def test_for_liveness(n, x):
# type: (int, Tensor) -> Tuple[int, Tensor, Tensor]
sum = 0
sum2 = 0
c = x
for i in range(n):
sum += i
j = i + 1
sum2 = sum2 + j + i
c = c * i
return (sum + sum2, c, x)
)JIT";
std::vector<std::size_t> expected_for{0, 1, 2, 7, 15};
std::vector<std::size_t> expected_first_node_liveness{0, 1};
auto cu = compile(basic_example);
auto& fun = cu->get_function("test_for_liveness");
auto liveness = BuildLivenessSets(fun.graph());
auto nodes = fun.graph()->nodes();
auto loop_node = std::find_if(nodes.begin(), nodes.end(), [](Node* n) {
return n->kind() == prim::Loop;
});
auto first_node = *fun.graph()->nodes().begin();
auto actual_first_node_liveness = values_to_value_ids(liveness[first_node]);
ASSERT_EQ(actual_first_node_liveness, expected_first_node_liveness);
auto actual_loop_liveness = values_to_value_ids(liveness[*loop_node]);
ASSERT_EQ(actual_loop_liveness, expected_for);
}
void testInsertAndEliminateGuards() {
static const auto basic_example = R"JIT(
def basic(x, y):

View File

@ -91,6 +91,7 @@ libtorch_sources = [
"torch/csrc/jit/passes/inline_forked_closures.cpp",
"torch/csrc/jit/passes/inplace_check.cpp",
"torch/csrc/jit/passes/insert_guards.cpp",
"torch/csrc/jit/passes/liveness.cpp",
"torch/csrc/jit/passes/loop_unrolling.cpp",
"torch/csrc/jit/passes/lower_grad_of.cpp",
"torch/csrc/jit/passes/lower_tuples.cpp",

View File

@ -0,0 +1,115 @@
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/liveness.h>
#include <memory>
namespace torch {
namespace jit {
// LivenessAnalyzer computes "bailout" liveness which is equivalent to
// "{LIVE_IN} or {GEN}" or "{LIVE_OUT} - {KILL}"
struct LivenessAnalyzer {
explicit LivenessAnalyzer(std::shared_ptr<Graph> graph)
: graph_(std::move(graph)) {}
std::unordered_map<Node*, std::vector<Value*>> run() {
processBlock(graph_->block(), SparseBitVector{});
std::unordered_map<Node*, std::vector<Value*>> result;
for (const auto& e : liveness_sets_) {
result.insert({e.first, toValueVector(e.second)});
}
return result;
}
void dump(const std::map<Node*, std::vector<Value*>>& liveness_sets) {
std::cout << "Liveness info:\n";
for (auto e : liveness_sets) {
if (e.first->outputs().size() > 0) {
std::cout << e.first->outputs()[0]->uniqueName();
}
std::cout << " " << e.first->kind().toQualString();
std::cout << " = ";
dump(e.second);
std::cout << std::endl;
}
std::cout << "graph :\n";
graph_->dump();
}
void dump(const std::vector<Value*>& set) {
bool first = true;
std::cout << "[";
for (auto el : set) {
if (first) {
first = false;
} else {
std::cout << ", ";
}
std::cout << el->uniqueName() << "(" << el->unique() << ")";
}
std::cout << "]";
}
private:
SparseBitVector toSparseBitVector(at::ArrayRef<Value*> values) {
SparseBitVector sbv;
for (auto v : values) {
ids_to_values_[v->unique()] = v;
sbv.set(v->unique());
}
return sbv;
}
std::vector<Value*> toValueVector(const SparseBitVector& sbv) {
std::vector<Value*> vec;
for (auto id : sbv) {
vec.push_back(ids_to_values_[id]);
}
return vec;
}
SparseBitVector processBlock(Block* b, SparseBitVector liveness) {
// block outputs are the uses
auto block_outputs = toSparseBitVector(b->outputs());
liveness |= block_outputs;
SparseBitVector defs;
for (Node* it : b->nodes().reverse()) {
// kill outputs
liveness -= toSparseBitVector(it->outputs());
if (it->kind() == prim::Loop) {
auto loop_block = liveness;
// loop's outputs aren't live inside the loop
// loop's block outputs, OTOH, will be considered
// as uses
loop_block = processBlock(it->blocks()[0], loop_block);
// loop block's inputs die outside loop's block
loop_block -= toSparseBitVector(it->blocks()[0]->inputs());
liveness |= loop_block;
} else if (it->kind() == prim::If) {
auto true_liveness = processBlock(it->blocks()[0], liveness);
auto false_liveness = processBlock(it->blocks()[1], liveness);
liveness |= true_liveness;
liveness |= false_liveness;
}
liveness |= toSparseBitVector(it->inputs());
liveness_sets_.insert({it, liveness});
}
return liveness;
}
std::shared_ptr<Graph> graph_;
std::map<Node*, SparseBitVector> liveness_sets_;
std::map<size_t, Value*> ids_to_values_;
};
std::unordered_map<Node*, std::vector<Value*>> BuildLivenessSets(
std::shared_ptr<Graph> graph) {
LivenessAnalyzer la(std::move(graph));
return la.run();
}
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,24 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/stack.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/ir.h>
#include <unordered_map>
#include <list>
#include <vector>
#include <c10/util/sparse_bitset.h>
namespace torch {
namespace jit {
using ::c10::ProfiledTensorTypePtr;
using SparseBitVector = ::c10::SparseBitVector<256>;
// BuildLivenessSets computes "bailout" liveness which is equivalent to
// "{LIVE_IN} or {GEN}" or "{LIVE_OUT} - {KILL}"
TORCH_API std::unordered_map<Node*, std::vector<Value*>> BuildLivenessSets(
std::shared_ptr<Graph> graph);
} // namespace jit
} // namespace torch