mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/138976 Approved by: https://github.com/Skylion007
48 lines
1.2 KiB
C++
48 lines
1.2 KiB
C++
#pragma once
|
|
|
|
#include <torch/csrc/jit/codegen/onednn/graph_helper.h>
|
|
#include <torch/csrc/jit/ir/ir.h>
|
|
|
|
namespace torch::jit::fuser::onednn {
|
|
|
|
struct WorkBlock : public std::pair<Node*, Node*> {
|
|
using pair::pair;
|
|
|
|
Node* begin() {
|
|
return this->first;
|
|
}
|
|
Node* end() {
|
|
return this->second;
|
|
}
|
|
};
|
|
|
|
class GraphRewriter {
|
|
public:
|
|
GraphRewriter(Block* block, std::shared_ptr<Graph> graph, AliasDb& aliasDb)
|
|
: block_(block),
|
|
graph_(std::move(graph)),
|
|
aliasDb_(aliasDb),
|
|
llgaHelper_(graph_) {}
|
|
|
|
void cleanupSubgraphs();
|
|
void buildupSubgraphs();
|
|
|
|
private:
|
|
Block* block_;
|
|
std::shared_ptr<Graph> graph_;
|
|
AliasDb& aliasDb_;
|
|
LlgaGraphHelper llgaHelper_;
|
|
std::vector<WorkBlock> buildWorkBlocks();
|
|
std::pair<graph_node_list::iterator, bool> scanNode(
|
|
Node* consumer,
|
|
graph_node_list::iterator workblock_begin);
|
|
std::optional<Node*> tryMerge(Node* consumer, Node* producer);
|
|
};
|
|
|
|
// This pass creates the subgraphs for oneDNN Graph Fusion Nodes.
|
|
// Its code-structure has been vastly inspired from
|
|
// torch/csrc/jit/passes/create_autodiff_subgraphs.cpp
|
|
void CreateLlgaSubgraphs(std::shared_ptr<Graph>& graph);
|
|
|
|
} // namespace torch::jit::fuser::onednn
|