mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[quant] Enable fusion for conv modules with bias (#36173)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/36173 Previously we were ignoring the conv bias during training if it existed This PR adds the bias from the conv op during the conv+bn fusion process Test Plan: python test/quantization/test_quantization.py Imported from OSS Differential Revision: D20921613 fbshipit-source-id: eacb2ccf9107f413ac4ef23163ba914af9b90924
This commit is contained in:
committed by
Facebook GitHub Bot
parent
caa45c8e33
commit
6972c27d94
@ -24,7 +24,6 @@ def fuse_conv_bn(conv, bn):
|
||||
is_3d = isinstance(conv, torch.nn.Conv3d)
|
||||
|
||||
if conv.training:
|
||||
assert conv.bias is None, 'Only support fusing Conv2d that does not have bias'
|
||||
assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
|
||||
assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
|
||||
assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
|
||||
@ -50,7 +49,6 @@ def fuse_conv_bn_relu(conv, bn, relu):
|
||||
"Conv and BN both must be in the same mode (train or eval)."
|
||||
is_3d = isinstance(conv, torch.nn.Conv3d)
|
||||
if conv.training:
|
||||
assert conv.bias is None, 'Only support fusing Conv that does not have bias'
|
||||
assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
|
||||
assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
|
||||
assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'
|
||||
|
Reference in New Issue
Block a user