mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/138976 Approved by: https://github.com/Skylion007
26 lines
1.1 KiB
C++
26 lines
1.1 KiB
C++
#include <torch/csrc/jit/codegen/onednn/graph_fuser.h>
|
|
#include <torch/csrc/jit/codegen/onednn/graph_helper.h>
|
|
#include <torch/csrc/jit/ir/alias_analysis.h>
|
|
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
|
|
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
|
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
|
|
|
|
namespace torch::jit::fuser::onednn {
|
|
|
|
void CreateLlgaSubgraphs(std::shared_ptr<Graph>& graph) {
|
|
AliasDb db(graph);
|
|
GraphRewriter graphRewriter(graph->block(), graph, db);
|
|
// We maintain alias db correctness in-place while building up the LLGA
|
|
// subgraphs, however it is difficult to preserve correctness when
|
|
// un-inlining autodiff subgraphs. We first recursively construct all
|
|
// subgraphs and then recursively cleanup & unmerge the small subgraphs
|
|
graphRewriter.buildupSubgraphs();
|
|
graphRewriter.cleanupSubgraphs();
|
|
// Run CSE globally onceto eliminate duplicates that may have occurred
|
|
// while inlining subgraphs.
|
|
EliminateCommonSubexpression(graph);
|
|
EliminateDeadCode(graph);
|
|
}
|
|
|
|
} // namespace torch::jit::fuser::onednn
|