Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
fix conv+bn folding issue for mixed dtype (#99696)

Align the conv+bn folding behavior with the JIT path for the mixed-dtype case: always keep the conv's weight and bias dtype after folding.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99696
Approved by: https://github.com/jgong5, https://github.com/jansel
committed by PyTorch MergeBot
parent 1fc4d58f43
commit 9b0b31a5e3
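The behavior described in the commit message can be illustrated with a small standalone sketch (not part of this diff; it assumes an eval-mode Conv2d/BatchNorm2d pair and uses the public fuse_conv_bn_eval helper that wraps the function patched below):

import torch
from torch.nn.utils.fusion import fuse_conv_bn_eval

# bfloat16 conv folded with a float32 BatchNorm (BN parameters and running
# statistics default to float32).
conv = torch.nn.Conv2d(3, 16, 3, padding=1, bias=False, dtype=torch.bfloat16).eval()
bn = torch.nn.BatchNorm2d(16).eval()

fused = fuse_conv_bn_eval(conv, bn)
# With this fix the folded parameters keep the conv's dtype instead of being
# promoted to float32 by the BatchNorm statistics.
print(fused.weight.dtype, fused.bias.dtype)  # torch.bfloat16 torch.bfloat16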
@@ -84,6 +84,37 @@ class CPUReproTests(TestCase):

         self.assertTrue(conv_seen)

+    @patch("torch.cuda.is_available", lambda: False)
+    def test_conv2d_bn_mixed_dtype(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+                self.conv = torch.nn.Conv2d(
+                    3,
+                    16,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=False,
+                    dtype=torch.bfloat16,
+                )
+                self.bn = torch.nn.BatchNorm2d(
+                    16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
+                )
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.bn(x)
+                return x
+
+        v = torch.randn(1, 3, 64, 64, dtype=torch.bfloat16)
+        mod = Model().eval()
+        with torch.no_grad():
+            self.common(
+                mod,
+                (v,),
+            )
+
     @unittest.skipIf(not torch._C.has_mkldnn, "MKLDNN is not enabled")
     @patch("torch.cuda.is_available", lambda: False)
     def test_conv2d_packed(self):
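For reference, a rough standalone analogue of the inductor test above (assumptions: a CPU inductor build with a working C++ compiler and bfloat16 support; self.common in the test performs a comparable eager-vs-compiled check internally):

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            3, 16, kernel_size=3, stride=1, padding=1, bias=False, dtype=torch.bfloat16
        )
        self.bn = torch.nn.BatchNorm2d(16, eps=0.001, momentum=0.1)

    def forward(self, x):
        return self.bn(self.conv(x))

mod = Model().eval()
v = torch.randn(1, 3, 64, 64, dtype=torch.bfloat16)
with torch.no_grad():
    compiled = torch.compile(mod)
    # Loose tolerances because the comparison is done in bfloat16.
    torch.testing.assert_close(compiled(v), mod(v), atol=2e-2, rtol=2e-2)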
@@ -556,6 +556,29 @@ class TestFXExperimental(JitTestCase):
         )
         self.assertEqual(fused(inp), model(inp))

+    def test_conv_bn_fusion_mixed_dtype(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False, dtype=torch.bfloat16)
+                self.bn = torch.nn.BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.bn(x)
+                return x
+
+        model = M().eval()
+
+        traced = symbolic_trace(model)
+        fused = optimization.fuse(traced)
+        inp = torch.randn(1, 3, 64, 64, dtype=torch.bfloat16)
+
+        self.assertTrue(
+            all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
+        )
+        self.assertEqual(fused(inp), model(inp))
+
     def test_call_to_assert_no_msg(self):
         class M(torch.nn.Module):
             def forward(self, a, b):
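The FX fusion path exercised by the test above can also be run as a standalone sketch (assumptions: torch.fx.experimental.optimization is available; module names here simply mirror the test):

import torch
from torch.fx import symbolic_trace
from torch.fx.experimental import optimization

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False, dtype=torch.bfloat16)
        self.bn = torch.nn.BatchNorm2d(16)

    def forward(self, x):
        return self.bn(self.conv(x))

model = M().eval()
fused = optimization.fuse(symbolic_trace(model))

# The BatchNorm modules are folded away, and with this fix the remaining
# conv keeps its bfloat16 parameters.
assert all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
print([m.weight.dtype for m in fused.modules() if isinstance(m, torch.nn.Conv2d)])  # [torch.bfloat16]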
@@ -14,6 +14,8 @@ def fuse_conv_bn_eval(conv, bn, transpose=False):
     return fused_conv

 def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=False):
+    conv_weight_dtype = conv_w.dtype
+    conv_bias_dtype = conv_b.dtype if conv_b is not None else conv_weight_dtype
     if conv_b is None:
         conv_b = torch.zeros_like(bn_rm)
     if bn_w is None:
@@ -27,8 +29,8 @@ def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, trans
     else:
         shape = [-1, 1] + [1] * (len(conv_w.shape) - 2)

-    fused_conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape(shape)
-    fused_conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b
+    fused_conv_w = (conv_w * (bn_w * bn_var_rsqrt).reshape(shape)).to(dtype=conv_weight_dtype)
+    fused_conv_b = ((conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b).to(dtype=conv_bias_dtype)

     return torch.nn.Parameter(fused_conv_w, conv_w.requires_grad), torch.nn.Parameter(fused_conv_b, conv_b.requires_grad)

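At the tensor level, the patched fuse_conv_bn_weights can be exercised directly; a minimal sketch (assumed shapes and values: a bfloat16 conv weight with no bias, float32 BatchNorm statistics and affine parameters):

import torch
from torch.nn.utils.fusion import fuse_conv_bn_weights

conv_w = torch.randn(16, 3, 3, 3, dtype=torch.bfloat16)  # bfloat16 conv weight, no bias
bn_rm, bn_rv = torch.zeros(16), torch.ones(16)           # float32 running mean / variance
bn_w, bn_b = torch.ones(16), torch.zeros(16)             # float32 affine parameters

fused_w, fused_b = fuse_conv_bn_weights(conv_w, None, bn_rm, bn_rv, 1e-3, bn_w, bn_b)
# The intermediate math promotes to float32, but the results are cast back to
# the conv dtypes captured at the top of the function.
print(fused_w.dtype, fused_b.dtype)  # torch.bfloat16 torch.bfloat16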