mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
lower batchmm to non-diff optimization (#19987)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19987 ghimport-source-id: ca4c38312bd56d8a71f1925297deee7f64f573d3 Differential Revision: D15190356 Pulled By: wanchaol fbshipit-source-id: 761edb08c670fcbc24a06a5b11ceddf311f75884
This commit is contained in:
committed by
Facebook Github Bot
parent
0c5dc965a4
commit
8fbde94664
@ -638,16 +638,17 @@ struct GraphExecutorImpl {
|
||||
UnrollLoops(graph);
|
||||
EliminateCommonSubexpression(graph);
|
||||
|
||||
// Rewrite subgraphs with many MMs into expressions that batch them.
|
||||
BatchMM(graph);
|
||||
|
||||
CheckInplace(graph);
|
||||
}
|
||||
|
||||
void runNondiffOptimization(std::shared_ptr<Graph>& graph) {
|
||||
// run custom passes that different backends can register
|
||||
for (const auto& pass : getCustomPasses()) {
|
||||
pass(graph);
|
||||
}
|
||||
// Rewrite subgraphs with many MMs into expressions that batch them.
|
||||
BatchMM(graph);
|
||||
|
||||
FuseGraph(graph);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user