Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138976
Approved by: https://github.com/Skylion007
40 lines · 1.4 KiB · C++
#include <torch/csrc/jit/codegen/onednn/guard_shape.h>

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/graph_executor.h>

namespace torch::jit::fuser::onednn {

//! [ Note -- prepareFusionGroupAndGuardOutputs implementation ]
//! Adapted with very little modification from the NNC (tensorexpr_fuser)
//! implementation; original code at:
//! `torch/csrc/jit/passes/tensorexpr_fuser.cpp:prepareFusionGroupAndGuardOutputs`
//!
//! We assume that LLGA has no operators whose behavior depends on the
//! content of a tensor.
void prepareFusionGroupAndGuardOutputs(Block* block) {
  std::vector<Node*> fusion_groups;
  // Recurse into nested blocks first, then collect the oneDNN fusion groups
  // at this level.
  for (Node* n : block->nodes()) {
    for (Block* b : n->blocks()) {
      prepareFusionGroupAndGuardOutputs(b);
    }
    if (n->kind() == prim::oneDNNFusionGroup) {
      fusion_groups.push_back(n);
    }
  }
  // Insert a prim::oneDNNFusionGuard that checks the input tensor types
  // before each collected fusion group.
  for (Node* fusion_group : fusion_groups) {
    // TODO: add further optimization pass to removeOutputsUsedOnlyInSize,
    // refer to
    // `torch/csrc/jit/passes/tensorexpr_fuser.cpp:removeOutputsUsedOnlyInSize`
    // removeOutputsUsedOnlyInSize(fusion_group);
    insertTypeGuard(
        fusion_group,
        [](const TensorTypePtr& t) { return t; },
        prim::oneDNNFusionGuard);
  }
}

} // namespace torch::jit::fuser::onednn
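
For orientation, the sketch below shows how this pass might be invoked once the oneDNN fusion groups have been formed. It is a minimal illustration under assumptions: the wrapper name guardOneDNNFusionGroups and its call site are hypothetical and not part of the file above; only prepareFusionGroupAndGuardOutputs comes from it.

#include <memory>

#include <torch/csrc/jit/codegen/onednn/guard_shape.h>
#include <torch/csrc/jit/ir/ir.h>

namespace torch::jit::fuser::onednn {

// Hypothetical wrapper, for illustration only: run the guard pass over a
// whole graph. prepareFusionGroupAndGuardOutputs recurses into nested
// blocks itself and guards every prim::oneDNNFusionGroup it finds with a
// prim::oneDNNFusionGuard.
inline void guardOneDNNFusionGroups(const std::shared_ptr<Graph>& graph) {
  prepareFusionGroupAndGuardOutputs(graph->block());
}

} // namespace torch::jit::fuser::onednn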