ConvBn2d/ConvBnReLU2d (#23357)

Summary:
Added _intrinsic.qat.ConvBn2d/_intrinsic.qat.ConvBnReLU2d.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/23357
ghstack-source-id: 87519573

Differential Revision: D16295500

fbshipit-source-id: 81e6d1d10d05bf6e343721fc5701d3d6bd7e07e6
This commit is contained in:
Jerry Zhang
2019-08-01 10:03:39 -07:00
committed by Facebook Github Bot
parent 029c8e7754
commit 6cf9ed4a54
9 changed files with 412 additions and 60 deletions

View File

@@ -21,6 +21,10 @@ def fuse_conv_bn(conv, bn):
"Conv and BN both must be in the same mode (train or eval)."
if conv.training:
assert conv.bias is None, 'Only support fusing Conv2d that does not have bias'
assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
return torch.nn._intrinsic.ConvBn2d(conv, bn)
else:
return torch.nn.utils.fuse_conv_bn_eval(conv, bn)
@@ -42,6 +46,7 @@ def fuse_conv_bn_relu(conv, bn, relu):
"Conv and BN both must be in the same mode (train or eval)."
if conv.training:
assert not relu.inplace, 'We only support fusion of non-inplace ReLU.'
return torch_fused.ConvBnReLU2d(conv, bn, relu)
else:
return torch_fused.ConvReLU2d(