fix conv+bn folding issue for mixed dtype (#99696)

Align the conv+bn folding behavior with the JIT path for the mixed-dtype case: always keep the conv's weight and bias dtypes after folding.
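
For context, a minimal sketch (not part of this PR's tests; it only relies on the public `torch.nn.utils.fusion.fuse_conv_bn_eval` helper touched in the diff below) of the behavior expected after this change:

```python
import torch
from torch.nn.utils.fusion import fuse_conv_bn_eval

# Mixed-dtype case: a bfloat16 conv folded with a float32 BatchNorm.
conv = torch.nn.Conv2d(3, 16, kernel_size=3, bias=False, dtype=torch.bfloat16)
bn = torch.nn.BatchNorm2d(16)  # float32 affine params and running stats

fused = fuse_conv_bn_eval(conv.eval(), bn.eval())

# After this fix, the folded parameters keep the conv's dtype instead of
# being promoted to float32 by the BatchNorm statistics.
assert fused.weight.dtype == torch.bfloat16
assert fused.bias.dtype == torch.bfloat16  # conv had no bias, so it follows the weight dtype
```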

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99696
Approved by: https://github.com/jgong5, https://github.com/jansel
Author: XiaobingSuper
Date: 2023-04-21 00:02:27 -04:00
Committed by: PyTorch MergeBot
Parent: 1fc4d58f43
Commit: 9b0b31a5e3
3 changed files with 58 additions and 2 deletions


@@ -84,6 +84,37 @@ class CPUReproTests(TestCase):
         self.assertTrue(conv_seen)
 
+    @patch("torch.cuda.is_available", lambda: False)
+    def test_conv2d_bn_mixed_dtype(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+                self.conv = torch.nn.Conv2d(
+                    3,
+                    16,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=False,
+                    dtype=torch.bfloat16,
+                )
+                self.bn = torch.nn.BatchNorm2d(
+                    16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
+                )
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.bn(x)
+                return x
+
+        v = torch.randn(1, 3, 64, 64, dtype=torch.bfloat16)
+        mod = Model().eval()
+        with torch.no_grad():
+            self.common(
+                mod,
+                (v,),
+            )
+
     @unittest.skipIf(not torch._C.has_mkldnn, "MKLDNN is not enabled")
     @patch("torch.cuda.is_available", lambda: False)
     def test_conv2d_packed(self):


@@ -556,6 +556,29 @@ class TestFXExperimental(JitTestCase):
         )
         self.assertEqual(fused(inp), model(inp))
 
+    def test_conv_bn_fusion_mixed_dtype(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False, dtype=torch.bfloat16)
+                self.bn = torch.nn.BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.bn(x)
+                return x
+
+        model = M().eval()
+        traced = symbolic_trace(model)
+        fused = optimization.fuse(traced)
+        inp = torch.randn(1, 3, 64, 64, dtype=torch.bfloat16)
+
+        self.assertTrue(
+            all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
+        )
+        self.assertEqual(fused(inp), model(inp))
+
     def test_call_to_assert_no_msg(self):
         class M(torch.nn.Module):
             def forward(self, a, b):


@@ -14,6 +14,8 @@ def fuse_conv_bn_eval(conv, bn, transpose=False):
     return fused_conv
 
 def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=False):
+    conv_weight_dtype = conv_w.dtype
+    conv_bias_dtype = conv_b.dtype if conv_b is not None else conv_weight_dtype
     if conv_b is None:
         conv_b = torch.zeros_like(bn_rm)
     if bn_w is None:
@@ -27,8 +29,8 @@ def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=False):
     else:
         shape = [-1, 1] + [1] * (len(conv_w.shape) - 2)
-    fused_conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape(shape)
-    fused_conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b
+    fused_conv_w = (conv_w * (bn_w * bn_var_rsqrt).reshape(shape)).to(dtype=conv_weight_dtype)
+    fused_conv_b = ((conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b).to(dtype=conv_bias_dtype)
 
     return torch.nn.Parameter(fused_conv_w, conv_w.requires_grad), torch.nn.Parameter(fused_conv_b, conv_b.requires_grad)