Add Post Freezing Optimizations, turn on by default in torch.jit.freeze (#50222)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50222

This PR adds a pass that runs a set of optimizations after freezing. Currently this encompasses Conv-BN folding and folding of Add/Sub/Mul/Div into Conv; I am also planning to add dropout removal.
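
As an illustrative sketch (not code from this PR; the module and shapes are arbitrary), this is the intended default behavior: freezing a scripted Conv + BatchNorm module in eval mode should fold the batch norm into the convolution without any extra call.

```python
# Illustrative sketch only: freeze a Conv+BN module and inspect the frozen graph.
import torch
import torch.nn as nn

class ConvBN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.bn = nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))

mod = torch.jit.script(ConvBN().eval())
frozen = torch.jit.freeze(mod)  # post-freezing optimizations run by default after this change

# After Conv-BN folding, the frozen graph should no longer contain an aten::batch_norm node.
print(frozen.graph)

x = torch.randn(1, 3, 16, 16)
print(torch.allclose(mod(x), frozen(x), atol=1e-5))  # outputs should match up to numerics
```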

I would like some feedback on the API. `torch.jit.freeze` is technically in the prototype phase, so we have some leeway to make changes. In the majority of cases the user is going to want to freeze their model and then run inference, so I would prefer the optimizations to be opt-out instead of opt-in. All internal/framework use cases of freezing go through `freeze_module`, not the Python API, so this shouldn't break anything.

I have exposed the optimization pass as a separate API to keep things modular, even though I suspect that use case is unlikely. In a future PR I would like to add a `torch::jit::freeze` for C++ use that follows the same API as `torch.jit.freeze` and runs the optimizations.
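
As a hedged sketch of the modular path (the pass bindings below are the ones registered in this diff; that they accept a frozen module's `graph` is my assumption), the individual passes can also be invoked by hand:

```python
# Hedged sketch: invoke the post-freezing passes explicitly via torch._C bindings.
# Once freeze applies them by default this is redundant; it only illustrates the
# standalone, modular API.
import torch
import torch.nn as nn

mod = torch.jit.script(nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)).eval())
frozen = torch.jit.freeze(mod)

# Individual folding passes (assumed to mutate the Graph in place):
torch._C._jit_pass_fold_frozen_conv_bn(frozen.graph)
torch._C._jit_pass_fold_frozen_conv_add_or_sub(frozen.graph)
torch._C._jit_pass_fold_frozen_conv_mul_or_div(frozen.graph)

# Or the combined pass added in this PR:
torch._C._jit_pass_optimize_frozen_graph(frozen.graph)
```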

Test Plan: Imported from OSS

Reviewed By: tugsbayasgalan

Differential Revision: D25856264

Pulled By: eellison

fbshipit-source-id: 56be1f12cfc459b4c4421d4dfdedff8b9ac77112
Author: Elias Ellison
Date: 2021-01-12 11:35:08 -08:00
Committed by: Facebook GitHub Bot
Parent: 30aeed7c2b
Commit: a389b30bfc
8 changed files with 110 additions and 8 deletions

@@ -23,6 +23,7 @@
 #include <torch/csrc/jit/passes/fold_conv_bn.h>
 #include <torch/csrc/jit/passes/freeze_module.h>
 #include <torch/csrc/jit/passes/frozen_conv_folding.h>
+#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
 #include <torch/csrc/jit/passes/fuse_linear.h>
 #include <torch/csrc/jit/passes/fuse_relu.h>
 #include <torch/csrc/jit/passes/graph_fuser.h>
@@ -299,6 +300,7 @@ void initJITBindings(PyObject* module) {
       .def("_jit_pass_fold_frozen_conv_bn", &FoldFrozenConvBatchnorm)
       .def("_jit_pass_fold_frozen_conv_add_or_sub", &FoldFrozenConvAddOrSub)
       .def("_jit_pass_fold_frozen_conv_mul_or_div", &FoldFrozenConvMulOrDiv)
+      .def("_jit_pass_optimize_frozen_graph", &OptimizeFrozenGraph)
       .def("_jit_pass_fuse_linear", &FuseLinear)
       .def(
           "_jit_pass_fuse_add_relu",