mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Graph-mode quantization for convolution from traced model (#30245)
Summary: In the PR, we enhance the graph-mode quantization for aten::_convolution, which could be generated from tracing path. Pull Request resolved: https://github.com/pytorch/pytorch/pull/30245 Differential Revision: D18671597 Pulled By: lly-zero-one fbshipit-source-id: 78a2470fbb0fe0def55d63c6bda7cbb5c89f7848
This commit is contained in:
committed by
Facebook Github Bot
parent
2a7a39c1af
commit
59ca9b7430
@ -842,15 +842,17 @@ class GraphModePostTrainingQuantTest(QuantizationTestCase):
|
||||
qconfig_dict = {
|
||||
'': default_qconfig
|
||||
}
|
||||
model_script = quantize_script(
|
||||
torch.jit.script(conv_model_to_script),
|
||||
qconfig_dict,
|
||||
default_eval_fn,
|
||||
[self.img_data],
|
||||
inplace=False)
|
||||
model_traced = torch.jit.trace(conv_model_to_script, self.img_data[0][0])
|
||||
model_script = torch.jit.script(conv_model_to_script)
|
||||
result_eager = model_eager(self.img_data[0][0])
|
||||
result_script = model_script(self.img_data[0][0])
|
||||
self.assertEqual(result_eager, result_script)
|
||||
for model_under_test in [model_traced, model_script]:
|
||||
model_quantized = quantize_script(
|
||||
model_under_test,
|
||||
qconfig_dict,
|
||||
default_eval_fn,
|
||||
[self.img_data],
|
||||
inplace=False)
|
||||
self.assertEqual(model_quantized(self.img_data[0][0]), result_eager)
|
||||
|
||||
@unittest.skip("This doesn't work right now, re-enable after fold_convbn is fixed")
|
||||
def test_conv_bn(self):
|
||||
|
Reference in New Issue
Block a user