mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
ConvBn2d/ConvBnReLU2d (#23357)
Summary: Added _intrinsic.qat.ConvBn2d/_intrinsic.qat.ConvBnReLU2d. Pull Request resolved: https://github.com/pytorch/pytorch/pull/23357 ghstack-source-id: 87519573 Differential Revision: D16295500 fbshipit-source-id: 81e6d1d10d05bf6e343721fc5701d3d6bd7e07e6
This commit is contained in:
committed by
Facebook Github Bot
parent
029c8e7754
commit
6cf9ed4a54
@ -21,6 +21,10 @@ def fuse_conv_bn(conv, bn):
|
||||
"Conv and BN both must be in the same mode (train or eval)."
|
||||
|
||||
if conv.training:
|
||||
assert conv.bias is None, 'Only support fusing Conv2d that does not have bias'
|
||||
assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
|
||||
assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
|
||||
assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
|
||||
return torch.nn._intrinsic.ConvBn2d(conv, bn)
|
||||
else:
|
||||
return torch.nn.utils.fuse_conv_bn_eval(conv, bn)
|
||||
@ -42,6 +46,7 @@ def fuse_conv_bn_relu(conv, bn, relu):
|
||||
"Conv and BN both must be in the same mode (train or eval)."
|
||||
|
||||
if conv.training:
|
||||
assert not relu.inplace, 'We only support fusion of non-inplace ReLU.'
|
||||
return torch_fused.ConvBnReLU2d(conv, bn, relu)
|
||||
else:
|
||||
return torch_fused.ConvReLU2d(
|
||||
|
Reference in New Issue
Block a user