mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Liveness for BailOut graphs
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/21615 Differential Revision: D15793434 Pulled By: Krovatkin fbshipit-source-id: d89f1bf61ea57a1e3b75f8e2b200c27beb8b46cf
This commit is contained in:
committed by
Facebook Github Bot
parent
8c57ce87b0
commit
8dd670657b
@ -579,6 +579,11 @@
|
||||
return changed;
|
||||
}
|
||||
|
||||
// Intersect our bitmap with the RHS and return true if ours changed.
|
||||
bool operator-=(const SparseBitVector &RHS) {
|
||||
return intersectWithComplement(RHS);
|
||||
}
|
||||
|
||||
// Intersect our bitmap with the RHS and return true if ours changed.
|
||||
bool operator&=(const SparseBitVector &RHS) {
|
||||
if (this == &RHS)
|
||||
@ -867,5 +872,26 @@
|
||||
return Result;
|
||||
}
|
||||
|
||||
template <unsigned ElementSize>
|
||||
std::ostream& operator<<(std::ostream& stream, const SparseBitVector<ElementSize>& vec) {
|
||||
|
||||
} // end namespace llvm
|
||||
bool first = true;
|
||||
stream << "{";
|
||||
for (auto el : vec)
|
||||
{
|
||||
if (first)
|
||||
{
|
||||
first = false;
|
||||
}
|
||||
else
|
||||
{
|
||||
stream << ", ";
|
||||
}
|
||||
stream << el;
|
||||
}
|
||||
stream << "}";
|
||||
return stream;
|
||||
}
|
||||
|
||||
|
||||
} // end namespace c10
|
||||
|
@ -402,6 +402,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/graph_fuser.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/guard_elimination.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/inplace_check.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/liveness.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/loop_unrolling.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/lower_grad_of.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/lower_tuples.cpp
|
||||
|
@ -77,6 +77,8 @@ namespace jit {
|
||||
_(NoneSchemaMatch) \
|
||||
_(ClassParser) \
|
||||
_(Profiler) \
|
||||
_(LivenessIf) \
|
||||
_(LivenessFor) \
|
||||
_(InsertAndEliminateGuards) \
|
||||
_(InsertBailOuts) \
|
||||
_(PeepholeOptimize) \
|
||||
|
@ -26,6 +26,7 @@
|
||||
#include "torch/csrc/jit/passes/graph_fuser.h"
|
||||
#include "torch/csrc/jit/passes/guard_elimination.h"
|
||||
#include "torch/csrc/jit/passes/insert_guards.h"
|
||||
#include "torch/csrc/jit/passes/liveness.h"
|
||||
#include "torch/csrc/jit/passes/lower_grad_of.h"
|
||||
#include "torch/csrc/jit/passes/lower_tuples.h"
|
||||
#include "torch/csrc/jit/passes/requires_grad_analysis.h"
|
||||
@ -908,6 +909,92 @@ static void checkShape(
|
||||
ASSERT_EQ(ptp->sizes().concrete_sizes().value(), expected);
|
||||
}
|
||||
|
||||
std::vector<std::size_t> values_to_value_ids(
|
||||
const std::vector<Value*>& values) {
|
||||
std::vector<std::size_t> result;
|
||||
for (auto v : values) {
|
||||
result.push_back(v->unique());
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
void testLivenessIf() {
|
||||
static const auto basic_example = R"JIT(
|
||||
def test_if_liveness(x, y, z):
|
||||
# type: (Tensor, Tensor, bool) -> Tensor
|
||||
c = torch.empty(2)
|
||||
a = x + y
|
||||
if z:
|
||||
t1 = x * 2
|
||||
t2 = t1 + 1
|
||||
c = t1
|
||||
else:
|
||||
d1 = y * 3
|
||||
d2 = d1 + 2
|
||||
c = d2
|
||||
|
||||
return c * a
|
||||
)JIT";
|
||||
|
||||
std::vector<std::size_t> expected_liveness_add_before_if{0, 1, 2, 3, 16, 28};
|
||||
std::vector<std::size_t> expected_liveness_if{0, 1, 2, 3, 16, 17, 28};
|
||||
std::vector<std::size_t> expected_first_node_liveness{0, 1, 2};
|
||||
|
||||
auto cu = compile(basic_example);
|
||||
auto& fun = cu->get_function("test_if_liveness");
|
||||
auto liveness = BuildLivenessSets(fun.graph());
|
||||
auto nodes = fun.graph()->nodes();
|
||||
auto if_node = std::find_if(nodes.begin(), nodes.end(), [](Node* n) {
|
||||
return n->kind() == prim::If;
|
||||
});
|
||||
|
||||
auto first_node = *fun.graph()->nodes().begin();
|
||||
auto actual_first_node_liveness = values_to_value_ids(liveness[first_node]);
|
||||
ASSERT_EQ(actual_first_node_liveness, expected_first_node_liveness);
|
||||
|
||||
auto actual_if_liveness = values_to_value_ids(liveness[*if_node]);
|
||||
ASSERT_EQ(actual_if_liveness, expected_liveness_if);
|
||||
auto add = if_node->prev();
|
||||
auto actual_add_before_if_liveness = values_to_value_ids(liveness[add]);
|
||||
ASSERT_EQ(actual_add_before_if_liveness, expected_liveness_add_before_if);
|
||||
}
|
||||
|
||||
void testLivenessFor() {
|
||||
static const auto basic_example = R"JIT(
|
||||
def test_for_liveness(n, x):
|
||||
|
||||
# type: (int, Tensor) -> Tuple[int, Tensor, Tensor]
|
||||
sum = 0
|
||||
sum2 = 0
|
||||
c = x
|
||||
for i in range(n):
|
||||
sum += i
|
||||
j = i + 1
|
||||
sum2 = sum2 + j + i
|
||||
c = c * i
|
||||
|
||||
return (sum + sum2, c, x)
|
||||
)JIT";
|
||||
|
||||
std::vector<std::size_t> expected_for{0, 1, 2, 7, 15};
|
||||
std::vector<std::size_t> expected_first_node_liveness{0, 1};
|
||||
|
||||
auto cu = compile(basic_example);
|
||||
auto& fun = cu->get_function("test_for_liveness");
|
||||
auto liveness = BuildLivenessSets(fun.graph());
|
||||
auto nodes = fun.graph()->nodes();
|
||||
auto loop_node = std::find_if(nodes.begin(), nodes.end(), [](Node* n) {
|
||||
return n->kind() == prim::Loop;
|
||||
});
|
||||
|
||||
auto first_node = *fun.graph()->nodes().begin();
|
||||
auto actual_first_node_liveness = values_to_value_ids(liveness[first_node]);
|
||||
ASSERT_EQ(actual_first_node_liveness, expected_first_node_liveness);
|
||||
|
||||
auto actual_loop_liveness = values_to_value_ids(liveness[*loop_node]);
|
||||
ASSERT_EQ(actual_loop_liveness, expected_for);
|
||||
}
|
||||
|
||||
void testInsertAndEliminateGuards() {
|
||||
static const auto basic_example = R"JIT(
|
||||
def basic(x, y):
|
||||
|
@ -91,6 +91,7 @@ libtorch_sources = [
|
||||
"torch/csrc/jit/passes/inline_forked_closures.cpp",
|
||||
"torch/csrc/jit/passes/inplace_check.cpp",
|
||||
"torch/csrc/jit/passes/insert_guards.cpp",
|
||||
"torch/csrc/jit/passes/liveness.cpp",
|
||||
"torch/csrc/jit/passes/loop_unrolling.cpp",
|
||||
"torch/csrc/jit/passes/lower_grad_of.cpp",
|
||||
"torch/csrc/jit/passes/lower_tuples.cpp",
|
||||
|
115
torch/csrc/jit/passes/liveness.cpp
Normal file
115
torch/csrc/jit/passes/liveness.cpp
Normal file
@ -0,0 +1,115 @@
|
||||
#include <torch/csrc/jit/passes/alias_analysis.h>
|
||||
#include <torch/csrc/jit/passes/constant_pooling.h>
|
||||
#include <torch/csrc/jit/passes/liveness.h>
|
||||
#include <memory>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// LivenessAnalyzer computes "bailout" liveness which is equivalent to
|
||||
// "{LIVE_IN} or {GEN}" or "{LIVE_OUT} - {KILL}"
|
||||
struct LivenessAnalyzer {
|
||||
explicit LivenessAnalyzer(std::shared_ptr<Graph> graph)
|
||||
: graph_(std::move(graph)) {}
|
||||
|
||||
std::unordered_map<Node*, std::vector<Value*>> run() {
|
||||
processBlock(graph_->block(), SparseBitVector{});
|
||||
std::unordered_map<Node*, std::vector<Value*>> result;
|
||||
|
||||
for (const auto& e : liveness_sets_) {
|
||||
result.insert({e.first, toValueVector(e.second)});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void dump(const std::map<Node*, std::vector<Value*>>& liveness_sets) {
|
||||
std::cout << "Liveness info:\n";
|
||||
for (auto e : liveness_sets) {
|
||||
if (e.first->outputs().size() > 0) {
|
||||
std::cout << e.first->outputs()[0]->uniqueName();
|
||||
}
|
||||
|
||||
std::cout << " " << e.first->kind().toQualString();
|
||||
std::cout << " = ";
|
||||
dump(e.second);
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout << "graph :\n";
|
||||
graph_->dump();
|
||||
}
|
||||
|
||||
void dump(const std::vector<Value*>& set) {
|
||||
bool first = true;
|
||||
std::cout << "[";
|
||||
for (auto el : set) {
|
||||
if (first) {
|
||||
first = false;
|
||||
} else {
|
||||
std::cout << ", ";
|
||||
}
|
||||
std::cout << el->uniqueName() << "(" << el->unique() << ")";
|
||||
}
|
||||
std::cout << "]";
|
||||
}
|
||||
|
||||
private:
|
||||
SparseBitVector toSparseBitVector(at::ArrayRef<Value*> values) {
|
||||
SparseBitVector sbv;
|
||||
for (auto v : values) {
|
||||
ids_to_values_[v->unique()] = v;
|
||||
sbv.set(v->unique());
|
||||
}
|
||||
return sbv;
|
||||
}
|
||||
|
||||
std::vector<Value*> toValueVector(const SparseBitVector& sbv) {
|
||||
std::vector<Value*> vec;
|
||||
for (auto id : sbv) {
|
||||
vec.push_back(ids_to_values_[id]);
|
||||
}
|
||||
return vec;
|
||||
}
|
||||
|
||||
SparseBitVector processBlock(Block* b, SparseBitVector liveness) {
|
||||
// block outputs are the uses
|
||||
auto block_outputs = toSparseBitVector(b->outputs());
|
||||
liveness |= block_outputs;
|
||||
|
||||
SparseBitVector defs;
|
||||
for (Node* it : b->nodes().reverse()) {
|
||||
// kill outputs
|
||||
liveness -= toSparseBitVector(it->outputs());
|
||||
if (it->kind() == prim::Loop) {
|
||||
auto loop_block = liveness;
|
||||
// loop's outputs aren't live inside the loop
|
||||
// loop's block outputs, OTOH, will be considered
|
||||
// as uses
|
||||
loop_block = processBlock(it->blocks()[0], loop_block);
|
||||
// loop block's inputs die outside loop's block
|
||||
loop_block -= toSparseBitVector(it->blocks()[0]->inputs());
|
||||
liveness |= loop_block;
|
||||
} else if (it->kind() == prim::If) {
|
||||
auto true_liveness = processBlock(it->blocks()[0], liveness);
|
||||
auto false_liveness = processBlock(it->blocks()[1], liveness);
|
||||
liveness |= true_liveness;
|
||||
liveness |= false_liveness;
|
||||
}
|
||||
liveness |= toSparseBitVector(it->inputs());
|
||||
liveness_sets_.insert({it, liveness});
|
||||
}
|
||||
return liveness;
|
||||
}
|
||||
|
||||
std::shared_ptr<Graph> graph_;
|
||||
std::map<Node*, SparseBitVector> liveness_sets_;
|
||||
std::map<size_t, Value*> ids_to_values_;
|
||||
};
|
||||
|
||||
std::unordered_map<Node*, std::vector<Value*>> BuildLivenessSets(
|
||||
std::shared_ptr<Graph> graph) {
|
||||
LivenessAnalyzer la(std::move(graph));
|
||||
return la.run();
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
24
torch/csrc/jit/passes/liveness.h
Normal file
24
torch/csrc/jit/passes/liveness.h
Normal file
@ -0,0 +1,24 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <ATen/core/stack.h>
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
#include <torch/csrc/jit/ir.h>
|
||||
#include <unordered_map>
|
||||
#include <list>
|
||||
#include <vector>
|
||||
#include <c10/util/sparse_bitset.h>
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
using ::c10::ProfiledTensorTypePtr;
|
||||
using SparseBitVector = ::c10::SparseBitVector<256>;
|
||||
|
||||
// BuildLivenessSets computes "bailout" liveness which is equivalent to
|
||||
// "{LIVE_IN} or {GEN}" or "{LIVE_OUT} - {KILL}"
|
||||
TORCH_API std::unordered_map<Node*, std::vector<Value*>> BuildLivenessSets(
|
||||
std::shared_ptr<Graph> graph);
|
||||
} // namespace jit
|
||||
} // namespace torch
|
Reference in New Issue
Block a user