[ONNX] Add pass that fuses Conv and BatchNormalization (#40547)

Summary:
Add pass that fuses Conv and Batchnormalization nodes into one node Conv.
This pass is only applied in inference mode (training is None or TrainingMode.Eval).
Since this pass needs access to param_dict it is written outside peephole file where these kind of passes (fusing multiple nodes into one) is usually placed.

This PR also adds wrapper skipIfNoEmbed to skip debug_embed_params test:
Pass that fuses Conv and Batchnorm changes the params of resnet model and parameters of onnx and pytorch model won't match. Since parameters are not matching, debug_embed_params test for test_resnet will fail and that is expected, therefore debug_embed_params test for test_resnet should be skipped.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/40547

Reviewed By: gchanan

Differential Revision: D22631687

Pulled By: bzinodev

fbshipit-source-id: fe45812400398a32541e797f727fd8697eb6d8c0
This commit is contained in:
Ksenija Stanojevic
2020-07-22 14:57:23 -07:00
committed by Facebook GitHub Bot
parent ad7133d3c1
commit af5d0bff00
9 changed files with 346 additions and 7 deletions

View File

@ -32,6 +32,7 @@
#include <torch/csrc/jit/passes/onnx.h>
#include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h>
#include <torch/csrc/jit/passes/onnx/constant_fold.h>
#include <torch/csrc/jit/passes/onnx/eval_peephole.h>
#include <torch/csrc/jit/passes/onnx/fixup_onnx_conditionals.h>
#include <torch/csrc/jit/passes/onnx/fixup_onnx_loop.h>
#include <torch/csrc/jit/passes/onnx/function_substitution.h>
@ -143,6 +144,14 @@ void initJITBindings(PyObject* module) {
bool fixed_batch_size) {
return PeepholeOptimizeONNX(graph, opset_version, fixed_batch_size);
})
.def(
"_jit_pass_onnx_eval_peephole",
[](std::shared_ptr<Graph>& graph,
std::map<std::string, IValue>& paramsDict) {
EvalPeepholeONNX(graph->block(), paramsDict);
return paramsDict;
},
pybind11::return_value_policy::move)
.def(
"_jit_pass_onnx_cast_all_constant_to_floating",
CastAllConstantToFloating)