Conv BN folding data type issue when conv has no bias (#78241)
PR https://github.com/pytorch/pytorch/pull/77042 fixed the data type issue in the new conv-bn folding, but missed the case where the original conv has no bias input. This PR:

- Fixes the new folded conv-bn bias data type: when the conv has no bias but its weight has a lower-precision dtype, the newly generated bias should have the same dtype as the conv's weight.
- Moves the autocast JIT trace UTs from `test_jit.py` to `test_jit_autocast.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78241
Approved by: https://github.com/davidberard98
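For context, a minimal sketch of the scenario being fixed, distilled from the `TestJitTraceAutocast` test added below; the `convbn` module and input shape come straight from the PR's test, while the surrounding driver lines are illustrative:

import torch

class convbn(torch.nn.Module):
    def __init__(self, bias_enabled=True):
        super(convbn, self).__init__()
        self.conv = torch.nn.Conv2d(3, 64, 7, stride=2, bias=bias_enabled)
        self.bn = torch.nn.BatchNorm2d(64)

    def forward(self, x):
        return self.bn(self.conv(x))

# Trace under CPU autocast, then freeze: freezing runs the frozen conv-bn
# folding pass. With bias=False the pass must synthesize a bias tensor; before
# this fix it kept the fp32 dtype of the batch-norm stats even when the conv
# weight had been autocast to bfloat16, producing a dtype mismatch.
model = convbn(bias_enabled=False).eval()
x = torch.randn(32, 3, 224, 224, device='cpu')
with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
    traced_model = torch.jit.trace(model, x)
traced_model = torch.jit.freeze(traced_model)
with torch.cpu.amp.autocast(), torch.no_grad():
    y = traced_model(x)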
This commit: 1a41cd8f97 (parent dde56ca329), committed by PyTorch MergeBot.
test/test_jit.py
@@ -71,7 +71,6 @@ from jit.test_aten_pow import TestAtenPow  # noqa: F401
 from jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo  # noqa: F401
 from jit.test_union import TestUnion  # noqa: F401
 from jit.test_legacy_upgraders import TestLegacyUpgraders  # noqa: F401
-from jit.test_models import MnistNet
 from jit.test_batch_mm import TestBatchMM  # noqa: F401
 from jit.test_dtype_analysis import TestDtypeAnalysis, TestDtypeCustomRulesCPU  # noqa: F401
 from jit.test_device_analysis import TestDeviceAnalysis  # noqa: F401
@@ -16197,50 +16196,6 @@ class TestJitGeneratedModule(JitTestCase):
 class TestJitGeneratedFunctional(JitTestCase):
     pass

-
-class TestJitAutocast(JitTestCase):
-    def setUp(self):
-        super(TestJitAutocast, self).setUp()
-        self.models = [MnistNet()]
-        self.inputs = [torch.randn(5, 1, 28, 28, device='cpu')]
-
-    def tearDown(self):
-        super(TestJitAutocast, self).tearDown()
-
-    def test_generate_autocast_jit_trace_model(self):
-        def test_generate_autocast_jit_trace_model(model, x):
-            model.eval()
-            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
-                traced_model = torch.jit.trace(model, x)
-        for i in range(self.models.__len__()):
-            test_generate_autocast_jit_trace_model(self.models[i], self.inputs[i])
-
-    def test_nchw_autocast_jit_trace_model(self):
-        def test_nchw_autocast_jit_trace_model(model, x):
-            model.eval()
-            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
-                traced_model = torch.jit.trace(model, x)
-            with torch.cpu.amp.autocast(), torch.no_grad():
-                y = traced_model(x.clone())
-                y2 = model(x.clone())
-            torch.testing.assert_allclose(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
-        for i in range(self.models.__len__()):
-            test_nchw_autocast_jit_trace_model(self.models[i], self.inputs[i])
-
-    def test_nhwc_autocast_jit_trace_model(self):
-        def test_nhwc_autocast_jit_trace_model(model, x):
-            model.eval()
-            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
-                traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last))
-            with torch.cpu.amp.autocast(), torch.no_grad():
-                y = traced_model(x.clone().to(memory_format=torch.channels_last))
-                y2 = model(x.clone().to(memory_format=torch.channels_last))
-            torch.testing.assert_allclose(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
-        for i in range(self.models.__len__()):
-            if self.inputs[i].size().__len__() == 5:
-                # NHWC 3D case not support yet
-                continue
-            test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])
-
 # UBSAN per-function exclusions don't seem to work with OpenMP pragmas,
 # and we have to disable the failing tests here instead.
 UBSAN_DISABLED_TESTS = [
test/test_jit_autocast.py
@@ -9,6 +9,7 @@ from test_jit import JitTestCase
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_utils import run_tests
 from torch.testing import FileCheck
+from jit.test_models import MnistNet

 TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()
@@ -750,5 +751,73 @@ class TestAutocast(JitTestCase):
         g = torch.jit.last_executed_optimized_graph()
         FileCheck().check_not("_autocast_to_reduced").run(g)

+class convbn(torch.nn.Module):
+    def __init__(self, bias_enabled=True):
+        super(convbn, self).__init__()
+        self.conv = torch.nn.Conv2d(3, 64, 7, stride=2, bias=bias_enabled)
+        self.bn = torch.nn.BatchNorm2d(64)
+
+    def forward(self, x):
+        return self.bn(self.conv(x))
+
+class TestJitTraceAutocast(JitTestCase):
+    def setUp(self):
+        super(TestJitTraceAutocast, self).setUp()
+        self.previous_default_dtype = torch.get_default_dtype()
+        torch.set_default_dtype(torch.float32)
+        self.models = [MnistNet(),
+                       convbn(bias_enabled=True),
+                       convbn(bias_enabled=False)]
+        self.inputs = [torch.randn(5, 1, 28, 28, device='cpu'),
+                       torch.randn(32, 3, 224, 224, device='cpu'),
+                       torch.randn(32, 3, 224, 224, device='cpu')]
+        self.previous_jit_autocast_pass = torch._C._jit_set_autocast_mode(False)
+
+    def tearDown(self):
+        torch._C._jit_set_autocast_mode(self.previous_jit_autocast_pass)
+        torch.set_default_dtype(self.previous_default_dtype)
+        super(TestJitTraceAutocast, self).tearDown()
+
+    def test_generate_autocast_jit_trace_model(self):
+        def test_generate_autocast_jit_trace_model(model, x):
+            model.eval()
+            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+                traced_model = torch.jit.trace(model, x)
+            traced_model = torch.jit.freeze(traced_model)
+        for i in range(self.models.__len__()):
+            test_generate_autocast_jit_trace_model(self.models[i], self.inputs[i])
+
+    def test_nchw_autocast_jit_trace_model(self):
+        def test_nchw_autocast_jit_trace_model(model, x):
+            model.eval()
+            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+                traced_model = torch.jit.trace(model, x)
+            traced_model = torch.jit.freeze(traced_model)
+            with torch.no_grad():
+                y = traced_model(x.clone())
+            with torch.cpu.amp.autocast(), torch.no_grad():
+                y2 = model(x.clone())
+            torch.testing.assert_allclose(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
+        for i in range(self.models.__len__()):
+            test_nchw_autocast_jit_trace_model(self.models[i], self.inputs[i])
+
+    def test_nhwc_autocast_jit_trace_model(self):
+        def test_nhwc_autocast_jit_trace_model(model, x):
+            model = model.to(memory_format=torch.channels_last)
+            model.eval()
+            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+                traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last))
+            traced_model = torch.jit.freeze(traced_model)
+            with torch.no_grad():
+                y = traced_model(x.clone().to(memory_format=torch.channels_last))
+            with torch.cpu.amp.autocast(), torch.no_grad():
+                y2 = model(x.clone().to(memory_format=torch.channels_last))
+            torch.testing.assert_allclose(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
+        for i in range(self.models.__len__()):
+            if self.inputs[i].size().__len__() == 5:
+                # NHWC 3D case not support yet
+                continue
+            test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])
+
 if __name__ == "__main__":
     run_tests()
torch/csrc/jit/passes/frozen_conv_folding.cpp
@@ -87,9 +87,7 @@ bool FoldFrozenConvBatchnorm(Block* b) {
       // placeholder have the same type as conv_w.
       at::ScalarType bias_dtype = bn_rm.scalar_type();
      at::ScalarType weight_dtype = conv_w.scalar_type();
-      at::DeviceType weight_device = conv_w.device().type();
-      if (weight_device == at::kCUDA &&
-          (weight_dtype == at::kHalf || weight_dtype == at::kBFloat16) &&
+      if ((weight_dtype == at::kHalf || weight_dtype == at::kBFloat16) &&
          bias_dtype == at::kFloat) {
        bias_dtype = weight_dtype;
      }
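As a hedged Python restatement of the rule the pass now applies (illustrative only, not a real PyTorch API): with the CUDA-only device check removed, a half or bfloat16 conv weight forces the synthesized bias dtype to match the weight on any device, which covers the CPU autocast path exercised by the new tests.

import torch

def folded_bias_dtype(weight_dtype, bn_stats_dtype):
    # Mirrors the C++ logic above: a half/bfloat16 conv weight with fp32
    # batch-norm stats yields a bias in the weight's dtype; otherwise the
    # bias keeps the stats dtype. Hypothetical helper for illustration.
    if weight_dtype in (torch.half, torch.bfloat16) and bn_stats_dtype == torch.float:
        return weight_dtype
    return bn_stats_dtype

assert folded_bias_dtype(torch.bfloat16, torch.float) == torch.bfloat16
assert folded_bias_dtype(torch.float, torch.float) == torch.float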