mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/138976 Approved by: https://github.com/Skylion007
48 lines
1.3 KiB
C++
48 lines
1.3 KiB
C++
#include <torch/csrc/jit/codegen/onednn/graph_helper.h>
|
|
#include <torch/csrc/jit/codegen/onednn/layout_propagation.h>
|
|
#include <torch/csrc/jit/jit_log.h>
|
|
|
|
namespace torch::jit::fuser::onednn {
|
|
|
|
static void LayoutPropagation(Node* n) {
|
|
if (!LlgaGraphHelper::isLlgaSubgraph(n))
|
|
return;
|
|
|
|
// initial attr::output_layouts if undefined
|
|
if (!n->hasAttribute(attr::output_layouts)) {
|
|
const auto num_output = n->outputs().size();
|
|
GRAPH_DEBUG("Initial output_layouts of size ", num_output);
|
|
std::vector<int64_t> layouts(num_output, STRIDED_LAYOUT);
|
|
n->is_(attr::output_layouts, layouts);
|
|
}
|
|
|
|
for (auto input : n->inputs()) {
|
|
auto prev = input->node();
|
|
auto offset = input->offset();
|
|
if (LlgaGraphHelper::isLlgaSubgraph(prev)) {
|
|
bool useOpaqueLayout = true;
|
|
for (auto& use : input->uses()) {
|
|
if (!LlgaGraphHelper::isLlgaSubgraph(use.user)) {
|
|
useOpaqueLayout = false;
|
|
break;
|
|
}
|
|
}
|
|
if (useOpaqueLayout) {
|
|
LlgaNodeWrapper(prev).setOpaqueLayout(offset);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
static void LayoutPropagation(at::ArrayRef<Block*> blocks) {
|
|
for (Block* block : blocks)
|
|
for (Node* node : block->nodes())
|
|
LayoutPropagation(node);
|
|
}
|
|
|
|
void PropagateLayout(const std::shared_ptr<Graph>& graph) {
|
|
LayoutPropagation(graph->block());
|
|
}
|
|
|
|
} // namespace torch::jit::fuser::onednn
|