Optimize boolean expressions & unwraps (#18259)

Summary:
Simplify or eliminate boolean and/or expressions, optimize away unwrapping of a value that cannot be None, and fold `is`/`is not` comparisons between a None value and a value that cannot be None.
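
To make the targets concrete, here is a minimal TorchScript sketch of the first and third patterns (mirroring the Python and C++ tests added below; whether the frontend accepts `is None` on a non-Optional value is an assumption here, the C++ tests exercise it at the IR level):

import torch

@torch.jit.script
def redundant(x):
    # type: (int) -> bool
    return (x == 1) and True  # `and True` is redundant; peephole leaves just x == 1

@torch.jit.script
def never_none(x):
    # type: (int) -> bool
    return x is None  # an int can never be None, so this folds to the constant False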

Since the peephole pass now introduces constants, I added another constant propagation pass after running it.
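
Roughly, the ordering now looks like this from Python, using the torch._C pass bindings that test_jit.py's run_pass helper wraps (a sketch of the sequencing, not the full GraphExecutor pipeline):

import torch

@torch.jit.script
def fn(x):
    # type: (int) -> bool
    return x is not None  # folds to the constant True for a non-Optional input

g = fn.graph
torch._C._jit_pass_peephole(g)              # may insert new constants (as here)
torch._C._jit_pass_constant_propagation(g)  # run again afterwards to fold them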

Previously I had a PR that did this and also optimized shape ops; I will add the shape optimizations in a separate PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18259

Differential Revision: D14602749

Pulled By: eellison

fbshipit-source-id: 1c3f5a67067d8dfdf55d7b78dcb616472ea8a267
Author: eellison
Date: 2019-03-25 21:48:11 -07:00
Committed by: Facebook Github Bot
Parent: a729630cbf
Commit: dc6b5b2a52
10 changed files with 226 additions and 20 deletions

View File

@ -24,6 +24,7 @@
#include <test/cpp/jit/test_ivalue.h>
#include <test/cpp/jit/test_misc.h>
#include <test/cpp/jit/test_netdef_converter.h>
#include <test/cpp/jit/test_peephole_optimize.h>
#include <test/cpp/jit/test_subgraph_utils.h>
using namespace torch::jit::script;
@ -61,7 +62,8 @@ namespace jit {
_(THNNConv) \
_(ATenNativeBatchNorm) \
_(NoneSchemaMatch) \
_(ClassParser) \
_(PeepholeOptimize)
#define TH_FORALL_TESTS_CUDA(_) \
_(ArgumentSpec) \

View File

@ -0,0 +1,104 @@
#pragma once
#include <test/cpp/jit/test_base.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/irparser.h>
#include <torch/csrc/jit/passes/peephole.h>
namespace torch {
namespace jit {
using namespace script;
using namespace testing;
namespace test {
void testPeepholeOptimize() {
// test is / is not none optimization
{
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%0 : int):
%1 : None = prim::Constant()
%2 : bool = aten::__is__(%0, %1)
%3 : bool = aten::__isnot__(%0, %1)
return (%2, %3)
)IR",
graph.get());
PeepholeOptimize(graph);
testing::FileCheck()
.check_not("aten::__is__")
->check_not("aten::__isnot__")
->run(*graph);
}
{
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%0: int?):
%1 : None = prim::Constant()
%2 : bool = aten::__is__(%0, %1)
%3 : bool = aten::__isnot__(%0, %1)
return (%2, %3)
)IR",
graph.get());
PeepholeOptimize(graph);
testing::FileCheck()
.check("aten::__is__")
->check("aten::__isnot__")
->run(*graph);
}
{
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%0: int?):
%1 : Tensor = prim::AutogradZero()
%2 : None = prim::Constant()
%4 : bool = aten::__is__(%0, %1)
%5 : bool = aten::__isnot__(%1, %2)
return (%4, %5)
)IR",
graph.get());
PeepholeOptimize(graph);
testing::FileCheck()
.check("aten::__is__")
->check_not("aten::__isnot__")
->run(*graph);
}
// test unwrap optional
{
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph():
%1 : Float(*, *, *) = prim::Constant()
%2 : bool = aten::_unwrap_optional(%1)
%3 : bool = prim::unchecked_unwrap_optional(%1)
return (%2, %3)
)IR",
graph.get());
PeepholeOptimize(graph);
testing::FileCheck().check_not("unwrap")->run(*graph);
}
{
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%1 : Float(*, *, *)?):
%2 : bool = aten::_unwrap_optional(%1)
%3 : bool = prim::unchecked_unwrap_optional(%1)
return (%2, %3)
)IR",
graph.get());
PeepholeOptimize(graph);
testing::FileCheck().check_count("unwrap", 2)->run(*graph);
}
}
} // namespace test
} // namespace jit
} // namespace torch
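
The unwrap elimination is also visible from Python. A hedged sketch using torch.jit._unwrap_optional, relying on its `t? -> t` schema accepting a plain int argument:

import torch
from torch.testing import FileCheck

@torch.jit.script
def unwrap(x):
    # type: (int) -> int
    # x's static type is int, never None, so the unwrap is a provable no-op
    return torch.jit._unwrap_optional(x)

torch._C._jit_pass_peephole(unwrap.graph)
FileCheck().check_not("unwrap").run(unwrap.graph)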

View File

@ -1881,6 +1881,26 @@ class TestJit(JitTestCase):
# testing that 1 // 0 error is not thrown
self.run_pass('constant_propagation', constant_prop.graph)
def test_short_circuit_optimization(self):
@torch.jit.script
def const_expressions(x):
# type: (int) -> Tuple[bool, bool]
return x == 1 and False, x == 1 or True
self.run_pass('constant_propagation', const_expressions.graph)
FileCheck().check_not("prim::If").check_not("aten::eq").run(const_expressions.graph)
self.assertEqual(const_expressions(1), (False, True))
@torch.jit.script
def redundant_expressions(x):
# type: (int) -> Tuple[bool, bool]
return x == 1 and True, x == 1 or False
self.run_pass('peephole', redundant_expressions.graph)
self.assertEqual(redundant_expressions(1), (True, True))
self.assertEqual(redundant_expressions(0), (False, False))
# and True / or False are removed from graph
FileCheck().check("aten::eq").check_not("prim::If").run(redundant_expressions.graph)
def test_trace_records_names(self):
def foo(bar, baz):
baz = bar + 3

View File

@ -292,7 +292,6 @@ Gradient getGradient(const Node* n) {
grad.df_output_vjps = fmap<size_t>(n->is(attr::df_output_vjps));
return grad;
}
} // anonymous namespace
RegisterOperators reg_graph_executor_ops(
@ -308,7 +307,6 @@ GraphExecutor* getGradExecutor(Operation& op) {
}
return nullptr;
}
} // namespace detail
// a Graph can be created via tracing, or via a language-based frontend
@ -505,6 +503,7 @@ struct GraphExecutorImpl {
ConstantPooling(graph);
PeepholeOptimize(graph);
ConstantPropagation(graph);
// Unroll small loops, and eliminate expressions that are the same at every
// iteration.
@ -644,6 +643,5 @@ void runRequiredPasses(const std::shared_ptr<Graph>& g) {
CanonicalizeOps(g);
EliminateDeadCode(g);
}
} // namespace jit
} // namespace torch

View File

@ -641,6 +641,10 @@ std::shared_ptr<Graph> Graph::copy() {
bool Value::mustBeNone() const {
return node_->mustBeNone();
}
bool Value::mustNotBeNone() const {
return node_->kind() != prim::AutogradAdd && type() != NoneType::get() &&
!type()->cast<OptionalType>();
}
std::string Value::uniqueNameBase() const {
std::string name = uniqueName();
@ -771,9 +775,10 @@ bool Node::matches(
}
bool Node::mustBeNone() const {
return kind_ == prim::AutogradZero ||
(kind_ == prim::Constant && !this->hasAttributes() &&
(output()->type()->cast<OptionalType>() ||
output()->type() == NoneType::get()));
}
void Node::dump() const {

View File

@ -171,6 +171,7 @@ struct Value {
return type()->kind() == TypeKind::CompleteTensorType;
}
TORCH_API bool mustBeNone() const;
TORCH_API bool mustNotBeNone() const;
size_t unique() const {
return unique_;
}
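
The intent of the new query, rendered as a hypothetical Python predicate over a value's producing node kind and its printed type (the trailing `?` marks an Optional type, as in the `int?` fixtures above):

def must_not_be_none(node_kind, type_str):
    # provably non-None: not produced by prim::AutogradAdd, not the None
    # type itself, and not statically Optional
    return (node_kind != "prim::AutogradAdd"
            and type_str != "None"
            and not type_str.endswith("?"))

assert must_not_be_none("prim::Constant", "int")
assert not must_not_be_none("prim::Constant", "int?")
assert not must_not_be_none("prim::AutogradAdd", "Tensor")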

View File

@ -5,6 +5,7 @@
#include <torch/csrc/jit/constants.h>
#include <torch/csrc/jit/interpreter.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/node_hashing.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
@ -119,22 +120,41 @@ void inlineIf(Node* n, const AliasDb& aliasDb) {
inlineIfBody(n->blocks().at(block_index));
}
void replaceAndRemoveIfOutput(Node* n, size_t i, Value* replacement) {
n->outputs().at(i)->replaceAllUsesWith(replacement);
n->eraseOutput(i);
n->blocks().at(0)->eraseOutput(i);
n->blocks().at(1)->eraseOutput(i);
}
// remove extra outputs from the node
bool removeExtraIfOutputs(Node* n) {
AT_CHECK(n->kind() == prim::If, "Only supported for If nodes");
auto true_block = n->blocks()[0];
auto false_block = n->blocks()[1];
auto graph = n->owningGraph();
auto initial_outputs = true_block->outputs().size();
WithInsertPoint guard(n);
for (size_t i = 0; i < true_block->outputs().size();) {
auto t_out = true_block->outputs().at(i);
auto f_out = false_block->outputs().at(i);
// neither block changes the output value
if (true_block->outputs()[i] == false_block->outputs()[i]) {
replaceAndRemoveIfOutput(n, i, true_block->outputs()[i]);
continue;
}
// true block output is constant and constant matches false block output
auto maybe_const = toIValue(t_out);
auto eq = EqualNode();
if (maybe_const && eq(t_out->node(), f_out->node())) {
auto new_const = graph->insertConstant(*maybe_const, t_out->type());
replaceAndRemoveIfOutput(n, i, new_const);
continue;
}
i++; // increment bc we didn't remove current index
}
// an output was removed
return initial_outputs != true_block->outputs().size();
@ -213,6 +233,5 @@ void ConstantPropagation(std::shared_ptr<Graph>& graph) {
ConstantPropagation(graph->block(), aliasDb);
EliminateDeadCode(graph);
}
} // namespace jit
} // namespace torch
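
In effect, an If whose branches agree on an output no longer carries it. A sketch of the case the new constant comparison catches, where both branches produce equal constants (whether the emptied If is then deleted depends on the dead-code elimination run at the end of the pass):

import torch
from torch.testing import FileCheck

@torch.jit.script
def same_out(cond):
    # type: (bool) -> int
    if cond:
        y = 1
    else:
        y = 1  # equal constant in both branches
    return y

torch._C._jit_pass_constant_propagation(same_out.graph)
FileCheck().check_not("prim::If").run(same_out.graph)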

View File

@ -1,5 +1,5 @@
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/ir_views.h>
#include <torch/csrc/jit/symbolic_variable.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
@ -165,6 +165,49 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) {
u.user->replaceInput(0, node->inputs().at(0));
}
}
} else if (node->kind() == prim::If) {
IfView n(node);
// this handles redundant short circuits like "x and True" or "x or False"
for (size_t i = 0; i < n.outputs().size(); ++i) {
if (n.outputs().at(i)->type() != BoolType::get()) {
continue;
}
bool true_val =
constant_as<bool>(n.thenOutputs().at(i)).value_or(false);
bool false_val =
constant_as<bool>(n.elseOutputs().at(i)).value_or(true);
// if an if node's output equals its condition replace output with
// condition
if (true_val && !false_val) {
n.outputs().at(i)->replaceAllUsesWith(n.cond());
}
}
} else if (
node->kind() == aten::__is__ || node->kind() == aten::__isnot__) {
// if we are comparing a None value with a value that can't be None
// replace the output with true if node is __isnot__ or false if node is
// __is__
AT_ASSERT(node->inputs().size() == 2);
for (size_t check_none_index : {0, 1}) {
bool input_must_be_none =
node->inputs().at(check_none_index)->mustBeNone();
bool other_must_not_be_none =
node->inputs().at(1 - check_none_index)->mustNotBeNone();
if (input_must_be_none && other_must_not_be_none) {
WithInsertPoint guard(node);
auto output = node->owningGraph()->insertConstant(
node->kind() == aten::__isnot__);
node->output()->replaceAllUsesWith(output);
}
}
} else if (
node->kind() == prim::unchecked_unwrap_optional ||
node->kind() == aten::_unwrap_optional) {
// we are unwrapping an input that can't be None, remove the unwrap
auto input = node->input();
if (input->mustNotBeNone()) {
node->output()->replaceAllUsesWith(node->input());
}
}
}
}
@ -180,6 +223,5 @@ void PeepholeOptimize(
bool addmm_fusion_enabled) {
PeepholeOptimize(graph->block(), addmm_fusion_enabled);
}
} // namespace jit
} // namespace torch
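
Tying this together: `x or False` lowers to an If whose then-output is the constant True and else-output the constant False, so the output tracks the condition exactly and the loop above rewires its uses to the condition. A sketch mirroring test_short_circuit_optimization:

import torch
from torch.testing import FileCheck

@torch.jit.script
def g(x):
    # type: (int) -> bool
    return (x == 1) or False

torch._C._jit_pass_peephole(g.graph)
# the If emitted for `or` now just forwards its condition, so it is removed
FileCheck().check("aten::eq").check_not("prim::If").run(g.graph)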

View File

@ -1039,21 +1039,32 @@ struct to_ir {
const auto first_bool_info = findRefinements(first_expr);
Value* first_value = emitCond(Expr(first_expr));
// if the second expr in the short circuit is not evaluated,
// then the first expression is False if the short circuit
// is an `and` and True if the short circuit is an `or`.
// `False and expr` -> False, `True or expr` -> True
//
// inserting it as a constant makes optimization easier
Value* first_value_returned;
const Refinements* first_expr_refinements;
const Refinements* second_expr_refinements;
// if it's an OR the first expr is emitted in the true branch
// and the second expr in the false branch, if it's an AND the opposite
if (is_or) {
first_value_returned = graph->insertConstant(true, nullptr, loc);
first_expr_refinements = &first_bool_info.true_refinements_;
second_expr_refinements = &first_bool_info.false_refinements_;
} else {
first_value_returned = graph->insertConstant(false, nullptr, loc);
first_expr_refinements = &first_bool_info.false_refinements_;
second_expr_refinements = &first_bool_info.true_refinements_;
}
auto get_first_expr = [&] {
insertRefinements(*first_expr_refinements);
return first_value_returned;
};
auto get_second_expr = [&] {
@ -2094,7 +2105,6 @@ struct to_ir {
}
return classNew->createObject(
apply.range(), method, Var(apply.inputs()[0]).name().name());
} else {
auto inputs = getNamedValues(apply.inputs(), true);
auto attributes = emitAttributes(apply.attributes());
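
In plain Python terms, the lowering now behaves like the sketch below (an illustrative desugaring only; the real emitter builds a prim::If node). The branch where the second expression is skipped yields an inserted constant rather than re-using the first value:

def and_lowering(a, b_thunk):
    if a:
        r = b_thunk()  # second expr evaluated only when needed
    else:
        r = False      # constant: `False and expr` is always False
    return r

def or_lowering(a, b_thunk):
    if a:
        r = True       # constant: `True or expr` is always True
    else:
        r = b_thunk()
    return r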

View File

@ -21,9 +21,14 @@ TypeAndAlias SchemaTypeParser::parseBaseType() {
{"float", FloatType::get()},
{"int", IntType::get()},
{"bool", BoolType::get()},
{"None", NoneType::get()},
};
auto tok = L.cur();
if (!L.nextIf(TK_NONE)) {
L.expect(TK_IDENT);
}
std::string text = tok.text();
auto it = type_map.find(text);
if (it == type_map.end()) {
if (text.size() > 0 && islower(text[0])) {
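
This addition is what lets the IR fixtures above annotate a value as `None` (e.g. `%1 : None = prim::Constant()`). A sketch, assuming the parser is reachable from Python as torch._C.parse_ir:

import torch

g = torch._C.parse_ir("""
graph(%x : int?):
  %n : None = prim::Constant()
  %r : bool = aten::__is__(%x, %n)
  return (%r)
""")
print(g)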