mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Add Post Freezing Optimizations, turn on by default in torch.jit.freeze (#50222)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50222 This PR adds a pass which runs a set of optimizations to be done after freezing. Currently this encompasses Conv-BN folding, Conv->Add/Sub/Mul/Div folding and i'm also planning on adding dropout removal. I would like some feedback on the API. torch.jit.freeze is technically in \~prototype\~ phase so we have some leeway around making changes. I think in the majority of cases, the user is going to want to freeze their model, and then run in inference. I would prefer if the optimization was opt-out instead of opt-in. All internal/framework use cases of freezing all use `freeze_module`, not the python API, so this shouldn't break anything. I have separated out the optimization pass as a separate API to make things potentially modular, even though I suspect that is an unlikely case. In a future PR i would like to add a `torch::jit::freeze` which follows the same api as `torch.jit.freeze` intended for C++ use, and runs the optimizations. Test Plan: Imported from OSS Reviewed By: tugsbayasgalan Differential Revision: D25856264 Pulled By: eellison fbshipit-source-id: 56be1f12cfc459b4c4421d4dfdedff8b9ac77112
This commit is contained in:
committed by
Facebook GitHub Bot
parent
30aeed7c2b
commit
a389b30bfc
@ -23,6 +23,7 @@
|
||||
#include <torch/csrc/jit/passes/fold_conv_bn.h>
|
||||
#include <torch/csrc/jit/passes/freeze_module.h>
|
||||
#include <torch/csrc/jit/passes/frozen_conv_folding.h>
|
||||
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
|
||||
#include <torch/csrc/jit/passes/fuse_linear.h>
|
||||
#include <torch/csrc/jit/passes/fuse_relu.h>
|
||||
#include <torch/csrc/jit/passes/graph_fuser.h>
|
||||
@ -299,6 +300,7 @@ void initJITBindings(PyObject* module) {
|
||||
.def("_jit_pass_fold_frozen_conv_bn", &FoldFrozenConvBatchnorm)
|
||||
.def("_jit_pass_fold_frozen_conv_add_or_sub", &FoldFrozenConvAddOrSub)
|
||||
.def("_jit_pass_fold_frozen_conv_mul_or_div", &FoldFrozenConvMulOrDiv)
|
||||
.def("_jit_pass_optimize_frozen_graph", &OptimizeFrozenGraph)
|
||||
.def("_jit_pass_fuse_linear", &FuseLinear)
|
||||
.def(
|
||||
"_jit_pass_fuse_add_relu",
|
||||
|
Reference in New Issue
Block a user