pytorch/torch/csrc/jit/codegen/onednn/layout_propagation.cpp
cyy d4a98280a8 [Reland] Use missing-prototypes in torch_cpu (#104138)
This PR enables -Wmissing-prototypes in torch_cpu, except for some generated cpp files, the mps, metal, and vulkan backends, and caffe2 sources.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104138
Approved by: https://github.com/albanD, https://github.com/malfet
2023-06-26 22:53:43 +00:00
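
For reference (not part of this file): -Wmissing-prototypes warns when a function with external linkage is defined without a previously seen declaration. Code typically satisfies it either by declaring the function in a header or by giving it internal linkage, which is why the helpers in the file below are marked static while only PropagateLayout is presumably declared in the included layout_propagation.h. A minimal sketch of the pattern, with illustrative names that are not from the PR:

// example.cpp -- illustrative only; names are hypothetical
// A prior declaration (normally in a header) makes the prototype visible,
// so defining exported_helper later does not trigger the warning.
int exported_helper(int x);

static int internal_helper(int x) {
  // static gives internal linkage, so -Wmissing-prototypes does not apply.
  return x + 1;
}

int exported_helper(int x) {
  return internal_helper(x) * 2;
}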


#include <torch/csrc/jit/codegen/onednn/graph_helper.h>
#include <torch/csrc/jit/codegen/onednn/layout_propagation.h>
#include <torch/csrc/jit/jit_log.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

static void LayoutPropagation(Node* n) {
  if (!LlgaGraphHelper::isLlgaSubgraph(n))
    return;

  // Initialize attr::output_layouts to strided layouts if it is undefined
  if (!n->hasAttribute(attr::output_layouts)) {
    const auto num_output = n->outputs().size();
    GRAPH_DEBUG("Initial output_layouts of size ", num_output);
    std::vector<int64_t> layouts(num_output, STRIDED_LAYOUT);
    n->is_(attr::output_layouts, layouts);
  }

  // A value produced by an LLGA subgraph may keep oneDNN's opaque (blocked)
  // layout only when every consumer of that value is also an LLGA subgraph;
  // otherwise it must stay strided so non-LLGA nodes can read it.
  for (auto input : n->inputs()) {
    auto prev = input->node();
    auto offset = input->offset();
    if (LlgaGraphHelper::isLlgaSubgraph(prev)) {
      bool useOpaqueLayout = true;
      for (auto& use : input->uses()) {
        if (!LlgaGraphHelper::isLlgaSubgraph(use.user)) {
          useOpaqueLayout = false;
          break;
        }
      }
      if (useOpaqueLayout) {
        LlgaNodeWrapper(prev).setOpaqueLayout(offset);
      }
    }
  }
}

static void LayoutPropagation(at::ArrayRef<Block*> blocks) {
  for (Block* block : blocks)
    for (Node* node : block->nodes())
      LayoutPropagation(node);
}

void PropagateLayout(const std::shared_ptr<Graph>& graph) {
  LayoutPropagation(graph->block());
}

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
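
For orientation, PropagateLayout is the only symbol this file exports; a caller in the oneDNN fuser pipeline would run it over a graph after the LLGA subgraphs have been formed. A rough sketch of such a call site follows; the wrapper function is hypothetical and not taken from the repository:

// illustrative_call_site.cpp -- hypothetical wrapper, for illustration only
#include <torch/csrc/jit/codegen/onednn/layout_propagation.h>
#include <torch/csrc/jit/ir/ir.h>

void runOneDnnLayoutPass(std::shared_ptr<torch::jit::Graph>& graph) {
  // Earlier passes are assumed to have grouped ops into LLGA subgraphs;
  // this pass then decides which cross-subgraph values keep opaque layouts.
  torch::jit::fuser::onednn::PropagateLayout(graph);
}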