[TensorExpr] Add a boilerplate pass for future TensorExpr fusion pass. (#33464)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33464

I added a Python-exposed knob to register this pass in the custom passes pipeline. If the knob is not used, the pass is not registered and thus not run at all.

Differential Revision: D19958217

Test Plan: Imported from OSS

Pulled By: ZolotukhinM

fbshipit-source-id: fecdd98567fcda069fbdf8995c796899a3dbfa5c
This commit is contained in:
Mikhail Zolotukhin
2020-02-24 18:45:55 -08:00
committed by Facebook Github Bot
parent 9278196d89
commit bf00b4d305
5 changed files with 227 additions and 0 deletions

View File

@ -411,6 +411,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/passes/requires_grad_analysis.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/requires_grad_analysis.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/specialize_autogradzero.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/specialize_autogradzero.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/subgraph_rewrite.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/subgraph_rewrite.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/tensorexpr_fuser.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/python_print.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/python_print.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/utils/subgraph_utils.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/utils/subgraph_utils.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/utils/check_alias_annotation.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/utils/check_alias_annotation.cpp

View File

@ -142,6 +142,7 @@ libtorch_sources = [
"torch/csrc/jit/passes/shape_analysis.cpp", "torch/csrc/jit/passes/shape_analysis.cpp",
"torch/csrc/jit/passes/specialize_autogradzero.cpp", "torch/csrc/jit/passes/specialize_autogradzero.cpp",
"torch/csrc/jit/passes/subgraph_rewrite.cpp", "torch/csrc/jit/passes/subgraph_rewrite.cpp",
"torch/csrc/jit/passes/tensorexpr_fuser.cpp",
"torch/csrc/jit/passes/utils/subgraph_utils.cpp", "torch/csrc/jit/passes/utils/subgraph_utils.cpp",
"torch/csrc/jit/passes/utils/memory_dag.cpp", "torch/csrc/jit/passes/utils/memory_dag.cpp",
"torch/csrc/jit/print_handler.cpp", "torch/csrc/jit/print_handler.cpp",

View File

@ -42,6 +42,7 @@
#include <torch/csrc/jit/passes/shape_analysis.h> #include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/specialize_autogradzero.h> #include <torch/csrc/jit/passes/specialize_autogradzero.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h> #include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/passes/utils/check_alias_annotation.h> #include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
#include <torch/csrc/jit/print_handler.h> #include <torch/csrc/jit/print_handler.h>
#include <torch/csrc/jit/pybind_utils.h> #include <torch/csrc/jit/pybind_utils.h>
@ -322,6 +323,7 @@ void initJITBindings(PyObject* module) {
.def("_jit_pass_specialize_autogradzero", specializeAutogradZero) .def("_jit_pass_specialize_autogradzero", specializeAutogradZero)
.def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU) .def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
.def("_jit_override_can_fuse_on_gpu", &overrideCanFuseOnGPU) .def("_jit_override_can_fuse_on_gpu", &overrideCanFuseOnGPU)
.def("_jit_register_tensorexpr_fuser", &registerTensorExprFuser)
.def( .def(
"_jit_differentiate", "_jit_differentiate",
[](Graph& g) { [](Graph& g) {

View File

@ -0,0 +1,205 @@
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/autograd/record_function.h>
#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/operator_options.h>
#include <torch/csrc/jit/pass_manager.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
namespace torch {
namespace jit {
// Interned symbol naming the node kind used for TensorExpr fusion groups.
// The symbol is created once on first use; C++11 magic statics make the
// initialization thread-safe.
const Symbol& getTensorExprSymbol() {
  static Symbol kind = Symbol::fromQualString("tensorexpr::Group");
  return kind;
}
// Filters `inputs` down to values whose producing nodes live directly in
// `block`, then orders them in reverse topological order (values defined
// later in the block come first).
value_list sortReverseTopological(
    ArrayRef<torch::jit::Value*> inputs,
    torch::jit::Block* block) {
  value_list local_values;
  // Keep only values produced inside the block we are currently scanning.
  for (torch::jit::Value* v : inputs) {
    if (v->node()->owningBlock() == block) {
      local_values.push_back(v);
    }
  }

  // Sort in reverse topological order: `a` precedes `b` when a's producer
  // appears after b's producer in the block.
  auto later_producer_first = [](torch::jit::Value* a, torch::jit::Value* b) {
    return a->node()->isAfter(b->node());
  };
  std::sort(local_values.begin(), local_values.end(), later_producer_first);
  return local_values;
}
// Whether `node` is an operation the TensorExpr backend can compile.
// Placeholder: every node is rejected, so the fusion pass currently never
// forms any fusion groups (it is an intentional no-op until ops are added).
bool canHandle(Node* node, AliasDb& aliasDb) {
  // TODO: actually support some ops
  return false;
}
// Bails out of the enclosing function (returning false) when `cond` does not
// hold, logging the stringified condition so GRAPH_DEBUG output shows which
// requirement rejected the merge.
#define REQ(cond)                           \
  if (!(cond)) {                            \
    GRAPH_DEBUG("Failed cond " #cond "\n"); \
    return false;                           \
  }

// Decides whether `producer` may legally be fused into `consumer`.
bool canMerge(Node* consumer, Node* producer, AliasDb& aliasDb) {
  // Only handle complete tensor types
  for (torch::jit::Value* output : consumer->outputs()) {
    REQ(output->isCompleteTensor());
  }

  // Only fuse within a block
  REQ(consumer->owningBlock() == producer->owningBlock());

  // Symbolic checks: the producer must be compilable, and the consumer must
  // either be compilable itself or already be a tensorexpr::Group node that
  // can absorb another producer.
  REQ(canHandle(producer, aliasDb));
  REQ(
      (canHandle(consumer, aliasDb) ||
       consumer->kind() == getTensorExprSymbol()));

  // Alias checks: moving producer next to consumer must not reorder
  // reads/writes to aliasing memory.
  REQ(aliasDb.couldMoveAfterTopologically(consumer, producer));

  return true;
}
#undef REQ
// Returns `n` unchanged when it is already a tensorexpr fusion-group node;
// otherwise wraps `n` in a fresh single-node subgraph of that kind and
// returns the new group node.
Node* getOrCreateTensorExprSubgraph(Node* n) {
  const bool is_group =
      n->kind() == getTensorExprSymbol() && n->hasAttribute(attr::Subgraph);
  if (!is_group) {
    return SubgraphUtils::createSingletonSubgraph(n, getTensorExprSymbol());
  }
  return n;
}
// Attempts to pull `producer` into `consumer`'s fusion group. On success
// returns the (possibly freshly created) group node; on failure returns
// nullopt and leaves the graph untouched.
c10::optional<Node*> tryMerge(
    Node* consumer,
    Node* producer,
    AliasDb& aliasDb) {
  GRAPH_DEBUG(
      "Trying producer ",
      producer->kind().toQualString(),
      " and consumer ",
      consumer->kind().toQualString(),
      ":\n");

  if (!canMerge(consumer, producer, aliasDb)) {
    return c10::nullopt;
  }

  // Ensure the consumer is a fusion-group node, relocate the producer next
  // to it, then absorb the producer into the group's subgraph.
  Node* group = getOrCreateTensorExprSubgraph(consumer);
  aliasDb.moveAfterTopologicallyValid(group, producer);
  SubgraphUtils::mergeNodeIntoSubgraph(producer, group);
  return group;
}
// Tries to fuse each in-block producer of `consumer` into it, in reverse
// topological order. Returns the iterator from which the caller should
// resume its reverse scan, plus whether any merge happened.
std::pair<graph_node_list::iterator, bool> scanNode(
    Node* consumer,
    AliasDb& aliasDb) {
  auto candidates =
      sortReverseTopological(consumer->inputs(), consumer->owningBlock());

  // Grab the iterator below consumer. We'll use that to determine
  // where to resume iteration, even if consumer gets relocated within
  // the block.
  auto below = --consumer->reverseIterator();

  for (torch::jit::Value* input : candidates) {
    if (tryMerge(consumer, input->node(), aliasDb)) {
      // A merge happened; resume iteration from where consumer is/used to be.
      return {++below, true};
    }
  }

  // Nothing merged, so consumer didn't move: step past it.
  ++below;
  return {++below, false};
}
// Builds the interpreter operation that executes a tensorexpr::Group node.
// Placeholder implementation: it only records a profiling event and leaves
// the stack untouched.
// TODO: actually compile the fusion group.
Operation createTensorExprOp(const Node* node) {
  auto run = [](Stack& stack) {
    RECORD_FUNCTION("TensorExpr", std::vector<c10::IValue>());
    return 0;
  };
  return run;
}
// Wraps an AliasAnalysisKind into the OperatorOptions struct expected by
// operator registration.
c10::OperatorOptions getAliasAnalysisOption(AliasAnalysisKind k) {
  c10::OperatorOptions options;
  options.setAliasAnalysis(k);
  return options;
}
// Static registration of the interpreter implementation for
// tensorexpr::Group nodes. PURE_FUNCTION alias analysis tells the JIT the
// op has no side effects on aliasing state — NOTE(review): presumably valid
// since fusable subgraphs are side-effect-free; confirm once real ops are
// supported.
RegisterOperators TensorExprOps({
    torch::jit::Operator(
        getTensorExprSymbol(),
        createTensorExprOp,
        getAliasAnalysisOption(AliasAnalysisKind::PURE_FUNCTION)),
});
// Pass entry point: walks every block of `graph` (including nested blocks)
// and fuses eligible nodes into tensorexpr::Group subgraph nodes.
// Currently has no visible effect because canHandle() rejects all nodes.
void fuseTensorExprs(std::shared_ptr<Graph>& graph) {
  GRAPH_DUMP("Before TExprFuser: ", graph);

  // Get rid of dead code so that we don't waste effort fusing it.
  EliminateDeadCode(graph);

  // NOTE(review): aliasDb is built once and reused across graph mutations;
  // moveAfterTopologicallyValid presumably keeps it consistent — confirm.
  AliasDb aliasDb(graph);
  auto block = graph->block();

  // Worklist of (current, end) reverse-iterator pairs — one entry per block
  // still being scanned. Nested blocks are pushed as they are encountered.
  std::vector<std::pair<graph_node_list_iterator, graph_node_list_iterator>>
      worklist;
  std::unordered_set<torch::jit::Block*> visited_blocks;

  // Iterate to a fixed point: a successful merge can enable further merges.
  // NOTE(review): visited_blocks persists across fixed-point iterations, so
  // nested blocks are only scanned during the first outer iteration — verify
  // this is intended.
  bool any_changed = true;
  while (any_changed) {
    any_changed = false;
    worklist.push_back({block->nodes().rbegin(), block->nodes().rend()});

    while (worklist.size()) {
      auto& it = worklist.back().first;
      auto end = worklist.back().second;

      if (it->blocks().size()) {
        // Nodes owning sub-blocks (e.g. control flow) are not fused
        // themselves; queue their blocks for scanning instead.
        Node* n = *it;
        ++it;
        if (it == end) {
          worklist.pop_back();
        }
        for (auto b : n->blocks()) {
          if (!visited_blocks.count(b)) {
            worklist.push_back({b->nodes().rbegin(), b->nodes().rend()});
            visited_blocks.insert(b);
          }
        }
      } else {
        // Plain node: try to fuse its producers into it. scanNode returns
        // the iterator to resume from, accounting for node relocation.
        bool changed;
        std::tie(it, changed) = scanNode(*it, aliasDb);
        any_changed |= changed;
        if (it == end) {
          worklist.pop_back();
        }
      }
    }
  }

  // Clean up after fusion.
  EliminateCommonSubexpression(graph);
  EliminateDeadCode(graph);

  GRAPH_DUMP("After TExprFuser: ", graph);
}
void registerTensorExprFuser() {
static bool already_registered = false;
if (!already_registered) {
RegisterPass pass(fuseTensorExprs);
already_registered = true;
}
}
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,18 @@
#pragma once
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <memory>
namespace torch {
namespace jit {

struct Graph;

// Run the TensorExpressions-based fuser pass over `graph`, grouping
// eligible nodes into tensorexpr::Group subgraph nodes.
TORCH_API void fuseTensorExprs(std::shared_ptr<Graph>& graph);

// Register the TensorExpressions-based fuser in the custom passes pipeline
// so it runs as part of graph optimization. Registers at most once.
TORCH_API void registerTensorExprFuser();

} // namespace jit
} // namespace torch