mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Optimize boolean expressions & unwraps (#18259)
Summary: Simplify or eliminate boolean and/or expressions, optimize unwrapping a value that cannot be None, and optimize using `is` with a None and a non-None value Since peephole optimize is now introducing constants, i added another constant propagation pass after running it. Previously i had a PR that did this & optimized shape ops - i will add the shape optimizations in a separate PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/18259 Differential Revision: D14602749 Pulled By: eellison fbshipit-source-id: 1c3f5a67067d8dfdf55d7b78dcb616472ea8a267
This commit is contained in:
committed by
Facebook Github Bot
parent
a729630cbf
commit
dc6b5b2a52
@ -24,6 +24,7 @@
|
||||
#include <test/cpp/jit/test_ivalue.h>
|
||||
#include <test/cpp/jit/test_misc.h>
|
||||
#include <test/cpp/jit/test_netdef_converter.h>
|
||||
#include <test/cpp/jit/test_peephole_optimize.h>
|
||||
#include <test/cpp/jit/test_subgraph_utils.h>
|
||||
|
||||
using namespace torch::jit::script;
|
||||
@ -61,7 +62,8 @@ namespace jit {
|
||||
_(THNNConv) \
|
||||
_(ATenNativeBatchNorm) \
|
||||
_(NoneSchemaMatch) \
|
||||
_(ClassParser)
|
||||
_(ClassParser) \
|
||||
_(PeepholeOptimize)
|
||||
|
||||
#define TH_FORALL_TESTS_CUDA(_) \
|
||||
_(ArgumentSpec) \
|
||||
|
104
test/cpp/jit/test_peephole_optimize.h
Normal file
104
test/cpp/jit/test_peephole_optimize.h
Normal file
@ -0,0 +1,104 @@
|
||||
#pragma once
|
||||
|
||||
#include <test/cpp/jit/test_base.h>
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
|
||||
#include <torch/csrc/jit/irparser.h>
|
||||
#include <torch/csrc/jit/passes/peephole.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
using namespace script;
|
||||
using namespace testing;
|
||||
|
||||
namespace test {
|
||||
|
||||
void testPeepholeOptimize() {
|
||||
// test is / is not none optimization
|
||||
{
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph(%0 : int):
|
||||
%1 : None = prim::Constant()
|
||||
%2 : bool = aten::__is__(%0, %1)
|
||||
%3 : bool = aten::__isnot__(%0, %1)
|
||||
return (%2, %3)
|
||||
)IR",
|
||||
graph.get());
|
||||
PeepholeOptimize(graph);
|
||||
testing::FileCheck()
|
||||
.check_not("aten::__is__")
|
||||
->check_not("aten::__isnot__")
|
||||
->run(*graph);
|
||||
}
|
||||
{
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph(%0: int?):
|
||||
%1 : None = prim::Constant()
|
||||
%2 : bool = aten::__is__(%0, %1)
|
||||
%3 : bool = aten::__isnot__(%0, %1)
|
||||
return (%2, %3)
|
||||
)IR",
|
||||
graph.get());
|
||||
PeepholeOptimize(graph);
|
||||
testing::FileCheck()
|
||||
.check("aten::__is__")
|
||||
->check("aten::__isnot__")
|
||||
->run(*graph);
|
||||
}
|
||||
|
||||
{
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph(%0: int?):
|
||||
%1 : Tensor = prim::AutogradZero()
|
||||
%2 : None = prim::Constant()
|
||||
%4 : bool = aten::__is__(%0, %1)
|
||||
%5 : bool = aten::__isnot__(%1, %2)
|
||||
return (%4, %5)
|
||||
)IR",
|
||||
graph.get());
|
||||
PeepholeOptimize(graph);
|
||||
testing::FileCheck()
|
||||
.check("aten::__is__")
|
||||
->check_not("aten::__isnot__")
|
||||
->run(*graph);
|
||||
}
|
||||
|
||||
// test unwrap optional
|
||||
{
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph():
|
||||
%1 : Float(*, *, *) = prim::Constant()
|
||||
%2 : bool = aten::_unwrap_optional(%1)
|
||||
%3 : bool = prim::unchecked_unwrap_optional(%1)
|
||||
return (%2, %3)
|
||||
)IR",
|
||||
graph.get());
|
||||
PeepholeOptimize(graph);
|
||||
testing::FileCheck().check_not("unwrap")->run(*graph);
|
||||
}
|
||||
{
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph(%1 : Float(*, *, *)?):
|
||||
%2 : bool = aten::_unwrap_optional(%1)
|
||||
%3 : bool = prim::unchecked_unwrap_optional(%1)
|
||||
return (%2, %3)
|
||||
)IR",
|
||||
graph.get());
|
||||
PeepholeOptimize(graph);
|
||||
testing::FileCheck().check_count("unwrap", 2)->run(*graph);
|
||||
}
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace jit
|
||||
} // namespace torch
|
@ -1881,6 +1881,26 @@ class TestJit(JitTestCase):
|
||||
# testing that 1 // 0 error is not thrownn
|
||||
self.run_pass('constant_propagation', constant_prop.graph)
|
||||
|
||||
def test_short_circuit_optimization(self):
|
||||
@torch.jit.script
|
||||
def const_expressions(x):
|
||||
# type: (int) -> Tuple[bool, bool]
|
||||
return x == 1 and False, x == 1 or True
|
||||
self.run_pass('constant_propagation', const_expressions.graph)
|
||||
FileCheck().check_not("prim::If").check_not("aten::eq").run(const_expressions.graph)
|
||||
self.assertEqual(const_expressions(1), (False, True))
|
||||
|
||||
@torch.jit.script
|
||||
def redundant_expressions(x):
|
||||
# type: (int) -> Tuple[bool, bool]
|
||||
return x == 1 and True, x == 1 or False
|
||||
|
||||
self.run_pass('peephole', redundant_expressions.graph)
|
||||
self.assertEqual(redundant_expressions(1), (True, True))
|
||||
self.assertEqual(redundant_expressions(0), (False, False))
|
||||
# and True / or False are removed from graph
|
||||
FileCheck().check("aten::eq").check_not("prim::If").run(redundant_expressions.graph)
|
||||
|
||||
def test_trace_records_names(self):
|
||||
def foo(bar, baz):
|
||||
baz = bar + 3
|
||||
|
@ -292,7 +292,6 @@ Gradient getGradient(const Node* n) {
|
||||
grad.df_output_vjps = fmap<size_t>(n->is(attr::df_output_vjps));
|
||||
return grad;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
RegisterOperators reg_graph_executor_ops(
|
||||
@ -308,7 +307,6 @@ GraphExecutor* getGradExecutor(Operation& op) {
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// a Graph can be created via tracing, or via a language-based frontend
|
||||
@ -505,6 +503,7 @@ struct GraphExecutorImpl {
|
||||
ConstantPooling(graph);
|
||||
|
||||
PeepholeOptimize(graph);
|
||||
ConstantPropagation(graph);
|
||||
|
||||
// Unroll small loops, and eliminate expressions that are the same at every
|
||||
// iteration.
|
||||
@ -644,6 +643,5 @@ void runRequiredPasses(const std::shared_ptr<Graph>& g) {
|
||||
CanonicalizeOps(g);
|
||||
EliminateDeadCode(g);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -641,6 +641,10 @@ std::shared_ptr<Graph> Graph::copy() {
|
||||
bool Value::mustBeNone() const {
|
||||
return node_->mustBeNone();
|
||||
}
|
||||
bool Value::mustNotBeNone() const {
|
||||
return node_->kind() != prim::AutogradAdd && type() != NoneType::get() &&
|
||||
!type()->cast<OptionalType>();
|
||||
}
|
||||
|
||||
std::string Value::uniqueNameBase() const {
|
||||
std::string name = uniqueName();
|
||||
@ -771,9 +775,10 @@ bool Node::matches(
|
||||
}
|
||||
|
||||
bool Node::mustBeNone() const {
|
||||
return kind_ == prim::Constant && !this->hasAttributes() &&
|
||||
(output()->type()->cast<OptionalType>() ||
|
||||
output()->type() == NoneType::get());
|
||||
return kind_ == prim::AutogradZero ||
|
||||
(kind_ == prim::Constant && !this->hasAttributes() &&
|
||||
(output()->type()->cast<OptionalType>() ||
|
||||
output()->type() == NoneType::get()));
|
||||
}
|
||||
|
||||
void Node::dump() const {
|
||||
|
@ -171,6 +171,7 @@ struct Value {
|
||||
return type()->kind() == TypeKind::CompleteTensorType;
|
||||
}
|
||||
TORCH_API bool mustBeNone() const;
|
||||
TORCH_API bool mustNotBeNone() const;
|
||||
size_t unique() const {
|
||||
return unique_;
|
||||
}
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include <torch/csrc/jit/constants.h>
|
||||
#include <torch/csrc/jit/interpreter.h>
|
||||
#include <torch/csrc/jit/ir.h>
|
||||
#include <torch/csrc/jit/node_hashing.h>
|
||||
#include <torch/csrc/jit/operator.h>
|
||||
#include <torch/csrc/jit/passes/alias_analysis.h>
|
||||
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
||||
@ -119,22 +120,41 @@ void inlineIf(Node* n, const AliasDb& aliasDb) {
|
||||
inlineIfBody(n->blocks().at(block_index));
|
||||
}
|
||||
|
||||
void replaceAndRemoveIfOutput(Node* n, size_t i, Value* replacement) {
|
||||
n->outputs().at(i)->replaceAllUsesWith(replacement);
|
||||
n->eraseOutput(i);
|
||||
n->blocks().at(0)->eraseOutput(i);
|
||||
n->blocks().at(1)->eraseOutput(i);
|
||||
}
|
||||
|
||||
// remove extra outputs from the node
|
||||
bool removeExtraIfOutputs(Node* n) {
|
||||
AT_CHECK(n->kind() == prim::If, "Only supported for If nodes");
|
||||
auto true_block = n->blocks()[0];
|
||||
auto false_block = n->blocks()[1];
|
||||
auto graph = n->owningGraph();
|
||||
auto initial_outputs = true_block->outputs().size();
|
||||
WithInsertPoint guard(n);
|
||||
for (size_t i = 0; i < true_block->outputs().size();) {
|
||||
auto t_out = true_block->outputs().at(i);
|
||||
auto f_out = false_block->outputs().at(i);
|
||||
|
||||
// neither block changes the output value
|
||||
if (true_block->outputs()[i] == false_block->outputs()[i]) {
|
||||
n->outputs().at(i)->replaceAllUsesWith(true_block->outputs()[i]);
|
||||
n->eraseOutput(i);
|
||||
true_block->eraseOutput(i);
|
||||
false_block->eraseOutput(i);
|
||||
} else {
|
||||
i++; // increment bc we didn't remove current index
|
||||
replaceAndRemoveIfOutput(n, i, true_block->outputs()[i]);
|
||||
continue;
|
||||
}
|
||||
|
||||
// true block output is constant and constant matches false block output
|
||||
auto maybe_const = toIValue(t_out);
|
||||
auto eq = EqualNode();
|
||||
if (maybe_const && eq(t_out->node(), f_out->node())) {
|
||||
auto new_const = graph->insertConstant(*maybe_const, t_out->type());
|
||||
replaceAndRemoveIfOutput(n, i, new_const);
|
||||
continue;
|
||||
}
|
||||
|
||||
i++; // increment bc we didn't remove current index
|
||||
}
|
||||
// an output was removed
|
||||
return initial_outputs != true_block->outputs().size();
|
||||
@ -213,6 +233,5 @@ void ConstantPropagation(std::shared_ptr<Graph>& graph) {
|
||||
ConstantPropagation(graph->block(), aliasDb);
|
||||
EliminateDeadCode(graph);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -1,5 +1,5 @@
|
||||
#include <torch/csrc/jit/passes/peephole.h>
|
||||
|
||||
#include <torch/csrc/jit/ir_views.h>
|
||||
#include <torch/csrc/jit/symbolic_variable.h>
|
||||
|
||||
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
||||
@ -165,6 +165,49 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) {
|
||||
u.user->replaceInput(0, node->inputs().at(0));
|
||||
}
|
||||
}
|
||||
} else if (node->kind() == prim::If) {
|
||||
IfView n(node);
|
||||
// this handles redundant short circuits like "x and True" or "x or False"
|
||||
for (size_t i = 0; i < n.outputs().size(); ++i) {
|
||||
if (n.outputs().at(i)->type() != BoolType::get()) {
|
||||
continue;
|
||||
}
|
||||
bool true_val =
|
||||
constant_as<bool>(n.thenOutputs().at(i)).value_or(false);
|
||||
bool false_val =
|
||||
constant_as<bool>(n.elseOutputs().at(i)).value_or(true);
|
||||
// if an if node's output equals its condition replace output with
|
||||
// condition
|
||||
if (true_val && !false_val) {
|
||||
n.outputs().at(i)->replaceAllUsesWith(n.cond());
|
||||
}
|
||||
}
|
||||
} else if (
|
||||
node->kind() == aten::__is__ || node->kind() == aten::__isnot__) {
|
||||
// if we are comparing a None value with a value that can't be None
|
||||
// replace the output with true if node is __isnot__ or false if node is
|
||||
// __is__
|
||||
AT_ASSERT(node->inputs().size() == 2);
|
||||
for (size_t check_none_index : {0, 1}) {
|
||||
bool input_must_be_none =
|
||||
node->inputs().at(check_none_index)->mustBeNone();
|
||||
bool other_must_not_be_none =
|
||||
node->inputs().at(1 - check_none_index)->mustNotBeNone();
|
||||
if (input_must_be_none && other_must_not_be_none) {
|
||||
WithInsertPoint guard(node);
|
||||
auto output = node->owningGraph()->insertConstant(
|
||||
node->kind() == aten::__isnot__);
|
||||
node->output()->replaceAllUsesWith(output);
|
||||
}
|
||||
}
|
||||
} else if (
|
||||
node->kind() == prim::unchecked_unwrap_optional ||
|
||||
node->kind() == aten::_unwrap_optional) {
|
||||
// we are unwrapping an input that can't be None, remove the unwrap
|
||||
auto input = node->input();
|
||||
if (input->mustNotBeNone()) {
|
||||
node->output()->replaceAllUsesWith(node->input());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -180,6 +223,5 @@ void PeepholeOptimize(
|
||||
bool addmm_fusion_enabled) {
|
||||
PeepholeOptimize(graph->block(), addmm_fusion_enabled);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -1039,21 +1039,32 @@ struct to_ir {
|
||||
const auto first_bool_info = findRefinements(first_expr);
|
||||
Value* first_value = emitCond(Expr(first_expr));
|
||||
|
||||
// if the second expr in the short circuit is not evaluated,
|
||||
// than the first expression is False if the short circuit
|
||||
// is an `and` and True if the short circuit is an `or`.
|
||||
// `False and expr` -> False, `True or expr` -> True
|
||||
//
|
||||
// inserting it as a constant makes optimization easier
|
||||
|
||||
Value* first_value_returned;
|
||||
|
||||
const Refinements* first_expr_refinements;
|
||||
const Refinements* second_expr_refinements;
|
||||
// if it's an OR the first expr is emitted in the true branch
|
||||
// and the second expr in the false branch, if it's an AND the opposite
|
||||
if (is_or) {
|
||||
first_value_returned = graph->insertConstant(true, nullptr, loc);
|
||||
first_expr_refinements = &first_bool_info.true_refinements_;
|
||||
second_expr_refinements = &first_bool_info.false_refinements_;
|
||||
} else {
|
||||
first_value_returned = graph->insertConstant(false, nullptr, loc);
|
||||
first_expr_refinements = &first_bool_info.false_refinements_;
|
||||
second_expr_refinements = &first_bool_info.true_refinements_;
|
||||
}
|
||||
|
||||
auto get_first_expr = [&] {
|
||||
insertRefinements(*first_expr_refinements);
|
||||
return first_value;
|
||||
return first_value_returned;
|
||||
};
|
||||
|
||||
auto get_second_expr = [&] {
|
||||
@ -2094,7 +2105,6 @@ struct to_ir {
|
||||
}
|
||||
return classNew->createObject(
|
||||
apply.range(), method, Var(apply.inputs()[0]).name().name());
|
||||
;
|
||||
} else {
|
||||
auto inputs = getNamedValues(apply.inputs(), true);
|
||||
auto attributes = emitAttributes(apply.attributes());
|
||||
|
@ -21,9 +21,14 @@ TypeAndAlias SchemaTypeParser::parseBaseType() {
|
||||
{"float", FloatType::get()},
|
||||
{"int", IntType::get()},
|
||||
{"bool", BoolType::get()},
|
||||
{"None", NoneType::get()},
|
||||
};
|
||||
auto tok = L.expect(TK_IDENT);
|
||||
auto text = tok.text();
|
||||
auto tok = L.cur();
|
||||
if (!L.nextIf(TK_NONE)) {
|
||||
L.expect(TK_IDENT);
|
||||
}
|
||||
std::string text = tok.text();
|
||||
|
||||
auto it = type_map.find(text);
|
||||
if (it == type_map.end()) {
|
||||
if (text.size() > 0 && islower(text[0])) {
|
||||
|
Reference in New Issue
Block a user