[quant] Support for fused ConvBn1d and ConvBnRelu1d modules (#38452) (#38749)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/38749

Test Plan: python test/test_quantization.py TestFused

Differential Revision: D21654659

Pulled By: supriyar

fbshipit-source-id: 301be24083e794f4e71ff1d6d842e1aaefa640f0
Supriya Rao authored 2020-05-19 22:46:28 -07:00; committed by Facebook GitHub Bot
parent 7587188037
commit 530d48e93a
8 changed files with 75 additions and 17 deletions
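For context, a minimal usage sketch of the 1d fusion path this change enables (not part of the commit; the model and module names are illustrative). In eval mode, a Conv1d/BatchNorm1d/ReLU triple now fuses into a single ConvReLU1d module:

    import torch
    import torch.nn as nn
    from torch.quantization import fuse_modules

    class ToyModel(nn.Module):  # hypothetical model, for illustration only
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv1d(3, 8, kernel_size=3)
            self.bn = nn.BatchNorm1d(8)
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(self.bn(self.conv(x)))

    m = ToyModel().eval()  # eval mode: BN is folded into the conv weights
    fused = fuse_modules(m, [['conv', 'bn', 'relu']])
    # fused.conv is now torch.nn.intrinsic.ConvReLU1d; fused.bn and fused.relu become Identity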


@@ -47,19 +47,30 @@ def fuse_conv_bn_relu(conv, bn, relu):
     """
     assert(conv.training == bn.training == relu.training),\
         "Conv and BN both must be in the same mode (train or eval)."
-    is_3d = isinstance(conv, torch.nn.Conv3d)
     if conv.training:
+        map_to_fused_module_train = {
+            torch.nn.Conv2d: torch_fused.ConvBnReLU2d,
+            torch.nn.Conv3d: torch_fused.ConvBnReLU3d,
+        }
         assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
         assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
         assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'
-        return torch_fused.ConvBnReLU3d(conv, bn, relu) if is_3d \
-            else torch_fused.ConvBnReLU2d(conv, bn, relu)
+        fused_module = map_to_fused_module_train.get(type(conv))
+        if fused_module is not None:
+            return fused_module(conv, bn, relu)
+        else:
+            raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu)))
     else:
-        return torch_fused.ConvReLU3d(
-            torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn), relu) if is_3d \
-            else torch_fused.ConvReLU2d(
-                torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn), relu)
+        map_to_fused_module_eval = {
+            torch.nn.Conv1d: torch_fused.ConvReLU1d,
+            torch.nn.Conv2d: torch_fused.ConvReLU2d,
+            torch.nn.Conv3d: torch_fused.ConvReLU3d,
+        }
+        fused_module = map_to_fused_module_eval.get(type(conv))
+        if fused_module is not None:
+            return fused_module(torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn), relu)
+        else:
+            raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))
 
 
 # Generalization of getattr
 def _get_module(model, submodule_key):
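The eval branch above relies on torch.nn.utils.fusion.fuse_conv_bn_eval to fold the BatchNorm statistics into the conv weights before wrapping the result in ConvReLU1d/2d/3d, so fusion should not change the module's output. A quick numeric check (a sketch, not part of the commit):

    import torch

    conv = torch.nn.Conv1d(4, 4, kernel_size=3).eval()
    bn = torch.nn.BatchNorm1d(4).eval()
    fused = torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn)  # a plain Conv1d with folded weights

    x = torch.randn(2, 4, 16)
    assert torch.allclose(bn(conv(x)), fused(x), atol=1e-6)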
@@ -93,6 +104,8 @@ def fuse_known_modules(mod_list):
     """
     OP_LIST_TO_FUSER_METHOD = {
+        (torch.nn.Conv1d, torch.nn.BatchNorm1d): fuse_conv_bn,
+        (torch.nn.Conv1d, torch.nn.BatchNorm1d, torch.nn.ReLU): fuse_conv_bn_relu,
         (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn,
         (torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU): fuse_conv_bn_relu,
         (torch.nn.Conv3d, torch.nn.BatchNorm3d): fuse_conv_bn,
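For reference, OP_LIST_TO_FUSER_METHOD is keyed by the tuple of module types, so the dispatch inside fuse_known_modules amounts to something like the following (simplified sketch, not the verbatim implementation):

    types = tuple(type(m) for m in mod_list)
    fuser_method = OP_LIST_TO_FUSER_METHOD.get(types, None)
    if fuser_method is None:
        raise NotImplementedError('Cannot fuse modules: {}'.format(types))
    new_mod = fuser_method(*mod_list)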