Conv BN folding data type issue when conv has no bias (#78241)

PR https://github.com/pytorch/pytorch/pull/77042 fixed the data type issue in the new conv-bn folding, but missed the case where the original conv has no bias input.
In this PR:

- Fix the bias data type in the new conv-bn folding: when the conv has no bias but its weight is a lower-precision data type, the generated bias should use the same data type as the conv's weight (see the sketch after this list).
- Move the autocast JIT trace unit tests from `test_jit.py` to `test_jit_autocast.py`.
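
For context, here is a minimal Python sketch of the folding arithmetic for the bias-less case. The function and tensor names are illustrative; the real logic lives in the C++ `FoldFrozenConvBatchnorm` pass shown in the last diff below.

```python
import torch

def fold_conv_bn_no_bias(conv_w, bn_rm, bn_rv, bn_w, bn_b, eps=1e-5):
    # The conv has no bias, so folding materializes a zero placeholder. After
    # this PR its dtype follows conv_w (e.g. bfloat16 from an autocast trace)
    # instead of defaulting to the float32 dtype of the BN running stats.
    conv_b = torch.zeros(conv_w.size(0), dtype=conv_w.dtype)
    scale = bn_w / torch.sqrt(bn_rv + eps)  # per-channel BN scale
    folded_w = conv_w * scale.reshape(-1, 1, 1, 1).to(conv_w.dtype)
    folded_b = (conv_b - bn_rm.to(conv_b.dtype)) * scale.to(conv_b.dtype) + bn_b.to(conv_b.dtype)
    return folded_w, folded_b

# With a bfloat16 weight the folded bias comes out bfloat16 as well:
# w = torch.randn(64, 3, 7, 7, dtype=torch.bfloat16)
# fold_conv_bn_no_bias(w, torch.randn(64), torch.rand(64), torch.ones(64), torch.zeros(64))[1].dtype
# -> torch.bfloat16
```

The point of the fix is simply that the placeholder's dtype tracks the weight rather than the float32 running statistics.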

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78241
Approved by: https://github.com/davidberard98
leslie-fang-intel, committed by PyTorch MergeBot, 2022-05-26 18:42:17 +00:00
commit 1a41cd8f97 (parent dde56ca329)
3 changed files with 70 additions and 48 deletions


@@ -71,7 +71,6 @@ from jit.test_aten_pow import TestAtenPow # noqa: F401
 from jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo # noqa: F401
 from jit.test_union import TestUnion # noqa: F401
 from jit.test_legacy_upgraders import TestLegacyUpgraders # noqa: F401
-from jit.test_models import MnistNet
 from jit.test_batch_mm import TestBatchMM # noqa: F401
 from jit.test_dtype_analysis import TestDtypeAnalysis, TestDtypeCustomRulesCPU # noqa: F401
 from jit.test_device_analysis import TestDeviceAnalysis # noqa: F401
@@ -16197,50 +16196,6 @@ class TestJitGeneratedModule(JitTestCase):
 class TestJitGeneratedFunctional(JitTestCase):
     pass
 
-class TestJitAutocast(JitTestCase):
-    def setUp(self):
-        super(TestJitAutocast, self).setUp()
-        self.models = [MnistNet()]
-        self.inputs = [torch.randn(5, 1, 28, 28, device='cpu')]
-
-    def tearDown(self):
-        super(TestJitAutocast, self).tearDown()
-
-    def test_generate_autocast_jit_trace_model(self):
-        def test_generate_autocast_jit_trace_model(model, x):
-            model.eval()
-            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
-                traced_model = torch.jit.trace(model, x)
-        for i in range(self.models.__len__()):
-            test_generate_autocast_jit_trace_model(self.models[i], self.inputs[i])
-
-    def test_nchw_autocast_jit_trace_model(self):
-        def test_nchw_autocast_jit_trace_model(model, x):
-            model.eval()
-            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
-                traced_model = torch.jit.trace(model, x)
-            with torch.cpu.amp.autocast(), torch.no_grad():
-                y = traced_model(x.clone())
-                y2 = model(x.clone())
-            torch.testing.assert_allclose(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
-        for i in range(self.models.__len__()):
-            test_nchw_autocast_jit_trace_model(self.models[i], self.inputs[i])
-
-    def test_nhwc_autocast_jit_trace_model(self):
-        def test_nhwc_autocast_jit_trace_model(model, x):
-            model.eval()
-            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
-                traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last))
-            with torch.cpu.amp.autocast(), torch.no_grad():
-                y = traced_model(x.clone().to(memory_format=torch.channels_last))
-                y2 = model(x.clone().to(memory_format=torch.channels_last))
-            torch.testing.assert_allclose(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
-        for i in range(self.models.__len__()):
-            if self.inputs[i].size().__len__() == 5:
-                # NHWC 3D case not support yet
-                continue
-            test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])
-
 # UBSAN per-function exclusions don't seem to work with OpenMP pragmas,
 # and we have to disable the failing tests here instead.
 UBSAN_DISABLED_TESTS = [


@@ -9,6 +9,7 @@ from test_jit import JitTestCase
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_utils import run_tests
 from torch.testing import FileCheck
+from jit.test_models import MnistNet
 
 TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()
@@ -750,5 +751,73 @@ class TestAutocast(JitTestCase):
         g = torch.jit.last_executed_optimized_graph()
         FileCheck().check_not("_autocast_to_reduced").run(g)
 
+
+class convbn(torch.nn.Module):
+    def __init__(self, bias_enabled=True):
+        super(convbn, self).__init__()
+        self.conv = torch.nn.Conv2d(3, 64, 7, stride=2, bias=bias_enabled)
+        self.bn = torch.nn.BatchNorm2d(64)
+
+    def forward(self, x):
+        return self.bn(self.conv(x))
+
+
+class TestJitTraceAutocast(JitTestCase):
+    def setUp(self):
+        super(TestJitTraceAutocast, self).setUp()
+        self.previous_default_dtype = torch.get_default_dtype()
+        torch.set_default_dtype(torch.float32)
+        self.models = [MnistNet(),
+                       convbn(bias_enabled=True),
+                       convbn(bias_enabled=False)]
+        self.inputs = [torch.randn(5, 1, 28, 28, device='cpu'),
+                       torch.randn(32, 3, 224, 224, device='cpu'),
+                       torch.randn(32, 3, 224, 224, device='cpu')]
+        self.previous_jit_autocast_pass = torch._C._jit_set_autocast_mode(False)
+
+    def tearDown(self):
+        torch._C._jit_set_autocast_mode(self.previous_jit_autocast_pass)
+        torch.set_default_dtype(self.previous_default_dtype)
+        super(TestJitTraceAutocast, self).tearDown()
+
+    def test_generate_autocast_jit_trace_model(self):
+        def test_generate_autocast_jit_trace_model(model, x):
+            model.eval()
+            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+                traced_model = torch.jit.trace(model, x)
+            traced_model = torch.jit.freeze(traced_model)
+        for i in range(self.models.__len__()):
+            test_generate_autocast_jit_trace_model(self.models[i], self.inputs[i])
+
+    def test_nchw_autocast_jit_trace_model(self):
+        def test_nchw_autocast_jit_trace_model(model, x):
+            model.eval()
+            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+                traced_model = torch.jit.trace(model, x)
+            traced_model = torch.jit.freeze(traced_model)
+            with torch.no_grad():
+                y = traced_model(x.clone())
+            with torch.cpu.amp.autocast(), torch.no_grad():
+                y2 = model(x.clone())
+            torch.testing.assert_allclose(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
+        for i in range(self.models.__len__()):
+            test_nchw_autocast_jit_trace_model(self.models[i], self.inputs[i])
+
+    def test_nhwc_autocast_jit_trace_model(self):
+        def test_nhwc_autocast_jit_trace_model(model, x):
+            model = model.to(memory_format=torch.channels_last)
+            model.eval()
+            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+                traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last))
+            traced_model = torch.jit.freeze(traced_model)
+            with torch.no_grad():
+                y = traced_model(x.clone().to(memory_format=torch.channels_last))
+            with torch.cpu.amp.autocast(), torch.no_grad():
+                y2 = model(x.clone().to(memory_format=torch.channels_last))
+            torch.testing.assert_allclose(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
+        for i in range(self.models.__len__()):
+            if self.inputs[i].size().__len__() == 5:
+                # NHWC 3D case not support yet
+                continue
+            test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])
+
 if __name__ == "__main__":
     run_tests()


@@ -87,9 +87,7 @@ bool FoldFrozenConvBatchnorm(Block* b) {
       // placeholder have the same type as conv_w.
       at::ScalarType bias_dtype = bn_rm.scalar_type();
      at::ScalarType weight_dtype = conv_w.scalar_type();
-      at::DeviceType weight_device = conv_w.device().type();
-      if (weight_device == at::kCUDA &&
-          (weight_dtype == at::kHalf || weight_dtype == at::kBFloat16) &&
+      if ((weight_dtype == at::kHalf || weight_dtype == at::kBFloat16) &&
           bias_dtype == at::kFloat) {
         bias_dtype = weight_dtype;
       }
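
Restated in Python purely for readability (the helper name is illustrative; the authoritative logic is the C++ above), the post-fix dtype rule amounts to:

```python
import torch

def folded_bias_dtype(weight_dtype, bn_stats_dtype):
    # If the conv weight is a reduced-precision type and the BN running stats
    # are float32, the folded bias follows the weight; otherwise it keeps the
    # running-stats dtype.
    if weight_dtype in (torch.half, torch.bfloat16) and bn_stats_dtype == torch.float:
        return weight_dtype
    return bn_stats_dtype
```

Dropping the `weight_device == at::kCUDA` guard is what lets this rule apply to the CPU autocast traces exercised by the new tests.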