lower batchmm to non-diff optimization (#19987)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19987
ghimport-source-id: ca4c38312bd56d8a71f1925297deee7f64f573d3

Differential Revision: D15190356

Pulled By: wanchaol

fbshipit-source-id: 761edb08c670fcbc24a06a5b11ceddf311f75884
This commit is contained in:
Wanchao Liang
2019-05-06 15:36:49 -07:00
committed by Facebook Github Bot
parent 0c5dc965a4
commit 8fbde94664

View File

@ -638,16 +638,17 @@ struct GraphExecutorImpl {
UnrollLoops(graph);
EliminateCommonSubexpression(graph);
// Rewrite subgraphs with many MMs into expressions that batch them.
BatchMM(graph);
CheckInplace(graph);
}
void runNondiffOptimization(std::shared_ptr<Graph>& graph) {
// run custom passes that different backends can register
for (const auto& pass : getCustomPasses()) {
pass(graph);
}
// Rewrite subgraphs with many MMs into expressions that batch them.
BatchMM(graph);
FuseGraph(graph);
}