[quant] Enable fusion for conv modules with bias (#36173)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36173

Previously, we were ignoring the conv bias during training if it existed.
This PR adds the bias from the conv op during the conv+bn fusion process.

Test Plan:
python test/quantization/test_quantization.py

Imported from OSS

Differential Revision: D20921613

fbshipit-source-id: eacb2ccf9107f413ac4ef23163ba914af9b90924
This commit is contained in:
Supriya Rao
2020-04-08 15:51:25 -07:00
committed by Facebook GitHub Bot
parent caa45c8e33
commit 6972c27d94
5 changed files with 76 additions and 18 deletions

View File

@@ -24,7 +24,6 @@ def fuse_conv_bn(conv, bn):
is_3d = isinstance(conv, torch.nn.Conv3d)
if conv.training:
assert conv.bias is None, 'Only support fusing Conv2d that does not have bias'
assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
@@ -50,7 +49,6 @@ def fuse_conv_bn_relu(conv, bn, relu):
"Conv and BN both must be in the same mode (train or eval)."
is_3d = isinstance(conv, torch.nn.Conv3d)
if conv.training:
assert conv.bias is None, 'Only support fusing Conv that does not have bias'
assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'