pytorch/torch/csrc/jit/codegen/onednn/prepare_binary.cpp
sanchitintel 4ee29d6033 [Reland take-2] Add JIT graph fuser for oneDNN Graph API (v0.5)
Re-landing #68111/#74596

## Description
v0.5 PR of this [RFC](https://github.com/pytorch/pytorch/issues/49444).

Building on #50256, this PR includes the following improvements:

 * The [v0.5 release branch](https://github.com/oneapi-src/oneDNN/releases/tag/graph-v0.5) of the oneDNN Graph API is used
 * The fuser now works with the profiling graph executor. We have inserted type check nodes to guard the profiled tensor properties.

 ### User API:
The optimization pass is disabled by default. Users can enable it with:

```
 torch.jit.enable_onednn_fusion(True)
```
`torch.jit.freeze` should be used after tracing (recommended) or scripting a model.
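
For reference, a minimal end-to-end sketch might look as follows (`MyModel` and the input shape are hypothetical placeholders):

```
import torch

# Enable the oneDNN Graph fusion pass (disabled by default).
torch.jit.enable_onednn_fusion(True)

model = MyModel().eval()
example_input = torch.rand(1, 3, 224, 224)

with torch.no_grad():
    # Trace (recommended) or script the model, then freeze it.
    frozen = torch.jit.freeze(torch.jit.trace(model, example_input))
    # Warm-up runs let the profiling graph executor record tensor
    # properties before the fused graph is generated.
    frozen(example_input)
    frozen(example_input)
    output = frozen(example_input)
```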

 ### Performance:
The [pytorch/benchmark](https://github.com/pytorch/benchmark) tool was used to compare performance:

* Skylake 8180 (1 socket of 28 cores):
  ![image](https://user-images.githubusercontent.com/65992142/151162305-05e44425-a24e-4d5e-94e1-743b40b87a8c.png)
* Skylake 8180 (single thread):
  ![image](https://user-images.githubusercontent.com/65992142/151162528-69f90b79-d08d-46b8-8775-d80a6ccbce8a.png)
  * By mapping hardswish to oneDNN Graph, it is 8% faster than PyTorch JIT (NNC + OFI)
  * We expect further performance gains after mapping transpose, contiguous & view to oneDNN Graph ops

 ### Directory structure of the integration code
 Fuser-related code is placed under:

 ```
 torch/csrc/jit/codegen/onednn/
 ```

 Optimization pass registration is done in:

 ```
 torch/csrc/jit/passes/onednn_graph_fuser.h
 ```

 CMake for the integration code is in:

 ```
 caffe2/CMakeLists.txt
 cmake/public/mkldnn.cmake
 cmake/Modules/FindMKLDNN.cmake
 ```

 ## Limitations
 * In this PR, we only support the PyTorch-oneDNN Graph integration on Linux. Support for Windows and macOS will be enabled as a next step.
 * Only the inference use case has been optimized so far.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76622
Approved by: https://github.com/eellison
2022-05-05 16:57:03 +00:00


#include <torch/csrc/jit/codegen/onednn/prepare_binary.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/shape_analysis.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

// Returns true if v is a constant int or double equal to d.
bool compareConstValue(Value* v, double d) {
  auto ival = toIValue(v);
  return ival.has_value() &&
      ((ival->isInt() && static_cast<int>(ival->toInt()) == d) ||
       (ival->isDouble() && ival->toDouble() == d));
}
void mayConvertScalarInputToTensor(Node* node) {
  // We do not handle binary ops with two scalar inputs,
  // and we assume the scalar is always in the second position.
  if (node->input(0)->type()->isSubtypeOf(TensorType::get()) &&
      (node->input(1)->type()->isSubtypeOf(FloatType::get()) ||
       node->input(1)->type()->isSubtypeOf(IntType::get()))) {
    auto scalar = node->input(1);
    WithInsertPoint guard(node);
    auto g = node->owningGraph();
    // 42 : Scalar  -->  tensor(42.0) : Float([])
    auto t = g->insert(
        aten::as_tensor, {scalar}, {{"dtype", at::ScalarType::Float}});
    // add dim & stride info to IR
    c10::optional<size_t> t_dim = 1;
    auto target_type = TensorTypePtr(
        TensorType::create(at::ScalarType::Float, at::kCPU, t_dim, false));
    target_type = target_type->withSizes({1});
    t->setType(target_type);
    // tensor(42.0) : Float([])  -->  tensor([42.0]) : Float([1])
    auto unsqueezed = g->insert(aten::unsqueeze, {t, 0});
    unsqueezed->setType(target_type);
    node->replaceInput(1, unsqueezed);
  }
}
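
// Recursively walk a block (including any sub-blocks, e.g. the bodies of
// prim::If / prim::Loop nodes) and tensorize the scalar operand of every
// aten::add / aten::mul node.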
static void ConvertScalarToTensor(Block* block) {
  for (auto node : block->nodes()) {
    for (auto sub : node->blocks()) {
      ConvertScalarToTensor(sub);
    }
    if (node->kind() == aten::add || node->kind() == aten::mul) {
      mayConvertScalarInputToTensor(node);
    }
  }
}
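
// Fold a constant non-unit alpha into a multiply so that only plain binary
// ops remain for the fuser to match, e.g.:
//   %y = aten::add(%a, %b, %alpha=2)
//     -->  %m = aten::mul(%b, 2); %y = aten::add(%a, %m, 1)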
void mayDecomposeAdd(Node* node) {
  if (toIValue(node->namedInput("alpha")).has_value()) {
    auto alphaEqualsOne = compareConstValue(node->namedInput("alpha"), 1.0);
    if (!alphaEqualsOne) {
      WithInsertPoint guard(node);
      auto g = node->owningGraph();
      auto mul = g->insert(
          aten::mul, {node->namedInput("other"), node->namedInput("alpha")});
      node->replaceInput(1, mul);
      auto one = g->insertConstant(1.0);
      node->replaceInput(2, one);
    }
  }
}
static void DecomposeFusedAdd(Block* block) {
  for (auto node : block->nodes()) {
    for (auto sub : node->blocks()) {
      DecomposeFusedAdd(sub);
    }
    if (node->kind() == aten::add) {
      mayDecomposeAdd(node);
    }
  }
}
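
// Drop identity arithmetic: aten::add(%x, 0) and aten::mul(%x, 1) are
// replaced by %x itself. The now-dead node is cleaned up afterwards by
// EliminateDeadCode in PrepareBinaryForLLGA.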
static void EliminateIdentityMulAdd(Block* block) {
  for (auto node : block->nodes()) {
    for (auto sub : node->blocks()) {
      EliminateIdentityMulAdd(sub);
    }
    if ((node->kind() == aten::add && compareConstValue(node->input(1), 0.0)) ||
        (node->kind() == aten::mul && compareConstValue(node->input(1), 1.0))) {
      node->output()->replaceAllUsesWith(node->namedInput("self"));
    }
  }
}
void PrepareBinaryForLLGA(const std::shared_ptr<Graph>& graph) {
  DecomposeFusedAdd(graph->block());
  EliminateIdentityMulAdd(graph->block());
  EliminateDeadCode(graph);
  // ConvertScalarToTensor must be placed after EliminateIdentityMulAdd:
  // the identity checks match constant scalar operands, which would no
  // longer be scalars once converted to tensors.
  ConvertScalarToTensor(graph->block());
}
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch