Conv BN folding data type issue when conv has no bias (#78241)

PR https://github.com/pytorch/pytorch/pull/77042 fixed the data type issue in the new conv-bn folding, but missed the case where the original conv has no bias input.
In this PR:

- Fix the data type of the bias generated by the new conv-bn folding: when the conv has no bias but its weight is a lower-precision data type, the newly generated bias should have the same data type as the conv's weight (a sketch of the folding arithmetic follows this list).
- Move the autocast JIT trace unit tests from `test_jit.py` to `test_jit_autocast.py`.
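
For context, the following is a minimal sketch of the conv-bn folding arithmetic this fix targets, not PyTorch's internal JIT pass; `fold_conv_bn` and its signature are hypothetical:

```python
import torch

def fold_conv_bn(conv_w, conv_b, bn_rm, bn_rv, bn_w, bn_b, eps=1e-5):
    # Illustrative folding arithmetic:
    # y = gamma * (conv(x) + b - mean) / sqrt(var + eps) + beta
    #   = conv(x) * scale + (b - mean) * scale + beta
    scale = bn_w / torch.sqrt(bn_rv + eps)
    if conv_b is None:
        # The fix in this PR: synthesize the missing bias with the conv
        # weight's dtype (e.g. bfloat16 under autocast), not the default
        # dtype (float32).
        conv_b = torch.zeros(conv_w.size(0), dtype=conv_w.dtype)
    folded_w = conv_w * scale.reshape(-1, 1, 1, 1).to(conv_w.dtype)
    folded_b = ((conv_b - bn_rm) * scale + bn_b).to(conv_w.dtype)
    return folded_w, folded_b

# With a bfloat16 conv weight and no conv bias, the folded bias now
# comes out as bfloat16 as well:
w = torch.randn(64, 3, 7, 7, dtype=torch.bfloat16)
bn = torch.nn.BatchNorm2d(64)
fw, fb = fold_conv_bn(w, None, bn.running_mean, bn.running_var, bn.weight, bn.bias)
assert fw.dtype == torch.bfloat16 and fb.dtype == torch.bfloat16
```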

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78241
Approved by: https://github.com/davidberard98
Author: leslie-fang-intel
Date: 2022-05-26 18:42:17 +00:00
Committed by: PyTorch MergeBot
Parent: dde56ca329
Commit: 1a41cd8f97

3 changed files with 70 additions and 48 deletions


@@ -9,6 +9,7 @@ from test_jit import JitTestCase
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_utils import run_tests
 from torch.testing import FileCheck
+from jit.test_models import MnistNet
 
 TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()
@@ -750,5 +751,73 @@ class TestAutocast(JitTestCase):
         g = torch.jit.last_executed_optimized_graph()
         FileCheck().check_not("_autocast_to_reduced").run(g)
 
+
+class convbn(torch.nn.Module):
+    def __init__(self, bias_enabled=True):
+        super(convbn, self).__init__()
+        self.conv = torch.nn.Conv2d(3, 64, 7, stride=2, bias=bias_enabled)
+        self.bn = torch.nn.BatchNorm2d(64)
+
+    def forward(self, x):
+        return self.bn(self.conv(x))
+
+
+class TestJitTraceAutocast(JitTestCase):
+    def setUp(self):
+        super(TestJitTraceAutocast, self).setUp()
+        self.previous_default_dtype = torch.get_default_dtype()
+        torch.set_default_dtype(torch.float32)
+        self.models = [MnistNet(),
+                       convbn(bias_enabled=True),
+                       convbn(bias_enabled=False)]
+        self.inputs = [torch.randn(5, 1, 28, 28, device='cpu'),
+                       torch.randn(32, 3, 224, 224, device='cpu'),
+                       torch.randn(32, 3, 224, 224, device='cpu')]
+        self.previous_jit_autocast_pass = torch._C._jit_set_autocast_mode(False)
+
+    def tearDown(self):
+        torch._C._jit_set_autocast_mode(self.previous_jit_autocast_pass)
+        torch.set_default_dtype(self.previous_default_dtype)
+        super(TestJitTraceAutocast, self).tearDown()
+
+    def test_generate_autocast_jit_trace_model(self):
+        def test_generate_autocast_jit_trace_model(model, x):
+            model.eval()
+            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+                traced_model = torch.jit.trace(model, x)
+            traced_model = torch.jit.freeze(traced_model)
+        for i in range(self.models.__len__()):
+            test_generate_autocast_jit_trace_model(self.models[i], self.inputs[i])
+
+    def test_nchw_autocast_jit_trace_model(self):
+        def test_nchw_autocast_jit_trace_model(model, x):
+            model.eval()
+            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+                traced_model = torch.jit.trace(model, x)
+            traced_model = torch.jit.freeze(traced_model)
+            with torch.no_grad():
+                y = traced_model(x.clone())
+            with torch.cpu.amp.autocast(), torch.no_grad():
+                y2 = model(x.clone())
+            torch.testing.assert_allclose(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
+        for i in range(self.models.__len__()):
+            test_nchw_autocast_jit_trace_model(self.models[i], self.inputs[i])
+
+    def test_nhwc_autocast_jit_trace_model(self):
+        def test_nhwc_autocast_jit_trace_model(model, x):
+            model = model.to(memory_format=torch.channels_last)
+            model.eval()
+            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+                traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last))
+            traced_model = torch.jit.freeze(traced_model)
+            with torch.no_grad():
+                y = traced_model(x.clone().to(memory_format=torch.channels_last))
+            with torch.cpu.amp.autocast(), torch.no_grad():
+                y2 = model(x.clone().to(memory_format=torch.channels_last))
+            torch.testing.assert_allclose(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
+        for i in range(self.models.__len__()):
+            if self.inputs[i].size().__len__() == 5:
+                # NHWC 3D case not support yet
+                continue
+            test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])
+
+
 if __name__ == "__main__":
     run_tests()
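
Since the file ends with the standard `run_tests()` entry point, the relocated tests should be runnable on their own with a unittest-style invocation such as `python test/test_jit_autocast.py TestJitTraceAutocast` from a PyTorch checkout.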