mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[TensorExpr] Add a boilerplate pass for future TensorExpr fusion pass. (#33464)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33464 I added a python-exposed knob to register this pass in custom passes pipeline. If the knob is not used, the pass is not registered and thus not run at all. Differential Revision: D19958217 Test Plan: Imported from OSS Pulled By: ZolotukhinM fbshipit-source-id: fecdd98567fcda069fbdf8995c796899a3dbfa5c
This commit is contained in:
committed by
Facebook Github Bot
parent
9278196d89
commit
bf00b4d305
@ -411,6 +411,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/requires_grad_analysis.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/specialize_autogradzero.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/subgraph_rewrite.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/tensorexpr_fuser.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/python_print.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/utils/subgraph_utils.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/utils/check_alias_annotation.cpp
|
||||
|
@ -142,6 +142,7 @@ libtorch_sources = [
|
||||
"torch/csrc/jit/passes/shape_analysis.cpp",
|
||||
"torch/csrc/jit/passes/specialize_autogradzero.cpp",
|
||||
"torch/csrc/jit/passes/subgraph_rewrite.cpp",
|
||||
"torch/csrc/jit/passes/tensorexpr_fuser.cpp",
|
||||
"torch/csrc/jit/passes/utils/subgraph_utils.cpp",
|
||||
"torch/csrc/jit/passes/utils/memory_dag.cpp",
|
||||
"torch/csrc/jit/print_handler.cpp",
|
||||
|
@ -42,6 +42,7 @@
|
||||
#include <torch/csrc/jit/passes/shape_analysis.h>
|
||||
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
|
||||
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
||||
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
|
||||
#include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
|
||||
#include <torch/csrc/jit/print_handler.h>
|
||||
#include <torch/csrc/jit/pybind_utils.h>
|
||||
@ -322,6 +323,7 @@ void initJITBindings(PyObject* module) {
|
||||
.def("_jit_pass_specialize_autogradzero", specializeAutogradZero)
|
||||
.def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
|
||||
.def("_jit_override_can_fuse_on_gpu", &overrideCanFuseOnGPU)
|
||||
.def("_jit_register_tensorexpr_fuser", ®isterTensorExprFuser)
|
||||
.def(
|
||||
"_jit_differentiate",
|
||||
[](Graph& g) {
|
||||
|
205
torch/csrc/jit/passes/tensorexpr_fuser.cpp
Normal file
205
torch/csrc/jit/passes/tensorexpr_fuser.cpp
Normal file
@ -0,0 +1,205 @@
|
||||
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
|
||||
#include <torch/csrc/autograd/record_function.h>
|
||||
#include <torch/csrc/jit/custom_operator.h>
|
||||
#include <torch/csrc/jit/jit_log.h>
|
||||
#include <torch/csrc/jit/operator_options.h>
|
||||
#include <torch/csrc/jit/pass_manager.h>
|
||||
#include <torch/csrc/jit/passes/alias_analysis.h>
|
||||
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
|
||||
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
||||
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// Interned symbol naming the fusion-group node kind produced by this pass
// ("tensorexpr::Group"). Built once on first use and cached for the
// lifetime of the process.
const Symbol& getTensorExprSymbol() {
  static Symbol kind = Symbol::fromQualString("tensorexpr::Group");
  return kind;
}
|
||||
|
||||
// Return the subset of `inputs` whose producing nodes live directly in
// `block`, sorted in reverse topological order (a value whose node comes
// later in the block sorts first). Inputs produced in other blocks are
// dropped.
value_list sortReverseTopological(
    ArrayRef<torch::jit::Value*> inputs,
    torch::jit::Block* block) {
  // Keep only values defined directly in `block`.
  value_list in_block;
  std::copy_if(
      inputs.begin(),
      inputs.end(),
      std::back_inserter(in_block),
      [&](torch::jit::Value* v) { return v->node()->owningBlock() == block; });

  // Reverse topological order: lhs precedes rhs iff lhs's node is after
  // rhs's node in the block.
  auto later_first = [](torch::jit::Value* lhs, torch::jit::Value* rhs) {
    return lhs->node()->isAfter(rhs->node());
  };
  std::sort(in_block.begin(), in_block.end(), later_first);
  return in_block;
}
|
||||
|
||||
// Returns true if `node` is an operation this fuser knows how to pull into
// a tensorexpr::Group. Currently a stub that rejects every node, so no
// fusion groups are ever formed yet.
bool canHandle(Node* node, AliasDb& aliasDb) {
  // Silence -Wunused-parameter until real op support lands; both arguments
  // will be needed by the eventual implementation.
  (void)node;
  (void)aliasDb;
  // TODO: actually support some ops
  return false;
}
|
||||
|
||||
// REQ(cond): bail out of the enclosing bool-returning function with `false`
// (after logging the failed condition) when `cond` does not hold. Scoped to
// canMerge via the #undef below.
#define REQ(cond) \
  if (!(cond)) { \
    GRAPH_DEBUG("Failed cond " #cond "\n"); \
    return false; \
  }

// Decide whether `producer` may be fused into `consumer`. Returns false as
// soon as any requirement fails; the REQ macro logs which one. Note that
// with canHandle() currently always false, this rejects everything.
bool canMerge(Node* consumer, Node* producer, AliasDb& aliasDb) {
  // Only handle complete tensor types
  for (torch::jit::Value* output : consumer->outputs()) {
    REQ(output->isCompleteTensor());
  }

  // Only fuse within a block
  REQ(consumer->owningBlock() == producer->owningBlock());

  // Symbolic checks
  REQ(canHandle(producer, aliasDb));
  // The consumer is allowed to be an already-formed fusion group.
  REQ(
      (canHandle(consumer, aliasDb) ||
       consumer->kind() == getTensorExprSymbol()));

  // Alias checks: fusion moves producer next to consumer, so the move must
  // be legal under alias analysis.
  REQ(aliasDb.couldMoveAfterTopologically(consumer, producer));

  return true;
}
#undef REQ
|
||||
|
||||
// Ensure `n` is a tensorexpr::Group subgraph node: return it unchanged if it
// already is one, otherwise wrap it in a fresh singleton subgraph.
Node* getOrCreateTensorExprSubgraph(Node* n) {
  const bool is_group =
      n->hasAttribute(attr::Subgraph) && n->kind() == getTensorExprSymbol();
  if (is_group) {
    return n;
  }
  return SubgraphUtils::createSingletonSubgraph(n, getTensorExprSymbol());
}
|
||||
|
||||
// Attempt to fuse `producer` into `consumer`. On success, returns the
// fusion-group node (which may be a new node wrapping the original
// consumer); on failure, returns nullopt and leaves the graph untouched.
c10::optional<Node*> tryMerge(
    Node* consumer,
    Node* producer,
    AliasDb& aliasDb) {
  GRAPH_DEBUG(
      "Trying producer ",
      producer->kind().toQualString(),
      " and consumer ",
      consumer->kind().toQualString(),
      ":\n");

  // All graph mutation below is guarded by this check; canMerge already
  // verified the topological move is legal.
  if (!canMerge(consumer, producer, aliasDb)) {
    return c10::nullopt;
  }

  // May replace `consumer` with a new tensorexpr::Group node wrapping it.
  consumer = getOrCreateTensorExprSubgraph(consumer);

  // Move producer adjacent to consumer, then absorb it into the subgraph.
  aliasDb.moveAfterTopologicallyValid(consumer, producer);
  SubgraphUtils::mergeNodeIntoSubgraph(producer, consumer);

  return consumer;
}
|
||||
|
||||
// Try to fuse each in-block input of `consumer` into it, stopping after the
// first successful merge. Returns the iterator at which the caller should
// resume its (reverse) scan, plus whether anything changed.
std::pair<graph_node_list::iterator, bool> scanNode(
    Node* consumer,
    AliasDb& aliasDb) {
  // Visit inputs reverse-topologically so the nearest producers are tried
  // first.
  auto inputs =
      sortReverseTopological(consumer->inputs(), consumer->owningBlock());

  // Grab the iterator below consumer. We'll use that to determine
  // where to resume iteration, even if consumer gets relocated within
  // the block.
  auto iter = --consumer->reverseIterator();
  for (auto input : inputs) {
    if (auto group = tryMerge(consumer, input->node(), aliasDb)) {
      // Resume iteration from where consumer is/used to be.
      // (Only one merge per call: the merge may have invalidated `inputs`.)
      return {++iter, true};
    }
  }

  // We know consumer didn't move, so skip over it.
  return {++(++iter), false};
}
|
||||
|
||||
// Build the interpreter operation that executes a tensorexpr::Group node.
// Compilation is not implemented yet, so the returned op only records a
// "TensorExpr" profiling event and performs no computation.
Operation createTensorExprOp(const Node* node) {
  // TODO: actually compile the fusion group.
  auto run = [](Stack& stack) {
    RECORD_FUNCTION("TensorExpr", std::vector<c10::IValue>());
    return 0;
  };
  return run;
}
|
||||
|
||||
// Convenience helper: construct OperatorOptions carrying the given
// alias-analysis kind.
c10::OperatorOptions getAliasAnalysisOption(AliasAnalysisKind k) {
  c10::OperatorOptions opts;
  opts.setAliasAnalysis(k);
  return opts;
}
|
||||
|
||||
// Static registration: give tensorexpr::Group nodes an executor so graphs
// containing them can run. PURE_FUNCTION alias analysis marks the groups as
// side-effect-free for the purposes of AliasDb.
RegisterOperators TensorExprOps({
    torch::jit::Operator(
        getTensorExprSymbol(),
        createTensorExprOp,
        getAliasAnalysisOption(AliasAnalysisKind::PURE_FUNCTION)),
});
|
||||
|
||||
// Entry point of the TensorExpr fusion pass: repeatedly scan every node in
// the graph (including nested blocks) and merge fusable producers into
// tensorexpr::Group subgraphs, until a full sweep makes no change.
void fuseTensorExprs(std::shared_ptr<Graph>& graph) {
  GRAPH_DUMP("Before TExprFuser: ", graph);

  // Get rid of dead code so that we don't waste effort fusing it.
  EliminateDeadCode(graph);

  AliasDb aliasDb(graph);
  auto block = graph->block();

  // Explicit DFS stack of (current, end) reverse-iterator pairs, one entry
  // per block being scanned; avoids recursion over nested blocks.
  std::vector<std::pair<graph_node_list_iterator, graph_node_list_iterator>>
      worklist;
  std::unordered_set<torch::jit::Block*> visited_blocks;

  // Iterate to a fixed point: a merge can enable further merges.
  bool any_changed = true;
  while (any_changed) {
    any_changed = false;
    worklist.push_back({block->nodes().rbegin(), block->nodes().rend()});
    // NOTE(review): visited_blocks is not cleared between outer iterations,
    // so nested blocks are scanned at most once overall while the top-level
    // block is rescanned each pass — confirm this is intended.

    while (worklist.size()) {
      auto& it = worklist.back().first;
      auto end = worklist.back().second;

      if (it->blocks().size()) {
        // Node owns sub-blocks (e.g. control flow): advance past it, then
        // queue its unvisited blocks for scanning.
        Node* n = *it;
        ++it;

        if (it == end) {
          worklist.pop_back();
        }

        for (auto b : n->blocks()) {
          if (!visited_blocks.count(b)) {
            worklist.push_back({b->nodes().rbegin(), b->nodes().rend()});
            visited_blocks.insert(b);
          }
        }
      } else {
        // Plain node: try to fuse its producers into it. scanNode returns
        // the iterator to resume from (the node may have been relocated).
        bool changed;
        std::tie(it, changed) = scanNode(*it, aliasDb);
        any_changed |= changed;
        if (it == end) {
          worklist.pop_back();
        }
      }
    }
  }

  // Clean up after fusion.
  EliminateCommonSubexpression(graph);
  EliminateDeadCode(graph);

  GRAPH_DUMP("After TExprFuser: ", graph);
}
|
||||
|
||||
void registerTensorExprFuser() {
|
||||
static bool already_registered = false;
|
||||
if (!already_registered) {
|
||||
RegisterPass pass(fuseTensorExprs);
|
||||
already_registered = true;
|
||||
}
|
||||
}
|
||||
} // namespace jit
|
||||
} // namespace torch
|
18
torch/csrc/jit/passes/tensorexpr_fuser.h
Normal file
18
torch/csrc/jit/passes/tensorexpr_fuser.h
Normal file
@ -0,0 +1,18 @@
|
||||
#pragma once

#include <torch/csrc/WindowsTorchApiMacro.h>
#include <memory>

namespace torch {
namespace jit {

// Forward declaration; keeps this header free of the full IR headers.
struct Graph;

// Run TensorExpressions-based fuser.
// Scans `graph` (in place) and merges fusable nodes into
// tensorexpr::Group subgraph nodes.
TORCH_API void fuseTensorExprs(std::shared_ptr<Graph>& graph);

// Register TensorExpressions-based fuser in custom passes.
// Idempotent; the pass does not run unless this is called.
TORCH_API void registerTensorExprFuser();

} // namespace jit
} // namespace torch
|
Reference in New Issue
Block a user