mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/38749 Test Plan: python test/test_quantization.py TestFused Differential Revision: D21654659 Pulled By: supriyar fbshipit-source-id: 301be24083e794f4e71ff1d6d842e1aaefa640f0
This commit is contained in:
committed by
Facebook GitHub Bot
parent
7587188037
commit
530d48e93a
@ -47,19 +47,30 @@ def fuse_conv_bn_relu(conv, bn, relu):
|
||||
"""
|
||||
assert(conv.training == bn.training == relu.training),\
|
||||
"Conv and BN both must be in the same mode (train or eval)."
|
||||
is_3d = isinstance(conv, torch.nn.Conv3d)
|
||||
if conv.training:
|
||||
map_to_fused_module_train = {
|
||||
torch.nn.Conv2d: torch_fused.ConvBnReLU2d,
|
||||
torch.nn.Conv3d: torch_fused.ConvBnReLU3d,
|
||||
}
|
||||
assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
|
||||
assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
|
||||
assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'
|
||||
|
||||
return torch_fused.ConvBnReLU3d(conv, bn, relu) if is_3d \
|
||||
else torch_fused.ConvBnReLU2d(conv, bn, relu)
|
||||
fused_module = map_to_fused_module_train.get(type(conv))
|
||||
if fused_module is not None:
|
||||
return fused_module(conv, bn, relu)
|
||||
else:
|
||||
raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu)))
|
||||
else:
|
||||
return torch_fused.ConvReLU3d(
|
||||
torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn), relu) if is_3d \
|
||||
else torch_fused.ConvReLU2d(
|
||||
torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn), relu)
|
||||
map_to_fused_module_eval = {
|
||||
torch.nn.Conv1d: torch_fused.ConvReLU1d,
|
||||
torch.nn.Conv2d: torch_fused.ConvReLU2d,
|
||||
torch.nn.Conv3d: torch_fused.ConvReLU3d,
|
||||
}
|
||||
fused_module = map_to_fused_module_eval[type(conv)]
|
||||
if fused_module is not None:
|
||||
return fused_module(torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn), relu)
|
||||
else:
|
||||
raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))
|
||||
|
||||
# Generalization of getattr
|
||||
def _get_module(model, submodule_key):
|
||||
@ -93,6 +104,8 @@ def fuse_known_modules(mod_list):
|
||||
"""
|
||||
|
||||
OP_LIST_TO_FUSER_METHOD = {
|
||||
(torch.nn.Conv1d, torch.nn.BatchNorm1d): fuse_conv_bn,
|
||||
(torch.nn.Conv1d, torch.nn.BatchNorm1d, torch.nn.ReLU): fuse_conv_bn_relu,
|
||||
(torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn,
|
||||
(torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU): fuse_conv_bn_relu,
|
||||
(torch.nn.Conv3d, torch.nn.BatchNorm3d): fuse_conv_bn,
|
||||
|
Reference in New Issue
Block a user