mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-22 14:15:01 +08:00
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/138976 Approved by: https://github.com/Skylion007
21 lines
499 B
C++
21 lines
499 B
C++
#pragma once
|
|
|
|
#include <torch/csrc/jit/ir/ir.h>
|
|
|
|
namespace torch::jit::fuser::onednn {
|
|
|
|
// Prepare binary ops for LLGA
|
|
//
|
|
// The pass does the following:
|
|
//
|
|
// - Convert scalar input of aten::add and aten::mul into Float tensor with
|
|
// dimension [1]
|
|
//
|
|
// - Decompose fused add into aten::mul + aten::add when alpha != 1.0
|
|
//
|
|
// - Eliminate identity add/mul, i.e., tensor + 0, tensor * 1
|
|
//
|
|
void PrepareBinaryForLLGA(const std::shared_ptr<Graph>& graph);
|
|
|
|
} // namespace torch::jit::fuser::onednn
|