mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[ONNX] Add pass that fuses Conv and BatchNormalization (#40547)
Summary: Add pass that fuses Conv and Batchnormalization nodes into one node Conv. This pass is only applied in inference mode (training is None or TrainingMode.Eval). Since this pass needs access to param_dict it is written outside peephole file where these kind of passes (fusing multiple nodes into one) is usually placed. This PR also adds wrapper skipIfNoEmbed to skip debug_embed_params test: Pass that fuses Conv and Batchnorm changes the params of resnet model and parameters of onnx and pytorch model won't match. Since parameters are not matching, debug_embed_params test for test_resnet will fail and that is expected, therefore debug_embed_params test for test_resnet should be skipped. Pull Request resolved: https://github.com/pytorch/pytorch/pull/40547 Reviewed By: gchanan Differential Revision: D22631687 Pulled By: bzinodev fbshipit-source-id: fe45812400398a32541e797f727fd8697eb6d8c0
This commit is contained in:
committed by
Facebook GitHub Bot
parent
ad7133d3c1
commit
af5d0bff00
@ -32,6 +32,7 @@
|
||||
#include <torch/csrc/jit/passes/onnx.h>
|
||||
#include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h>
|
||||
#include <torch/csrc/jit/passes/onnx/constant_fold.h>
|
||||
#include <torch/csrc/jit/passes/onnx/eval_peephole.h>
|
||||
#include <torch/csrc/jit/passes/onnx/fixup_onnx_conditionals.h>
|
||||
#include <torch/csrc/jit/passes/onnx/fixup_onnx_loop.h>
|
||||
#include <torch/csrc/jit/passes/onnx/function_substitution.h>
|
||||
@ -143,6 +144,14 @@ void initJITBindings(PyObject* module) {
|
||||
bool fixed_batch_size) {
|
||||
return PeepholeOptimizeONNX(graph, opset_version, fixed_batch_size);
|
||||
})
|
||||
.def(
|
||||
"_jit_pass_onnx_eval_peephole",
|
||||
[](std::shared_ptr<Graph>& graph,
|
||||
std::map<std::string, IValue>& paramsDict) {
|
||||
EvalPeepholeONNX(graph->block(), paramsDict);
|
||||
return paramsDict;
|
||||
},
|
||||
pybind11::return_value_policy::move)
|
||||
.def(
|
||||
"_jit_pass_onnx_cast_all_constant_to_floating",
|
||||
CastAllConstantToFloating)
|
||||
|
Reference in New Issue
Block a user