Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32331
Test Plan: Imported from OSS
Differential Revision: D19441158
Pulled By: jamesr66a
fbshipit-source-id: c04247ffe707be68718c486c31bc6c6040f7dc11

1451 lines | 62 KiB | Python

import unittest
import math
import torch
import torch.nn as nn
import torch.nn.quantized as nnq
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.qat as nniqat
from torch.nn.utils.rnn import PackedSequence
from torch.quantization import \
    get_observer_dict, default_weight_observer, \
    quantize, prepare, convert, prepare_qat, quantize_qat, fuse_modules, \
    quantize_dynamic, default_qconfig, default_debug_qconfig, default_qat_qconfig, \
    default_dynamic_qconfig, per_channel_dynamic_qconfig, HistogramObserver, MinMaxObserver, \
    PerChannelMinMaxObserver, RecordingObserver, MovingAverageMinMaxObserver, \
    MovingAveragePerChannelMinMaxObserver, QuantWrapper, default_eval_fn, \
    float16_dynamic_qconfig

from torch.quantization import QConfig
from torch.quantization import default_histogram_observer
from torch.quantization import default_observer
from torch.quantization import default_per_channel_weight_observer
from torch.quantization import default_per_channel_qconfig
from torch.quantization._quantize_script import quantize_script

from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_quantization import QuantizationTestCase, \
    AnnotatedSingleLayerLinearModel, SingleLayerLinearModel, \
    AnnotatedConvModel, ConvModel, \
    AnnotatedConvBnModel, ConvBnModel, \
    SkipQuantModel, QuantStubModel, \
    ModelForFusion, ModelWithSequentialFusion, ManualLinearQATModel, ManualConvLinearQATModel, \
    ModelWithFunctionals, \
    test_only_eval_fn, test_only_train_fn, \
    prepare_dynamic, convert_dynamic, SingleLayerLinearDynamicModel, \
    TwoLayerLinearModel, NestedModel, ResNetBase, LSTMDynamicModel, \
    ModelWithNoQconfigPropagation

from torch.testing._internal.common_quantization import AnnotatedTwoLayerLinearModel, AnnotatedNestedModel, \
    AnnotatedSubNestedModel, AnnotatedCustomConfigNestedModel

from hypothesis import given
from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()
import io
import copy

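
# The test classes below repeatedly exercise the eager-mode post-training static
# quantization flow.  The helper below is an illustrative sketch only (it is not
# part of the original test suite); `float_model` and `calib_batches` are
# hypothetical placeholders standing in for a user model and calibration data.
def _example_eager_ptq(float_model, calib_batches):
    model = QuantWrapper(float_model.eval())   # add quant/dequant stubs at the model boundary
    model.qconfig = default_qconfig            # choose observers for activations and weights
    model = prepare(model)                     # insert observers
    for data, _ in calib_batches:              # calibrate: run representative data through observers
        model(data)
    return convert(model)                      # swap float modules for quantized versions
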
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                     " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
                     " with instruction set support avx2 or newer.")
class EagerModePostTrainingQuantTest(QuantizationTestCase):
    @given(qconfig=st.sampled_from((torch.quantization.default_qconfig, torch.quantization.default_per_channel_qconfig)))
    def test_single_layer(self, qconfig):
        r"""Quantize SingleLayerLinearModel which has one Linear module, make sure it is swapped
        to nnq.Linear which is the quantized version of the module
        """
        model = AnnotatedSingleLayerLinearModel()
        model.qconfig = qconfig
        model = prepare(model)
        # Check if observers and quant/dequant nodes are inserted
        self.checkNoPrepModules(model)
        self.checkHasPrepModules(model.fc1)
        self.checkObservers(model)

        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            self.checkNoPrepModules(model)
            self.checkHasPrepModules(model.fc1)
            self.checkWrappedQuantizedLinear(model.fc1)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API - out of place version
        base = AnnotatedSingleLayerLinearModel()
        base.qconfig = qconfig
        keys_before = set(list(base.state_dict().keys()))
        model = quantize(base, test_only_eval_fn, self.calib_data)
        checkQuantized(model)
        keys_after = set(list(base.state_dict().keys()))
        self.assertEqual(keys_before, keys_after)  # simple check that nothing changed

        # in-place version
        model = AnnotatedSingleLayerLinearModel()
        model.qconfig = qconfig
        quantize(model, test_only_eval_fn, self.calib_data, inplace=True)
        checkQuantized(model)

    def test_two_layers(self):
        r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one,
        `fc2`; `fc1` is not quantized
        """
        model = AnnotatedTwoLayerLinearModel()
        model = prepare(model)

        self.checkNoPrepModules(model)
        self.checkObservers(model)
        self.checkNoPrepModules(model.fc1)
        self.checkHasPrepModules(model.fc2)

        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.fc1)
            self.checkHasPrepModules(model.fc2)
            self.assertEqual(type(model.fc1), torch.nn.Linear)
            self.checkWrappedQuantizedLinear(model.fc2)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(AnnotatedTwoLayerLinearModel(), test_only_eval_fn,
                         self.calib_data)
        checkQuantized(model)

    def test_nested1(self):
        r"""Test quantization for nested model: top level 'fc3' and
        'fc1' of submodule 'sub2' are quantized, 'sub2.fc2' is not quantized
        """
        model = AnnotatedNestedModel()

        def checkPrepModules(model, before_calib=False):
            if before_calib:
                self.checkObservers(model)
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.sub1)
            self.checkNoPrepModules(model.sub1.fc)
            self.checkNoPrepModules(model.sub1.relu)
            self.checkNoPrepModules(model.sub2)
            self.checkHasPrepModules(model.sub2.fc1)
            self.checkNoPrepModules(model.sub2.fc2)
            self.checkHasPrepModules(model.fc3)

        model = prepare(model)
        checkPrepModules(model, True)
        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            checkPrepModules(model)
            self.checkLinear(model.sub1.fc)
            self.checkWrappedQuantizedLinear(model.fc3)
            self.checkWrappedQuantizedLinear(model.sub2.fc1)
            self.checkLinear(model.sub2.fc2)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(AnnotatedNestedModel(), test_only_eval_fn,
                         self.calib_data)
        checkQuantized(model)

    def test_nested2(self):
        model = AnnotatedSubNestedModel()
        model = prepare(model)

        def checkPrepModules(model, before_calib=False):
            if before_calib:
                self.checkObservers(model)
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.sub1)
            self.checkNoPrepModules(model.sub1.fc)
            self.checkNoPrepModules(model.sub1.relu)
            self.checkHasPrepModules(model.sub2)
            self.checkNoPrepModules(model.sub2.module.fc1)
            self.checkNoPrepModules(model.sub2.module.fc2)
            self.checkHasPrepModules(model.fc3)

        checkPrepModules(model, True)

        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            checkPrepModules(model)
            self.checkLinear(model.sub1.fc)
            self.assertEqual(type(model.sub1.relu), torch.nn.ReLU)
            self.checkQuantizedLinear(model.sub2.module.fc1)
            self.checkQuantizedLinear(model.sub2.module.fc2)
            self.checkWrappedQuantizedLinear(model.fc3)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(AnnotatedSubNestedModel(), test_only_eval_fn,
                         self.calib_data)
        checkQuantized(model)

    def test_nested3(self):
        r"""More complicated nested test case where the child qconfig overrides
        the parent qconfig
        """
        model = AnnotatedCustomConfigNestedModel()
        model = prepare(model)

        def checkPrepModules(model, before_calib=False):
            if before_calib:
                self.checkObservers(model)
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.sub1)
            self.checkNoPrepModules(model.sub1.fc)
            self.checkNoPrepModules(model.sub1.relu)
            self.checkNoPrepModules(model.sub2)
            self.checkHasPrepModules(model.sub2.fc1)
            self.checkHasPrepModules(model.sub2.fc2)
            self.checkHasPrepModules(model.fc3)

        checkPrepModules(model, True)

        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            checkPrepModules(model)
            self.checkWrappedQuantizedLinear(model.sub2.fc1)
            self.checkWrappedQuantizedLinear(model.sub2.fc2)
            self.checkWrappedQuantizedLinear(model.fc3)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(AnnotatedCustomConfigNestedModel(), test_only_eval_fn,
                         self.calib_data)
        checkQuantized(model)

    def test_skip_quant(self):
        r"""The case when we want to skip quantizing some layers
        """

        model = SkipQuantModel()
        model = prepare(model)
        self.checkObservers(model)

        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            self.checkLinear(model.fc)
            self.checkQuantDequant(model.sub)
            self.checkQuantizedLinear(model.sub.module.fc1)
            self.checkQuantizedLinear(model.sub.module.fc2)
            self.assertEqual(type(model.sub.module.relu), nnq.ReLU)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(SkipQuantModel(), test_only_eval_fn, self.calib_data)
        checkQuantized(model)

    def test_manual(self):
        r"""User inserts QuantStub and DeQuantStub in model code
        and calls the quantization utility functions.
        """
        model = QuantStubModel()
        # propagate the qconfig of parents to children, model is changed
        # inplace
        model = prepare(model)
        self.checkObservers(model)

        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.fc), nnq.Linear)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(QuantStubModel(), test_only_eval_fn, self.calib_data)
        checkQuantized(model)

    @given(qconfig=st.sampled_from((torch.quantization.default_qconfig, torch.quantization.default_per_channel_qconfig)))
    def test_resnet_base(self, qconfig):
        r"""Test quantization for bottleneck topology used in resnet/resnext
        and add coverage for conversion of average pool and float functional
        """
        model = ResNetBase().float().eval()
        model = QuantWrapper(model)
        model.qconfig = qconfig
        fuse_list = ['module.conv1', 'module.bn1', 'module.relu1']
        fuse_modules(model, fuse_list, inplace=True)
        model = prepare(model)
        self.checkObservers(model)
        test_only_eval_fn(model, self.img_data)
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.module.conv1), nn.intrinsic.quantized.ConvReLU2d)
            self.assertEqual(type(model.module.myop), nn.quantized.QFunctional)
            self.assertEqual(type(model.module.avgpool), nn.AdaptiveAvgPool2d)
            test_only_eval_fn(model, self.img_data)

        checkQuantized(model)

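
# Illustrative sketch only (not part of the original tests): the dynamic
# quantization one-liner exercised throughout PostTrainingDynamicQuantTest.
# `float_model` is a hypothetical placeholder for any model containing nn.Linear
# layers; weights are quantized ahead of time, activations on the fly at runtime.
def _example_dynamic_quantization(float_model):
    return quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)
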
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                     " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
                     " with instruction set support avx2 or newer.")
class PostTrainingDynamicQuantTest(QuantizationTestCase):
    def test_single_layer(self):
        r"""Dynamic Quantize SingleLayerLinearDynamicModel which has one Linear module,
        make sure it is swapped to nnqd.Linear which is the quantized version of
        the module
        """
        for dtype in [torch.qint8, torch.float16]:
            model = SingleLayerLinearDynamicModel().eval()
            qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
            qconfig_dict = {
                'fc1': qconfig
            }
            prepare_dynamic(model, qconfig_dict)
            convert_dynamic(model)

            def checkQuantized(model):
                self.checkDynamicQuantizedLinear(model.fc1, dtype)
                self.checkScriptable(model, self.calib_data, check_save_load=True)

            checkQuantized(model)

            # test one line API - out of place version
            base = SingleLayerLinearDynamicModel()
            keys_before = set(list(base.state_dict().keys()))
            model = quantize_dynamic(base, qconfig_dict)
            checkQuantized(model)
            keys_after = set(list(base.state_dict().keys()))
            self.assertEqual(keys_before, keys_after)  # simple check that nothing changed

            # in-place version
            model = SingleLayerLinearDynamicModel()
            quantize_dynamic(model, qconfig_dict, inplace=True)
            checkQuantized(model)

            # Test set qconfig
            model = SingleLayerLinearDynamicModel()
            quantize_dynamic(model, set([nn.Linear]), inplace=True, dtype=dtype)
            checkQuantized(model)

    def test_two_layers(self):
        r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one,
        `fc2`; `fc1` is not quantized
        """
        for dtype in [torch.qint8, torch.float16]:
            model = TwoLayerLinearModel().eval()
            qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
            qconfig_dict = {
                'fc2': qconfig
            }
            prepare_dynamic(model, qconfig_dict)

            convert_dynamic(model)

            def checkQuantized(model):
                self.assertEqual(type(model.fc1), torch.nn.Linear)
                self.checkDynamicQuantizedLinear(model.fc2, dtype=dtype)
                self.checkScriptable(model, self.calib_data, check_save_load=True)

            checkQuantized(model)

            # test one line API
            model = quantize_dynamic(TwoLayerLinearModel().eval(), qconfig_dict)
            checkQuantized(model)

            # Test set API
            model = quantize_dynamic(TwoLayerLinearModel().eval(), {'fc2'}, dtype=dtype)
            checkQuantized(model)

    def test_nested1(self):
        r"""Test quantization for nested model: top level 'fc3' and
        'fc1' of submodule 'sub2' are quantized, 'sub2.fc2' is not quantized
        """
        for dtype in [torch.qint8, torch.float16]:
            model = NestedModel().eval()
            qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
            qconfig_dict = {
                'fc3': qconfig,
                'sub2.fc1': qconfig
            }

            prepare_dynamic(model, qconfig_dict)
            convert_dynamic(model)

            def checkQuantized(model):
                self.checkLinear(model.sub1.fc)
                self.checkDynamicQuantizedLinear(model.fc3, dtype=dtype)
                self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=dtype)
                self.checkLinear(model.sub2.fc2)
                self.checkScriptable(model, self.calib_data, check_save_load=True)

            checkQuantized(model)

            # test one line API
            model = quantize_dynamic(NestedModel().eval(), qconfig_dict)
            checkQuantized(model)

            model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2.fc1'}, dtype=dtype)
            checkQuantized(model)

    def test_nested2(self):
        r"""Another nested test case; here we quantize all submodules
        of submodule sub2
        """
        for dtype in [torch.qint8, torch.float16]:
            model = NestedModel().eval()
            qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
            qconfig_dict = {
                'fc3': qconfig,
                'sub2': qconfig
            }
            prepare_dynamic(model, qconfig_dict)

            convert_dynamic(model)

            def checkQuantized(model):
                self.checkLinear(model.sub1.fc)
                self.assertEqual(type(model.sub1.relu), torch.nn.ReLU)
                self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=dtype)
                self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=dtype)
                self.checkDynamicQuantizedLinear(model.fc3, dtype=dtype)
                self.checkScriptable(model, self.calib_data, check_save_load=True)

            checkQuantized(model)

            # test one line API
            model = quantize_dynamic(NestedModel().eval(), qconfig_dict, dtype=dtype)
            checkQuantized(model)

            # Test set API
            model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2'}, dtype=dtype)
            checkQuantized(model)

    def test_nested3(self):
        r"""More complicated nested test case where the child qconfig overrides
        the parent qconfig
        """
        for dtype in [torch.qint8, torch.float16]:
            model = NestedModel().eval()
            qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
            qconfig_dynamic_dict = {
                'fc3': qconfig,
                'sub2': qconfig,
                'sub2.fc1': qconfig
            }
            prepare_dynamic(model, qconfig_dynamic_dict)

            convert_dynamic(model)

            def checkQuantized(model):
                self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=dtype)
                self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=dtype)
                self.checkDynamicQuantizedLinear(model.fc3, dtype=dtype)
                self.checkScriptable(model, self.calib_data, check_save_load=True)

            checkQuantized(model)

            # test one line API
            model = quantize_dynamic(NestedModel().eval(), qconfig_dynamic_dict)
            checkQuantized(model)

            # Test set API
            model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2', 'sub2.fc1'}, dtype=dtype)
            checkQuantized(model)

    def test_type_match_rule(self):
        r"""Test quantization for nested model: top level 'fc3' and
        'fc1' of submodule 'sub2' are excluded, all other 'torch.nn.Linear' modules are quantized
        """
        for dtype in [torch.qint8, torch.float16]:
            model = NestedModel().eval()
            qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
            qconfig_dict = {
                'fc3': None,
                'sub2.fc1': None,
                torch.nn.Linear: qconfig
            }

            prepare_dynamic(model, qconfig_dict)
            test_only_eval_fn(model, self.calib_data)
            convert_dynamic(model)

            def checkQuantized(model):
                self.checkDynamicQuantizedLinear(model.sub1.fc, dtype=dtype)
                self.checkLinear(model.fc3)
                self.checkLinear(model.sub2.fc1)
                self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=dtype)
                test_only_eval_fn(model, self.calib_data)
                self.checkScriptable(model, self.calib_data, check_save_load=True)

            checkQuantized(model)

            # test one line API
            model = quantize_dynamic(NestedModel().eval(), qconfig_dict, dtype=dtype)
            checkQuantized(model)

    def test_per_channel_quantize(self):
        r"""Test per_channel dynamic quantization
        """
        model = NestedModel().eval()
        qconfig_dict = {
            torch.nn.Linear: per_channel_dynamic_qconfig
        }

        prepare_dynamic(model, qconfig_dict)
        test_only_eval_fn(model, self.calib_data)
        convert_dynamic(model)

        def checkQuantized(model):
            self.checkDynamicQuantizedLinear(model.sub1.fc, dtype=torch.qint8)
            self.checkDynamicQuantizedLinear(model.fc3, dtype=torch.qint8)
            self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=torch.qint8)
            self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=torch.qint8)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data, check_save_load=True)

        checkQuantized(model)
        # test one line API
        model = quantize_dynamic(NestedModel().eval(), qconfig_dict)
        checkQuantized(model)

    def test_quantized_rnn(self):
        d_in, d_hid = 2, 2
        model = LSTMDynamicModel().eval()
        cell = model.lstm

        # Replace parameter values s.t. the range of values is exactly
        # 255, thus we will have 0 quantization error in the quantized
        # GEMM call. This is for testing purposes.
        #
        # Note that the current implementation does not support
        # accumulation values outside of the range representable by a
        # 16 bit integer, instead resulting in a saturated value. We
        # must take care that in our test we do not end up with a dot
        # product that overflows the int16 range, e.g.
        # (255*127+255*127) = 64770. So, we hardcode the test values
        # here and ensure a mix of signedness.
        vals = [[100, -155],
                [100, -155],
                [-155, 100],
                [-155, 100],
                [100, -155],
                [-155, 100],
                [-155, 100],
                [100, -155]]
        if isinstance(cell, torch.nn.LSTM):
            num_chunks = 4
            vals = vals[:d_hid * num_chunks]
            cell.weight_ih_l0 = torch.nn.Parameter(
                torch.tensor(vals, dtype=torch.float),
                requires_grad=False)
            cell.weight_hh_l0 = torch.nn.Parameter(
                torch.tensor(vals, dtype=torch.float),
                requires_grad=False)

        ref = copy.deepcopy(cell)

        model_int8 = quantize_dynamic(model=model, dtype=torch.qint8)
        model_fp16 = quantize_dynamic(model=model, dtype=torch.float16)

        # Smoke test extra reprs
        self.assertTrue('DynamicQuantizedLSTM' in str(model_int8))
        self.assertTrue('DynamicQuantizedLSTM' in str(model_fp16))
        cell_int8 = model_int8.lstm
        cell_fp16 = model_fp16.lstm

        assert type(cell_int8) == torch.nn.quantized.dynamic.LSTM, \
            'torch.nn.LSTM should be converted to torch.nn.quantized.dynamic.LSTM after quantize_dynamic'
        assert type(cell_fp16) == torch.nn.quantized.dynamic.LSTM, \
            'torch.nn.LSTM should be converted to torch.nn.quantized.dynamic.LSTM after quantize_dynamic'

        niter = 10
        x = torch.tensor([[100, -155],
                          [-155, 100],
                          [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1)

        h0_vals = [[-155, 100],
                   [-155, 155],
                   [100, -155]]

        hx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0)
        cx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0)

        if isinstance(ref, torch.nn.LSTM):
            hiddens = (hx, cx)

        ref_out, ref_hid = ref(x, hiddens)

        # Compare int8 quantized to unquantized
        output_int8, final_hiddens_int8 = cell_int8(x, hiddens)

        torch.testing.assert_allclose(output_int8, ref_out)
        self.assertEqual(output_int8, ref_out)
        for out_val, ref_val in zip(final_hiddens_int8, ref_hid):
            torch.testing.assert_allclose(out_val, ref_val)

        class ScriptWrapper(torch.nn.Module):
            def __init__(self, cell):
                super(ScriptWrapper, self).__init__()
                self.cell = cell

            def forward(self, x, hiddens):
                # type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
                return self.cell(x, hiddens)

        # TODO: TorchScript overloads don't work without this wrapper
        cell_script = torch.jit.script(ScriptWrapper(cell_int8))
        out_script, hid_script = cell_script(x, hiddens)
        self.assertEqual(len(out_script), len(ref_out))
        for out_val, ref_val in zip(out_script, ref_out):
            torch.testing.assert_allclose(out_val, ref_val)

        # Test save/load
        b = io.BytesIO()
        torch.jit.save(cell_script, b)
        b.seek(0)
        loaded = torch.jit.load(b)
        out_loaded, hid_loaded = loaded(x, hiddens)
        for loaded_val, ref_val in zip(out_loaded, ref_out):
            torch.testing.assert_allclose(loaded_val, ref_val)

        # Compare fp16 quantized to unquantized
        output_fp16, final_hiddens_fp16 = cell_fp16(x, hiddens)

        torch.testing.assert_allclose(output_fp16, ref_out)
        self.assertEqual(output_fp16, ref_out)
        for out, ref_val in zip(final_hiddens_fp16, ref_hid):
            torch.testing.assert_allclose(out, ref_val)

        # Test tracing
        # TODO: TorchScript overloads don't work without this wrapper
        cell_trace = torch.jit.trace(ScriptWrapper(cell_int8), (x, (hx, cx)))
        out_script, hid_script = cell_trace(x, hiddens)
        for out_val, ref_val in zip(out_script, ref_out):
            torch.testing.assert_allclose(out_val, ref_val)

        # print(cell_trace.code)

        # Test save/load
        b = io.BytesIO()
        torch.jit.save(cell_trace, b)
        b.seek(0)
        loaded = torch.jit.load(b)
        out_loaded, hid_loaded = loaded(x, hiddens)
        for loaded_val, ref_val in zip(out_loaded, ref_out):
            torch.testing.assert_allclose(loaded_val, ref_val)

        # Compare fp16 quantized to unquantized
        output_fp16, final_hiddens_fp16 = cell_fp16(x, hiddens)

        torch.testing.assert_allclose(output_fp16, ref_out)
        self.assertEqual(output_fp16, ref_out)
        for out, ref_val in zip(final_hiddens_fp16, ref_hid):
            torch.testing.assert_allclose(out, ref_val)

        class ScriptWrapperPacked(torch.nn.Module):
            def __init__(self, cell):
                super(ScriptWrapperPacked, self).__init__()
                self.cell = cell

            def forward(self,
                        x,  # type: PackedSequence
                        hiddens  # type: Tuple[torch.Tensor, torch.Tensor]
                        ):
                # type: (...) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]
                return self.cell(x, hiddens)

        cell_packed = torch.jit.script(ScriptWrapperPacked(cell_int8))
        packed_input = torch.nn.utils.rnn.pack_padded_sequence(x, torch.tensor([10, 5, 2]))
        ref_out_packed, ref_hid_packed = ref(packed_input, hiddens)
        output_packed, hiddens_packed = cell_packed(packed_input, hiddens)

        for packed_val, ref_val in zip(output_packed, ref_out_packed):
            if isinstance(packed_val, torch.Tensor):
                torch.testing.assert_allclose(packed_val, ref_val)
            else:
                self.assertEqual(packed_val, ref_val)

        # Test save/load
        b = io.BytesIO()
        torch.jit.save(cell_packed, b)
        b.seek(0)
        loaded_packed = torch.jit.load(b)
        out_loaded_packed, hid_loaded_packed = loaded_packed(packed_input, hiddens)
        for packed_val, ref_val in zip(out_loaded_packed, ref_out_packed):
            if isinstance(packed_val, torch.Tensor):
                torch.testing.assert_allclose(packed_val, ref_val)
            else:
                self.assertEqual(packed_val, ref_val)

        # Test default instantiation
        seq_len = 128
        batch = 16
        input_size = 3
        hidden_size = 7
        num_layers = 2
        bias = True
        bidirectional = False

        x = torch.rand(seq_len, batch, input_size)
        h = torch.rand(num_layers * (bidirectional + 1), batch, hidden_size)
        c = torch.rand(num_layers * (bidirectional + 1), batch, hidden_size)

        dtype = torch.qint8

        cell_dq = torch.nn.quantized.dynamic.LSTM(input_size=input_size,
                                                  hidden_size=hidden_size,
                                                  num_layers=num_layers,
                                                  bias=bias,
                                                  batch_first=False,
                                                  dropout=0.0,
                                                  bidirectional=bidirectional,
                                                  dtype=dtype)

        y, (h, c) = cell_dq(x, (h, c))

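
# Illustrative sketch only (not part of the original tests): the eager-mode
# quantization aware training (QAT) flow exercised by the class below.
# `float_model`, `train_fn` and `train_data` are hypothetical placeholders.
def _example_qat_workflow(float_model, train_fn, train_data):
    float_model.qconfig = default_qat_qconfig   # fake-quant + observer settings
    model = prepare_qat(float_model)            # insert fake-quant modules
    train_fn(model, train_data)                 # fine-tune with fake quantization enabled
    model.eval()
    return convert(model)                       # swap fake-quant modules for real quantized ones
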
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                     " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
                     " with instruction set support avx2 or newer.")
class EagerModeQuantizationAwareTrainingTest(QuantizationTestCase):
    def test_manual(self):
        model = ManualLinearQATModel()
        model = prepare_qat(model)
        self.checkObservers(model)
        test_only_train_fn(model, self.train_data)
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.fc1), nnq.Linear)
            self.assertEqual(type(model.fc2), nnq.Linear)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        model = quantize_qat(ManualLinearQATModel(), test_only_train_fn,
                             self.train_data)
        checkQuantized(model)

    def test_eval_only_fake_quant(self):
        r"""Using FakeQuant in evaluation only mode;
        this is useful for estimating accuracy loss when we quantize the
        network
        """
        model = ManualLinearQATModel()

        model = prepare_qat(model)
        self.checkObservers(model)

        model.eval()
        test_only_eval_fn(model, self.calib_data)

    def test_conv_linear(self):
        model = ManualConvLinearQATModel()

        model = prepare_qat(model)
        self.checkObservers(model)

        test_only_train_fn(model, self.img_data)
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.conv), nnq.Conv2d)
            self.assertEqual(type(model.fc1), nnq.Linear)
            self.assertEqual(type(model.fc2), nnq.Linear)
            test_only_eval_fn(model, self.img_data)
            self.checkScriptable(model, self.img_data)

        checkQuantized(model)

        model = ManualConvLinearQATModel()
        model = quantize_qat(model, test_only_train_fn, self.img_data)
        checkQuantized(model)

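
# Illustrative sketch only: graph-mode post-training quantization of a scripted
# model via quantize_script, the experimental API (imported above from
# torch.quantization._quantize_script) that the class below compares against
# eager mode.  `float_model` and `calib_data` are hypothetical placeholders.
def _example_graph_mode_ptq(float_model, calib_data):
    scripted = torch.jit.script(float_model.eval())
    return quantize_script(scripted,
                           {'': default_qconfig},
                           test_only_eval_fn,
                           [calib_data],
                           inplace=False)
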
@unittest.skipUnless(
    'fbgemm' in torch.backends.quantized.supported_engines,
    " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
    " with instruction set support avx2 or newer.",
)
class GraphModePostTrainingQuantTest(QuantizationTestCase):
    def test_single_linear(self):
        r"""Compare the result of quantizing a single linear layer in
        eager mode and graph mode
        """
        # eager mode
        annotated_linear_model = AnnotatedSingleLayerLinearModel()
        linear_model = SingleLayerLinearModel()
        # copy the weight from eager mode so that we can
        # compare the result of the two quantized models later
        linear_model.fc1.weight = torch.nn.Parameter(annotated_linear_model.fc1.module.weight.detach())
        linear_model.fc1.bias = torch.nn.Parameter(annotated_linear_model.fc1.module.bias.detach())
        model_eager = quantize(annotated_linear_model, test_only_eval_fn,
                               self.calib_data)

        qconfig_dict = {'': default_qconfig}
        model_traced = torch.jit.trace(linear_model, self.calib_data[0][0])
        model_script = torch.jit.script(linear_model)
        result_eager = model_eager(self.calib_data[0][0])
        for model_under_test in [model_traced, model_script]:
            model_quantized = quantize_script(
                model_under_test,
                qconfig_dict,
                test_only_eval_fn,
                [self.calib_data],
                inplace=False)
            self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)

    def test_observer_with_ignored_function(self):
        r"""Test observers with ignored function and make sure they work in
        graph mode
        """
        # eager mode
        annotated_linear_model = AnnotatedSingleLayerLinearModel().eval()
        for qconfig in [
                QConfig(
                    activation=default_observer,
                    weight=default_weight_observer),
                QConfig(
                    activation=default_histogram_observer,
                    weight=default_weight_observer),
                QConfig(
                    activation=default_observer,
                    weight=default_per_channel_weight_observer),
        ]:
            annotated_linear_model.qconfig = qconfig
            linear_model = SingleLayerLinearModel().eval()
            # copy the weight from eager mode so that we can
            # compare the result of the two quantized models later
            linear_model.fc1.weight = torch.nn.Parameter(annotated_linear_model.fc1.module.weight.detach())
            linear_model.fc1.bias = torch.nn.Parameter(annotated_linear_model.fc1.module.bias.detach())
            model_eager = quantize(annotated_linear_model, test_only_eval_fn,
                                   self.calib_data)

            qconfig_dict = {'': qconfig}
            model_traced = torch.jit.trace(linear_model, self.calib_data[0][0])
            model_script = torch.jit.script(linear_model)
            result_eager = model_eager(self.calib_data[0][0])
            for model_under_test in [model_traced, model_script]:
                model_quantized = quantize_script(
                    model_under_test,
                    qconfig_dict,
                    test_only_eval_fn,
                    [self.calib_data],
                    inplace=False)
                self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)

    def test_conv(self):
        r"""Compare the result of quantizing a conv layer in
        eager mode and graph mode
        """
        # eager mode
        annotated_conv_model = AnnotatedConvModel().eval()
        conv_model = ConvModel().eval()
        # copy the weight from eager mode so that we can
        # compare the result of the two quantized models later
        conv_model.conv.weight = torch.nn.Parameter(annotated_conv_model.conv.weight.detach())
        model_eager = quantize(annotated_conv_model, default_eval_fn,
                               self.img_data)
        qconfig_dict = {'': default_qconfig}
        model_traced = torch.jit.trace(conv_model, self.img_data[0][0])
        model_script = torch.jit.script(conv_model)
        result_eager = model_eager(self.img_data[0][0])
        for model_under_test in [model_traced, model_script]:
            model_quantized = quantize_script(
                model_under_test,
                qconfig_dict,
                default_eval_fn,
                [self.img_data],
                inplace=False)
            self.assertEqual(model_quantized(self.img_data[0][0]), result_eager)

    @unittest.skip("This doesn't work right now, re-enable after fold_convbn is fixed")
    def test_conv_bn(self):
        r"""Compare the result of quantizing a conv + bn layer in
        eager mode and graph mode
        """
        # eager mode
        conv_model = AnnotatedConvBnModel().eval()
        conv_model_to_script = ConvBnModel().eval()
        # copy the weight from eager mode so that we can
        # compare the result of the two quantized models later
        conv_model_to_script.conv.weight = torch.nn.Parameter(conv_model.conv.weight.detach())
        fuse_modules(conv_model, ['conv', 'bn'], inplace=True)
        model_eager = quantize(conv_model, default_eval_fn,
                               self.img_data)
        qconfig_dict = {
            '': default_qconfig
        }
        model_script = quantize_script(
            torch.jit.script(conv_model_to_script),
            qconfig_dict,
            default_eval_fn,
            [self.img_data],
            inplace=False)
        result_eager = model_eager(self.img_data[0][0])
        result_script = model_script(self.img_data[0][0])
        self.assertEqual(result_eager, result_script)

    def test_nested(self):
        # Eager mode
        eager_model = AnnotatedNestedModel()

        # Graph mode
        script_model = NestedModel()
        # Copy weights from eager_model so the two quantized models can be compared
        script_model.sub1.fc.weight = torch.nn.Parameter(eager_model.sub1.fc.weight.detach())
        script_model.sub1.fc.bias = torch.nn.Parameter(eager_model.sub1.fc.bias.detach())
        script_model.sub2.fc1.weight = torch.nn.Parameter(eager_model.sub2.fc1.module.weight.detach())
        script_model.sub2.fc1.bias = torch.nn.Parameter(eager_model.sub2.fc1.module.bias.detach())
        script_model.sub2.fc2.weight = torch.nn.Parameter(eager_model.sub2.fc2.weight.detach())
        script_model.sub2.fc2.bias = torch.nn.Parameter(eager_model.sub2.fc2.bias.detach())
        script_model.fc3.weight = torch.nn.Parameter(eager_model.fc3.module.weight.detach())
        script_model.fc3.bias = torch.nn.Parameter(eager_model.fc3.module.bias.detach())

        model_eager = quantize(eager_model, test_only_eval_fn, self.calib_data)
        qconfig_dict = {
            'sub2.fc1': default_per_channel_qconfig,
            'fc3': default_qconfig
        }
        model_traced = torch.jit.trace(script_model, self.calib_data[0][0])
        model_script = torch.jit.script(script_model)
        result_eager = model_eager(self.calib_data[0][0])
        for model_under_test in [model_traced, model_script]:
            model_quantized = quantize_script(
                model_under_test,
                qconfig_dict,
                test_only_eval_fn,
                [self.calib_data],
                inplace=False)
            self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)

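
# Illustrative sketch only (assumed API, not part of the original tests): the
# class below checks that functional-style ops wrapped in FloatFunctional are
# converted to nnq.QFunctional.  A hypothetical module using that wrapper might
# look like this, so that convert() can attach quantization parameters to the add.
class _ExampleFunctionalAdd(nn.Module):
    def __init__(self):
        super(_ExampleFunctionalAdd, self).__init__()
        self.add_op = nnq.FloatFunctional()   # records activation statistics for the add

    def forward(self, a, b):
        return self.add_op.add(a, b)
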
class FunctionalModuleTest(QuantizationTestCase):
    # Histogram Observers are slow, so have no-deadline to ensure test doesn't time out
    @given(train_mode=st.booleans())
    def test_functional_module(self, train_mode):
        model = ModelWithFunctionals()
        x = torch.rand(10, 1, dtype=torch.float)
        xq = torch.quantize_per_tensor(x, 0.01, 30, torch.quint8)
        self.checkScriptable(model, [(x, x)], check_save_load=True)
        if train_mode:
            model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
            model = prepare_qat(model)
        else:
            model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
            model = prepare(model)
        # Check if observers and quant/dequant nodes are inserted
        self.checkNoPrepModules(model)
        self.checkObservers(model)
        # Calibrate
        model(xq.dequantize())
        model = convert(model)

        def checkQuantized(model):
            self.checkNoPrepModules(model)
            self.assertEqual(type(model.myadd), torch.nn.quantized.QFunctional)
            self.assertEqual(type(model.mycat), torch.nn.quantized.QFunctional)
            self.assertEqual(type(model.myadd_relu), torch.nn.quantized.QFunctional)

        checkQuantized(model)
        self.checkScriptable(model, [(xq, xq)], check_save_load=True)

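
# Illustrative sketch only (not part of the original tests): module fusion as
# exercised by FusionTest below.  Assumes a hypothetical `float_model` whose
# children are named conv1/bn1/relu1 (as in ModelForFusion); fusion folds them
# into one intrinsic module and replaces the absorbed children with nn.Identity.
def _example_fusion(float_model):
    return fuse_modules(float_model, [['conv1', 'bn1', 'relu1']])
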
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                     " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
                     " with instruction set support avx2 or newer.")
class FusionTest(QuantizationTestCase):
    def test_fuse_module_train(self):
        model = ModelForFusion(default_qat_qconfig).train()
        # Test step by step fusion
        model = fuse_modules(model, ['conv1', 'bn1', 'relu1'])
        model = fuse_modules(model, ['sub1.conv', 'sub1.bn'])
        self.assertEqual(type(model.conv1), nni.ConvBnReLU2d,
                         "Fused Conv + BN + Relu first layer")
        self.assertEqual(type(model.bn1), torch.nn.Identity,
                         "Fused Conv + BN + Relu (skipped BN)")
        self.assertEqual(type(model.relu1), torch.nn.Identity,
                         "Fused Conv + BN + Relu (skipped Relu)")

        self.assertEqual(type(model.sub1.conv), nni.ConvBn2d,
                         "Fused submodule Conv + BN")
        self.assertEqual(type(model.sub1.bn), torch.nn.Identity,
                         "Fused submodule Conv + BN (skipped BN)")
        self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d,
                         "Non-fused submodule Conv")
        self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
                         "Non-fused submodule ReLU")
        model = prepare_qat(model)
        self.checkObservers(model)

        def checkQAT(model):
            self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nniqat.ConvBn2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)

        checkQAT(model)
        test_only_train_fn(model, self.img_data)
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)
            test_only_eval_fn(model, self.img_data)
        checkQuantized(model)

        model = ModelForFusion(default_qat_qconfig).train()
        model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
                                     ['sub1.conv', 'sub1.bn']])
        model = quantize_qat(model, test_only_train_fn, self.img_data)
        checkQuantized(model)

    def test_fuse_module_eval(self):
        model = ModelForFusion(default_qconfig)
        model.eval()
        model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
                                     ['sub1.conv', 'sub1.bn']])
        self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                         "Fused Conv + BN + Relu first layer (BN is folded)")
        self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                         "Fused Conv + BN + Relu (Conv + folded BN only)")
        self.assertEqual(type(model.conv1[1]), nn.ReLU,
                         "Fused Conv + BN + Relu second layer (Relu only)")
        self.assertEqual(type(model.bn1), nn.Identity,
                         "Fused Conv + BN + Relu second layer (Skipped BN)")
        self.assertEqual(type(model.relu1), nn.Identity,
                         "Fused Conv + BN + Relu second layer (Skipped Relu)")

        self.assertEqual(type(model.sub1.conv), nn.Conv2d,
                         "Fused submodule Conv + folded BN")
        self.assertEqual(type(model.sub1.bn), nn.Identity,
                         "Fused submodule (skipped BN)")
        self.assertEqual(type(model.sub2.conv), nn.Conv2d,
                         "Non-fused submodule Conv")
        self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
                         "Non-fused submodule ReLU")

        model = prepare(model)
        self.checkObservers(model)
        test_only_eval_fn(model, self.img_data)
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)
            test_only_eval_fn(model, self.img_data)
        checkQuantized(model)

        model = ModelForFusion(default_qconfig).eval()
        model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
                                     ['sub1.conv', 'sub1.bn']])
        model = quantize(model, test_only_eval_fn, self.img_data)
        checkQuantized(model)

    def test_fusion_sequential_model_train(self):
        model = ModelWithSequentialFusion().train()
        model.to(torch.float)
        fuse_modules(model, [['conv1', 'relu1'],
                             ['features.0.0', 'features.0.1', 'features.0.2'],
                             ['features.1.0', 'features.1.1', 'features.1.2'],
                             ['features.2.0', 'features.2.1', 'features.2.2'],
                             ['classifier.0', 'classifier.1']], inplace=True)
        self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                         "Fused Conv + Relu: nni.ConvReLU2d")
        self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                         "Fused Conv + Relu: Conv2d")
        self.assertEqual(type(model.conv1[1]), nn.ReLU,
                         "Fused Conv + Relu: Relu")
        self.assertEqual(type(model.relu1), nn.Identity,
                         "Fused Conv + Relu: Identity")
        for i in range(3):
            self.assertEqual(type(model.features[i][0]), nni.ConvBnReLU2d,
                             "Fused submodule Conv + folded BN")
            self.assertEqual(type(model.features[i][1]), nn.Identity,
                             "Fused submodule (skipped BN)")
            self.assertEqual(type(model.features[i][2]), nn.Identity,
                             "Non-fused submodule Conv")
        self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
        self.assertEqual(type(model.classifier[1]), nn.Identity)
        model.qconfig = default_qat_qconfig
        prepare_qat(model, inplace=True)
        self.checkObservers(model)
        model(self.img_data[0][0])

        def checkQAT(model):
            self.assertEqual(type(model.conv1), nniqat.ConvReLU2d)
            self.assertEqual(type(model.relu1), nn.Identity)
            for i in range(3):
                self.assertEqual(type(model.features[i][0]), nniqat.ConvBnReLU2d,
                                 "Fused submodule Conv + folded BN")
                self.assertEqual(type(model.features[i][1]), nn.Identity,
                                 "Fused submodule (skipped BN)")
                self.assertEqual(type(model.features[i][2]), nn.Identity,
                                 "Non-fused submodule Conv")
            self.assertEqual(type(model.classifier[0]), nniqat.LinearReLU)
            self.assertEqual(type(model.classifier[1]), nn.Identity)

        checkQAT(model)
        model(self.img_data[1][0])
        convert(model, inplace=True)
        model(self.img_data[1][0])
        self.checkModelWithSequentialQuantized(model)

    def test_fusion_sequential_model_eval(self):
        model = ModelWithSequentialFusion().eval()
        model.to(torch.float)
        fuse_modules(model, [['conv1', 'relu1'],
                             ['features.0.0', 'features.0.1', 'features.0.2'],
                             ['features.1.0', 'features.1.1', 'features.1.2'],
                             ['features.2.0', 'features.2.1', 'features.2.2'],
                             ['classifier.0', 'classifier.1']], inplace=True)
        self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                         "Fused Conv + Relu: nni.ConvReLU2d")
        self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                         "Fused Conv + Relu: Conv2d")
        self.assertEqual(type(model.conv1[1]), nn.ReLU,
                         "Fused Conv + Relu: Relu")
        self.assertEqual(type(model.relu1), nn.Identity,
                         "Fused Conv + Relu: Identity")
        for i in range(3):
            self.assertEqual(type(model.features[i][0]), nni.ConvReLU2d,
                             "Fused submodule Conv + folded BN")
            self.assertEqual(type(model.features[i][1]), nn.Identity,
                             "Fused submodule (skipped BN)")
            self.assertEqual(type(model.features[i][2]), nn.Identity,
                             "Non-fused submodule Conv")
        self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
        self.assertEqual(type(model.classifier[1]), nn.Identity)
        model.qconfig = default_qconfig
        prepare(model, inplace=True)
        self.checkObservers(model)
        model(self.img_data[0][0])
        convert(model, inplace=True)
        model(self.img_data[1][0])
        self.checkModelWithSequentialQuantized(model)

    def checkModelWithSequentialQuantized(self, model):
        self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
        self.assertEqual(type(model.relu1), nn.Identity)
        for i in range(3):
            self.assertEqual(type(model.features[i][0]), nniq.ConvReLU2d)
            self.assertEqual(type(model.features[i][1]), nn.Identity)
            self.assertEqual(type(model.features[i][2]), nn.Identity)
        self.assertEqual(type(model.classifier[0]), nniq.LinearReLU)
        self.assertEqual(type(model.classifier[1]), nn.Identity)

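
# Illustrative sketch only (not part of the original tests): observers collect
# min/max statistics from the tensors they see and derive quantization
# parameters from them, which is what ObserverTest below verifies in detail.
def _example_observer_qparams():
    obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
    obs(torch.tensor([-1.0, 0.0, 2.0, 4.0]))     # record running min/max
    scale, zero_point = obs.calculate_qparams()  # derive (scale, zero_point)
    return scale, zero_point
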
class ObserverTest(QuantizationTestCase):
    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
           reduce_range=st.booleans())
    def test_per_tensor_observers(self, qdtype, qscheme, reduce_range):
        # reduce_range cannot be true for symmetric quantization with uint8
        if qdtype == torch.quint8 and qscheme == torch.per_tensor_symmetric:
            reduce_range = False
        ObserverList = [MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range),
                        MovingAverageMinMaxObserver(averaging_constant=0.5,
                                                    dtype=qdtype,
                                                    qscheme=qscheme,
                                                    reduce_range=reduce_range)]
        for myobs in ObserverList:
            # Calculate Qparams should return with a warning for observers with no data
            qparams = myobs.calculate_qparams()
            if type(myobs) == MinMaxObserver:
                x = torch.tensor([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                y = torch.tensor([4.0, 5.0, 5.0, 6.0, 7.0, 8.0])
            else:
                # Moving average of min/max for x and y matches that of
                # extreme values for x/y used for minmax observer
                x = torch.tensor([0.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                y = torch.tensor([2.0, 5.0, 5.0, 6.0, 7.0, 10.0])

            result = myobs(x)
            result = myobs(y)
            self.assertEqual(result, y)
            self.assertEqual(myobs.min_val, 1.0)
            self.assertEqual(myobs.max_val, 8.0)
            qparams = myobs.calculate_qparams()
            if reduce_range:
                if qscheme == torch.per_tensor_symmetric:
                    ref_scale = 0.062745 * 255 / 127
                    ref_zero_point = 0 if qdtype is torch.qint8 else 128
                else:
                    ref_scale = 0.0313725 * 255 / 127
                    ref_zero_point = -64 if qdtype is torch.qint8 else 0
            else:
                if qscheme == torch.per_tensor_symmetric:
                    ref_scale = 0.062745
                    ref_zero_point = 0 if qdtype is torch.qint8 else 128
                else:
                    ref_scale = 0.0313725
                    ref_zero_point = -128 if qdtype is torch.qint8 else 0
            self.assertEqual(qparams[1].item(), ref_zero_point)
            self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)
            state_dict = myobs.state_dict()
            b = io.BytesIO()
            torch.save(state_dict, b)
            b.seek(0)
            loaded_dict = torch.load(b)
            for key in state_dict:
                self.assertEqual(state_dict[key], loaded_dict[key])
            loaded_obs = MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
            loaded_obs.load_state_dict(loaded_dict)
            loaded_qparams = loaded_obs.calculate_qparams()
            self.assertEqual(myobs.min_val, loaded_obs.min_val)
            self.assertEqual(myobs.max_val, loaded_obs.max_val)
            self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())

    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from((torch.per_channel_affine, torch.per_channel_symmetric)),
           ch_axis=st.sampled_from((0, 1, 2, 3)), reduce_range=st.booleans())
    def test_per_channel_observers(self, qdtype, qscheme, ch_axis, reduce_range):
        # reduce_range cannot be true for symmetric quantization with uint8
        if qdtype == torch.quint8 and qscheme == torch.per_channel_symmetric:
            reduce_range = False
        ObserverList = [PerChannelMinMaxObserver(reduce_range=reduce_range,
                                                 ch_axis=ch_axis,
                                                 dtype=qdtype,
                                                 qscheme=qscheme),
                        MovingAveragePerChannelMinMaxObserver(averaging_constant=0.5,
                                                              reduce_range=reduce_range,
                                                              ch_axis=ch_axis,
                                                              dtype=qdtype,
                                                              qscheme=qscheme)]

        for myobs in ObserverList:
            # Calculate qparams should work for empty observers
            qparams = myobs.calculate_qparams()
            x = torch.tensor(
                [
                    [[[1.0, 2.0], [2.0, 2.5]], [[3.0, 4.0], [4.5, 6.0]]],
                    [[[-4.0, -3.0], [5.0, 5.0]], [[6.0, 3.0], [7.0, 8.0]]],
                ]
            )
            if type(myobs) == MovingAveragePerChannelMinMaxObserver:
                # Scaling the input tensor to model change in min/max values
                # across batches
                result = myobs(0.5 * x)
                result = myobs(1.5 * x)
                self.assertEqual(result, 1.5 * x)
            else:
                result = myobs(x)
                self.assertEqual(result, x)

            qparams = myobs.calculate_qparams()
            ref_min_vals = [[1.0, -4.0], [-4.0, 3.0], [-4.0, 2.0], [-4.0, -3.0]]
            ref_max_vals = [[6.0, 8.0], [5.0, 8.0], [6.0, 8.0], [7.0, 8.0]]
            per_channel_symmetric_ref_scales = [
                [0.04705882, 0.06274509],
                [0.03921569, 0.0627451],
                [0.04705882, 0.0627451],
                [0.05490196, 0.0627451],
            ]
            per_channel_affine_ref_scales = [
                [0.02352941, 0.04705882],
                [0.03529412, 0.03137255],
                [0.03921569, 0.03137255],
                [0.04313726, 0.04313726],
            ]
            per_channel_affine_qint8_zp = [
                [-128, -43],
                [-15, -128],
                [-26, -128],
                [-35, -58],
            ]
            per_channel_affine_quint8_zp = [[0, 85], [113, 0], [102, 0], [93, 70]]

            self.assertEqual(myobs.min_vals, ref_min_vals[ch_axis])
            self.assertEqual(myobs.max_vals, ref_max_vals[ch_axis])
            if qscheme == torch.per_channel_symmetric:
                ref_scales = per_channel_symmetric_ref_scales[ch_axis]
                ref_zero_points = [0, 0] if qdtype is torch.qint8 else [128, 128]
            else:
                ref_scales = per_channel_affine_ref_scales[ch_axis]
                ref_zero_points = (
                    per_channel_affine_qint8_zp[ch_axis]
                    if qdtype is torch.qint8
                    else per_channel_affine_quint8_zp[ch_axis]
                )

            if reduce_range:
                ref_scales = [s * 255 / 127 for s in ref_scales]
                ref_zero_points = [math.floor(z / 2) for z in ref_zero_points]

            self.assertTrue(torch.allclose(qparams[0], torch.tensor(ref_scales, dtype=qparams[0].dtype)))
            self.assertTrue(torch.allclose(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype)))

            # Test for serializability
            state_dict = myobs.state_dict()
            b = io.BytesIO()
            torch.save(state_dict, b)
            b.seek(0)
            loaded_dict = torch.load(b)
            for key in state_dict:
                self.assertEqual(state_dict[key], loaded_dict[key])
            loaded_obs = PerChannelMinMaxObserver(reduce_range=reduce_range, ch_axis=ch_axis, dtype=qdtype, qscheme=qscheme)
            loaded_obs.load_state_dict(loaded_dict)
            loaded_qparams = loaded_obs.calculate_qparams()
            self.assertEqual(myobs.min_vals, loaded_obs.min_vals)
            self.assertEqual(myobs.max_vals, loaded_obs.max_vals)
            self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())

    def test_observer_scriptable(self):
        obs_list = [MinMaxObserver(), MovingAverageMinMaxObserver()]
        for obs in obs_list:
            scripted = torch.jit.script(obs)

            x = torch.rand(3, 4)
            obs(x)
            scripted(x)

            self.assertEqual(obs.calculate_qparams(), scripted.calculate_qparams())

            buf = io.BytesIO()
            torch.jit.save(scripted, buf)
            buf.seek(0)
            loaded = torch.jit.load(buf)
            self.assertEqual(obs.calculate_qparams(), loaded.calculate_qparams())

    def test_no_qconfig_propagation(self):
        model = ModelWithNoQconfigPropagation()
        model.qconfig = torch.quantization.default_qconfig

        model = prepare(model)
        self.assertTrue(hasattr(model.fc1, 'qconfig'),
                        "QConfig is expected to propagate")
        self.assertFalse(hasattr(model.no_quant_module, 'qconfig'),
                         "QConfig is expected to NOT propagate")

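
# Illustrative sketch only (not part of the original tests): default_debug_qconfig
# attaches RecordingObservers that keep the tensors they see, and
# get_observer_dict collects them after evaluation, as verified by the class
# below.  `model` and `data` are hypothetical placeholders.
def _example_record_observers(model, data):
    model.qconfig = default_debug_qconfig
    model = prepare(model)
    model(data)                       # recorded tensors accumulate in the observers
    observers = {}
    get_observer_dict(model, observers)
    return observers
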
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                     " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
                     " with instruction set support avx2 or newer.")
class RecordHistogramObserverTest(QuantizationTestCase):
    def test_record_observer(self):
        model = AnnotatedSingleLayerLinearModel()
        model.qconfig = default_debug_qconfig
        model = prepare(model)
        # run the evaluation and dump all tensors
        test_only_eval_fn(model, self.calib_data)
        test_only_eval_fn(model, self.calib_data)
        observer_dict = {}
        get_observer_dict(model, observer_dict)

        self.assertTrue('fc1.module.activation_post_process' in observer_dict.keys(),
                        'observer is not recorded in the dict')
        self.assertEqual(len(observer_dict['fc1.module.activation_post_process'].get_tensor_value()), 2 * len(self.calib_data))
        self.assertEqual(observer_dict['fc1.module.activation_post_process'].get_tensor_value()[0], model(self.calib_data[0][0]))

    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)))
    def test_observer_scriptable(self, qdtype, qscheme):
        obs = RecordingObserver(dtype=qdtype, qscheme=qscheme)
        scripted = torch.jit.script(obs)

        x = torch.rand(3, 4)
        obs(x)
        scripted(x)
        self.assertTrue(torch.equal(obs.get_tensor_value()[0], scripted.get_tensor_value()[0]))
        buf = io.BytesIO()
        torch.jit.save(scripted, buf)
        buf.seek(0)
        loaded = torch.jit.load(buf)
        self.assertTrue(torch.equal(obs.get_tensor_value()[0], loaded.get_tensor_value()[0]))

    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
           reduce_range=st.booleans())
    def test_histogram_observer(self, qdtype, qscheme, reduce_range):
        myobs = HistogramObserver(bins=3, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
        # Calculate qparams should work for empty observers
        qparams = myobs.calculate_qparams()
        x = torch.tensor([2.0, 3.0, 4.0, 5.0])
        y = torch.tensor([5.0, 6.0, 7.0, 8.0])
        myobs(x)
        myobs(y)
        self.assertEqual(myobs.min_val, 2.0)
        self.assertEqual(myobs.max_val, 8.0)
        self.assertEqual(myobs.histogram, [2., 3., 3.])

        qparams = myobs.calculate_qparams()

        if reduce_range:
            if qscheme == torch.per_tensor_symmetric:
                ref_scale = 0.0470588 * 255 / 127
                ref_zero_point = 0 if qdtype is torch.qint8 else 128
            else:
                ref_scale = 0.0235294 * 255 / 127
                ref_zero_point = -64 if qdtype is torch.qint8 else 0
        else:
            if qscheme == torch.per_tensor_symmetric:
                ref_scale = 0.0470588
                ref_zero_point = 0 if qdtype is torch.qint8 else 128
            else:
                ref_scale = 0.0235294
                ref_zero_point = -128 if qdtype is torch.qint8 else 0

        self.assertEqual(qparams[1].item(), ref_zero_point)
        self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)
        # Test for serializability
        state_dict = myobs.state_dict()
        b = io.BytesIO()
        torch.save(state_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in state_dict:
            self.assertEqual(state_dict[key], loaded_dict[key])
        loaded_obs = HistogramObserver(bins=3, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
        loaded_obs.load_state_dict(loaded_dict)
        loaded_qparams = loaded_obs.calculate_qparams()
        self.assertEqual(myobs.min_val, loaded_obs.min_val)
        self.assertEqual(myobs.max_val, loaded_obs.max_val)
        self.assertEqual(myobs.histogram, loaded_obs.histogram)
        self.assertEqual(myobs.bins, loaded_obs.bins)
        self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())


if __name__ == '__main__':
    run_tests()