mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[TensorExpr] Add a boilerplate pass for future TensorExpr fusion pass. (#33464)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33464 I added a python-exposed knob to register this pass in custom passes pipeline. If the knob is not used, the pass is not registered and thus not run at all. Differential Revision: D19958217 Test Plan: Imported from OSS Pulled By: ZolotukhinM fbshipit-source-id: fecdd98567fcda069fbdf8995c796899a3dbfa5c
This commit is contained in:
committed by
Facebook Github Bot
parent
9278196d89
commit
bf00b4d305
@ -411,6 +411,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
|||||||
${TORCH_SRC_DIR}/csrc/jit/passes/requires_grad_analysis.cpp
|
${TORCH_SRC_DIR}/csrc/jit/passes/requires_grad_analysis.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/passes/specialize_autogradzero.cpp
|
${TORCH_SRC_DIR}/csrc/jit/passes/specialize_autogradzero.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/passes/subgraph_rewrite.cpp
|
${TORCH_SRC_DIR}/csrc/jit/passes/subgraph_rewrite.cpp
|
||||||
|
${TORCH_SRC_DIR}/csrc/jit/passes/tensorexpr_fuser.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/passes/python_print.cpp
|
${TORCH_SRC_DIR}/csrc/jit/passes/python_print.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/passes/utils/subgraph_utils.cpp
|
${TORCH_SRC_DIR}/csrc/jit/passes/utils/subgraph_utils.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/passes/utils/check_alias_annotation.cpp
|
${TORCH_SRC_DIR}/csrc/jit/passes/utils/check_alias_annotation.cpp
|
||||||
|
@ -142,6 +142,7 @@ libtorch_sources = [
|
|||||||
"torch/csrc/jit/passes/shape_analysis.cpp",
|
"torch/csrc/jit/passes/shape_analysis.cpp",
|
||||||
"torch/csrc/jit/passes/specialize_autogradzero.cpp",
|
"torch/csrc/jit/passes/specialize_autogradzero.cpp",
|
||||||
"torch/csrc/jit/passes/subgraph_rewrite.cpp",
|
"torch/csrc/jit/passes/subgraph_rewrite.cpp",
|
||||||
|
"torch/csrc/jit/passes/tensorexpr_fuser.cpp",
|
||||||
"torch/csrc/jit/passes/utils/subgraph_utils.cpp",
|
"torch/csrc/jit/passes/utils/subgraph_utils.cpp",
|
||||||
"torch/csrc/jit/passes/utils/memory_dag.cpp",
|
"torch/csrc/jit/passes/utils/memory_dag.cpp",
|
||||||
"torch/csrc/jit/print_handler.cpp",
|
"torch/csrc/jit/print_handler.cpp",
|
||||||
|
@ -42,6 +42,7 @@
|
|||||||
#include <torch/csrc/jit/passes/shape_analysis.h>
|
#include <torch/csrc/jit/passes/shape_analysis.h>
|
||||||
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
|
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
|
||||||
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
||||||
|
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
|
||||||
#include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
|
#include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
|
||||||
#include <torch/csrc/jit/print_handler.h>
|
#include <torch/csrc/jit/print_handler.h>
|
||||||
#include <torch/csrc/jit/pybind_utils.h>
|
#include <torch/csrc/jit/pybind_utils.h>
|
||||||
@ -322,6 +323,7 @@ void initJITBindings(PyObject* module) {
|
|||||||
.def("_jit_pass_specialize_autogradzero", specializeAutogradZero)
|
.def("_jit_pass_specialize_autogradzero", specializeAutogradZero)
|
||||||
.def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
|
.def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
|
||||||
.def("_jit_override_can_fuse_on_gpu", &overrideCanFuseOnGPU)
|
.def("_jit_override_can_fuse_on_gpu", &overrideCanFuseOnGPU)
|
||||||
|
.def("_jit_register_tensorexpr_fuser", ®isterTensorExprFuser)
|
||||||
.def(
|
.def(
|
||||||
"_jit_differentiate",
|
"_jit_differentiate",
|
||||||
[](Graph& g) {
|
[](Graph& g) {
|
||||||
|
205
torch/csrc/jit/passes/tensorexpr_fuser.cpp
Normal file
205
torch/csrc/jit/passes/tensorexpr_fuser.cpp
Normal file
@ -0,0 +1,205 @@
|
|||||||
|
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
|
||||||
|
#include <torch/csrc/autograd/record_function.h>
|
||||||
|
#include <torch/csrc/jit/custom_operator.h>
|
||||||
|
#include <torch/csrc/jit/jit_log.h>
|
||||||
|
#include <torch/csrc/jit/operator_options.h>
|
||||||
|
#include <torch/csrc/jit/pass_manager.h>
|
||||||
|
#include <torch/csrc/jit/passes/alias_analysis.h>
|
||||||
|
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
|
||||||
|
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
||||||
|
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace jit {
|
||||||
|
|
||||||
|
// Interned symbol used as the node kind for TensorExpr fusion-group
// subgraph nodes. Built lazily on first use and cached for the process
// lifetime.
const Symbol& getTensorExprSymbol() {
  static Symbol group_symbol = Symbol::fromQualString("tensorexpr::Group");
  return group_symbol;
}
|
||||||
|
|
||||||
|
// Collect the subset of `inputs` whose producing node lives directly in
// `block`, ordered so that producers scheduled later in the block come first.
value_list sortReverseTopological(
    ArrayRef<torch::jit::Value*> inputs,
    torch::jit::Block* block) {
  value_list local_inputs;
  for (auto* value : inputs) {
    if (value->node()->owningBlock() == block) {
      local_inputs.push_back(value);
    }
  }

  // Sort in reverse topological order: `lhs` precedes `rhs` when its
  // producer is scheduled after the producer of `rhs`.
  std::sort(
      local_inputs.begin(),
      local_inputs.end(),
      [](torch::jit::Value* lhs, torch::jit::Value* rhs) {
        return lhs->node()->isAfter(rhs->node());
      });
  return local_inputs;
}
|
||||||
|
|
||||||
|
// Op filter for the fuser: decides whether `node` can be compiled by the
// TensorExpr backend. Currently a placeholder that rejects everything.
// TODO: actually support some ops
bool canHandle(Node* node, AliasDb& aliasDb) {
  (void)node;
  (void)aliasDb;
  return false;
}
|
||||||
|
|
||||||
|
#define REQ(cond) \
|
||||||
|
if (!(cond)) { \
|
||||||
|
GRAPH_DEBUG("Failed cond " #cond "\n"); \
|
||||||
|
return false; \
|
||||||
|
}
|
||||||
|
|
||||||
|
bool canMerge(Node* consumer, Node* producer, AliasDb& aliasDb) {
|
||||||
|
// Only handle complete tensor types
|
||||||
|
for (torch::jit::Value* output : consumer->outputs()) {
|
||||||
|
REQ(output->isCompleteTensor());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only fuse within a block
|
||||||
|
REQ(consumer->owningBlock() == producer->owningBlock());
|
||||||
|
|
||||||
|
// Symbolic checks
|
||||||
|
REQ(canHandle(producer, aliasDb));
|
||||||
|
REQ(
|
||||||
|
(canHandle(consumer, aliasDb) ||
|
||||||
|
consumer->kind() == getTensorExprSymbol()));
|
||||||
|
|
||||||
|
// Alias checks
|
||||||
|
REQ(aliasDb.couldMoveAfterTopologically(consumer, producer));
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
#undef REQ
|
||||||
|
|
||||||
|
// Return `n` itself when it is already a tensorexpr::Group subgraph node;
// otherwise wrap it into a fresh single-node subgraph of that kind.
Node* getOrCreateTensorExprSubgraph(Node* n) {
  const bool already_group =
      n->hasAttribute(attr::Subgraph) && n->kind() == getTensorExprSymbol();
  return already_group
      ? n
      : SubgraphUtils::createSingletonSubgraph(n, getTensorExprSymbol());
}
|
||||||
|
|
||||||
|
// Attempt to fuse `producer` into `consumer`. On success, returns the
// (possibly newly created) tensorexpr::Group node that now contains both;
// returns nullopt when the pair fails the canMerge checks.
c10::optional<Node*> tryMerge(
    Node* consumer,
    Node* producer,
    AliasDb& aliasDb) {
  GRAPH_DEBUG(
      "Trying producer ",
      producer->kind().toQualString(),
      " and consumer ",
      consumer->kind().toQualString(),
      ":\n");

  if (!canMerge(consumer, producer, aliasDb)) {
    return c10::nullopt;
  }

  // Make sure the consumer is a subgraph node, then pull the producer into
  // it (moving the group after the producer first, as validated above).
  Node* group = getOrCreateTensorExprSubgraph(consumer);

  aliasDb.moveAfterTopologicallyValid(group, producer);
  SubgraphUtils::mergeNodeIntoSubgraph(producer, group);

  return group;
}
|
||||||
|
|
||||||
|
// Examine `consumer` and try to fuse each of its same-block producers into
// it. Returns the iterator from which the caller should resume traversal,
// plus a flag indicating whether any fusion happened.
std::pair<graph_node_list::iterator, bool> scanNode(
    Node* consumer,
    AliasDb& aliasDb) {
  auto ordered_inputs =
      sortReverseTopological(consumer->inputs(), consumer->owningBlock());

  // Grab the iterator below consumer. We'll use that to determine
  // where to resume iteration, even if consumer gets relocated within
  // the block.
  auto below_consumer = --consumer->reverseIterator();

  for (auto* input : ordered_inputs) {
    if (tryMerge(consumer, input->node(), aliasDb)) {
      // Resume iteration from where consumer is/used to be.
      return {++below_consumer, true};
    }
  }

  // We know consumer didn't move, so skip over it.
  return {++(++below_consumer), false};
}
|
||||||
|
|
||||||
|
// Build the runtime operation backing a tensorexpr::Group node.
// TODO: actually compile the fusion group. For now this returns a stub that
// only records a profiling event and leaves the stack untouched.
Operation createTensorExprOp(const Node* node) {
  (void)node;
  auto stub = [](Stack& stack) {
    RECORD_FUNCTION("TensorExpr", std::vector<c10::IValue>());
    return 0;
  };
  return stub;
}
|
||||||
|
|
||||||
|
// Convenience helper: build an OperatorOptions with the given alias-analysis
// kind already set.
c10::OperatorOptions getAliasAnalysisOption(AliasAnalysisKind k) {
  c10::OperatorOptions options;
  options.setAliasAnalysis(k);
  return options;
}
|
||||||
|
|
||||||
|
// Bind the tensorexpr::Group symbol to its (currently stub) implementation.
// The op is marked PURE_FUNCTION for alias analysis.
RegisterOperators TensorExprOps({
    torch::jit::Operator(
        getTensorExprSymbol(),
        createTensorExprOp,
        getAliasAnalysisOption(AliasAnalysisKind::PURE_FUNCTION)),
});
|
||||||
|
|
||||||
|
// Main fusion pass: walks every node of `graph` (and, once each, every
// nested block) in reverse order, greedily merging eligible producers into
// tensorexpr::Group subgraphs. Top-level sweeps repeat until one completes
// with no change; nested blocks are visited at most once across all sweeps.
void fuseTensorExprs(std::shared_ptr<Graph>& graph) {
  GRAPH_DUMP("Before TExprFuser: ", graph);

  // Get rid of dead code so that we don't waste effort fusing it.
  EliminateDeadCode(graph);

  AliasDb aliasDb(graph);
  auto block = graph->block();

  // Explicit DFS stack of (current, end) reverse iterators, one entry per
  // block being traversed; avoids recursion over nested blocks.
  std::vector<std::pair<graph_node_list_iterator, graph_node_list_iterator>>
      worklist;
  std::unordered_set<torch::jit::Block*> visited_blocks;

  bool any_changed = true;
  while (any_changed) {
    any_changed = false;
    worklist.push_back({block->nodes().rbegin(), block->nodes().rend()});

    while (worklist.size()) {
      // NOTE: `it` is a reference into the worklist entry, so advancing it
      // (here or via scanNode's return value) updates the stored position.
      auto& it = worklist.back().first;
      auto end = worklist.back().second;

      if (it->blocks().size()) {
        // Node with nested blocks: don't try to fuse it, but queue each of
        // its sub-blocks for traversal (at most once per sub-block).
        Node* n = *it;
        ++it;

        if (it == end) {
          worklist.pop_back();
        }

        for (auto b : n->blocks()) {
          if (!visited_blocks.count(b)) {
            worklist.push_back({b->nodes().rbegin(), b->nodes().rend()});
            visited_blocks.insert(b);
          }
        }
      } else {
        // Plain node: try to pull its producers into a fusion group and
        // resume from wherever scanNode says.
        bool changed;
        std::tie(it, changed) = scanNode(*it, aliasDb);
        any_changed |= changed;
        if (it == end) {
          worklist.pop_back();
        }
      }
    }
  }

  // Fusion can leave duplicated or dead nodes behind; clean up before
  // returning.
  EliminateCommonSubexpression(graph);
  EliminateDeadCode(graph);

  GRAPH_DUMP("After TExprFuser: ", graph);
}
|
||||||
|
|
||||||
|
void registerTensorExprFuser() {
|
||||||
|
static bool already_registered = false;
|
||||||
|
if (!already_registered) {
|
||||||
|
RegisterPass pass(fuseTensorExprs);
|
||||||
|
already_registered = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace torch
|
18
torch/csrc/jit/passes/tensorexpr_fuser.h
Normal file
18
torch/csrc/jit/passes/tensorexpr_fuser.h
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
#pragma once

#include <torch/csrc/WindowsTorchApiMacro.h>
#include <memory>

namespace torch {
namespace jit {

struct Graph;

// Run TensorExpressions-based fuser.
// Merges fusable nodes of `graph` into tensorexpr::Group subgraph nodes
// (currently a boilerplate pass — no ops are accepted for fusion yet).
TORCH_API void fuseTensorExprs(std::shared_ptr<Graph>& graph);

// Register TensorExpressions-based fuser in custom passes.
// The pass only runs if this has been called; registration happens at most
// once per process, and repeated calls are no-ops.
TORCH_API void registerTensorExprFuser();

} // namespace jit
} // namespace torch
|
Reference in New Issue
Block a user