# Owner(s): ["oncall: quantization"]

# ruff: noqa: F841

from collections import OrderedDict
import contextlib
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.reference as nnqr
import torch.ao.nn.quantized.dynamic as nnqd
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
import torch.multiprocessing as mp
from torch.fx.graph_module import _USER_PRESERVED_ATTRIBUTES_KEY

# graph mode quantization based on fx
from torch.ao.quantization.quantize_fx import (
    prepare_fx,
    convert_fx,
    convert_to_reference_fx,
    _convert_to_reference_decomposed_fx,
    prepare_qat_fx,
    fuse_fx,
)

from torch.ao.quantization.fx.quantize_handler import DefaultNodeQuantizeHandler

from torch.ao.quantization.fx.match_utils import (
    _is_match,
    MatchAllNode,
)

from torch.ao.quantization import (
    QuantType,
)
from torch.ao.quantization.quant_type import _get_quant_type_to_str

from torch.ao.quantization import (
    QuantStub,
    DeQuantStub,
    QuantWrapper,
    default_qconfig,
    default_dynamic_qconfig,
    default_per_channel_qconfig,
    default_qat_qconfig,
    default_reuse_input_qconfig,
    default_symmetric_qnnpack_qconfig,
    default_symmetric_qnnpack_qat_qconfig,
    per_channel_dynamic_qconfig,
    float16_dynamic_qconfig,
    float16_static_qconfig,
    float_qparams_weight_only_qconfig,
    float_qparams_weight_only_qconfig_4bit,
    get_default_qconfig,
    get_default_qat_qconfig,
    get_default_qconfig_mapping,
    get_default_qat_qconfig_mapping,
    fuse_modules,
    fuse_modules_qat,
    prepare,
    prepare_qat,
    convert,
    quantize_dynamic,
    default_placeholder_observer,
    default_weight_observer,
    PerChannelMinMaxObserver,
    FixedQParamsFakeQuantize,
    FixedQParamsObserver,
    FusedMovingAvgObsFakeQuantize,
    FakeQuantize,
    MovingAverageMinMaxObserver,
    HistogramObserver,
    ReuseInputObserver,
    QConfig,
    default_embedding_qat_qconfig,
)

from torch.ao.quantization.backend_config import (
    get_fbgemm_backend_config,
    get_qnnpack_backend_config,
    BackendConfig,
    BackendPatternConfig,
    DTypeConfig,
    DTypeWithConstraints,
    ObservationType
)
from torch.ao.quantization.backend_config.native import (
    get_test_only_legacy_native_backend_config,
)

from torch.ao.quantization.qconfig_mapping import (
    _get_symmetric_qnnpack_qconfig_mapping,
    _get_symmetric_qnnpack_qat_qconfig_mapping,
    _GLOBAL_DICT_KEY,
    _MODULE_NAME_DICT_KEY,
    _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY,
    _MODULE_NAME_REGEX_DICT_KEY,
    _OBJECT_TYPE_DICT_KEY,
    QConfigMapping,
)

from torch.ao.quantization.fx.qconfig_mapping_utils import (
    _get_object_type_qconfig,
    _get_module_name_qconfig,
    _get_module_name_regex_qconfig,
    _maybe_adjust_qconfig_for_module_name_object_type_order,
)

from torch.ao.quantization.fx.pattern_utils import (
    _DEFAULT_FUSION_PATTERNS,
    _DEFAULT_QUANTIZATION_PATTERNS,
    _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP,
    _DEFAULT_OUTPUT_OBSERVER_MAP,
    _register_fusion_pattern,
    _register_quant_pattern,
    get_default_output_activation_post_process_map
)

from torch.ao.quantization.fx.custom_config import (
    STANDALONE_MODULE_NAME_DICT_KEY,
    STANDALONE_MODULE_CLASS_DICT_KEY,
    FLOAT_TO_OBSERVED_DICT_KEY,
    OBSERVED_TO_QUANTIZED_DICT_KEY,
    NON_TRACEABLE_MODULE_NAME_DICT_KEY,
    NON_TRACEABLE_MODULE_CLASS_DICT_KEY,
    INPUT_QUANTIZED_INDEXES_DICT_KEY,
    OUTPUT_QUANTIZED_INDEXES_DICT_KEY,
    PRESERVED_ATTRIBUTES_DICT_KEY,
    FuseCustomConfig,
    ConvertCustomConfig,
    PrepareCustomConfig,
    StandaloneModuleConfigEntry,
)
import torch.ao.quantization.fx.lstm_utils

from torch.ao.quantization.fx.utils import (
    _reroute_tuple_getitem_pattern,
    NodeInfo,
)

from torch.ao.quantization.fake_quantize import (
    default_fixed_qparams_range_0to1_fake_quant,
    default_fixed_qparams_range_neg1to1_fake_quant,
)

from torch.ao.quantization.observer import (
    default_fixed_qparams_range_0to1_observer,
    default_fixed_qparams_range_neg1to1_observer,
    MinMaxObserver,
    _is_activation_post_process,
)

# test utils
from hypothesis import given, settings
from hypothesis import strategies as st
from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
from torch.testing._internal.common_quantization import (
    LinearReluLinearModel,
    LinearReluModel,
    LinearBnLeakyReluModel,
    LinearTanhModel,
    ConvBnAddReluModel,
    QuantizationTestCase,
    skipIfNoFBGEMM,
    skipIfNoQNNPACK,
    skip_if_no_torchvision,
    train_one_epoch,
    run_ddp,
    test_only_eval_fn,
    test_only_train_fn,
    ModelForConvTransposeBNFusion,
    get_supported_device_types,
    skipIfNoONEDNN,
)

from torch.testing._internal.common_quantization import (
    LinearModelWithSubmodule,
    ResNetBase,
    RNNDynamicModel,
    RNNCellDynamicModel,
)

from torch.testing._internal.common_quantized import (
    supported_qengines,
    override_qengines,
    override_quantized_engine,
)

from torch.testing._internal.common_utils import (
    TemporaryFileName,
    IS_ARM64,
    skipIfTorchDynamo,
)

from torch.testing._internal.common_quantization import NodeSpec as ns

from torch.testing import FileCheck

import copy
import itertools
import operator
import unittest
import io
from typing import Callable, Optional


class BinaryOp(torch.nn.Module):
|
|
def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar):
|
|
""" ibinary_op means inplace binary op
|
|
"""
|
|
super().__init__()
|
|
self.conv1 = torch.nn.Conv2d(1, 1, 1).float()
|
|
self.conv2 = torch.nn.Conv2d(1, 1, 1).float()
|
|
self.is_scalar = is_scalar
|
|
self.op = ibinary_op if ibinary_op and is_inplace else binary_op
|
|
|
|
def forward(self, x, y):
|
|
x = self.conv1(x)
|
|
y = 3 if self.is_scalar else self.conv2(y)
|
|
# x = x + y
|
|
x = self.op(x, y)
|
|
# x = y + x
|
|
x = self.op(y, x)
|
|
return x
|
|
|
|
class BinaryOpNonQuantizedInput(torch.nn.Module):
|
|
def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar):
|
|
""" ibinary_op means inplace binary op
|
|
"""
|
|
super().__init__()
|
|
self.is_scalar = is_scalar
|
|
self.op = ibinary_op if ibinary_op and is_inplace else binary_op
|
|
|
|
def forward(self, x, y):
|
|
y = 3 if self.is_scalar else y
|
|
x = self.op(x, y)
|
|
return x
|
|
|
|
class BinaryOpRelu(torch.nn.Module):
|
|
def __init__(self, binary_op, ibinary_op, is_inplace, relu_callable,
|
|
is_scalar):
|
|
""" ibinary_op means inplace binary op
|
|
"""
|
|
super().__init__()
|
|
self.conv1 = torch.nn.Conv2d(1, 1, 1).float()
|
|
self.conv2 = torch.nn.Conv2d(1, 1, 1).float()
|
|
self.op = ibinary_op if ibinary_op and is_inplace else binary_op
|
|
self.relu_callable = relu_callable
|
|
self.is_scalar = is_scalar
|
|
if relu_callable is torch.nn.ReLU:
|
|
self.relu = torch.nn.ReLU()
|
|
else:
|
|
self.relu = relu_callable
|
|
|
|
def forward(self, x, y):
|
|
x = self.conv1(x)
|
|
y = 3 if self.is_scalar else self.conv2(y)
|
|
x = self.op(x, y)
|
|
x = self.relu(x)
|
|
x = self.op(y, x)
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
@torch.fx.wrap
|
|
def _user_func_with_complex_return_type(x):
|
|
return list(torch.split(x, 1, 1))
|
|
|
|
class TestFuseFx(QuantizationTestCase):
|
|
def test_fuse_conv_bn_relu(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv1d = nn.Conv1d(1, 1, 1)
|
|
self.conv2d = nn.Conv2d(1, 1, 1)
|
|
self.conv3d = nn.Conv3d(1, 1, 1)
|
|
self.bn1d = nn.BatchNorm1d(1)
|
|
self.bn2d = nn.BatchNorm2d(1)
|
|
self.bn3d = nn.BatchNorm3d(1)
|
|
self.conv1d2 = nn.Conv1d(1, 1, 1)
|
|
self.conv2d2 = nn.Conv2d(1, 1, 1)
|
|
self.conv3d2 = nn.Conv3d(1, 1, 1)
|
|
self.bn1d2 = nn.BatchNorm1d(1)
|
|
self.bn2d2 = nn.BatchNorm2d(1)
|
|
self.bn3d2 = nn.BatchNorm3d(1)
|
|
self.relu = nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = self.conv1d(x)
|
|
x = self.bn1d(x)
|
|
x = self.conv2d(x)
|
|
x = self.bn2d(x)
|
|
x = self.conv3d(x)
|
|
x = self.bn3d(x)
|
|
x = self.conv1d2(x)
|
|
x = self.bn1d2(x)
|
|
x = self.relu(x)
|
|
x = self.conv2d2(x)
|
|
x = self.bn2d2(x)
|
|
x = self.relu(x)
|
|
x = self.conv3d2(x)
|
|
x = self.bn3d2(x)
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
# test train mode
|
|
m = M().train()
|
|
        # currently we don't check if the modules are configured with qconfig before fusion
|
|
# TODO: if we decide to do that in the future, this test needs to
|
|
# be updated
|
|
# train mode fuse_fx is called in prepare_qat_fx
|
|
m = prepare_qat_fx(m, {}, example_inputs=(torch.randn(1, 1, 1, 1),))
|
|
expected_nodes = [
|
|
ns.call_module(nni.ConvBn1d),
|
|
ns.call_module(nni.ConvBn2d),
|
|
ns.call_module(nni.ConvBn3d),
|
|
ns.call_module(nni.ConvBnReLU1d),
|
|
ns.call_module(nni.ConvBnReLU2d),
|
|
ns.call_module(nni.ConvBnReLU3d),
|
|
]
|
|
expected_occurrence = {
|
|
ns.call_module(nn.ReLU): 0
|
|
}
|
|
self.checkGraphModuleNodes(
|
|
m,
|
|
expected_node_list=expected_nodes,
|
|
expected_node_occurrence=expected_occurrence)
|
|
|
|
# test eval mode
|
|
m = M().eval()
|
|
# fuse_fx is a top level api and only supports eval mode
|
|
m = fuse_fx(m)
|
|
expected_nodes = [
|
|
ns.call_module(nn.Conv1d),
|
|
ns.call_module(nn.Conv2d),
|
|
ns.call_module(nn.Conv3d),
|
|
ns.call_module(nni.ConvReLU1d),
|
|
ns.call_module(nni.ConvReLU2d),
|
|
ns.call_module(nni.ConvReLU3d),
|
|
]
|
|
# ConvBnRelu1d is not fused
|
|
expected_occurrence = {
|
|
ns.call_module(nn.ReLU): 0
|
|
}
|
|
self.checkGraphModuleNodes(
|
|
m,
|
|
expected_node_list=expected_nodes,
|
|
expected_node_occurrence=expected_occurrence)
|
|
|
|
def test_fuse_linear_bn_eval(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = nn.Linear(1, 1)
|
|
self.bn1d = nn.BatchNorm1d(1)
|
|
|
|
def forward(self, x):
|
|
x = self.linear(x)
|
|
x = self.bn1d(x)
|
|
return x
|
|
|
|
# test eval mode
|
|
m = M().eval()
|
|
# fuse_fx is a top level api and only supports eval mode
|
|
m = fuse_fx(m)
|
|
expected_nodes = [
|
|
ns.call_module(nn.Linear),
|
|
]
|
|
expected_occurrence = {
|
|
ns.call_module(nn.BatchNorm1d): 0,
|
|
}
|
|
self.checkGraphModuleNodes(
|
|
m,
|
|
expected_node_list=expected_nodes,
|
|
expected_node_occurrence=expected_occurrence)
|
|
|
|
@skipIfNoONEDNN
|
|
def test_fuse_linear_bn_leaky_relu_onednn(self):
|
|
# linear - bn - leaky_relu is fused for onednn backend only
|
|
from torch.ao.quantization.backend_config import get_onednn_backend_config
|
|
expected_nodes = [
|
|
ns.call_module(nni.LinearLeakyReLU),
|
|
]
|
|
expected_occurrence = {
|
|
ns.call_module(nn.BatchNorm1d): 0,
|
|
ns.call_module(nn.LeakyReLU): 0,
|
|
}
|
|
|
|
for with_bn in [True, False]:
|
|
# test eval mode
|
|
m = LinearBnLeakyReluModel(with_bn).eval()
|
|
# fuse_fx is a top level api and only supports eval mode
|
|
m = fuse_fx(m,
|
|
backend_config=get_onednn_backend_config())
|
|
self.checkGraphModuleNodes(
|
|
m,
|
|
expected_node_list=expected_nodes,
|
|
expected_node_occurrence=expected_occurrence)
|
|
|
|
def test_linear_bn_leaky_relu_not_fused_by_default(self):
|
|
# Make sure linear - bn - leaky_relu is not fused by default
|
|
for with_bn in [True, False]:
|
|
# test eval mode
|
|
m = LinearBnLeakyReluModel(with_bn).eval()
|
|
# fuse_fx is a top level api and only supports eval mode
|
|
m = fuse_fx(m)
|
|
expected_nodes = [
|
|
ns.call_module(nn.Linear),
|
|
ns.call_module(nn.LeakyReLU),
|
|
]
|
|
expected_occurrence = {
|
|
ns.call_module(nni.LinearLeakyReLU): 0,
|
|
}
|
|
self.checkGraphModuleNodes(
|
|
m,
|
|
expected_node_list=expected_nodes,
|
|
expected_node_occurrence=expected_occurrence)
|
|
|
|
@skipIfNoONEDNN
|
|
def test_fuse_linear_tanh_for_onednn_backend(self):
|
|
# linear - tanh is fused for onednn backend only
|
|
from torch.ao.quantization.backend_config import get_onednn_backend_config
|
|
expected_nodes = [
|
|
ns.call_module(nni.LinearTanh),
|
|
]
|
|
expected_occurrence = {
|
|
ns.call_module(nn.Linear): 0,
|
|
ns.call_module(nn.Tanh): 0,
|
|
}
|
|
|
|
# test eval mode
|
|
m = LinearTanhModel().eval()
|
|
# fuse_fx is a top level api and only supports eval mode
|
|
m = fuse_fx(m,
|
|
backend_config=get_onednn_backend_config())
|
|
self.checkGraphModuleNodes(
|
|
m,
|
|
expected_node_list=expected_nodes,
|
|
expected_node_occurrence=expected_occurrence)
|
|
|
|
def test_linear_tanh_not_fused_by_default(self):
|
|
# Make sure linear - tanh is not fused by default
|
|
# test eval mode
|
|
m = LinearTanhModel().eval()
|
|
# fuse_fx is a top level api and only supports eval mode
|
|
m = fuse_fx(m)
|
|
expected_nodes = [
|
|
ns.call_module(nn.Linear),
|
|
ns.call_module(nn.Tanh),
|
|
]
|
|
expected_occurrence = {
|
|
ns.call_module(nni.LinearTanh): 0,
|
|
}
|
|
self.checkGraphModuleNodes(
|
|
m,
|
|
expected_node_list=expected_nodes,
|
|
expected_node_occurrence=expected_occurrence)
|
|
|
|
def test_fuse_conv_bn_add_relu_onednn(self):
|
|
# conv - bn - add - relu is fused for onednn backend only
|
|
from torch.ao.quantization.backend_config import get_onednn_backend_config
|
|
options = itertools.product(
|
|
[True, False], # with_bn
|
|
[True, False], # with_relu
|
|
[True, False], # conv in the left
|
|
[True, False], # with_two_conv
|
|
[True, False], # use_torch_add
|
|
)
|
|
for with_bn, with_relu, left_conv, two_conv, use_torch_add in options:
|
|
expected_nodes = [
|
|
ns.call_module(nni.ConvAddReLU2d if with_relu else nni.ConvAdd2d),
|
|
]
|
|
expected_occurrence = {
|
|
ns.call_module(nni.ConvAddReLU2d if with_relu else nni.ConvAdd2d): 1,
|
|
ns.call_module(nn.BatchNorm2d): 0,
|
|
}
|
|
|
|
# test eval mode
|
|
m = ConvBnAddReluModel(
|
|
with_bn=with_bn,
|
|
with_relu=with_relu,
|
|
left_conv=left_conv,
|
|
two_conv=two_conv,
|
|
use_torch_add=use_torch_add).eval()
|
|
|
|
m = fuse_fx(m,
|
|
backend_config=get_onednn_backend_config())
|
|
self.checkGraphModuleNodes(
|
|
m,
|
|
expected_node_list=expected_nodes,
|
|
expected_node_occurrence=expected_occurrence)
|
|
|
|
def test_fuse_conv_bn_add_relu_by_default(self):
|
|
options = itertools.product(
|
|
[True, False], # with_bn
|
|
[True, False], # with_relu
|
|
[True, False], # conv in the left
|
|
[True, False], # with_two_conv
|
|
[True, False], # use_torch_add
|
|
)
|
|
for with_bn, with_relu, left_conv, two_conv, use_torch_add in options:
|
|
# test eval mode
|
|
expected_nodes = [
|
|
ns.call_module(nn.Conv2d),
|
|
]
|
|
expected_occurrence = {
|
|
ns.call_module(nni.ConvAdd2d): 0,
|
|
}
|
|
m = ConvBnAddReluModel(
|
|
with_bn=with_bn,
|
|
with_relu=with_relu,
|
|
left_conv=left_conv,
|
|
two_conv=two_conv,
|
|
use_torch_add=use_torch_add).eval()
|
|
m = fuse_fx(m)
|
|
self.checkGraphModuleNodes(
|
|
m,
|
|
expected_node_list=expected_nodes,
|
|
expected_node_occurrence=expected_occurrence)
|
|
|
|
@skipIfNoONEDNN
|
|
def test_fuse_conv_bn_add_relu_lowering(self):
|
|
""" Test fusion and lowering of Conv2d - (bn -) ReLU
|
|
by FX. For onednn backedn only.
|
|
"""
|
|
from torch.ao.quantization.backend_config import get_onednn_backend_config
|
|
qconfig_mapping = get_default_qconfig_mapping('onednn')
|
|
with override_quantized_engine('onednn'):
|
|
options = itertools.product(
|
|
[True, False], # with_bn
|
|
[True, False], # with_relu
|
|
[True, False], # conv in the left
|
|
[True, False], # two_conv
|
|
[True, False], # use_torch_add
|
|
)
|
|
for with_bn, with_relu, left_conv, two_conv, use_torch_add in options:
|
|
node_occurrence = {
|
|
ns.call_function(torch.quantize_per_tensor): 1 if two_conv else 2,
|
|
ns.call_method("dequantize"): 1,
|
|
ns.call_module(nniq.ConvAddReLU2d if with_relu else nniq.ConvAdd2d): 1,
|
|
ns.call_module(nn.Conv2d): 0,
|
|
ns.call_module(nn.ReLU): 0,
|
|
}
|
|
node_occurrence_ref = {
|
|
ns.call_function(torch.quantize_per_tensor): 3,
|
|
ns.call_method("dequantize"): 3,
|
|
}
|
|
|
|
# test eval mode
|
|
m = ConvBnAddReluModel(
|
|
with_bn=with_bn,
|
|
with_relu=with_relu,
|
|
left_conv=left_conv,
|
|
two_conv=two_conv,
|
|
use_torch_add=use_torch_add).eval()
|
|
example_x = m.get_example_inputs()
|
|
m = prepare_fx(m, qconfig_mapping,
|
|
example_inputs=example_x,
|
|
backend_config=get_onednn_backend_config())
|
|
m_copy = copy.deepcopy(m)
|
|
m = convert_fx(m, backend_config=get_onednn_backend_config())
|
|
m_ref = convert_to_reference_fx(m_copy)
|
|
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
|
|
self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref)
|
|
m(*example_x)
|
|
|
|
def test_fuse_convtranspose_bn_eval(self):
|
|
|
|
m = ModelForConvTransposeBNFusion().eval()
|
|
m = fuse_fx(m)
|
|
|
|
expected_nodes = [
|
|
ns.call_module(nn.ConvTranspose1d),
|
|
ns.call_module(nn.ConvTranspose2d),
|
|
ns.call_module(nn.ConvTranspose3d),
|
|
]
|
|
expected_occurrence = {
|
|
ns.call_module(nn.BatchNorm1d): 0,
|
|
ns.call_module(nn.BatchNorm2d): 0,
|
|
ns.call_module(nn.BatchNorm3d): 0,
|
|
}
|
|
self.checkGraphModuleNodes(
|
|
m,
|
|
expected_node_list=expected_nodes,
|
|
expected_node_occurrence=expected_occurrence)
|
|
|
|
|
|
def test_fuse_module_relu(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv1d = nn.Conv1d(1, 1, 1)
|
|
self.conv2d = nn.Conv2d(1, 1, 1)
|
|
self.conv3d = nn.Conv3d(1, 1, 1)
|
|
self.bn1d = nn.BatchNorm1d(1)
|
|
self.bn2d = nn.BatchNorm2d(1)
|
|
self.bn3d = nn.BatchNorm3d(1)
|
|
self.relu = nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = self.conv1d(x)
|
|
x = self.relu(x)
|
|
x = self.conv2d(x)
|
|
x = self.relu(x)
|
|
x = self.conv3d(x)
|
|
x = self.relu(x)
|
|
x = self.bn1d(x)
|
|
x = self.relu(x)
|
|
x = self.bn2d(x)
|
|
x = self.relu(x)
|
|
x = self.bn3d(x)
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
m = M().eval()
|
|
m = fuse_fx(m)
|
|
expected_nodes = [
|
|
ns.call_module(nni.ConvReLU1d),
|
|
ns.call_module(nni.ConvReLU2d),
|
|
ns.call_module(nni.ConvReLU3d),
|
|
ns.call_module(nni.BNReLU2d),
|
|
ns.call_module(nni.BNReLU3d),
|
|
]
|
|
self.checkGraphModuleNodes(m, expected_node_list=expected_nodes)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_qconfig_fused_module(self):
|
|
""" TODO: add test for all fused modules
|
|
"""
|
|
qconfig_dict = {
|
|
"": None,
|
|
"object_type": [(nn.Linear, default_qconfig),
|
|
(nn.ReLU, default_qconfig),
|
|
(F.relu, default_qconfig)]
|
|
}
|
|
|
|
linearRelu_node_list = [
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_module(nniq.LinearReLU),
|
|
ns.call_method('dequantize')
|
|
]
|
|
|
|
linearReluLinear_node_list = [
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_module(nniq.LinearReLU),
|
|
ns.call_module(nnq.Linear),
|
|
ns.call_method('dequantize')
|
|
]
|
|
|
|
tests = [(LinearReluModel, linearRelu_node_list),
|
|
(LinearReluLinearModel, linearReluLinear_node_list)]
|
|
|
|
for M, node_list in tests:
|
|
m = M().eval()
|
|
example_inputs = (torch.rand(5, 5),)
|
|
prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
|
|
|
|
prepared(*example_inputs)
|
|
quantized = convert_fx(prepared)
|
|
|
|
self.checkGraphModuleNodes(quantized, expected_node_list=node_list)
|
|
|
|
def test_problematic_fuse_example(self):
|
|
class LinearRelu(nn.Sequential):
|
|
def __init__(self) -> None:
|
|
super().__init__(
|
|
nn.Linear(5, 5),
|
|
nn.ReLU(),
|
|
)
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.lin_relu = LinearRelu()
|
|
self.linear = nn.Linear(5, 5)
|
|
|
|
def forward(self, x):
|
|
x = self.lin_relu(x)
|
|
x = self.linear(x)
|
|
return x
|
|
|
|
model = M().eval()
|
|
# these qconfigs somehow fail equality where default_qconfig does not
|
|
qconfig_dict = {
|
|
"": None,
|
|
"object_type": [
|
|
(torch.nn.Linear, get_default_qconfig('fbgemm')),
|
|
(torch.nn.ReLU, get_default_qconfig('fbgemm')),
|
|
],
|
|
}
|
|
m = prepare_fx(model, qconfig_dict, example_inputs=(torch.randn(1, 5),))
|
|
|
|
self.checkGraphModuleNodes(m, expected_node=ns.call_module(torch.ao.nn.intrinsic.modules.fused.LinearReLU))
|
|
|
|
@unittest.skip("Temporarily skipping the test case, will enable after the simple"
|
|
"pattern format is supported")
|
|
def test_fuse_addtional_fuser_method(self):
|
|
class MyConvReLU(torch.nn.Module):
|
|
pass
|
|
|
|
def my_conv_relu_fuser(conv, relu):
|
|
return MyConvReLU()
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(3, 3, 3)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
return self.relu(self.conv(x))
|
|
|
|
m = M().eval()
|
|
m = fuse_fx(m, fuse_custom_config={
|
|
"additional_fuser_method_mapping": {
|
|
(torch.nn.Conv2d, torch.nn.ReLU): my_conv_relu_fuser
|
|
}
|
|
})
|
|
self.checkGraphModuleNodes(m, expected_node=ns.call_module(MyConvReLU))
|
|
|
|
def test_fuse_custom_pattern(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self, use_torch_add=True):
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(3, 3, 3)
|
|
self.bn = torch.nn.BatchNorm2d(3)
|
|
self.relu = torch.nn.ReLU()
|
|
self.maxpool = torch.nn.MaxPool2d(3)
|
|
if use_torch_add:
|
|
self.add = torch.add
|
|
else:
|
|
self.add = operator.add
|
|
|
|
def forward(self, x):
|
|
y = x
|
|
y = self.maxpool(x)
|
|
x = self.conv(x)
|
|
x = self.bn(x)
|
|
x = self.add(y, x)
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
for use_torch_add in [True, False]:
|
|
m = M(use_torch_add).eval()
|
|
|
|
def fuse_conv_bn_relu(is_qat, relu, add_pattern):
|
|
_, _, bn_pattern = add_pattern
|
|
bn, conv = bn_pattern
|
|
return conv
|
|
|
|
conv_bn_res_relu_config1 = BackendPatternConfig() \
|
|
._set_pattern_complex_format((nn.ReLU, (torch.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))) \
|
|
.set_fuser_method(fuse_conv_bn_relu)
|
|
conv_bn_res_relu_config2 = BackendPatternConfig() \
|
|
._set_pattern_complex_format((nn.ReLU, (operator.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))) \
|
|
.set_fuser_method(fuse_conv_bn_relu)
|
|
backend_config = BackendConfig() \
|
|
.set_backend_pattern_config(conv_bn_res_relu_config1) \
|
|
.set_backend_pattern_config(conv_bn_res_relu_config2)
|
|
m = fuse_fx(m, backend_config=backend_config)
|
|
self.assertEqual(type(m.conv), torch.nn.Conv2d)
|
|
# check bn and relu are gone since we replaced the whole pattern to conv
|
|
self.assertFalse(hasattr(m, "bn"))
|
|
self.assertFalse(hasattr(m, "relu"))
|
|
|
|
def test_fusion_pattern_with_multiple_inputs(self):
|
|
""" This test tests two keys in backend_config: root_node_getter and
|
|
extra_inputs_getter,
|
|
root_node_getter is used to identify a "root" module in the node pattern,
|
|
the node that we'll keep after fusion.
|
|
extra_inputs_getter will return a list of node that needs to be added to the
|
|
fused node as extra inputs.
|
|
"""
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(3, 3, 3)
|
|
self.bn = torch.nn.BatchNorm2d(3)
|
|
self.relu = torch.nn.ReLU()
|
|
self.maxpool = torch.nn.MaxPool2d(3)
|
|
|
|
def forward(self, x):
|
|
y = x
|
|
y = self.maxpool(x)
|
|
x = self.conv(x)
|
|
x = self.bn(x)
|
|
x = torch.add(x, y)
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
m = M().eval()
|
|
|
|
def fuse_conv_bn_relu(is_qat, relu, add_pattern):
|
|
_, bn_pattern, _ = add_pattern
|
|
bn, conv = bn_pattern
|
|
return conv
|
|
|
|
def conv_bn_res_relu_root_node_getter(pattern):
|
|
relu, add_pattern = pattern
|
|
_, bn_pattern, _ = add_pattern
|
|
bn, conv = bn_pattern
|
|
return conv
|
|
|
|
def conv_bn_res_relu_extra_inputs_getter(pattern):
|
|
""" get inputs pattern for extra inputs, inputs for root node
|
|
are assumed to be copied over from root node to the fused node
|
|
"""
|
|
relu, add_pattern = pattern
|
|
_, bn_pattern, extra_input = add_pattern
|
|
bn, conv = bn_pattern
|
|
return [extra_input]
|
|
|
|
conv_bn_res_relu_config = BackendPatternConfig() \
|
|
._set_pattern_complex_format((nn.ReLU, (torch.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))) \
|
|
.set_fuser_method(fuse_conv_bn_relu) \
|
|
._set_root_node_getter(conv_bn_res_relu_root_node_getter) \
|
|
._set_extra_inputs_getter(conv_bn_res_relu_extra_inputs_getter)
|
|
backend_config = BackendConfig().set_backend_pattern_config(conv_bn_res_relu_config)
|
|
m = fuse_fx(m, backend_config=backend_config)
|
|
self.assertEqual(type(m.conv), torch.nn.Conv2d)
|
|
# check bn and relu are gone since we replaced the whole pattern to conv
|
|
self.assertFalse(hasattr(m, "bn"))
|
|
self.assertFalse(hasattr(m, "relu"))
|
|
|
|
# check conv module has two inputs
|
|
named_modules = dict(m.named_modules())
|
|
for node in m.graph.nodes:
|
|
if node.op == "call_module" and type(named_modules[node.target]) is torch.nn.Conv2d:
|
|
self.assertTrue(len(node.args) == 2, msg="Expecting the fused op to have two arguments")
|
|
|
|
def test_fusion_pattern_with_matchallnode(self):
|
|
"""This test tests that the node matched by MatchAllNode will be regared as an input
|
|
instead of a module to be fused. For instance, we have two patterns:
|
|
(nn.ReLU, (torch.add, MatchAllNode, nn.Conv2d))
|
|
(nn.ReLU, nn.Conv2d)
|
|
And we wanna fuse the following model
|
|
Conv2d -> ReLU +
|
|
Conv2d ------ Add -> ReLU
|
|
ReLU in the first row is matched as MatchAllNode in the residual pattern. But it won't be
|
|
fused as part of that pattnern. It needs to be properly fused with the upstream Conv2d.
|
|
"""
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv1 = torch.nn.Conv2d(3, 3, 3)
|
|
self.relu1 = torch.nn.ReLU()
|
|
self.conv2 = torch.nn.Conv2d(3, 3, 3)
|
|
self.relu2 = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
y = self.conv1(x)
|
|
y = self.relu1(y)
|
|
|
|
x = self.conv2(x)
|
|
x = torch.add(x, y)
|
|
x = self.relu2(x)
|
|
return x
|
|
|
|
m = M().eval()
|
|
|
|
def fuse_conv_relu(is_qat, conv, relu):
|
|
return conv
|
|
|
|
def fuse_conv_res_relu(is_qat, relu, add_pattern):
|
|
_, conv, _ = add_pattern
|
|
return conv
|
|
|
|
def conv_res_relu_root_node_getter(pattern):
|
|
relu, (_, conv, _) = pattern
|
|
return conv
|
|
|
|
def conv_res_relu_extra_inputs_getter(pattern):
|
|
relu, (_, _, extra_input) = pattern
|
|
return [extra_input]
|
|
|
|
conv_relu_config = BackendPatternConfig((nn.Conv2d, nn.ReLU)) \
|
|
.set_fuser_method(fuse_conv_relu)
|
|
conv_res_relu_config = BackendPatternConfig() \
|
|
._set_pattern_complex_format((nn.ReLU, (torch.add, nn.Conv2d, MatchAllNode))) \
|
|
.set_fuser_method(fuse_conv_res_relu) \
|
|
._set_root_node_getter(conv_res_relu_root_node_getter) \
|
|
._set_extra_inputs_getter(conv_res_relu_extra_inputs_getter)
|
|
backend_config = BackendConfig() \
|
|
.set_backend_pattern_config(conv_relu_config) \
|
|
.set_backend_pattern_config(conv_res_relu_config)
|
|
m = fuse_fx(m, backend_config=backend_config)
|
|
self.assertEqual(type(m.conv1), torch.nn.Conv2d)
|
|
self.assertEqual(type(m.conv2), torch.nn.Conv2d)
|
|
# check relu are gone since we replaced both patterns to conv
|
|
self.assertFalse(hasattr(m, "relu1"))
|
|
self.assertFalse(hasattr(m, "relu2"))
|
|
|
|
|
|
@skipIfNoFBGEMM
|
|
class TestQuantizeFx(QuantizationTestCase):
|
|
def test_pattern_match(self):
|
|
""" test MatchAllNode with
|
|
conv - bn - add - relu pattern
|
|
"""
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = nn.Conv2d(1, 1, 1)
|
|
self.bn = nn.BatchNorm2d(1)
|
|
self.relu = nn.ReLU()
|
|
|
|
def forward(self, x, y):
|
|
x = self.conv(x)
|
|
x = self.bn(x)
|
|
x = x + y
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
pattern = (nn.ReLU, (operator.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))
|
|
m = torch.fx.symbolic_trace(M())
|
|
modules = dict(m.named_modules())
|
|
for n in m.graph.nodes:
|
|
if n.op == 'call_module' and type(modules[n.target]) is nn.ReLU:
|
|
self.assertTrue(_is_match(modules, n, pattern))
|
|
|
|
def test_pattern_match_constant(self):
|
|
class M(torch.nn.Module):
|
|
def forward(self, x):
|
|
x, _ = torch.ops.aten.max_pool2d_with_indices.default(x)
|
|
return x
|
|
|
|
pattern = (operator.getitem, torch.ops.aten.max_pool2d_with_indices.default, 0)
|
|
m = torch.fx.symbolic_trace(M())
|
|
# eliminate the code that get the second output of maxpool, so that the pattern
|
|
# can be matched
|
|
m.graph.eliminate_dead_code()
|
|
modules = dict(m.named_modules())
|
|
for n in m.graph.nodes:
|
|
if n.op == "call_function" and n.target == operator.getitem:
|
|
self.assertTrue(_is_match(modules, n, pattern))
|
|
|
|
def test_fused_module_qat_swap(self):
|
|
class Tmp(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.tmp = torch.nn.Linear(5, 5)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = self.tmp(x)
|
|
return self.relu(x)
|
|
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.mods1 = torch.nn.Sequential(Tmp(), torch.nn.Linear(5, 5))
|
|
self.mods2 = torch.nn.Linear(5, 5)
|
|
|
|
def forward(self, x):
|
|
a = self.mods1(x)
|
|
x = torch.add(x, 5)
|
|
x = self.mods2(x)
|
|
x = torch.add(x, 5)
|
|
return a, x
|
|
|
|
|
|
model = M().train()
|
|
qconfig_dict = {
|
|
"": None,
|
|
"object_type": [
|
|
(torch.nn.Linear, default_qat_qconfig),
|
|
(torch.nn.ReLU, default_qat_qconfig),
|
|
],
|
|
}
|
|
prepared = prepare_qat_fx(model, qconfig_dict, example_inputs=(torch.randn(1, 5),))
|
|
self.assertTrue(isinstance(getattr(prepared.mods1, "0").tmp, torch.ao.nn.intrinsic.qat.LinearReLU))
|
|
|
|
def _get_conv_linear_test_cases(self, is_reference):
|
|
""" Returns a list of test cases, with format:
|
|
is_dynamic, ModuleClass, module_constructor_inputs,
|
|
inputs, quantized_node, weight_prepack_op
|
|
"""
|
|
class FunctionalConv1d(torch.nn.Module):
|
|
def __init__(self, weight):
|
|
super().__init__()
|
|
self.weight = torch.nn.Parameter(weight)
|
|
self.stride = 1
|
|
self.padding = 0
|
|
self.dilation = 1
|
|
self.groups = 1
|
|
|
|
def forward(self, x):
|
|
return F.conv1d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups)
|
|
|
|
|
|
class Conv1d(torch.nn.Module):
|
|
def __init__(self, *args):
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv1d(*args)
|
|
|
|
def forward(self, x):
|
|
return self.conv(x)
|
|
|
|
conv1d_input = torch.rand(1, 3, 224)
|
|
conv1d_weight = torch.rand(3, 3, 3)
|
|
conv1d_module_args = (3, 3, 3)
|
|
|
|
class FunctionalConv2d(torch.nn.Module):
|
|
def __init__(self, weight):
|
|
super().__init__()
|
|
self.weight = torch.nn.Parameter(weight)
|
|
self.stride = (1, 1)
|
|
self.padding = (0, 0)
|
|
self.dilation = (1, 1)
|
|
self.groups = 1
|
|
|
|
def forward(self, x):
|
|
return F.conv2d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups)
|
|
|
|
class Conv2d(torch.nn.Module):
|
|
def __init__(self, *args):
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(*args)
|
|
|
|
def forward(self, x):
|
|
return self.conv(x)
|
|
|
|
conv2d_input = torch.rand(1, 3, 224, 224)
|
|
conv2d_weight = torch.rand(3, 3, 3, 3)
|
|
conv2d_module_args = (3, 3, 3)
|
|
|
|
class FunctionalConv3d(torch.nn.Module):
|
|
def __init__(self, weight):
|
|
super().__init__()
|
|
self.weight = torch.nn.Parameter(weight)
|
|
self.stride = (1, 1, 1)
|
|
self.padding = (0, 0, 0)
|
|
self.dilation = (1, 1, 1)
|
|
self.groups = 1
|
|
|
|
def forward(self, x):
|
|
return F.conv3d(
|
|
x,
|
|
self.weight,
|
|
None,
|
|
self.stride,
|
|
self.padding,
|
|
self.dilation,
|
|
self.groups,
|
|
)
|
|
|
|
class Conv3d(torch.nn.Module):
|
|
def __init__(self, *args):
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv3d(*args)
|
|
|
|
def forward(self, x):
|
|
return self.conv(x)
|
|
|
|
conv3d_input = torch.rand(1, 3, 32, 224, 224)
|
|
conv3d_weight = torch.rand(3, 3, 3, 3, 3)
|
|
conv3d_module_args = (3, 3, 3)
|
|
|
|
class Linear(torch.nn.Module):
|
|
def __init__(self, weight):
|
|
super().__init__()
|
|
self.weight = torch.nn.Parameter(weight)
|
|
|
|
def forward(self, x):
|
|
return F.linear(x, self.weight)
|
|
|
|
linear_input = torch.rand(8, 5)
|
|
linear_weight = torch.rand(10, 5)
|
|
|
|
class LinearModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(5, 10)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
linear_module_input = torch.rand(8, 5)
|
|
|
|
# is_dynamic, ModuleClass, module_constructor_inputs,
|
|
# inputs, quantized_node, weight_prepack_node
|
|
tests = [
|
|
(
|
|
False,
|
|
FunctionalConv1d,
|
|
(conv1d_weight,),
|
|
(conv1d_input,),
|
|
ns.call_function(torch.nn.functional.conv1d if is_reference else torch.ops.quantized.conv1d) ,
|
|
ns.call_function(torch.ops.quantized.conv1d_prepack),
|
|
),
|
|
(
|
|
False,
|
|
FunctionalConv2d,
|
|
(conv2d_weight,),
|
|
(conv2d_input,),
|
|
ns.call_function(torch.nn.functional.conv2d if is_reference else torch.ops.quantized.conv2d),
|
|
ns.call_function(torch.ops.quantized.conv2d_prepack),
|
|
),
|
|
(
|
|
False,
|
|
FunctionalConv3d,
|
|
(conv3d_weight,),
|
|
(conv3d_input,),
|
|
ns.call_function(torch.nn.functional.conv3d if is_reference else torch.ops.quantized.conv3d),
|
|
ns.call_function(torch.ops.quantized.conv3d_prepack),
|
|
),
|
|
(
|
|
False,
|
|
Conv1d,
|
|
conv1d_module_args,
|
|
(conv1d_input,),
|
|
ns.call_module(nnqr.Conv1d if is_reference else nnq.Conv1d),
|
|
None
|
|
),
|
|
(
|
|
False,
|
|
Conv2d,
|
|
conv2d_module_args,
|
|
(conv2d_input,),
|
|
ns.call_module(nnqr.Conv2d if is_reference else nnq.Conv2d),
|
|
None
|
|
),
|
|
(
|
|
False,
|
|
Conv3d,
|
|
conv3d_module_args,
|
|
(conv3d_input,),
|
|
ns.call_module(nnqr.Conv3d if is_reference else nnq.Conv3d),
|
|
None
|
|
),
|
|
(
|
|
True,
|
|
Linear,
|
|
(linear_weight,),
|
|
(linear_input,),
|
|
None if is_reference else ns.call_function(torch.ops.quantized.linear_dynamic),
|
|
ns.call_function(torch.ops.quantized.linear_prepack),
|
|
),
|
|
(
|
|
False,
|
|
Linear,
|
|
(linear_weight,),
|
|
(linear_input,),
|
|
ns.call_function(torch.nn.functional.linear if is_reference else torch.ops.quantized.linear),
|
|
ns.call_function(torch.ops.quantized.linear_prepack),
|
|
),
|
|
(
|
|
True,
|
|
LinearModule,
|
|
(),
|
|
(linear_module_input,),
|
|
ns.call_module(nnqr.Linear) if is_reference else ns.call_module(nnqd.Linear),
|
|
None,
|
|
),
|
|
(
|
|
False,
|
|
LinearModule,
|
|
(),
|
|
(linear_module_input,),
|
|
ns.call_module(nnqr.Linear if is_reference else nnq.Linear),
|
|
None,
|
|
),
|
|
]
|
|
return tests
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_conv_linear_not_reference(self):
|
|
""" Test quantizing conv and linear
|
|
"""
|
|
tests = self._get_conv_linear_test_cases(is_reference=False)
|
|
for (is_dynamic, ModuleClass, module_constructor_inputs,
|
|
inputs, quantized_node, weight_prepack_node) in tests:
|
|
quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
|
|
node_occurrence = {}
|
|
if weight_prepack_node:
|
|
node_occurrence[weight_prepack_node] = 0
|
|
self.checkGraphModeFxOp(
|
|
ModuleClass(*module_constructor_inputs),
|
|
inputs, quant_type,
|
|
expected_node=quantized_node,
|
|
expected_node_occurrence=node_occurrence,
|
|
is_reference=False)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_conv_linear_reference(self):
|
|
""" Test quantizing functional conv and linear with reference option
|
|
"""
|
|
tests = self._get_conv_linear_test_cases(is_reference=True)
|
|
|
|
def _get_keys(prefix, is_dynamic):
|
|
all_keys = [prefix + "." + k for k in ["weight_qscheme", "weight_dtype"]]
|
|
if not is_dynamic:
|
|
all_keys.extend([prefix + "." + k for k in ["weight_scale", "weight_zero_point"]])
|
|
return all_keys
|
|
|
|
for (is_dynamic, ModuleClass, module_constructor_inputs,
|
|
inputs, quantized_node, weight_prepack_node) in tests:
|
|
quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
|
|
node_occurrence = {}
|
|
if weight_prepack_node:
|
|
node_occurrence[weight_prepack_node] = 0
|
|
result_dict = self.checkGraphModeFxOp(
|
|
ModuleClass(*module_constructor_inputs),
|
|
inputs, quant_type,
|
|
expected_node=quantized_node,
|
|
expected_node_occurrence=node_occurrence,
|
|
is_reference=True)
|
|
qr = result_dict["quantized_reference"]
|
|
|
|
def checkWeightQParams(model):
|
|
for module_name in ("linear", "conv"):
|
|
if hasattr(model, module_name):
|
|
self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_qscheme"))
|
|
self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_scale"))
|
|
self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_zero_point"))
|
|
self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name())
|
|
|
|
def checkSerDeser(model, is_dynamic):
|
|
for module_name in ("linear", "conv"):
|
|
if hasattr(model, module_name):
|
|
                        # make sure serialization works
|
|
state_dict = copy.deepcopy(model.state_dict())
|
|
all_keys = _get_keys(module_name, is_dynamic)
|
|
for key in all_keys:
|
|
self.assertTrue(key in state_dict)
|
|
# check load_state_dict restores states
|
|
module = getattr(model, module_name)
|
|
prev_scale = module.weight_scale
|
|
module.weight_scale = None
|
|
model.load_state_dict(state_dict)
|
|
module = getattr(model, module_name)
|
|
self.assertTrue(torch.equal(prev_scale, module.weight_scale))
|
|
|
|
|
|
checkWeightQParams(qr)
|
|
qr = copy.deepcopy(qr)
|
|
# make sure the qparams are preserved after copy
|
|
checkWeightQParams(qr)
|
|
|
|
checkSerDeser(qr, is_dynamic)
|
|
|
|
def _get_conv_transpose_test_cases(self, use_relu, is_reference):
|
|
""" Returns a list of test cases, with format:
|
|
is_dynamic, ModuleClass, module_constructor_inputs,
|
|
inputs, quantized_node, weight_prepack_op
|
|
"""
|
|
class FunctionalConvTranspose1d(torch.nn.Module):
|
|
def __init__(self, weight):
|
|
super().__init__()
|
|
self.weight = torch.nn.Parameter(weight)
|
|
self.stride = 1
|
|
self.padding = 0
|
|
self.output_padding = 0
|
|
self.dilation = 1
|
|
self.groups = 1
|
|
|
|
def forward(self, x):
|
|
y = F.conv_transpose1d(
|
|
x,
|
|
self.weight,
|
|
None,
|
|
self.stride,
|
|
self.padding,
|
|
self.output_padding,
|
|
self.groups,
|
|
self.dilation
|
|
)
|
|
if use_relu:
|
|
y = F.relu(y)
|
|
return y
|
|
|
|
class ConvTranspose1d(torch.nn.Module):
|
|
def __init__(self, *args):
|
|
super().__init__()
|
|
self.deconv = torch.nn.ConvTranspose1d(*args)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
y = self.deconv(x)
|
|
if use_relu:
|
|
y = self.relu(y)
|
|
return y
|
|
|
|
conv_transpose1d_input = torch.rand(1, 3, 224)
|
|
conv_transpose1d_weight = torch.rand(3, 3, 3)
|
|
conv_transpose1d_module_args = (3, 3, 3)
|
|
|
|
class FunctionalConvTranspose2d(torch.nn.Module):
|
|
def __init__(self, weight):
|
|
super().__init__()
|
|
self.weight = torch.nn.Parameter(weight)
|
|
self.stride = (1, 1)
|
|
self.padding = (0, 0)
|
|
self.output_padding = (0, 0)
|
|
self.dilation = (1, 1)
|
|
self.groups = 1
|
|
|
|
def forward(self, x):
|
|
y = F.conv_transpose2d(
|
|
x,
|
|
self.weight,
|
|
None,
|
|
self.stride,
|
|
self.padding,
|
|
self.output_padding,
|
|
self.groups,
|
|
self.dilation
|
|
)
|
|
if use_relu:
|
|
y = F.relu(y)
|
|
return y
|
|
|
|
class ConvTranspose2d(torch.nn.Module):
|
|
def __init__(self, *args):
|
|
super().__init__()
|
|
self.deconv = torch.nn.ConvTranspose2d(*args)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
y = self.deconv(x)
|
|
if use_relu:
|
|
y = self.relu(y)
|
|
return y
|
|
|
|
conv_transpose2d_input = torch.rand(1, 3, 224, 224)
|
|
conv_transpose2d_weight = torch.rand(3, 3, 3, 3)
|
|
conv_transpose2d_module_args = (3, 3, 3)
|
|
|
|
class FunctionalConvTranspose3d(torch.nn.Module):
|
|
def __init__(self, weight):
|
|
super().__init__()
|
|
self.weight = torch.nn.Parameter(weight)
|
|
self.stride = (1, 1, 1)
|
|
self.padding = (0, 0, 0)
|
|
self.output_padding = (0, 0, 0)
|
|
self.dilation = (1, 1, 1)
|
|
self.groups = 1
|
|
|
|
def forward(self, x):
|
|
y = F.conv_transpose3d(
|
|
x,
|
|
self.weight,
|
|
None,
|
|
self.stride,
|
|
self.padding,
|
|
self.output_padding,
|
|
self.groups,
|
|
self.dilation
|
|
)
|
|
if use_relu:
|
|
y = F.relu(y)
|
|
return y
|
|
|
|
class ConvTranspose3d(torch.nn.Module):
|
|
def __init__(self, *args):
|
|
super().__init__()
|
|
self.deconv = torch.nn.ConvTranspose3d(*args)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
y = self.deconv(x)
|
|
if use_relu:
|
|
y = self.relu(y)
|
|
return y
|
|
|
|
conv_transpose3d_input = torch.rand(1, 3, 32, 224, 224)
|
|
conv_transpose3d_weight = torch.rand(3, 3, 3, 3, 3)
|
|
conv_transpose3d_module_args = (3, 3, 3)
|
|
|
|
# is_dynamic, ModuleClass, module_constructor_inputs,
|
|
# inputs, quantized_node, weight_prepack_node
|
|
tests = [
|
|
(
|
|
False,
|
|
FunctionalConvTranspose1d,
|
|
(conv_transpose1d_weight,),
|
|
(conv_transpose1d_input,),
|
|
ns.call_function(
|
|
torch.nn.functional.conv_transpose1d if is_reference else torch.ops.quantized.conv_transpose1d
|
|
),
|
|
ns.call_function(torch.ops.quantized.conv_transpose1d_prepack),
|
|
),
|
|
(
|
|
False,
|
|
FunctionalConvTranspose2d,
|
|
(conv_transpose2d_weight,),
|
|
(conv_transpose2d_input,),
|
|
ns.call_function(
|
|
torch.nn.functional.conv_transpose2d if is_reference else torch.ops.quantized.conv_transpose2d
|
|
),
|
|
ns.call_function(torch.ops.quantized.conv_transpose2d_prepack),
|
|
),
|
|
(
|
|
False,
|
|
FunctionalConvTranspose3d,
|
|
(conv_transpose3d_weight,),
|
|
(conv_transpose3d_input,),
|
|
ns.call_function(
|
|
torch.nn.functional.conv_transpose3d if is_reference else torch.ops.quantized.conv_transpose3d),
|
|
ns.call_function(torch.ops.quantized.conv_transpose3d_prepack),
|
|
),
|
|
(
|
|
False,
|
|
ConvTranspose1d,
|
|
conv_transpose1d_module_args,
|
|
(conv_transpose1d_input,),
|
|
ns.call_module(nnqr.ConvTranspose1d if is_reference else nnq.ConvTranspose1d),
|
|
None
|
|
),
|
|
(
|
|
False,
|
|
ConvTranspose2d,
|
|
conv_transpose2d_module_args,
|
|
(conv_transpose2d_input,),
|
|
ns.call_module(nnqr.ConvTranspose2d if is_reference else nnq.ConvTranspose2d),
|
|
None
|
|
),
|
|
(
|
|
False,
|
|
ConvTranspose3d,
|
|
conv_transpose3d_module_args,
|
|
(conv_transpose3d_input,),
|
|
ns.call_module(nnqr.ConvTranspose3d if is_reference else nnq.ConvTranspose3d),
|
|
None
|
|
),
|
|
]
|
|
return tests
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_conv_transpose_not_reference(self):
|
|
""" Test quantizing transposed conv
|
|
"""
|
|
tests = self._get_conv_transpose_test_cases(use_relu=False, is_reference=False)
|
|
for (is_dynamic, ModuleClass, module_constructor_inputs,
|
|
inputs, quantized_node, weight_prepack_node) in tests:
|
|
quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
|
|
node_occurrence = {}
|
|
if weight_prepack_node:
|
|
node_occurrence[weight_prepack_node] = 0
|
|
self.checkGraphModeFxOp(
|
|
ModuleClass(*module_constructor_inputs),
|
|
inputs, quant_type,
|
|
expected_node=quantized_node,
|
|
expected_node_occurrence=node_occurrence,
|
|
is_reference=False)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_conv_transpose_reference(self):
|
|
""" Test quantizing transposed conv with reference option
|
|
"""
|
|
tests = self._get_conv_transpose_test_cases(use_relu=False, is_reference=True)
|
|
|
|
def _get_keys(prefix, is_dynamic):
|
|
all_keys = [prefix + "." + k for k in ["weight_qscheme", "weight_dtype"]]
|
|
if not is_dynamic:
|
|
all_keys.extend([prefix + "." + k for k in ["weight_scale", "weight_zero_point"]])
|
|
return all_keys
|
|
|
|
for (is_dynamic, ModuleClass, module_constructor_inputs,
|
|
inputs, quantized_node, weight_prepack_node) in tests:
|
|
quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
|
|
node_occurrence = {}
|
|
if weight_prepack_node:
|
|
node_occurrence[weight_prepack_node] = 0
|
|
result_dict = self.checkGraphModeFxOp(
|
|
ModuleClass(*module_constructor_inputs),
|
|
inputs, quant_type,
|
|
expected_node=quantized_node,
|
|
expected_node_occurrence=node_occurrence,
|
|
is_reference=True)
|
|
qr = result_dict["quantized_reference"]
|
|
|
|
def checkWeightQParams(model):
|
|
module_name = "deconv"
|
|
if hasattr(model, module_name):
|
|
self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_qscheme"))
|
|
self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_scale"))
|
|
self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_zero_point"))
|
|
self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name())
|
|
|
|
def checkSerDeser(model, is_dynamic):
|
|
module_name = "deconv"
|
|
if hasattr(model, module_name):
|
|
                    # make sure serialization works
|
|
state_dict = copy.deepcopy(model.state_dict())
|
|
all_keys = _get_keys(module_name, is_dynamic)
|
|
for key in all_keys:
|
|
self.assertTrue(key in state_dict)
|
|
# check load_state_dict restores states
|
|
module = getattr(model, module_name)
|
|
prev_scale = module.weight_scale
|
|
module.weight_scale = None
|
|
model.load_state_dict(state_dict)
|
|
module = getattr(model, module_name)
|
|
self.assertTrue(torch.equal(prev_scale, module.weight_scale))
|
|
|
|
|
|
checkWeightQParams(qr)
|
|
qr = copy.deepcopy(qr)
|
|
# make sure the qparams are preserved after copy
|
|
checkWeightQParams(qr)
|
|
|
|
checkSerDeser(qr, is_dynamic)
|
|
|
|
def test_conv_transpose_relu_not_reference(self):
|
|
""" Test quantizing transposed conv + relu
|
|
Fusion with relu is not supported.
|
|
"""
|
|
tests = self._get_conv_transpose_test_cases(use_relu=True, is_reference=False)
|
|
for (is_dynamic, ModuleClass, module_constructor_inputs,
|
|
inputs, quantized_node, weight_prepack_node) in tests:
|
|
quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
|
|
node_occurrence = {}
|
|
if weight_prepack_node:
|
|
node_occurrence[weight_prepack_node] = 0
|
|
if quantized_node.op == 'call_module':
|
|
node_occurrence[ns.call_module(nn.ReLU)] = 1
|
|
else:
|
|
node_occurrence[ns.call_function(F.relu)] = 1
|
|
self.checkGraphModeFxOp(
|
|
ModuleClass(*module_constructor_inputs),
|
|
inputs, quant_type,
|
|
expected_node=quantized_node,
|
|
expected_node_occurrence=node_occurrence,
|
|
is_reference=False)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_conv_transpose_relu_reference(self):
|
|
""" Test quantizing transposed conv with reference option
|
|
Fusion with relu is not supported.
|
|
"""
|
|
tests = self._get_conv_transpose_test_cases(use_relu=True, is_reference=True)
|
|
|
|
def _get_keys(prefix, is_dynamic):
|
|
all_keys = [prefix + "." + k for k in ["weight_qscheme", "weight_dtype"]]
|
|
if not is_dynamic:
|
|
all_keys.extend([prefix + "." + k for k in ["weight_scale", "weight_zero_point"]])
|
|
return all_keys
|
|
|
|
for (is_dynamic, ModuleClass, module_constructor_inputs,
|
|
inputs, quantized_node, weight_prepack_node) in tests:
|
|
quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
|
|
node_occurrence = {}
|
|
if weight_prepack_node:
|
|
node_occurrence[weight_prepack_node] = 0
|
|
if quantized_node.op == 'call_module':
|
|
node_occurrence[ns.call_module(nn.ReLU)] = 1
|
|
else:
|
|
node_occurrence[ns.call_function(F.relu)] = 1
|
|
result_dict = self.checkGraphModeFxOp(
|
|
ModuleClass(*module_constructor_inputs),
|
|
inputs, quant_type,
|
|
expected_node=quantized_node,
|
|
expected_node_occurrence=node_occurrence,
|
|
is_reference=True)
|
|
qr = result_dict["quantized_reference"]
|
|
|
|
def checkWeightQParams(model):
|
|
module_name = "deconv"
|
|
if hasattr(model, module_name):
|
|
self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_qscheme"))
|
|
self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_scale"))
|
|
self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_zero_point"))
|
|
self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name())
|
|
|
|
def checkSerDeser(model, is_dynamic):
|
|
module_name = "deconv"
|
|
if hasattr(model, module_name):
|
|
                    # make sure serialization works
|
|
state_dict = copy.deepcopy(model.state_dict())
|
|
all_keys = _get_keys(module_name, is_dynamic)
|
|
for key in all_keys:
|
|
self.assertTrue(key in state_dict)
|
|
# check load_state_dict restores states
|
|
module = getattr(model, module_name)
|
|
prev_scale = module.weight_scale
|
|
module.weight_scale = None
|
|
model.load_state_dict(state_dict)
|
|
module = getattr(model, module_name)
|
|
self.assertTrue(torch.equal(prev_scale, module.weight_scale))
|
|
|
|
|
|
checkWeightQParams(qr)
|
|
qr = copy.deepcopy(qr)
|
|
# make sure the qparams are preserved after copy
|
|
checkWeightQParams(qr)
|
|
|
|
checkSerDeser(qr, is_dynamic)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_dynamic_quant_weight_observer(self):
|
|
''' Test that weight observer is run in convert step
|
|
'''
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self, weight):
|
|
super().__init__()
|
|
self.weight = torch.nn.Parameter(weight)
|
|
|
|
def forward(self, x):
|
|
return F.linear(x, self.weight)
|
|
|
|
m = M(torch.rand(1, 1)).eval()
|
|
qconfig = default_dynamic_qconfig
|
|
qconfig_dict = {'': qconfig}
|
|
example_inputs = (torch.rand(1, 1),)
|
|
prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
|
|
quantized = convert_to_reference_fx(prepared)
|
|
qparams = (quantized._scale_0, quantized._zero_point_0)
|
|
weight_obs = qconfig.weight()
|
|
weight_obs(quantized.weight)
|
|
# Get the actual value to avoid tensor size mismatch error, torch.Size([]) vs torch.Size([1])
|
|
ref_qparams = (weight_obs.calculate_qparams()[0].item(), weight_obs.calculate_qparams()[1].item())
|
|
self.assertEqual(qparams, ref_qparams)
|
|
|
|
def test_conv_bn_relu(self):
|
|
""" Tests fusion and quantization for "Conv - Bn" and "Conv - Bn - ReLU"
|
|
"""
|
|
convs = {
|
|
1: nn.Conv1d,
|
|
2: nn.Conv2d,
|
|
3: nn.Conv3d,
|
|
}
|
|
bns = {
|
|
1: nn.BatchNorm1d,
|
|
2: nn.BatchNorm2d,
|
|
3: nn.BatchNorm3d,
|
|
}
|
|
quantized_convs = {
|
|
1: nnq.Conv1d,
|
|
2: nnq.Conv2d,
|
|
3: nnq.Conv3d,
|
|
}
|
|
quantized_conv_relus = {
|
|
1: nniq.ConvReLU1d,
|
|
2: nniq.ConvReLU2d,
|
|
3: nniq.ConvReLU3d,
|
|
}
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self, dim, has_relu):
|
|
super().__init__()
|
|
self.conv = convs[dim](3, 3, 3)
|
|
self.bn = bns[dim](3)
|
|
self.relu = nn.ReLU() if has_relu else nn.Identity()
|
|
self.has_relu = has_relu
|
|
self.quant = QuantStub()
|
|
self.dequant = DeQuantStub()
|
|
|
|
def forward(self, x):
|
|
x = self.quant(x)
|
|
x = self.conv(x)
|
|
x = self.bn(x)
|
|
if self.has_relu:
|
|
x = self.relu(x)
|
|
x = self.dequant(x)
|
|
return x
|
|
|
|
options = itertools.product([1, 2, 3], [True, False], self.static_quant_types)
|
|
for dim, has_relu, quant_type in options:
|
|
expected_node = ns.call_module(
|
|
quantized_conv_relus[dim] if has_relu
|
|
else quantized_convs[dim])
|
|
m = M(dim, has_relu)
|
|
m_eager = copy.deepcopy(m)
|
|
result_dict = self.checkGraphModeFxOp(
|
|
m,
|
|
self.img_data_dict[dim],
|
|
quant_type,
|
|
expected_node=expected_node,
|
|
)
|
|
result = result_dict["quantized_output"]
|
|
|
|
# check numerics
|
|
qengine = torch.backends.quantized.engine
|
|
if quant_type == QuantType.STATIC:
|
|
m_eager.eval()
|
|
qconfig = get_default_qconfig(qengine)
|
|
prepare_fn = prepare
|
|
is_qat = False
|
|
else:
|
|
m_eager.train()
|
|
qconfig = get_default_qat_qconfig(qengine)
|
|
prepare_fn = prepare_qat
|
|
is_qat = True
|
|
|
|
fuse_list = ["conv", "bn"]
|
|
if has_relu:
|
|
fuse_list.append("relu")
|
|
if is_qat:
|
|
fuse_modules_qat(m_eager, fuse_list, inplace=True)
|
|
else:
|
|
fuse_modules(m_eager, fuse_list, inplace=True)
|
|
m_eager.qconfig = qconfig
|
|
m_eager = prepare_fn(m_eager)
|
|
prepared_fx = result_dict["prepared"]
|
|
|
|
m_eager(*self.img_data_dict[dim][0])
|
|
m_eager = convert(m_eager)
|
|
result_eager = m_eager(*self.img_data_dict[dim][0])
|
|
self.assertEqual(result, result_eager)
|
|
|
|
def test_linear_bn(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = nn.Linear(4, 4)
|
|
self.bn = nn.BatchNorm1d(4)
|
|
self.quant = QuantStub()
|
|
self.dequant = DeQuantStub()
|
|
|
|
def forward(self, x):
|
|
x = self.quant(x)
|
|
x = self.linear(x)
|
|
x = self.bn(x)
|
|
x = self.dequant(x)
|
|
return x
|
|
|
|
data = (torch.randn(4, 4),)
|
|
for quant_type in self.static_quant_types:
|
|
expected_node = ns.call_module(nnq.Linear)
|
|
m = M()
|
|
m_eager = copy.deepcopy(m)
|
|
result_dict = self.checkGraphModeFxOp(m, data, quant_type, expected_node=expected_node)
|
|
result = result_dict["quantized_output"]
|
|
|
|
# check numerics vs eager mode
|
|
fuse_list = ["linear", "bn"]
|
|
qengine = torch.backends.quantized.engine
|
|
if quant_type == QuantType.STATIC:
|
|
m_eager.eval()
|
|
qconfig = get_default_qconfig(qengine)
|
|
prepare_fn = prepare
|
|
fuse_modules(m_eager, fuse_list, inplace=True)
|
|
else:
|
|
m_eager.train()
|
|
qconfig = get_default_qat_qconfig(qengine)
|
|
prepare_fn = prepare_qat
|
|
fuse_modules_qat(m_eager, fuse_list, inplace=True)
|
|
m_eager.qconfig = qconfig
|
|
m_eager = prepare_fn(m_eager)
|
|
m_eager(*data)
|
|
m_eager = convert(m_eager)
|
|
result_eager = m_eager(*data)
|
|
self.assertEqual(result, result_eager)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_dynamic_quant_fp16(self):
|
|
with override_quantized_engine('fbgemm'):
|
|
class Linear(torch.nn.Module):
|
|
def __init__(self, weight):
|
|
super().__init__()
|
|
self.weight = torch.nn.Parameter(weight)
|
|
|
|
def forward(self, x):
|
|
return F.linear(x, self.weight)
|
|
|
|
linear_input = torch.rand(8, 5)
|
|
linear_weight = torch.rand(10, 5)
|
|
|
|
class LinearModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(5, 10)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
linear_module_input = torch.rand(8, 5)
|
|
|
|
tests = [
|
|
(Linear, (linear_weight,), (linear_input,),
|
|
ns.call_function(torch.ops.quantized.linear_dynamic_fp16),
|
|
ns.call_function(torch.ops.quantized.linear_prepack_fp16)),
|
|
(LinearModule, (), (linear_module_input,),
|
|
ns.call_module(nnqd.Linear),
|
|
None),
|
|
]
|
|
for (ModuleClass, module_constructor_inputs,
|
|
inputs, quantized_node, weight_prepack_node) in tests:
|
|
for is_reference in [True, False]:
|
|
node_occurrence = {}
|
|
if weight_prepack_node:
|
|
node_occurrence[weight_prepack_node] = 0
|
|
m = ModuleClass(*module_constructor_inputs).eval()
|
|
qconfig_dict = {"": float16_dynamic_qconfig}
|
|
m = prepare_fx(m, qconfig_dict, example_inputs=inputs)
|
|
convert_fn = convert_to_reference_fx if is_reference else convert_fx
|
|
m = convert_fn(m)
|
|
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
|
|
|
|
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
|
@override_qengines
|
|
def test_qat_prepare_device_affinity(self):
|
|
"""
|
|
Tests that FX QAT prepare pass respects device affinity
|
|
"""
|
|
class Model(nn.Module):
|
|
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = nn.Conv2d(1, 1, 1)
|
|
self.bn = nn.BatchNorm2d(1)
|
|
self.relu = nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
x = self.bn(x)
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
model = Model()
|
|
qengine = torch.backends.quantized.engine
|
|
qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig(qengine)}
|
|
device = torch.device('cuda:0')
|
|
model.to(device)
|
|
|
|
example_inputs = (torch.randn(4, 1, 4, 4, device=device),)
|
|
# QAT prepare
|
|
model = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs)
|
|
|
|
# ensure that running an input on CUDA works without any needed changes
|
|
model(*example_inputs)
|
|
|
|
# ensure all buffers and parameters are on the device we expect
|
|
model_devices = {p.device for p in model.parameters()} | \
|
|
{p.device for p in model.buffers()}
|
|
self.assertEqual(len(model_devices), 1)
|
|
model_device = next(iter(model_devices))
|
|
self.assertEqual(model_device, device)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_dict_output(self):
|
|
""" Make sure quantization runs for models with dictionary output
|
|
"""
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
return {"output": self.conv(x["input"])}
|
|
|
|
example_inputs = ({"input": torch.randn(1, 1, 1, 1)},)
|
|
m = M().eval()
|
|
qconfig_dict = {"": default_qconfig}
|
|
m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
|
|
m(*example_inputs)
|
|
m = convert_fx(m)
|
|
m(*example_inputs)
@override_qengines
|
|
def test_attention(self):
|
|
""" Make sure quantization runs for a corner case in attention module
|
|
"""
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
q, k, v = x.chunk(3, dim=0)
|
|
q = q.contiguous().view(-1, 1).transpose(0, 1)
|
|
k = k.contiguous().view(-1, 1).transpose(0, 1)
|
|
v = v.contiguous().view(-1, 1).transpose(0, 1)
|
|
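                # torch._assert keeps the check in the traced graph (a plain
                # Python assert would not be captured by FX tracing)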
torch._assert(
|
|
k.size(1) == 1, "key size should be equal to 1"
|
|
)
|
|
r = torch.mm(k, v)
|
|
return q * k + r
|
|
|
|
example_inputs = (torch.randn(3, 1, 1, 1),)
|
|
m = M().eval()
|
|
qconfig_dict = {
|
|
"": None,
|
|
"object_type": [
|
|
(nn.Conv2d, default_qconfig),
|
|
]
|
|
}
|
|
# make sure it runs
|
|
m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
|
|
m(*example_inputs)
|
|
m = convert_fx(m)
|
|
m(*example_inputs)
def _test_standalone_module(
|
|
self,
|
|
interface_config,
|
|
prepare_count_check,
|
|
standalone_prepare_count_check,
|
|
convert_count_check,
|
|
standalone_convert_count_check):
|
|
""" Test standalone module with different quantized input/quantized output
|
|
configurations
|
|
"""
|
|
class StandaloneModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
return self.conv(x)
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(1, 1, 1)
|
|
self.standalone = StandaloneModule()
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
x = self.standalone(x)
|
|
return x
|
|
|
|
class RefM(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv1 = torch.nn.Conv2d(1, 1, 1)
|
|
self.conv2 = torch.nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
x = self.conv1(x)
|
|
x = self.conv2(x)
|
|
return x
|
|
|
|
example_inputs = (torch.randn(1, 1, 1, 1),)
|
|
# instantiate M and RefM and align the parameters
|
|
original_m = M().eval()
|
|
original_ref_m = RefM().eval()
|
|
original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach())
|
|
original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach())
|
|
original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach())
|
|
original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach())
|
|
|
|
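        # exercise configuring the standalone module both by module name and by
        # module class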
for is_name in [True, False]:
|
|
sm_example_inputs = example_inputs
|
|
if is_name:
|
|
prepare_config = {
|
|
"standalone_module_name": [("standalone", None, sm_example_inputs, interface_config, None)]
|
|
}
|
|
else:
|
|
prepare_config = {
|
|
"standalone_module_class": [(StandaloneModule, None, sm_example_inputs, interface_config, None)]
|
|
}
|
|
|
|
original_m_copy = copy.deepcopy(original_m)
|
|
original_ref_m_copy = copy.deepcopy(original_ref_m)
|
|
|
|
qconfig_dict = {"": default_qconfig}
|
|
# check prepared model
|
|
m = prepare_fx(
|
|
original_m_copy,
|
|
qconfig_dict,
|
|
example_inputs=example_inputs,
|
|
prepare_custom_config=prepare_config)
|
|
# calibration
|
|
m(*example_inputs)
|
|
self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check)
|
|
self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check)
|
|
|
|
# check converted/quantized model
|
|
m = convert_fx(m)
|
|
self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check)
|
|
self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check)
|
|
res = m(*example_inputs)
|
|
|
|
# quantize the reference model
|
|
ref_m = prepare_fx(
|
|
original_ref_m_copy,
|
|
qconfig_dict,
|
|
example_inputs=example_inputs,
|
|
)
|
|
ref_m(*example_inputs)
|
|
ref_m = convert_fx(ref_m)
|
|
ref_res = ref_m(*example_inputs)
|
|
self.assertEqual(res, ref_res)
def test_standalone_module_float_interface(self):
|
|
float_interface_config = {
|
|
"input_quantized_idxs": [], # float input
|
|
"output_quantized_idxs": [], # float output
|
|
}
|
|
interface_config = float_interface_config
|
|
        # observers for the input and output of the first conv; observers for the
        # standalone module will be inserted in the standalone module itself
|
|
prepare_count_check = {
|
|
ns.call_module(torch.ao.quantization.MinMaxObserver): 2
|
|
}
|
|
# for input and output of conv in the standalone module
|
|
standalone_prepare_count_check = {
|
|
ns.call_module(torch.ao.quantization.MinMaxObserver): 2
|
|
}
|
|
convert_count_check = {
|
|
ns.call_function(torch.quantize_per_tensor) : 1,
|
|
ns.call_module(nnq.Conv2d) : 1,
|
|
ns.call_method("dequantize") : 1,
|
|
}
|
|
standalone_convert_count_check = {
|
|
            # standalone module will take float as input and output
            # so we'll see quantize and dequantize in the module
|
|
ns.call_function(torch.quantize_per_tensor) : 1,
|
|
ns.call_module(nnq.Conv2d): 1,
|
|
ns.call_method("dequantize") : 1,
|
|
}
|
|
self._test_standalone_module(
|
|
interface_config,
|
|
prepare_count_check,
|
|
standalone_prepare_count_check,
|
|
convert_count_check,
|
|
standalone_convert_count_check)
def test_standalone_module_quantized_interface(self):
|
|
quantized_interface_config = {
|
|
"input_quantized_idxs": [0], # quantized input
|
|
"output_quantized_idxs": [0], # quantized output
|
|
}
|
|
interface_config = quantized_interface_config
|
|
# observer for input and output of first conv
|
|
prepare_count_check = {
|
|
ns.call_module(torch.ao.quantization.MinMaxObserver): 2
|
|
}
|
|
# for output of conv in the standalone module
|
|
standalone_prepare_count_check = {
|
|
ns.call_module(torch.ao.quantization.MinMaxObserver): 1
|
|
}
|
|
convert_count_check = {
|
|
# quantizing input for conv
|
|
ns.call_function(torch.quantize_per_tensor) : 1,
|
|
ns.call_module(nnq.Conv2d) : 1,
|
|
# dequantizing output of standalone module
|
|
ns.call_method("dequantize") : 1,
|
|
}
|
|
standalone_convert_count_check = {
|
|
# quantization of input happens in parent module
|
|
# quantization of output happens in the quantized conv module
|
|
ns.call_function(torch.quantize_per_tensor) : 0,
|
|
ns.call_module(nnq.Conv2d): 1,
|
|
# dequantization for output happens in parent module
|
|
ns.call_method("dequantize") : 0,
|
|
}
|
|
self._test_standalone_module(
|
|
interface_config,
|
|
prepare_count_check,
|
|
standalone_prepare_count_check,
|
|
convert_count_check,
|
|
standalone_convert_count_check)
@skipIfNoFBGEMM
|
|
def test_qconfig_none(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv1 = nn.Conv2d(1, 1, 1)
|
|
self.conv2 = nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
x = self.conv1(x)
|
|
x = self.conv2(x)
|
|
return x
|
|
|
|
m = M().eval()
|
|
qconfig_dict = {"": default_qconfig,
|
|
"module_name": [("conv2", None)]}
|
|
example_inputs = (torch.randn(1, 1, 1, 1),)
|
|
m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
|
|
m(*example_inputs)
|
|
m = convert_fx(m)
|
|
m(*example_inputs)
|
|
# first conv is quantized, second conv is not quantized
|
|
node_list = [
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_module(nnq.Conv2d),
|
|
ns.call_method("dequantize"),
|
|
ns.call_module(nn.Conv2d),
|
|
]
|
|
self.checkGraphModuleNodes(m, expected_node_list=node_list)
def test_qconfig_module_type(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = nn.Conv2d(1, 1, 1)
|
|
self.linear = nn.Linear(9, 3)
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
x = x.reshape((1, -1))
|
|
x = self.linear(x)
|
|
return x
|
|
|
|
m = M().eval()
|
|
qconfig_dict = {"object_type": [(torch.nn.Conv2d, default_qconfig)]}
|
|
example_inputs = (torch.randn(1, 1, 3, 3),)
|
|
m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
|
|
m(*example_inputs)
|
|
m = convert_fx(m)
|
|
m(*example_inputs)
|
|
# conv is quantized, linear is not quantized
|
|
node_list = [
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_module(nnq.Conv2d),
|
|
ns.call_method("dequantize"),
|
|
ns.call_module(nn.Linear),
|
|
]
|
|
self.checkGraphModuleNodes(m, expected_node_list=node_list)
def test_qconfig_qat_module_type(self):
|
|
class LinearRelu(nn.Sequential):
|
|
def __init__(self) -> None:
|
|
super().__init__(
|
|
nn.Linear(5, 5),
|
|
nn.ReLU(),
|
|
)
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.lin_relu = LinearRelu()
|
|
self.linear = nn.Linear(5, 5)
|
|
|
|
def forward(self, x):
|
|
x = self.lin_relu(x)
|
|
x = self.linear(x)
|
|
return x
|
|
|
|
model = M().train()
|
|
|
|
qconfig_dict = {
|
|
"": None,
|
|
"object_type": [
|
|
(torch.nn.Linear, default_qat_qconfig),
|
|
(torch.nn.ReLU, default_qat_qconfig),
|
|
],
|
|
}
|
|
example_inputs = (torch.rand(5, 5),)
|
|
m = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs)
|
|
m(*example_inputs)
|
|
m = convert_fx(m)
|
|
m(*example_inputs)
|
|
node_list = [
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_module(nniq.LinearReLU),
|
|
ns.call_module(nnq.Linear),
|
|
ns.call_method("dequantize"),
|
|
]
|
|
self.checkGraphModuleNodes(m, expected_node_list=node_list)
def test_qconfig_function(self):
|
|
class M(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
return x + y
|
|
|
|
m = M().eval()
|
|
qconfig_dict = {"object_type": [(operator.add, default_qconfig)]}
|
|
data = torch.randn(1, 1, 1, 1)
|
|
example_inputs = (data, data)
|
|
m = prepare_fx(m, qconfig_dict, example_inputs)
|
|
m(*example_inputs)
|
|
m = convert_fx(m)
|
|
m(*example_inputs)
|
|
        # the add is quantized since operator.add is configured with default_qconfig
|
|
node_list = [
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_function(torch.ops.quantized.add),
|
|
ns.call_method("dequantize"),
|
|
]
|
|
self.checkGraphModuleNodes(m, expected_node_list=node_list)
def test_qconfig_module_name_regex(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv1 = nn.Conv2d(1, 1, 1)
|
|
self.conv2 = nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
x = self.conv1(x)
|
|
x = self.conv2(x)
|
|
return x
|
|
|
|
m = M().eval()
|
|
qconfig_dict = {"module_name_regex": [("conv*", default_qconfig)]}
|
|
example_inputs = (torch.randn(1, 1, 1, 1),)
|
|
m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
|
|
m(*example_inputs)
|
|
m = convert_fx(m)
|
|
m(*example_inputs)
|
|
        # both convs are quantized since the module_name_regex matches both
|
|
node_list = [
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_module(nnq.Conv2d),
|
|
ns.call_module(nnq.Conv2d),
|
|
ns.call_method("dequantize"),
|
|
]
|
|
self.checkGraphModuleNodes(m, expected_node_list=node_list)
def test_qconfig_precedence(self):
|
|
for device in get_supported_device_types():
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = nn.Linear(1, 1)
|
|
self.conv = nn.Conv2d(1, 1, 1)
|
|
self.module_conv1 = nn.Conv2d(1, 1, 1)
|
|
self.module_conv2 = nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
# global
|
|
x = self.linear(x)
|
|
# global + object_type --> object_type
|
|
x = self.conv(x)
|
|
# global + object_type + module_name_regex --> module_name_regex
|
|
x = self.module_conv1(x)
|
|
# global + object_type + module_name_regex + module_name --> module_name
|
|
x = self.module_conv2(x)
|
|
return x
|
|
|
|
m = M().to(device).eval()
|
|
|
|
global_qconfig = default_qconfig
|
|
object_type_qconfig = default_dynamic_qconfig
|
|
module_name_regex_qconfig = float16_dynamic_qconfig
|
|
module_name_qconfig = default_qat_qconfig
|
|
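            # precedence, from least to most specific:
            # global < object_type < module_name_regex < module_name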
qconfig_dict = {
|
|
"": global_qconfig,
|
|
"object_type": [(nn.Conv2d, object_type_qconfig)],
|
|
"module_name_regex": [("module_conv*", module_name_regex_qconfig)],
|
|
"module_name": [("module_conv2", module_name_qconfig)]}
|
|
m_prep = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1),))
|
|
self.assertEqual(m_prep.linear.qconfig.activation.p.func, global_qconfig.activation.p.func)
|
|
self.assertEqual(m_prep.linear.qconfig.weight.p.func, global_qconfig.weight.p.func)
|
|
self.assertEqual(m_prep.conv.qconfig.activation.p.func, object_type_qconfig.activation.p.func)
|
|
self.assertEqual(m_prep.conv.qconfig.weight.p.func, object_type_qconfig.weight.p.func)
|
|
self.assertEqual(m_prep.module_conv1.qconfig.activation.p.func, module_name_regex_qconfig.activation.p.func)
|
|
self.assertEqual(m_prep.module_conv1.qconfig.weight.p.func, module_name_regex_qconfig.weight.p.func)
|
|
self.assertEqual(m_prep.module_conv2.qconfig.activation.p.func, module_name_qconfig.activation.p.func)
|
|
self.assertEqual(m_prep.module_conv2.qconfig.weight.p.func, module_name_qconfig.weight.p.func)
def test_qconfig_module_name_object_type_order(self):
|
|
class M1(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.fc1 = nn.Linear(1, 1)
|
|
self.fc2 = nn.Linear(1, 1)
|
|
|
|
def forward(self, x):
|
|
x = self.fc1(x)
|
|
x = self.fc2(x)
|
|
x = torch.add(x, x)
|
|
x = torch.add(x, x)
|
|
return x
|
|
|
|
class M2(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.fc1 = nn.Linear(1, 1)
|
|
self.fc2 = nn.Linear(1, 1)
|
|
self.m1 = M1()
|
|
|
|
def forward(self, x):
|
|
x = self.fc1(x)
|
|
x = self.fc2(x)
|
|
x = torch.add(x, x)
|
|
x = torch.add(x, x)
|
|
x = self.m1(x)
|
|
return x
|
|
|
|
class M3(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.fc1 = nn.Linear(1, 1)
|
|
self.fc2 = nn.Linear(1, 1)
|
|
self.m2 = M2()
|
|
|
|
def forward(self, x):
|
|
x = self.fc1(x)
|
|
x = self.fc2(x)
|
|
x = torch.add(x, x)
|
|
x = torch.add(x, x)
|
|
x = self.m2(x)
|
|
return x
|
|
|
|
m = M3().eval()
|
|
qconfig_dict = {
|
|
"module_name_object_type_order": [
|
|
# test various FQNs: global, single child, multiple children
|
|
("", nn.Linear, 0, torch.ao.quantization.default_qconfig),
|
|
("", torch.add, 0, torch.ao.quantization.default_qconfig),
|
|
("m2", nn.Linear, 1, torch.ao.quantization.default_qconfig),
|
|
("m2", torch.add, 1, torch.ao.quantization.default_qconfig),
|
|
("m2.m1", nn.Linear, 0, torch.ao.quantization.default_qconfig),
|
|
("m2.m1", torch.add, 0, torch.ao.quantization.default_qconfig),
|
|
],
|
|
}
|
|
example_inputs = (torch.randn(1, 1, 1, 1),)
|
|
m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
|
|
m(*example_inputs)
|
|
m = convert_fx(m)
|
|
m(*example_inputs)
|
|
|
|
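        # only the (module path, op type, occurrence index) combinations listed
        # in the qconfig_dict above should be quantized; the other linears and
        # adds stay unquantized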
node_list = [
|
|
# m3
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_module(nnq.Linear),
|
|
ns.call_method("dequantize"),
|
|
ns.call_module(nn.Linear),
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_function(torch.ops.quantized.add),
|
|
ns.call_method("dequantize"),
|
|
ns.call_function(torch.add),
|
|
# m2
|
|
ns.call_module(nn.Linear),
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_module(nnq.Linear),
|
|
ns.call_method("dequantize"),
|
|
ns.call_function(torch.add),
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_function(torch.ops.quantized.add),
|
|
# m1
|
|
ns.call_module(nnq.Linear),
|
|
ns.call_method("dequantize"),
|
|
ns.call_module(nn.Linear),
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_function(torch.ops.quantized.add),
|
|
ns.call_method("dequantize"),
|
|
ns.call_function(torch.add),
|
|
]
|
|
self.checkGraphModuleNodes(m, expected_node_list=node_list)
|
|
|
|
# test that function order overrides global qconfig
|
|
class M4(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.fc1 = nn.Linear(1, 1)
|
|
self.fc2 = nn.Linear(1, 1)
|
|
|
|
def forward(self, x):
|
|
x = self.fc1(x)
|
|
x = self.fc2(x)
|
|
x = torch.add(x, x)
|
|
x = torch.add(x, x)
|
|
return x
|
|
|
|
m = M4().eval()
|
|
qconfig_dict = {
|
|
"": torch.ao.quantization.default_qconfig,
|
|
"module_name_object_type_order": [
|
|
("", nn.Linear, 1, None),
|
|
("", torch.add, 1, None),
|
|
],
|
|
}
|
|
example_inputs = (torch.randn(1, 1, 1, 1),)
|
|
m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
|
|
m(*example_inputs)
|
|
m = convert_fx(m)
|
|
m(*example_inputs)
|
|
|
|
node_list = [
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_module(nnq.Linear),
|
|
ns.call_method("dequantize"),
|
|
ns.call_module(nn.Linear),
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_function(torch.ops.quantized.add),
|
|
ns.call_method("dequantize"),
|
|
ns.call_function(torch.add),
|
|
]
|
|
self.checkGraphModuleNodes(m, expected_node_list=node_list)
@override_qengines
|
|
def test_qconfig_dict_with_fused_modules(self):
|
|
class LinearReLUModel(torch.nn.Module):
|
|
def __init__(self, relu):
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(3, 3)
|
|
self.relu = relu
|
|
|
|
def forward(self, x):
|
|
x = self.linear(x)
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
class ConvReLUModel(torch.nn.Module):
|
|
def __init__(self, relu):
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv1d(3, 3, 3)
|
|
self.relu = relu
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
class ConvBnReLUModel(torch.nn.Module):
|
|
def __init__(self, relu):
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv1d(3, 3, 3)
|
|
self.bn = torch.nn.BatchNorm1d(3)
|
|
self.relu = relu
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
x = self.bn(x)
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
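        # try every (model, relu variant) combination; fusion should handle both
        # the nn.ReLU module and the functional forms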
for model in [LinearReLUModel, ConvReLUModel, ConvBnReLUModel]:
|
|
for relu in [torch.nn.ReLU(), torch.nn.functional.relu, torch.relu]:
|
|
m = model(relu).eval()
|
|
qengine = torch.backends.quantized.engine
|
|
qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping(qengine)
|
|
# should not crash as in https://github.com/pytorch/pytorch/issues/75825
|
|
prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),))
# TODO: move QConfigMapping tests to test/quantization/core
|
|
def test_qconfig_mapping_set_global(self):
|
|
qconfig = get_default_qconfig()
|
|
qconfig_mapping = QConfigMapping()
|
|
self.assertEqual(qconfig_mapping.global_qconfig, None)
|
|
qconfig_mapping.set_global(qconfig)
|
|
self.assertEqual(qconfig_mapping.global_qconfig, qconfig)
def test_qconfig_mapping_set_object_type(self):
|
|
qconfig1 = get_default_qconfig()
|
|
qconfig2 = get_default_qconfig()
|
|
qconfig3 = get_default_qconfig()
|
|
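        # each get_default_qconfig() call returns a distinct QConfig instance, so
        # the three configs compare unequal and the test can tell which one is stored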
self.assertNotEqual(qconfig1, qconfig2)
|
|
self.assertNotEqual(qconfig1, qconfig3)
|
|
qconfig_mapping = QConfigMapping()
|
|
self.assertEqual(len(qconfig_mapping.object_type_qconfigs), 0)
|
|
# Insert some entries
|
|
qconfig_mapping.set_object_type(torch.nn.Linear, qconfig1)
|
|
qconfig_mapping.set_object_type(torch.nn.ReLU, qconfig2)
|
|
self.assertEqual(len(qconfig_mapping.object_type_qconfigs), 2)
|
|
self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.Linear], qconfig1)
|
|
self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.ReLU], qconfig2)
|
|
# Override existing key
|
|
qconfig_mapping.set_object_type(torch.nn.Linear, qconfig3)
|
|
self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.Linear], qconfig3)
|
|
self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.ReLU], qconfig2)
|
|
self.assertEqual(_get_object_type_qconfig(qconfig_mapping, torch.nn.Linear, None), qconfig3)
|
|
self.assertEqual(_get_object_type_qconfig(qconfig_mapping, torch.nn.ReLU, None), qconfig2)
|
|
self.assertEqual(_get_object_type_qconfig(qconfig_mapping, "nomatch", None), None)
def test_qconfig_mapping_set_module_name_regex(self):
|
|
qconfig1 = get_default_qconfig()
|
|
qconfig2 = get_default_qconfig()
|
|
qconfig3 = get_default_qconfig()
|
|
self.assertNotEqual(qconfig1, qconfig2)
|
|
self.assertNotEqual(qconfig1, qconfig3)
|
|
qconfig_mapping = QConfigMapping()
|
|
self.assertEqual(len(qconfig_mapping.module_name_regex_qconfigs), 0)
|
|
# Insert some entries
|
|
qconfig_mapping.set_module_name_regex("foo.*bar", qconfig1)
|
|
qconfig_mapping.set_module_name_regex("foo.*", qconfig2)
|
|
self.assertEqual(len(qconfig_mapping.module_name_regex_qconfigs), 2)
|
|
self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*bar"], qconfig1)
|
|
self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*"], qconfig2)
|
|
# Override existing key
|
|
qconfig_mapping.set_module_name_regex("foo.*bar", qconfig3)
|
|
self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*bar"], qconfig3)
|
|
self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*"], qconfig2)
|
|
self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foo123bar", None), qconfig3)
|
|
self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foobar", None), qconfig3)
|
|
self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foobaz", None), qconfig2)
|
|
self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foo", None), qconfig2)
|
|
self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "nomatch", None), None)
def test_qconfig_mapping_set_module_name(self):
|
|
qconfig1 = get_default_qconfig()
|
|
qconfig2 = get_default_qconfig()
|
|
qconfig3 = get_default_qconfig()
|
|
self.assertNotEqual(qconfig1, qconfig2)
|
|
self.assertNotEqual(qconfig1, qconfig3)
|
|
qconfig_mapping = QConfigMapping()
|
|
self.assertEqual(len(qconfig_mapping.module_name_qconfigs), 0)
|
|
# Insert some entries
|
|
qconfig_mapping.set_module_name("mod1", qconfig1)
|
|
qconfig_mapping.set_module_name("mod2", qconfig2)
|
|
self.assertEqual(len(qconfig_mapping.module_name_qconfigs), 2)
|
|
self.assertEqual(qconfig_mapping.module_name_qconfigs["mod1"], qconfig1)
|
|
self.assertEqual(qconfig_mapping.module_name_qconfigs["mod2"], qconfig2)
|
|
# Override existing key
|
|
qconfig_mapping.set_module_name("mod1", qconfig3)
|
|
self.assertEqual(qconfig_mapping.module_name_qconfigs["mod1"], qconfig3)
|
|
self.assertEqual(qconfig_mapping.module_name_qconfigs["mod2"], qconfig2)
|
|
self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "mod1", None), qconfig3)
|
|
self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "mod2", None), qconfig2)
|
|
self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "nomatch", None), None)
def test_qconfig_mapping_set_module_name_object_type_order(self):
|
|
qconfig1 = get_default_qconfig()
|
|
qconfig2 = get_default_qconfig()
|
|
qconfig3 = get_default_qconfig()
|
|
self.assertNotEqual(qconfig1, qconfig2)
|
|
self.assertNotEqual(qconfig1, qconfig3)
|
|
qconfig_mapping = QConfigMapping()
|
|
self.assertEqual(len(qconfig_mapping.module_name_object_type_order_qconfigs), 0)
|
|
# Insert some entries
|
|
qconfig_mapping.set_module_name_object_type_order("mod1", torch.nn.Linear, 0, qconfig1)
|
|
qconfig_mapping.set_module_name_object_type_order("mod2", torch.nn.ReLU, 1, qconfig2)
|
|
self.assertEqual(len(qconfig_mapping.module_name_object_type_order_qconfigs), 2)
|
|
key1 = ("mod1", torch.nn.Linear, 0)
|
|
key2 = ("mod2", torch.nn.ReLU, 1)
|
|
self.assertEqual(next(iter(qconfig_mapping.module_name_object_type_order_qconfigs)), key1)
|
|
self.assertEqual(list(qconfig_mapping.module_name_object_type_order_qconfigs)[1], key2)
|
|
self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key1], qconfig1)
|
|
self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key2], qconfig2)
|
|
self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
|
|
qconfig_mapping, "mod1", torch.nn.Linear, 0, None), qconfig1)
|
|
self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
|
|
qconfig_mapping, "mod2", torch.nn.ReLU, 1, None), qconfig2)
|
|
# Override existing key
|
|
qconfig_mapping.set_module_name_object_type_order("mod1", torch.nn.Linear, 0, qconfig3)
|
|
self.assertEqual(len(qconfig_mapping.module_name_object_type_order_qconfigs), 2)
|
|
self.assertEqual(next(iter(qconfig_mapping.module_name_object_type_order_qconfigs)), key1)
|
|
self.assertEqual(list(qconfig_mapping.module_name_object_type_order_qconfigs)[1], key2)
|
|
self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key1], qconfig3)
|
|
self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key2], qconfig2)
|
|
self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
|
|
qconfig_mapping, "mod1", torch.nn.Linear, 0, None), qconfig3)
|
|
self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
|
|
qconfig_mapping, "mod2", torch.nn.ReLU, 1, None), qconfig2)
|
|
# No match
|
|
self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
|
|
qconfig_mapping, "mod123", torch.nn.Linear, 0, None), None)
|
|
self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
|
|
qconfig_mapping, "mod1", torch.nn.Linear, 35, None), None)
|
|
self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
|
|
qconfig_mapping, "mod2", torch.nn.Conv2d, 1, None), None)
def _get_qconfig_dict_for_qconfig_mapping_test(self, global_qconfig, qconfig1, qconfig2):
|
|
"""
|
|
Return a dummy qconfig_dict to test QConfigMapping's to_dict and from_dict methods.
|
|
"""
|
|
return {
|
|
_GLOBAL_DICT_KEY: global_qconfig,
|
|
_OBJECT_TYPE_DICT_KEY: [
|
|
(torch.nn.Linear, qconfig1),
|
|
(torch.nn.ReLU, qconfig2),
|
|
],
|
|
_MODULE_NAME_REGEX_DICT_KEY: [
|
|
("foo.*bar", qconfig1),
|
|
("foo.*", qconfig2),
|
|
],
|
|
_MODULE_NAME_DICT_KEY: [
|
|
("bazbaz", qconfig1),
|
|
("borbor", qconfig2),
|
|
],
|
|
_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [
|
|
("bazbaz", torch.nn.Linear, 0, qconfig1),
|
|
("foofoo", torch.nn.ReLU, 1, qconfig2),
|
|
],
|
|
}
with self.assertRaises(ValueError) as context:
|
|
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),)) # noqa: F821
|
|
self.assertTrue(
|
|
'Expected qconfig_dict to have the following keys:' in str(context.exception)
|
|
)
|
|
self.assertTrue('But found \'object_typo\' instead.' in str(context.exception))
def test_qconfig_mapping_from_dict(self):
|
|
global_qconfig = QConfig(123, "global")
|
|
qconfig1 = QConfig(1, "one")
|
|
qconfig2 = QConfig(2, "two")
|
|
qconfig_dict = self._get_qconfig_dict_for_qconfig_mapping_test(global_qconfig, qconfig1, qconfig2)
|
|
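        # an unrecognized key should not break from_dict; only the known keys are
        # checked below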
qconfig_dict["undefined_dict_key"] = [(123, qconfig1), (234, qconfig2)]
|
|
qconfig_mapping = QConfigMapping.from_dict(qconfig_dict)
|
|
self.assertEqual(qconfig_mapping.global_qconfig, global_qconfig)
|
|
self.assertEqual(qconfig_mapping.object_type_qconfigs, OrderedDict({
|
|
torch.nn.Linear: qconfig1,
|
|
torch.nn.ReLU: qconfig2,
|
|
}))
|
|
self.assertEqual(qconfig_mapping.module_name_regex_qconfigs, OrderedDict({
|
|
"foo.*bar": qconfig1,
|
|
"foo.*": qconfig2,
|
|
}))
|
|
self.assertEqual(qconfig_mapping.module_name_qconfigs, OrderedDict({
|
|
"bazbaz": qconfig1,
|
|
"borbor": qconfig2,
|
|
}))
|
|
self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs, OrderedDict({
|
|
("bazbaz", torch.nn.Linear, 0): qconfig1,
|
|
("foofoo", torch.nn.ReLU, 1): qconfig2,
|
|
}))
def test_qconfig_mapping_to_dict(self):
|
|
global_qconfig = QConfig(123, "global")
|
|
qconfig1 = QConfig(1, "one")
|
|
qconfig2 = QConfig(2, "two")
|
|
qconfig_mapping = QConfigMapping().set_global(global_qconfig) \
|
|
.set_object_type(torch.nn.Linear, qconfig1) \
|
|
.set_object_type(torch.nn.ReLU, qconfig2) \
|
|
.set_module_name_regex("foo.*bar", qconfig1) \
|
|
.set_module_name_regex("foo.*", qconfig2) \
|
|
.set_module_name("bazbaz", qconfig1) \
|
|
.set_module_name("borbor", qconfig2) \
|
|
.set_module_name_object_type_order("bazbaz", torch.nn.Linear, 0, qconfig1) \
|
|
.set_module_name_object_type_order("foofoo", torch.nn.ReLU, 1, qconfig2)
|
|
qconfig_dict = self._get_qconfig_dict_for_qconfig_mapping_test(global_qconfig, qconfig1, qconfig2)
|
|
self.assertEqual(qconfig_mapping.to_dict(), qconfig_dict)
def test_qconfig_mapping_repr(self):
|
|
self.assertTrue(isinstance(get_default_qconfig_mapping().__repr__(), str))
def test_default_qconfig_mapping_override_global(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
return self.conv(x)
|
|
|
|
m = M().eval()
|
|
my_qconfig = QConfig(activation=MinMaxObserver, weight=default_weight_observer)
|
|
qconfig_mapping = get_default_qconfig_mapping()
|
|
# Override global qconfig
|
|
old_global_qconfig = qconfig_mapping.global_qconfig
|
|
qconfig_mapping.set_global(my_qconfig)
|
|
# Verify the correct qconfig was used
|
|
example_inputs = (torch.randn(1, 1, 1, 1),)
|
|
m = prepare_fx(m, qconfig_mapping, example_inputs)
|
|
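        # the default mapping uses HistogramObserver for activations; after the
        # global override, the inserted observers should be MinMaxObserver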
self.assertTrue(isinstance(old_global_qconfig.activation(), HistogramObserver))
|
|
self.assertTrue(isinstance(my_qconfig.activation(), MinMaxObserver))
|
|
self.assertTrue(hasattr(m, "activation_post_process_0"))
|
|
self.assertTrue(hasattr(m, "activation_post_process_1"))
|
|
self.assertTrue(isinstance(m.activation_post_process_0, MinMaxObserver))
|
|
self.assertTrue(isinstance(m.activation_post_process_1, MinMaxObserver))
# Dummy classes for PrepareCustomConfig testing
class _DummyStandaloneModule:
|
|
pass
class _DummyFloatModule:
|
|
pass
class _DummyObservedModule:
|
|
pass
class _DummyQuantizedModule:
|
|
pass
class _DummyNonTraceableModule1:
|
|
pass
class _DummyNonTraceableModule2:
|
|
pass
def test_prepare_custom_config_set_standalone_module_name(self):
|
|
qconfig_mapping = QConfigMapping()
|
|
example_inputs = (torch.randn(3),)
|
|
child_prepare_custom_config = PrepareCustomConfig()
|
|
backend_config = BackendConfig("my_backend")
|
|
config_entry = StandaloneModuleConfigEntry(
|
|
qconfig_mapping, example_inputs, child_prepare_custom_config, backend_config)
|
|
prepare_custom_config = PrepareCustomConfig()
|
|
self.assertEqual(len(prepare_custom_config.standalone_module_names), 0)
|
|
prepare_custom_config.set_standalone_module_name(
|
|
"module1", qconfig_mapping, example_inputs, child_prepare_custom_config, backend_config)
|
|
self.assertEqual(list(prepare_custom_config.standalone_module_names.keys()), ["module1"])
|
|
self.assertEqual(prepare_custom_config.standalone_module_names["module1"], config_entry)
def test_prepare_custom_config_set_standalone_module_class(self):
|
|
qconfig_mapping = QConfigMapping()
|
|
example_inputs = (torch.randn(3),)
|
|
child_prepare_custom_config = PrepareCustomConfig()
|
|
backend_config = BackendConfig("my_backend")
|
|
config_entry = StandaloneModuleConfigEntry(
|
|
qconfig_mapping, example_inputs, child_prepare_custom_config, backend_config)
|
|
prepare_custom_config = PrepareCustomConfig()
|
|
self.assertEqual(len(prepare_custom_config.standalone_module_classes), 0)
|
|
prepare_custom_config.set_standalone_module_class(
|
|
self._DummyStandaloneModule, qconfig_mapping, example_inputs, child_prepare_custom_config, backend_config)
|
|
self.assertEqual(len(prepare_custom_config.standalone_module_classes), 1)
|
|
self.assertTrue(self._DummyStandaloneModule in prepare_custom_config.standalone_module_classes)
|
|
self.assertEqual(prepare_custom_config.standalone_module_classes[self._DummyStandaloneModule], config_entry)
def test_prepare_custom_config_set_float_to_observed_mapping(self):
|
|
prepare_custom_config = PrepareCustomConfig()
|
|
self.assertEqual(len(prepare_custom_config.float_to_observed_mapping), 0)
|
|
prepare_custom_config.set_float_to_observed_mapping(self._DummyFloatModule, self._DummyObservedModule, QuantType.STATIC)
|
|
self.assertEqual(len(prepare_custom_config.float_to_observed_mapping), 1)
|
|
self.assertEqual(list(prepare_custom_config.float_to_observed_mapping.keys()), [QuantType.STATIC])
|
|
self.assertEqual(len(prepare_custom_config.float_to_observed_mapping[QuantType.STATIC]), 1)
|
|
self.assertTrue(self._DummyFloatModule in prepare_custom_config.float_to_observed_mapping[QuantType.STATIC])
|
|
self.assertEqual(prepare_custom_config.float_to_observed_mapping[QuantType.STATIC][self._DummyFloatModule],
|
|
self._DummyObservedModule)
def test_prepare_custom_config_set_non_traceable_module_names(self):
|
|
prepare_custom_config = PrepareCustomConfig()
|
|
self.assertEqual(len(prepare_custom_config.non_traceable_module_names), 0)
|
|
prepare_custom_config.set_non_traceable_module_names(["module1", "module2"])
|
|
self.assertEqual(prepare_custom_config.non_traceable_module_names, ["module1", "module2"])
def test_prepare_custom_config_set_non_traceable_module_classes(self):
|
|
prepare_custom_config = PrepareCustomConfig()
|
|
self.assertEqual(len(prepare_custom_config.non_traceable_module_classes), 0)
|
|
prepare_custom_config.set_non_traceable_module_classes([self._DummyNonTraceableModule1, self._DummyNonTraceableModule2])
|
|
self.assertEqual(prepare_custom_config.non_traceable_module_classes,
|
|
[self._DummyNonTraceableModule1, self._DummyNonTraceableModule2])
def test_prepare_custom_config_set_input_quantized_indexes(self):
|
|
prepare_custom_config = PrepareCustomConfig()
|
|
self.assertEqual(len(prepare_custom_config.input_quantized_indexes), 0)
|
|
prepare_custom_config.set_input_quantized_indexes([0, 1])
|
|
self.assertEqual(prepare_custom_config.input_quantized_indexes, [0, 1])
def test_prepare_custom_config_set_output_quantized_indexes(self):
|
|
prepare_custom_config = PrepareCustomConfig()
|
|
self.assertEqual(len(prepare_custom_config.output_quantized_indexes), 0)
|
|
prepare_custom_config.set_output_quantized_indexes([0, 1])
|
|
self.assertEqual(prepare_custom_config.output_quantized_indexes, [0, 1])
def test_prepare_custom_config_set_preserved_attributes(self):
|
|
prepare_custom_config = PrepareCustomConfig()
|
|
self.assertEqual(len(prepare_custom_config.preserved_attributes), 0)
|
|
prepare_custom_config.set_preserved_attributes(["attr1", "attr2"])
|
|
self.assertEqual(prepare_custom_config.preserved_attributes, ["attr1", "attr2"])
def _get_dummy_prepare_custom_config_dict(self):
|
|
"""
|
|
Return a dummy prepare_custom_config_dict to test PrepareCustomConfig's to_dict and from_dict methods.
|
|
"""
|
|
return {
|
|
STANDALONE_MODULE_NAME_DICT_KEY: [(
|
|
"module1",
|
|
QConfigMapping(),
|
|
(torch.randn(3),),
|
|
PrepareCustomConfig(),
|
|
BackendConfig("my_backend"),
|
|
)],
|
|
STANDALONE_MODULE_CLASS_DICT_KEY: [(
|
|
self._DummyStandaloneModule,
|
|
QConfigMapping(),
|
|
(torch.randn(10),),
|
|
PrepareCustomConfig(),
|
|
BackendConfig("my_backend"),
|
|
)],
|
|
FLOAT_TO_OBSERVED_DICT_KEY: {
|
|
"static": {
|
|
self._DummyFloatModule: self._DummyObservedModule
|
|
},
|
|
},
|
|
NON_TRACEABLE_MODULE_NAME_DICT_KEY: ["module2", "module3"],
|
|
NON_TRACEABLE_MODULE_CLASS_DICT_KEY: [self._DummyNonTraceableModule1, self._DummyNonTraceableModule2],
|
|
INPUT_QUANTIZED_INDEXES_DICT_KEY: [0, 1],
|
|
OUTPUT_QUANTIZED_INDEXES_DICT_KEY: [0, 1],
|
|
PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"]
|
|
}
def test_prepare_custom_config_from_dict(self):
|
|
prepare_custom_config_dict = self._get_dummy_prepare_custom_config_dict()
|
|
(sm_name, qm1, ei1, pcc1, bcd1) = prepare_custom_config_dict[STANDALONE_MODULE_NAME_DICT_KEY][0]
|
|
(sm_class, qm2, ei2, pcc2, bcd2) = prepare_custom_config_dict[STANDALONE_MODULE_CLASS_DICT_KEY][0]
|
|
sm_config_entry1 = StandaloneModuleConfigEntry(qm1, ei1, pcc1, bcd1)
|
|
sm_config_entry2 = StandaloneModuleConfigEntry(qm2, ei2, pcc2, bcd2)
|
|
prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config_dict)
|
|
|
|
# Standalone modules
|
|
self.assertEqual(len(prepare_custom_config.standalone_module_names), 1)
|
|
self.assertTrue(sm_name in prepare_custom_config.standalone_module_names)
|
|
self.assertEqual(prepare_custom_config.standalone_module_names[sm_name], sm_config_entry1)
|
|
self.assertEqual(len(prepare_custom_config.standalone_module_classes), 1)
|
|
self.assertTrue(sm_class in prepare_custom_config.standalone_module_classes)
|
|
self.assertEqual(prepare_custom_config.standalone_module_classes[sm_class], sm_config_entry2)
|
|
|
|
# Float to observed mapping
|
|
self.assertEqual(len(prepare_custom_config.float_to_observed_mapping), 1)
|
|
self.assertEqual(list(prepare_custom_config.float_to_observed_mapping.keys()), [QuantType.STATIC])
|
|
self.assertEqual(len(prepare_custom_config.float_to_observed_mapping[QuantType.STATIC]), 1)
|
|
self.assertTrue(self._DummyFloatModule in prepare_custom_config.float_to_observed_mapping[QuantType.STATIC])
|
|
self.assertEqual(prepare_custom_config.float_to_observed_mapping[QuantType.STATIC][self._DummyFloatModule],
|
|
self._DummyObservedModule)
|
|
|
|
# Other
|
|
self.assertEqual(prepare_custom_config.non_traceable_module_names, ["module2", "module3"])
|
|
self.assertEqual(prepare_custom_config.non_traceable_module_classes,
|
|
[self._DummyNonTraceableModule1, self._DummyNonTraceableModule2])
|
|
self.assertEqual(prepare_custom_config.input_quantized_indexes, [0, 1])
|
|
self.assertEqual(prepare_custom_config.output_quantized_indexes, [0, 1])
|
|
self.assertEqual(prepare_custom_config.preserved_attributes, ["attr1", "attr2"])
def test_prepare_custom_config_to_dict(self):
|
|
prepare_custom_config_dict = self._get_dummy_prepare_custom_config_dict()
|
|
(sm_name, qm1, ei1, pcc1, bcd1) = prepare_custom_config_dict[STANDALONE_MODULE_NAME_DICT_KEY][0]
|
|
(sm_class, qm2, ei2, pcc2, bcd2) = prepare_custom_config_dict[STANDALONE_MODULE_CLASS_DICT_KEY][0]
|
|
prepare_custom_config = PrepareCustomConfig() \
|
|
.set_standalone_module_name(sm_name, qm1, ei1, pcc1, bcd1) \
|
|
.set_standalone_module_class(sm_class, qm2, ei2, pcc2, bcd2) \
|
|
.set_float_to_observed_mapping(self._DummyFloatModule, self._DummyObservedModule) \
|
|
.set_non_traceable_module_names(["module2", "module3"]) \
|
|
.set_non_traceable_module_classes([self._DummyNonTraceableModule1, self._DummyNonTraceableModule2]) \
|
|
.set_input_quantized_indexes([0, 1]) \
|
|
.set_output_quantized_indexes([0, 1]) \
|
|
.set_preserved_attributes(["attr1", "attr2"])
|
|
# PrepareCustomConfig.to_dict also converts internal QConfigMappings and PrepareCustomConfigs to dicts
|
|
prepare_custom_config_dict[STANDALONE_MODULE_NAME_DICT_KEY][0] = (sm_name, qm1.to_dict(), ei1, pcc1.to_dict(), bcd1)
|
|
prepare_custom_config_dict[STANDALONE_MODULE_CLASS_DICT_KEY][0] = (sm_class, qm2.to_dict(), ei2, pcc2.to_dict(), bcd2)
|
|
self.assertEqual(prepare_custom_config.to_dict(), prepare_custom_config_dict)
def test_convert_custom_config_set_observed_to_quantized_mapping(self):
|
|
convert_custom_config = ConvertCustomConfig()
|
|
self.assertEqual(len(convert_custom_config.observed_to_quantized_mapping), 0)
|
|
convert_custom_config.set_observed_to_quantized_mapping(
|
|
self._DummyObservedModule, self._DummyQuantizedModule, QuantType.STATIC)
|
|
self.assertEqual(len(convert_custom_config.observed_to_quantized_mapping), 1)
|
|
self.assertEqual(list(convert_custom_config.observed_to_quantized_mapping.keys()), [QuantType.STATIC])
|
|
self.assertTrue(self._DummyObservedModule in convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC])
|
|
self.assertEqual(convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC][self._DummyObservedModule],
|
|
self._DummyQuantizedModule)
def test_convert_custom_config_set_preserved_attributes(self):
|
|
convert_custom_config = ConvertCustomConfig()
|
|
self.assertEqual(len(convert_custom_config.preserved_attributes), 0)
|
|
convert_custom_config.set_preserved_attributes(["attr1", "attr2"])
|
|
self.assertEqual(convert_custom_config.preserved_attributes, ["attr1", "attr2"])
def _get_dummy_convert_custom_config_dict(self):
|
|
"""
|
|
Return a dummy convert_custom_config_dict to test ConvertCustomConfig's to_dict and from_dict methods.
|
|
"""
|
|
return {
|
|
OBSERVED_TO_QUANTIZED_DICT_KEY: {
|
|
"static": {
|
|
self._DummyObservedModule: self._DummyQuantizedModule
|
|
},
|
|
},
|
|
PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"]
|
|
}
def test_convert_custom_config_from_dict(self):
|
|
convert_custom_config_dict = self._get_dummy_convert_custom_config_dict()
|
|
convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config_dict)
|
|
self.assertEqual(len(convert_custom_config.observed_to_quantized_mapping), 1)
|
|
self.assertEqual(list(convert_custom_config.observed_to_quantized_mapping.keys()), [QuantType.STATIC])
|
|
self.assertEqual(len(convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC]), 1)
|
|
self.assertTrue(self._DummyObservedModule in convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC])
|
|
self.assertEqual(convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC][self._DummyObservedModule],
|
|
self._DummyQuantizedModule)
|
|
self.assertEqual(convert_custom_config.preserved_attributes, ["attr1", "attr2"])
def test_convert_custom_config_to_dict(self):
|
|
convert_custom_config = ConvertCustomConfig() \
|
|
.set_observed_to_quantized_mapping(self._DummyObservedModule, self._DummyQuantizedModule) \
|
|
.set_preserved_attributes(["attr1", "attr2"])
|
|
self.assertEqual(convert_custom_config.to_dict(), self._get_dummy_convert_custom_config_dict())
def test_fuse_custom_config_set_preserved_attributes(self):
|
|
fuse_custom_config = FuseCustomConfig()
|
|
self.assertEqual(len(fuse_custom_config.preserved_attributes), 0)
|
|
fuse_custom_config.set_preserved_attributes(["attr1", "attr2"])
|
|
self.assertEqual(fuse_custom_config.preserved_attributes, ["attr1", "attr2"])
def test_fuse_custom_config_from_dict(self):
|
|
fuse_custom_config_dict = {PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"]}
|
|
fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config_dict)
|
|
self.assertEqual(fuse_custom_config.preserved_attributes, ["attr1", "attr2"])
def test_fuse_custom_config_to_dict(self):
|
|
fuse_custom_config_dict = {PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"]}
|
|
fuse_custom_config = FuseCustomConfig().set_preserved_attributes(["attr1", "attr2"])
|
|
self.assertEqual(fuse_custom_config.to_dict(), fuse_custom_config_dict)
def test_remove_qconfig(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.avg_pool = torch.nn.AvgPool2d(1)
|
|
|
|
def forward(self, x):
|
|
return self.avg_pool(x)
|
|
|
|
m = M().eval()
|
|
qconfig_dict = {'': default_qconfig}
|
|
example_inputs = (torch.randn(1, 1, 1, 1),)
|
|
m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
|
|
m(*example_inputs)
|
|
m = convert_fx(m)
|
|
m(*example_inputs)
|
|
for name, module in m.named_modules():
|
|
self.assertFalse(hasattr(module, 'qconfig'),
|
|
'qconfig is not removed for ' + name)
def test_return_none(self):
|
|
class M(torch.nn.Module):
|
|
def forward(self, x):
|
|
pass
|
|
|
|
m = M().eval()
|
|
qconfig_dict = {'': torch.ao.quantization.default_qconfig}
|
|
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1),))
|
|
m = convert_fx(m)
def test_default_quant_after_none_qconfig(self):
|
|
""" Make sure default quant is inserted properly"""
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv1 = torch.nn.Conv2d(1, 1, 1)
|
|
self.conv2 = torch.nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
x = self.conv1(x)
|
|
x = x.transpose(1, 2)
|
|
                x = self.conv2(x)
                return x
m = M().eval()
|
|
qconfig_dict = {
|
|
"": default_qconfig,
|
|
"module_name": [
|
|
("conv1", None)
|
|
]
|
|
}
|
|
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),))
|
|
m = convert_fx(m)
def test_qconfig_for_call_method(self):
|
|
class Sub(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
x = x.transpose(2, 3)
|
|
x = self.conv(x)
|
|
return x.transpose(2, 3)
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.sub = Sub()
|
|
self.conv1 = torch.nn.Conv2d(1, 1, 1)
|
|
self.conv2 = torch.nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
x = self.conv1(x)
|
|
x = self.sub(x)
|
|
x = self.conv2(x)
|
|
return x.transpose(2, 3)
|
|
|
|
qconfig_dict1 = {"": default_qconfig, "module_name": [("sub", None)]}
|
|
# since sub is configured to have qconfig None, we should dequantize the output
|
|
# of self.conv1 and quantize the input of self.conv2
|
|
# dequantize after conv2 should happen after transpose since
|
|
# it is configured with default_qconfig
|
|
        # nodes in the Sub module instance are not quantized
|
|
node_list1 = [
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_module(nnq.Conv2d),
|
|
ns.call_method("dequantize"),
|
|
ns.call_method("transpose"),
|
|
ns.call_module(nn.Conv2d),
|
|
ns.call_method("transpose"),
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_module(nnq.Conv2d),
|
|
ns.call_method("transpose"),
|
|
ns.call_method("dequantize")
|
|
]
|
|
|
|
qconfig_dict2 = {"": None, "module_name": [("sub", default_qconfig)]}
|
|
# Only nodes in Sub module instance are quantized
|
|
# the first transpose is not quantized because the input is not quantized
|
|
node_list2 = [
|
|
ns.call_module(nn.Conv2d),
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_method("transpose"),
|
|
ns.call_module(nnq.Conv2d),
|
|
ns.call_method("transpose"),
|
|
ns.call_method("dequantize"),
|
|
ns.call_module(nn.Conv2d),
|
|
ns.call_method("transpose"),
|
|
]
|
|
|
|
for qconfig_dict, node_list in [
|
|
(qconfig_dict1, node_list1),
|
|
(qconfig_dict2, node_list2)
|
|
]:
|
|
example_inputs = (torch.randn(2, 1, 3, 3),)
|
|
m = M().eval()
|
|
m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
|
|
m(torch.randn(2, 1, 3, 3))
|
|
m = convert_fx(m)
|
|
self.checkGraphModuleNodes(m, expected_node_list=node_list)
|
|
# make sure it runs
|
|
m(*example_inputs)
def test_qconfig_for_call_func(self):
|
|
class Linear(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.w = torch.ones(5, 5)
|
|
self.b = torch.zeros(5)
|
|
|
|
def forward(self, x):
|
|
return torch.nn.functional.linear(x, self.w, self.b)
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.mods1 = torch.nn.Sequential(
|
|
Linear(),
|
|
Linear()
|
|
)
|
|
self.mods2 = Linear()
|
|
|
|
def forward(self, x):
|
|
x = self.mods1(x)
|
|
x = self.mods2(x)
|
|
return x
|
|
|
|
model = M().eval()
|
|
example_inputs = (torch.rand(5, 5),)
|
|
qconfig_dict = {"": default_qconfig, "module_name": [("mods2", None)]}
|
|
m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
|
|
m(*example_inputs)
|
|
|
|
m = convert_fx(m)
|
|
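        # mods1's functional linears are quantized; mods2 has qconfig None, so its
        # functional linear runs in fp32 after a dequantize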
node_list = [
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_function(torch.ops.quantized.linear),
|
|
ns.call_function(torch.ops.quantized.linear),
|
|
ns.call_method('dequantize'),
|
|
ns.call_function(torch.nn.functional.linear)
|
|
]
|
|
self.checkGraphModuleNodes(m, expected_node_list=node_list)
|
|
m(torch.rand(5, 5))
def test_preserve_attributes(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
return self.conv(x)
|
|
|
|
m = M()
|
|
m.eval()
|
|
m.preserved_attr = 3
|
|
prepare_custom_config_dict = {
|
|
"preserved_attributes": ["preserved_attr"]
|
|
}
|
|
example_inputs = (torch.randn(1, 1, 1, 1),)
|
|
m = prepare_fx(
|
|
m,
|
|
{"": default_qconfig},
|
|
example_inputs=example_inputs,
|
|
prepare_custom_config=prepare_custom_config_dict)
|
|
|
|
def assertAttrPreserved(m):
|
|
self.assertTrue(hasattr(m, "preserved_attr"))
|
|
self.assertEqual(m.preserved_attr, 3)
|
|
|
|
assertAttrPreserved(m)
|
|
convert_custom_config_dict = {
|
|
"preserved_attributes": ["preserved_attr"]
|
|
}
|
|
m = convert_fx(m, convert_custom_config=convert_custom_config_dict)
|
|
assertAttrPreserved(m)
@skipIfNoFBGEMM
|
|
def test_qat_and_script(self):
|
|
model = LinearModelWithSubmodule().train()
|
|
qengine = torch.backends.quantized.engine
|
|
qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig(qengine)}
|
|
x = torch.randn(5, 5)
|
|
example_inputs = (x,)
|
|
model = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs)
|
|
|
|
# ensure scripting works
|
|
scripted = torch.jit.script(model)
|
|
# run one round to make sure model runs
|
|
scripted(x)
|
|
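        # the scripted graph should contain exactly 4 FakeQuantize attribute accesses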
FileCheck().check_count('FakeQuantize = prim::GetAttr[name="', 4, exactly=True) \
|
|
.run(scripted.graph)
|
|
|
|
# disable fake_quant and observer
|
|
for epoch in range(3):
|
|
if epoch == 1:
|
|
scripted.apply(torch.ao.quantization.disable_observer)
|
|
if epoch == 2:
|
|
scripted.apply(torch.ao.quantization.disable_fake_quant)
|
|
|
|
# ensure the fake_quant and observer have been disabled.
|
|
matches = ['.fake_quant_enabled', '.observer_enabled']
|
|
for key, v in scripted.state_dict().items():
|
|
if any(x in key for x in matches):
|
|
self.assertEqual(v, torch.tensor([0], dtype=torch.int64))
|
|
|
|
# enable them back
|
|
scripted.apply(torch.ao.quantization.enable_fake_quant)
|
|
scripted.apply(torch.ao.quantization.enable_observer)
|
|
for key, v in scripted.state_dict().items():
|
|
if any(x in key for x in matches):
|
|
self.assertEqual(v, torch.tensor([1], dtype=torch.int64))
@skipIfNoFBGEMM
|
|
def test_save_observer_state_dict(self):
|
|
orig = LinearModelWithSubmodule().eval()
|
|
model = orig
|
|
qconfig_dict = {'': torch.ao.quantization.get_default_qconfig('fbgemm')}
|
|
x = torch.randn(5, 5)
|
|
model = prepare_fx(model, qconfig_dict, example_inputs=(x,))
|
|
|
|
        # run the example input through the prepared model to collect observer statistics
|
|
model(x)
|
|
# save state_dict of model
|
|
obs_dict = torch.ao.quantization.get_observer_state_dict(model)
|
|
|
|
quant = convert_fx(model)
|
|
|
|
b = io.BytesIO()
|
|
torch.save(obs_dict, b)
|
|
|
|
# Load the stats into new model
|
|
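        # exercise both the weights_only and the legacy torch.load paths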
for weights_only in [True, False]:
|
|
b.seek(0)
|
|
model_2 = orig
|
|
model_2 = prepare_fx(model_2, qconfig_dict, example_inputs=(x,))
|
|
|
|
loaded_dict = torch.load(b, weights_only=weights_only)
|
|
torch.ao.quantization.load_observer_state_dict(model_2, loaded_dict)
|
|
|
|
quant_2 = convert_fx(model_2)
|
|
|
|
# Verify that loaded state dict produces same results.
|
|
self.assertEqual(quant(x), quant_2(x))
@skipIfNoFBGEMM
|
|
def test_custom_module_class(self):
|
|
class CustomModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(3, 3)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
class ObservedCustomModule(torch.nn.Module):
|
|
def __init__(self, linear):
|
|
super().__init__()
|
|
self.linear = linear
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
@classmethod
|
|
def from_float(cls, float_module):
|
|
assert hasattr(float_module, 'qconfig')
|
|
observed = cls(float_module.linear)
|
|
observed.qconfig = float_module.qconfig
|
|
return observed
|
|
|
|
class StaticQuantCustomModule(torch.nn.Module):
|
|
def __init__(self, linear):
|
|
super().__init__()
|
|
self.linear = linear
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
@classmethod
|
|
def from_observed(cls, observed_module):
|
|
assert hasattr(observed_module, 'qconfig')
|
|
assert hasattr(observed_module, 'activation_post_process')
|
|
observed_module.linear.activation_post_process = \
|
|
observed_module.activation_post_process
|
|
quantized = cls(nnq.Linear.from_float(observed_module.linear))
|
|
return quantized
|
|
|
|
class DynamicQuantCustomModule(torch.nn.Module):
|
|
def __init__(self, linear):
|
|
super().__init__()
|
|
self.linear = linear
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
@classmethod
|
|
def from_observed(cls, observed_module):
|
|
assert hasattr(observed_module, 'qconfig')
|
|
observed_module.linear.qconfig = observed_module.qconfig
|
|
quantized = cls(nnqd.Linear.from_float(observed_module.linear))
|
|
return quantized
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(3, 3)
|
|
self.custom = CustomModule()
|
|
|
|
def forward(self, x):
|
|
x = self.linear(x)
|
|
x = self.custom(x)
|
|
return x
|
|
|
|
class RefM(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear1 = torch.nn.Linear(3, 3)
|
|
self.linear2 = torch.nn.Linear(3, 3)
|
|
|
|
def forward(self, x):
|
|
x = self.linear1(x)
|
|
x = self.linear2(x)
|
|
return x
|
|
|
|
# instantiate M and RefM and align the parameters
|
|
original_m = M().eval()
|
|
original_ref_m = RefM().eval()
|
|
original_ref_m.linear1.weight = torch.nn.Parameter(original_m.linear.weight.detach())
|
|
original_ref_m.linear1.bias = torch.nn.Parameter(original_m.linear.bias.detach())
|
|
original_ref_m.linear2.weight = torch.nn.Parameter(original_m.custom.linear.weight.detach())
|
|
original_ref_m.linear2.bias = torch.nn.Parameter(original_m.custom.linear.bias.detach())
|
|
|
|
a16_qconfig = QConfig(
|
|
activation=MinMaxObserver.with_args(dtype=torch.qint32, quant_min=0, quant_max=65536),
|
|
weight=default_weight_observer,
|
|
)
|
|
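        # each entry: (qconfig, expected quantized custom module class, number of
        # activation observers expected after prepare)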
test_configs = {
|
|
"static": (default_qconfig, StaticQuantCustomModule, 3),
|
|
"static_a16": (a16_qconfig, StaticQuantCustomModule, 3),
|
|
"dynamic": (default_dynamic_qconfig, DynamicQuantCustomModule, 0)
|
|
}
|
|
|
|
for quant_type in [QuantType.STATIC, QuantType.DYNAMIC]:
|
|
key = _get_quant_type_to_str(quant_type)
|
|
qconfig, quantized_module_class, num_observers = test_configs[key]
|
|
qconfig_dict = {"": qconfig}
|
|
if key == "static":
|
|
prepare_custom_config_dict = {
|
|
"float_to_observed_custom_module_class": {
|
|
"static": {
|
|
CustomModule: ObservedCustomModule
|
|
}
|
|
}
|
|
}
|
|
convert_custom_config_dict = {
|
|
"observed_to_quantized_custom_module_class": {
|
|
"static": {
|
|
ObservedCustomModule: quantized_module_class
|
|
}
|
|
}
|
|
}
|
|
else:
|
|
prepare_custom_config_dict = {
|
|
"non_traceable_module_class": [
|
|
CustomModule
|
|
]
|
|
}
|
|
convert_custom_config_dict = {
|
|
"observed_to_quantized_custom_module_class": {
|
|
"dynamic": {
|
|
CustomModule: quantized_module_class
|
|
}
|
|
}
|
|
}
|
|
|
|
example_inputs = (torch.randn(3, 3),)
|
|
# check prepared model
|
|
m = prepare_fx(
|
|
copy.deepcopy(original_m),
|
|
qconfig_dict,
|
|
example_inputs=example_inputs,
|
|
prepare_custom_config=prepare_custom_config_dict)
|
|
# calibration
|
|
m(*example_inputs)
|
|
# all activation observers are inserted in the top level module
|
|
count_check = {
|
|
ns.call_module(torch.ao.quantization.MinMaxObserver): num_observers
|
|
}
|
|
self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
|
|
|
|
# check converted/quantized model
|
|
m = convert_fx(
|
|
m,
|
|
convert_custom_config=convert_custom_config_dict)
|
|
if quant_type == QuantType.STATIC:
|
|
count_check = {
|
|
ns.call_function(torch.quantize_per_tensor) : 1,
|
|
ns.call_module(nnq.Linear) : 1,
|
|
ns.call_method('dequantize') : 1,
|
|
}
|
|
self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
|
|
self.assertEqual(type(m.custom), quantized_module_class)
|
|
res = m(*example_inputs)
|
|
|
|
# quantize the reference model
|
|
ref_m = prepare_fx(
|
|
copy.deepcopy(original_ref_m), qconfig_dict, example_inputs=example_inputs)
|
|
ref_m(*example_inputs)
|
|
ref_m = convert_fx(ref_m)
|
|
ref_res = ref_m(*example_inputs)
|
|
self.assertEqual(res, ref_res)
@skipIfNoFBGEMM
|
|
def test_custom_module_class_input_has_multiple_users(self):
|
|
""" Tests that the flow still works when the input of custom module
|
|
has multiple users
|
|
"""
|
|
class CustomModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(3, 3)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
class ObservedCustomModule(torch.nn.Module):
|
|
def __init__(self, linear):
|
|
super().__init__()
|
|
self.linear = linear
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
@classmethod
|
|
def from_float(cls, float_module):
|
|
assert hasattr(float_module, 'qconfig')
|
|
observed = cls(float_module.linear)
|
|
observed.qconfig = float_module.qconfig
|
|
return observed
|
|
|
|
class StaticQuantCustomModule(torch.nn.Module):
|
|
def __init__(self, linear):
|
|
super().__init__()
|
|
self.linear = linear
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
@classmethod
|
|
def from_observed(cls, observed_module):
|
|
assert hasattr(observed_module, 'qconfig')
|
|
assert hasattr(observed_module, 'activation_post_process')
|
|
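# reuse the custom module's output observer as the inner linear's output observer
# so that from_float can compute the output qparams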
observed_module.linear.activation_post_process = \
|
|
observed_module.activation_post_process
|
|
quantized = cls(nnq.Linear.from_float(observed_module.linear))
|
|
return quantized
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(3, 3)
|
|
self.custom = CustomModule()
|
|
|
|
def forward(self, x0):
|
|
x1 = self.custom(x0)
|
|
x2 = self.linear(x0)
|
|
return x1 + x2
|
|
|
|
prepare_custom_config_dict = {
|
|
"float_to_observed_custom_module_class": {
|
|
"static": {
|
|
CustomModule: ObservedCustomModule
|
|
}
|
|
}
|
|
}
|
|
convert_custom_config_dict = {
|
|
"observed_to_quantized_custom_module_class": {
|
|
"static": {
|
|
ObservedCustomModule: StaticQuantCustomModule
|
|
}
|
|
}
|
|
}
|
|
m = M().eval()
|
|
example_inputs = (torch.randn(3, 3),)
|
|
m = prepare_fx(
|
|
m,
|
|
{"": default_qconfig},
|
|
example_inputs=example_inputs,
|
|
prepare_custom_config=prepare_custom_config_dict)
|
|
# make sure it works
|
|
m = convert_fx(
|
|
m,
|
|
convert_custom_config=convert_custom_config_dict)
|
|
# make sure it runs
|
|
m(*example_inputs)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_custom_module_class_input_has_duplicate_nodes(self):
|
|
""" Tests that the flow still works when the graph has
|
|
multiple nodes with the same custom module target.
|
|
"""
|
|
class CustomModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(3, 3)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
class ObservedCustomModule(torch.nn.Module):
|
|
def __init__(self, linear):
|
|
super().__init__()
|
|
self.linear = linear
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
@classmethod
|
|
def from_float(cls, float_module):
|
|
assert hasattr(float_module, 'qconfig')
|
|
observed = cls(float_module.linear)
|
|
observed.qconfig = float_module.qconfig
|
|
return observed
|
|
|
|
class StaticQuantCustomModule(torch.nn.Module):
|
|
def __init__(self, linear):
|
|
super().__init__()
|
|
self.linear = linear
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
@classmethod
|
|
def from_observed(cls, observed_module):
|
|
assert hasattr(observed_module, 'qconfig')
|
|
assert hasattr(observed_module, 'activation_post_process')
|
|
observed_module.linear.activation_post_process = \
|
|
observed_module.activation_post_process
|
|
quantized = cls(nnq.Linear.from_float(observed_module.linear))
|
|
return quantized
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.custom = CustomModule()
|
|
|
|
def forward(self, x0):
|
|
x1 = self.custom(x0)
|
|
x2 = self.custom(x0)
|
|
return x1 + x2
|
|
|
|
prepare_custom_config_dict = {
|
|
"float_to_observed_custom_module_class": {
|
|
"static": {
|
|
CustomModule: ObservedCustomModule
|
|
}
|
|
}
|
|
}
|
|
convert_custom_config_dict = {
|
|
"observed_to_quantized_custom_module_class": {
|
|
"static": {
|
|
ObservedCustomModule: StaticQuantCustomModule
|
|
}
|
|
}
|
|
}
|
|
m = M().eval()
|
|
example_inputs = (torch.randn(3, 3),)
|
|
m = prepare_fx(
|
|
m,
|
|
{"": default_qconfig},
|
|
example_inputs=example_inputs,
|
|
prepare_custom_config=prepare_custom_config_dict)
|
|
# make sure it works
|
|
m = convert_fx(
|
|
m,
|
|
convert_custom_config=convert_custom_config_dict)
|
|
# make sure it runs
|
|
m(*example_inputs)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_non_traceable_module(self):
|
|
class NonTraceable(torch.nn.Module):
|
|
def forward(self, x):
|
|
for k in x.keys():
|
|
print(x[k])
|
|
return x
|
|
|
|
class NonTraceable2(torch.nn.Module):
|
|
def forward(self, x):
|
|
# data dependent control flow is not traceable
|
|
for i in x:
|
|
print(i)
|
|
return x
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.m1 = NonTraceable()
|
|
self.m2 = NonTraceable2()
|
|
|
|
def forward(self, x):
|
|
x = self.m1(x)
|
|
x = self.m2(x)
|
|
return x
|
|
|
|
m = M().eval()
|
|
qconfig_dict = {"": default_qconfig}
|
|
prepare_custom_config_dict = {
|
|
"non_traceable_module_name": [
|
|
"m1"
|
|
],
|
|
"non_traceable_module_class": [
|
|
NonTraceable2
|
|
]
|
|
}
|
|
m = prepare_fx(
|
|
m, qconfig_dict,
|
|
example_inputs=({"key": torch.randn(1)},),
|
|
prepare_custom_config=prepare_custom_config_dict)
|
|
|
|
node_occurrence = {
|
|
ns.call_module(NonTraceable) : 1,
|
|
ns.call_module(NonTraceable2) : 1,
|
|
}
|
|
# make sure these modules are not traced
|
|
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
|
|
|
|
def test_prepared_model_deepcopy(self):
|
|
"""Ensures that copy.deepcopy works correctly on a prepared model.
|
|
"""
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(1, 1, 1)
|
|
self._foobar = 'foobar'
|
|
self.foobar2 = 'foobar2'
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
return x
|
|
|
|
m = M()
|
|
m.eval()
|
|
qconfig_dict = {'': torch.ao.quantization.default_qconfig}
|
|
example_inputs = (torch.randn(4, 1, 4, 4),)
|
|
prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
|
|
# calibrate
|
|
prepared(*example_inputs)
|
|
# copy
|
|
prepared_copy = copy.deepcopy(prepared)
|
|
# quantize, should run with no errors
|
|
quantized = convert_fx(prepared_copy)
|
|
|
|
def test_quantized_model_type(self):
|
|
""" Test state_dict and deepcopy works properly in the quantized model
|
|
"""
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(5, 5)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
example_inputs = (torch.rand(8, 5),)
|
|
m = M().eval()
|
|
m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs)
|
|
m = convert_fx(m)
|
|
# test deepcopy
|
|
m_copy = copy.deepcopy(m)
|
|
self.assertEqual(m_copy(*example_inputs), m(*example_inputs))
|
|
|
|
# test state_dict
|
|
state_dict = m.state_dict()
|
|
m_new = M().eval()
|
|
m_new = prepare_fx(m_new, {"": default_qconfig}, example_inputs=example_inputs)
|
|
m_new = convert_fx(m_new)
|
|
m_new.load_state_dict(state_dict)
|
|
self.assertEqual(m_new(*example_inputs), m(*example_inputs))
|
|
|
|
def test_dequantize(self):
|
|
r""" Test to make sure dequantize node are placed before
|
|
non-quantizable node
|
|
"""
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(1, 1, 1)
|
|
self.act = torch.nn.GELU()
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
return self.act(x)
|
|
|
|
data = torch.rand(5, 1, 3, 3, dtype=torch.float)
|
|
for quant_type in self.static_quant_types:
|
|
node_list = [
|
|
ns.call_module(nnq.Conv2d),
|
|
ns.call_method("dequantize"),
|
|
ns.call_module(nn.GELU),
|
|
]
|
|
self.checkGraphModeFxOp(
|
|
M().eval(), (data,), quant_type, expected_node_list=node_list)
|
|
|
|
def test_sequential(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.convs = torch.nn.Sequential(
|
|
torch.nn.Conv2d(1, 1, 1),
|
|
torch.nn.Conv2d(1, 1, 1)
|
|
)
|
|
|
|
def forward(self, x):
|
|
x = self.convs(x)
|
|
return x
|
|
|
|
data = torch.rand(5, 1, 3, 3, dtype=torch.float)
|
|
for quant_type in self.static_quant_types:
|
|
node_list = [
|
|
ns.call_module(nnq.Conv2d),
|
|
ns.call_module(nnq.Conv2d),
|
|
]
|
|
self.checkGraphModeFxOp(
|
|
M().eval(), (data,), quant_type, expected_node_list=node_list)
|
|
|
|
def _test_quantized_inputs_outputs(
|
|
self, prepare_custom_config_dict, prepare_count_check,
|
|
convert_count_check):
|
|
"""
|
|
Test the option to have inputs and outputs of the graph quantized
|
|
"""
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv1 = torch.nn.Conv2d(1, 1, 1)
|
|
self.conv2 = torch.nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
x = self.conv1(x)
|
|
x = self.conv2(x)
|
|
return x
|
|
|
|
# quantized input, quantized output
|
|
m = M()
|
|
qconfig_dict = {'': torch.ao.quantization.default_qconfig}
|
|
example_inputs = (torch.randn(1, 1, 4, 4),)
|
|
m.eval()
|
|
mp = torch.ao.quantization.quantize_fx.prepare_fx(
|
|
m, qconfig_dict,
|
|
example_inputs=example_inputs,
|
|
prepare_custom_config=prepare_custom_config_dict)
|
|
self.checkGraphModuleNodes(mp, expected_node_occurrence=prepare_count_check)
|
|
mp(*example_inputs)
|
|
mq = torch.ao.quantization.quantize_fx.convert_fx(mp)
|
|
self.checkGraphModuleNodes(mq, expected_node_occurrence=convert_count_check)
|
|
|
|
def test_quantized_input_quantized_output(self):
|
|
prepare_custom_config_dict = {
|
|
'input_quantized_idxs': [0], 'output_quantized_idxs': [0]}
|
|
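# the input arrives already quantized, so observers are only needed after conv1 and conv2;
# with the output also marked quantized, the converted graph needs no quantize/dequantize at the boundaries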
prepare_count_check = {
|
|
ns.call_module(torch.ao.quantization.MinMaxObserver): 2,
|
|
}
|
|
convert_count_check = {
|
|
ns.call_function(torch.quantize_per_tensor): 0,
|
|
ns.call_method('dequantize'): 0,
|
|
}
|
|
self._test_quantized_inputs_outputs(
|
|
prepare_custom_config_dict, prepare_count_check, convert_count_check)
|
|
|
|
def test_fp32_input_quantized_output(self):
|
|
prepare_custom_config_dict = {
|
|
'output_quantized_idxs': [0]}
|
|
prepare_count_check = {
|
|
ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
|
|
}
|
|
convert_count_check = {
|
|
ns.call_function(torch.quantize_per_tensor): 1,
|
|
ns.call_method('dequantize'): 0,
|
|
}
|
|
self._test_quantized_inputs_outputs(
|
|
prepare_custom_config_dict, prepare_count_check, convert_count_check)
|
|
|
|
def test_quantized_input_fp32_output(self):
|
|
prepare_custom_config_dict = {
|
|
'input_quantized_idxs': [0]}
|
|
prepare_count_check = {
|
|
ns.call_module(torch.ao.quantization.MinMaxObserver): 2,
|
|
}
|
|
convert_count_check = {
|
|
ns.call_function(torch.quantize_per_tensor): 0,
|
|
ns.call_method('dequantize'): 1,
|
|
}
|
|
self._test_quantized_inputs_outputs(
|
|
prepare_custom_config_dict, prepare_count_check, convert_count_check)
|
|
|
|
def test_fp32_input_fp32_output(self):
|
|
prepare_custom_config_dict = {}
|
|
prepare_count_check = {
|
|
ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
|
|
}
|
|
convert_count_check = {
|
|
ns.call_function(torch.quantize_per_tensor): 1,
|
|
ns.call_method('dequantize'): 1,
|
|
}
|
|
self._test_quantized_inputs_outputs(
|
|
prepare_custom_config_dict, prepare_count_check, convert_count_check)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_convtranspose_per_channel_fails_early(self):
|
|
r"""
|
|
Verifies that attempting to quantize a ConvTranspose module with per-channel
|
|
weight observers fails in the prepare step, as opposed to the convert step.
|
|
"""
|
|
m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1))
|
|
m.eval()
|
|
qconfig_dict = {'': torch.ao.quantization.get_default_qconfig('fbgemm')}
|
|
with self.assertRaises(AssertionError) as context:
|
|
mp = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),))
|
|
self.assertTrue(
|
|
str(context.exception) ==
|
|
'Per channel weight observer is not supported yet for ConvTranspose{n}d.')
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_qparams_buffers(self):
|
|
class Linear(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.w = torch.ones(5, 5)
|
|
self.b = torch.zeros(5)
|
|
|
|
def forward(self, x):
|
|
return torch.nn.functional.linear(x, self.w, self.b)
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.mods1 = torch.nn.Sequential(
|
|
Linear(),
|
|
Linear()
|
|
)
|
|
self.mods2 = Linear()
|
|
|
|
def forward(self, x):
|
|
x = self.mods1(x)
|
|
x = self.mods2(x)
|
|
return x
|
|
|
|
model = M().eval()
|
|
qconfig_dict = {"": default_qconfig}
|
|
example_inputs = (torch.rand(5, 5),)
|
|
m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
|
|
m(*example_inputs)
|
|
m = convert_fx(m)
|
|
keys = m.state_dict().keys()
|
|
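# count the qparam entries registered in the state_dict, grouped by key name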
quant_scale_count = quant_zero_point = scale_count = zero_point_count = 0
|
|
for k in keys:
|
|
if 'input_scale' in k:
|
|
quant_scale_count = quant_scale_count + 1
|
|
elif 'input_zero_point' in k:
|
|
quant_zero_point = quant_zero_point + 1
|
|
elif 'scale' in k:
|
|
scale_count = scale_count + 1
|
|
elif 'zero_point' in k:
|
|
zero_point_count = zero_point_count + 1
|
|
|
|
# Expect each quantized linear op to have a scale and zero point
|
|
self.assertTrue(scale_count == 3, "Expect each quantized linear op to have a scale in state_dict")
|
|
self.assertTrue(zero_point_count == 3, "Expect each quantized linear op to have a zero_point in state_dict")
|
|
m(*example_inputs)
|
|
# ensure it is scriptable
|
|
scripted = torch.jit.script(m)
|
|
scripted_keys = scripted.state_dict().keys()
|
|
scripted.mods1_0_packed_weight_0 = m.state_dict()["mods1_0_packed_weight_0"]
|
|
non_packed_weight_keys = [key for key in keys if "_packed_weight" not in key]
|
|
self.assertTrue(
|
|
set(scripted_keys) == set(non_packed_weight_keys),
|
|
"Expected the scripted model to preserve the state_dict for non-packed weight attributes")
|
|
# TODO: probably don't want to hardcode the attribute names, since they are generated
|
|
for attr_name in [
|
|
"mods1_0_input_scale_0", "mods1_0_input_zero_point_0",
|
|
"mods1_0_scale_1", "mods1_0_zero_point_1",
|
|
"mods1_1_scale_1", "mods1_1_zero_point_1",
|
|
"mods2_scale_1", "mods2_zero_point_1"]:
|
|
self.assertTrue(hasattr(m, attr_name), attr_name + " not found.")
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_packed_weight_fused_op(self):
|
|
class Linear(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.w = torch.ones(5, 5)
|
|
self.b = torch.zeros(5)
|
|
|
|
def forward(self, x):
|
|
return F.linear(x, self.w, self.b)
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.mods1 = torch.nn.Sequential(
|
|
Linear(),
|
|
Linear()
|
|
)
|
|
self.mods2 = Linear()
|
|
self.relu = F.relu
|
|
|
|
def forward(self, x):
|
|
x = self.mods1(x)
|
|
x = self.mods2(x)
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
model = M().eval()
|
|
example_inputs = (torch.rand(5, 5),)
|
|
qconfig_dict = {"": default_qconfig}
|
|
m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
|
|
m(*example_inputs)
|
|
m = convert_fx(m)
|
|
assert hasattr(m, "mods1_0_packed_weight_0")
|
|
assert hasattr(m, "mods1_1_packed_weight_0")
|
|
assert hasattr(m, "mods2_packed_weight_0")
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_mul_add_fp16_config(self):
|
|
with override_quantized_engine('fbgemm'):
|
|
class Linear(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.w = torch.ones(5, 5)
|
|
self.b = torch.zeros(5)
|
|
|
|
def forward(self, x):
|
|
return torch.nn.functional.linear(x, self.w, self.b)
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.mods1 = torch.nn.Sequential(
|
|
Linear(),
|
|
Linear()
|
|
)
|
|
self.mods2 = Linear()
|
|
|
|
def forward(self, x):
|
|
x = x * 5
|
|
x = x + 5
|
|
x = self.mods1(x)
|
|
x = self.mods2(x)
|
|
return x
|
|
model = M().eval()
|
|
qconfig_dict = {"": float16_dynamic_qconfig}
|
|
example_inputs = (torch.rand(5, 5),)
|
|
m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
|
|
m = convert_fx(m)
|
|
# make sure it runs
|
|
m(*example_inputs)
|
|
|
|
def test_getattr_with_nontensor_result(self):
|
|
"""
|
|
Verifies that binary ops get quantized correctly if some
|
|
of the args are nodes but not Tensors, such as an `x.ndim`
|
|
pattern.
|
|
"""
|
|
class M1(torch.nn.Module):
|
|
def forward(self, x):
|
|
dims = x.ndim
|
|
dims_sub = dims - 1
|
|
dims_sub2 = dims_sub - 1
|
|
x = torch.add(x, dims_sub2)
|
|
return x
|
|
|
|
class M2(torch.nn.Module):
|
|
def forward(self, x):
|
|
dims = x.ndim
|
|
dims_sub = dims - 2
|
|
mul = [1] * dims_sub
|
|
dims_list = [-1, x.size(1)] + mul
|
|
x = x.view(dims_list)
|
|
return x
|
|
|
|
class M3(torch.nn.Module):
|
|
def forward(self, x):
|
|
shape = x.shape
|
|
x = x.view(shape)
|
|
return x
|
|
|
|
for cls in (M1, M2, M3):
|
|
m = cls().eval()
|
|
example_inputs = (torch.rand(4, 4, 4, 4),)
|
|
m(*example_inputs)
|
|
qconfig_dict = {'': torch.ao.quantization.default_qconfig}
|
|
mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
|
|
mp(torch.rand(4, 4, 4, 4))
|
|
mc = convert_fx(mp)
|
|
|
|
class _NonReferenceTestModel(nn.Module):
|
|
def __init__(self, func, lin_in, lin_out):
|
|
super().__init__()
|
|
self.conv1 = nn.Conv2d(3, 6, 5)
|
|
self.pool = nn.MaxPool2d(2, 2)
|
|
self.lin = nn.Linear(lin_in, lin_out)
|
|
self.func = func
|
|
|
|
def forward(self, x, y, z):
|
|
x = self.pool(F.relu(self.conv1(x)))
|
|
x = torch.flatten(x, 1)
|
|
x = self.func(x, y, z)
|
|
x = self.lin(x)
|
|
return x
|
|
|
|
# This function looks at the node specified by the NodeInfo in the key of
|
|
# node_info_to_non_tensor_args and checks that the args at specified indices
|
|
# are not observed (since they are non-tensors). If the args at those indices
|
|
# are a tuple/list (which do not show up as nodes) the function checks the
|
|
# individual elements of the tuple/list recursively.
|
|
def _check_not_observed(self, model, node_info_to_non_tensor_args):
|
|
|
|
# this is a helper function (for easier recursion) that checks whether
|
|
# arg_node is observed
|
|
def _check_node_not_observed(model, arg_node, node):
|
|
if isinstance(arg_node, (tuple, list)):
|
|
for new_node in arg_node:
|
|
_check_node_not_observed(model, new_node, node)
|
|
elif arg_node.op == "call_module":
|
|
self.assertTrue(
|
|
not _is_activation_post_process(getattr(model, arg_node.target)),
|
|
f"Arg: {arg_node} of node: {node} is observed but is not a float tensor",
|
|
)
|
|
|
|
for node in model.graph.nodes:
|
|
indices = node_info_to_non_tensor_args.get(
|
|
NodeInfo(node.op, node.target), []
|
|
)
|
|
for index in indices:
|
|
if index < len(node.args):
|
|
arg_node = node.args[index]
|
|
_check_node_not_observed(model, arg_node, node)
|
|
|
|
# This test checks that the model gets prepared correctly, doesn't have observers
|
|
# on specific ops (see _check_not_observed) and that the prepared model runs
|
|
def _test_dtype_propagation(self, model, node_info_to_non_tensor_args, *args):
|
|
model.eval()
|
|
qconfig_dict = {"": torch.ao.quantization.get_default_qconfig("fbgemm")}
|
|
prepared_model = prepare_fx(model, qconfig_dict, example_inputs=tuple(args))
|
|
self._check_not_observed(prepared_model, node_info_to_non_tensor_args)
|
|
prepared_model(*args)
|
|
|
|
def test_masked_fill_nontensor_args_not_observed(self):
|
|
def func(x, y, z):
|
|
return x.masked_fill(y, z)
|
|
|
|
model = self._NonReferenceTestModel(func, 1176, 1)
|
|
args = [torch.randn(5, 3, 32, 32), torch.randn(1176) > 0, 0.1]
|
|
node_info_to_non_tensor_args = {NodeInfo("call_method", "masked_fill"): [1, 2]}
|
|
self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
|
|
|
|
def test_permute_nontensor_args_not_observed(self):
|
|
def func(x, y, z):
|
|
return x.permute(y, z)
|
|
|
|
model = self._NonReferenceTestModel(func, 1176, 1)
|
|
args = [torch.randn(5, 3, 32, 32), 0, 1]
|
|
node_info_to_non_tensor_args = {NodeInfo("call_method", "permute"): [1, 2]}
|
|
self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
|
|
|
|
def test_repeat_nontensor_args_not_observed(self):
|
|
def func(x, y, z):
|
|
return x.repeat(y, z)
|
|
|
|
model = self._NonReferenceTestModel(func, 1176, 1)
|
|
args = [torch.randn(5, 3, 32, 32), 2, 1]
|
|
node_info_to_non_tensor_args = {NodeInfo("call_method", "repeat"): [1, 2]}
|
|
self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
|
|
|
|
def test_reshape_nontensor_args_not_observed(self):
|
|
def func(x, y, z):
|
|
return x.reshape(-1, y)
|
|
|
|
model = self._NonReferenceTestModel(func, 5, 1)
|
|
args = [torch.randn(5, 3, 32, 32), 5, None]
|
|
node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [2]}
|
|
self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
|
|
|
|
def test_size_nontensor_args_not_observed(self):
|
|
def func(x, y, z):
|
|
return x.reshape((-1, x.size(y)))
|
|
|
|
model = self._NonReferenceTestModel(func, 5, 1)
|
|
args = [torch.randn(5, 3, 32, 32), 0, None]
|
|
node_info_to_non_tensor_args = {NodeInfo("call_method", "size"): [1]}
|
|
self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
|
|
|
|
def test_transpose_nontensor_args_not_observed(self):
|
|
def func(x, y, z):
|
|
return x.transpose(y, z)
|
|
|
|
model = self._NonReferenceTestModel(func, 5, 1)
|
|
args = [torch.randn(5, 3, 32, 32), 0, 1]
|
|
node_info_to_non_tensor_args = {NodeInfo("call_method", "transpose"): [1, 2]}
|
|
self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
|
|
|
|
def test_torch_transpose_nontensor_args_not_observed(self):
|
|
# TODO: make torch.transpose traceable by fx when using
|
|
# variable nontensor arguments
|
|
# func = lambda x, y, z: torch.transpose(x, y, z) # error
|
|
def func(x, y, z):
|
|
return torch.transpose(x, 0, 1)
|
|
|
|
model = self._NonReferenceTestModel(func, 5, 1)
|
|
node_info_to_non_tensor_args = {
|
|
NodeInfo("call_method", torch.transpose): [1, 2]
|
|
}
|
|
args = [torch.randn(5, 3, 32, 32), 0, 1]
|
|
self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
|
|
|
|
def test_unsqueeze_nontensor_args_not_observed(self):
|
|
def func(x, y, z):
|
|
return x.unsqueeze(y)
|
|
|
|
model = self._NonReferenceTestModel(func, 1176, 1)
|
|
args = [torch.randn(5, 3, 32, 32), 1, None]
|
|
node_info_to_non_tensor_args = {NodeInfo("call_method", "unsqueeze"): [1]}
|
|
self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
|
|
|
|
def test_unsqueeze__nontensor_args_not_observed(self):
|
|
def func(x, y, z):
|
|
return x.unsqueeze_(y)
|
|
|
|
model = self._NonReferenceTestModel(func, 1176, 1)
|
|
args = [torch.randn(5, 3, 32, 32), 1, None]
|
|
node_info_to_non_tensor_args = {NodeInfo("call_method", "unsqueeze_"): [1]}
|
|
self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
|
|
|
|
def test_torch_unsqueeze_nontensor_args_not_observed(self):
|
|
# TODO: make torch.unsqueeze scriptable by fx when using
|
|
# variable nontensor arguments
|
|
# func = lambda x, y, z: torch.unsqueeze(x, y) # error
|
|
def func(x, y, z):
|
|
return torch.unsqueeze(x, 1)
|
|
|
|
model = self._NonReferenceTestModel(func, 1176, 1)
|
|
args = [torch.randn(5, 3, 32, 32), 1, None]
|
|
node_info_to_non_tensor_args = {NodeInfo("call_method", torch.unsqueeze): [1]}
|
|
self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
|
|
|
|
def test_view_nontensor_args_not_observed(self):
|
|
def func(x, y, z):
|
|
return x.view(-1, y)
|
|
|
|
model = self._NonReferenceTestModel(func, 5, 1)
|
|
args = [torch.randn(5, 3, 32, 32), 5, None]
|
|
node_info_to_non_tensor_args = {NodeInfo("call_method", "view"): [2]}
|
|
self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
|
|
|
|
def test_propagate_dtypes_for_known_nodes_list_args(self):
|
|
def func(x, y, z):
|
|
return x.reshape(y)
|
|
|
|
model = self._NonReferenceTestModel(func, 5, 1)
|
|
args = [torch.randn(5, 3, 32, 32), [-1, 5], None]
|
|
node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]}
|
|
self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
|
|
|
|
def test_propagate_dtypes_for_known_nodes_split_list_args(self):
|
|
def func(x, y, z):
|
|
return x.reshape([y, z])
|
|
|
|
model = self._NonReferenceTestModel(func, 5, 1)
|
|
args = [torch.randn(5, 3, 32, 32), -1, 5]
|
|
node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]}
|
|
self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
|
|
|
|
def test_propagate_dtypes_for_known_nodes_tuple_args(self):
|
|
def func(x, y, z):
|
|
return x.reshape(y)
|
|
|
|
model = self._NonReferenceTestModel(func, 5, 1)
|
|
args = [torch.randn(5, 3, 32, 32), (-1, 5), None]
|
|
node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]}
|
|
self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
|
|
|
|
def test_propagate_dtypes_for_known_nodes_split_tuple_args(self):
|
|
def func(x, y, z):
|
|
return x.reshape((y, z))
|
|
|
|
model = self._NonReferenceTestModel(func, 5, 1)
|
|
args = [torch.randn(5, 3, 32, 32), -1, 5]
|
|
node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]}
|
|
self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
|
|
|
|
def test_propagate_dtypes_for_known_nodes_dict_args(self):
|
|
def func(x, y, z):
|
|
return x.transpose(y["first"], y["second"])
|
|
|
|
model = self._NonReferenceTestModel(func, 5, 1)
|
|
args = [torch.randn(5, 3, 32, 32), {"first": 0, "second": 1}, None]
|
|
node_info_to_non_tensor_args = {NodeInfo("call_method", "transpose"): [1, 2]}
|
|
self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
|
|
|
|
def test_propagate_dtypes_for_known_nodes_dict_tuple_args(self):
|
|
class reshape_module(nn.Module):
|
|
def forward(self, x, y, z):
|
|
return x.reshape(y["shape"])
|
|
|
|
model = self._NonReferenceTestModel(reshape_module(), 5, 1)
|
|
args = [torch.randn(5, 3, 32, 32), {"shape": (-1, 5)}, None]
|
|
node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]}
|
|
self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
|
|
|
|
def test_propagate_dtypes_for_known_nodes_dict_split_tuple_args(self):
|
|
def func(x, y, z):
|
|
return x.reshape((y["first"], y["second"]))
|
|
|
|
model = self._NonReferenceTestModel(func, 5, 1)
|
|
args = [torch.randn(5, 3, 32, 32), {"first": -1, "second": 5}, None]
|
|
node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]}
|
|
self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
|
|
|
|
def test_assert_on_size_after_quant_layer(self):
|
|
"""
|
|
Verifies that calculating a size of a quantized tensor works
|
|
correctly in quantization passes.
|
|
"""
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv1 = nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
x = self.conv1(x)
|
|
torch._assert(x.size(1) == 1, 'foobar')
|
|
return x
|
|
|
|
m = M().eval()
|
|
example_inputs = (torch.rand(4, 1, 4, 4),)
|
|
m(*example_inputs)
|
|
qconfig_dict = {'': torch.ao.quantization.default_qconfig}
|
|
mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
|
|
mp(*example_inputs)
|
|
mc = convert_fx(mp)
|
|
mc(*example_inputs)
|
|
|
|
def test_fp32_sum(self):
|
|
"""
|
|
Verifies that fp32 sum works correctly if it's before or after
|
|
quantized layers.
|
|
"""
|
|
class M1(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv1 = nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
x = self.conv1(x)
|
|
x = torch.stack([x])
|
|
x = torch.sum(x)
|
|
return x
|
|
|
|
class M2(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv1 = nn.Conv2d(1, 1, 1)
|
|
self.conv2 = nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
x = self.conv1(x)
|
|
x1 = torch.stack([x])
|
|
x1 = torch.sum(x1, dim=0)
|
|
x2 = self.conv2(x1)
|
|
return x2
|
|
|
|
for cls in (M1, M2):
|
|
m = cls().eval()
|
|
example_inputs = (torch.rand(4, 1, 4, 4),)
|
|
m(*example_inputs)
|
|
qconfig_dict = {'': torch.ao.quantization.default_qconfig}
|
|
mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
|
|
mp(*example_inputs)
|
|
mc = convert_fx(mp)
|
|
mc(*example_inputs)
|
|
|
|
def test_fusion_pattern_unquantized(self):
|
|
"""
|
|
Ensure that leaving a possible fusion pattern of multiple nodes
|
|
unquantized runs through the APIs without errors.
|
|
"""
|
|
class Child(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.relu = nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = torch.add(x, 1.0)
|
|
x = torch.nn.functional.relu(x)
|
|
return x
|
|
|
|
class Parent(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.child = Child()
|
|
self.conv = nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
x = self.child(x)
|
|
x = self.conv(x)
|
|
return x
|
|
|
|
m = Parent().eval()
|
|
qconfig_dict = {
|
|
'': torch.ao.quantization.default_qconfig,
|
|
'module_name': [
|
|
('child', None),
|
|
],
|
|
}
|
|
example_inputs = (torch.rand(1, 1, 1, 1),)
|
|
mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
|
|
mp(*example_inputs)
|
|
mc = convert_fx(mp)
|
|
|
|
def test_state_dict(self):
|
|
""" Make sure packed params appear in state_dict
|
|
"""
|
|
|
|
# test linear packed weight
|
|
class M1(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.w = torch.rand(4, 30)
|
|
self.b = torch.rand(4)
|
|
|
|
def forward(self, x):
|
|
return F.linear(x, self.w, self.b)
|
|
|
|
m = M1().eval()
|
|
qconfig_dict = {"": default_qconfig}
|
|
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 30),))
|
|
m = convert_fx(m)
|
|
state_dict = m.state_dict()
|
|
self.assertTrue("_packed_weight_0" in state_dict)
|
|
|
|
# test conv packed weight
|
|
class M2(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.w = torch.rand(3, 3, 3, 3)
|
|
self.b = torch.rand(3)
|
|
self.stride = (1, 1)
|
|
self.padding = (0, 0)
|
|
self.dilation = (1, 1)
|
|
self.groups = 1
|
|
|
|
def forward(self, x):
|
|
return F.conv2d(x, self.w, self.b, self.stride, self.padding, self.dilation, self.groups)
|
|
|
|
m = M2().eval()
|
|
qconfig_dict = {"": default_qconfig}
|
|
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),))
|
|
m = convert_fx(m)
|
|
state_dict = m.state_dict()
|
|
self.assertTrue("_packed_weight_0" in state_dict)
|
|
|
|
# test load
|
|
ref_weight, ref_bias = torch.ops.quantized.conv2d_unpack(state_dict["_packed_weight_0"])
|
|
data = torch.rand(1, 3, 5, 5)
|
|
ref_res = m(data)
|
|
m = M2().eval()
|
|
m = prepare_fx(m, qconfig_dict, (data,))
|
|
m = convert_fx(m)
|
|
res = m(data)
|
|
weight, bias = m._packed_weight_0.unpack()
|
|
# check that random model weight/bias does not match ref weight/bias
|
|
self.assertNotEqual(weight, ref_weight)
|
|
self.assertNotEqual(bias, ref_bias)
|
|
self.assertNotEqual(res, ref_res)
|
|
m.load_state_dict(state_dict)
|
|
|
|
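# helper: verify that the unpacked weight/bias and the output match the reference values after loading the state_dict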
def checkModel(m, data, ref_weight, ref_bias, ref_res):
|
|
res = m(data)
|
|
weight, bias = m._packed_weight_0.unpack()
|
|
# check that weight/bias matches after load the state_dict
|
|
self.assertEqual(weight, ref_weight)
|
|
self.assertEqual(bias, ref_bias)
|
|
self.assertEqual(res, ref_res)
|
|
|
|
checkModel(m, data, ref_weight, ref_bias, ref_res)
|
|
|
|
# Test save to disk and load back
|
|
m = M2().eval()
|
|
m = prepare_fx(m, qconfig_dict, example_inputs=(data,))
|
|
m = convert_fx(m)
|
|
m.load_state_dict(state_dict)
|
|
with TemporaryFileName() as fname:
|
|
torch.save(m.state_dict(), fname)
|
|
# weights_only=False since the state_dict contains packed-weight script objects
|
|
m.load_state_dict(torch.load(fname, weights_only=False))
|
|
|
|
checkModel(m, data, ref_weight, ref_bias, ref_res)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_preserve_qconfig(self):
|
|
"""
|
|
Test to make sure the temporary config option to preserve qconfig attributes
|
|
in the model works
|
|
"""
|
|
with override_quantized_engine('fbgemm'):
|
|
class Linear(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.w = torch.ones(5, 5)
|
|
self.b = torch.zeros(5)
|
|
|
|
def forward(self, x):
|
|
return torch.nn.functional.linear(x, self.w, self.b)
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.mods1 = torch.nn.Sequential(
|
|
Linear(),
|
|
Linear()
|
|
)
|
|
self.mods2 = torch.nn.Sigmoid()
|
|
|
|
def forward(self, x):
|
|
x = self.mods1(x)
|
|
x = self.mods2(x)
|
|
return x
|
|
|
|
model = M().eval()
|
|
qconfig_dict = {
|
|
"object_type": [
|
|
(torch.nn.functional.linear, float16_dynamic_qconfig),
|
|
],
|
|
}
|
|
example_inputs = (torch.rand(5, 5),)
|
|
m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
|
|
m(*example_inputs)
|
|
m = convert_fx(m, _remove_qconfig=False)
|
|
|
|
self.assertTrue(hasattr(m.mods2, 'qconfig'))
|
|
|
|
def test_not_used(self):
|
|
""" Test quantizing a not used value"""
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, x):
|
|
x = x + x
|
|
x.sigmoid_()
|
|
return x
|
|
|
|
m = M().eval()
|
|
qconfig_mapping = get_default_qconfig_mapping().set_global(float16_static_qconfig)
|
|
# make sure quantization runs
|
|
m = prepare_fx(m, qconfig_mapping, example_inputs=(torch.randn(1),))
|
|
m = convert_fx(m)
|
|
|
|
def test_qparams_fqn(self):
|
|
""" Test that the FQN of input_scale/zero_point is set
|
|
to that of the first linear use. """
|
|
class Linear(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.w = torch.ones(5, 5)
|
|
self.b = torch.zeros(5)
|
|
|
|
def forward(self, x):
|
|
return torch.nn.functional.linear(x, self.w, self.b)
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.mods1 = torch.nn.Sequential(
|
|
Linear(),
|
|
Linear()
|
|
)
|
|
|
|
def forward(self, x):
|
|
x = torch.cat((x,), 1)
|
|
tmp = x.size()
|
|
x = self.mods1(x)
|
|
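# the multiply by a size element is left in fp32 (no qconfig is set for mul)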
y = x * tmp[0]
|
|
return y
|
|
|
|
model = M().eval()
|
|
qconfig_dict = {
|
|
"": None,
|
|
"object_type": [
|
|
(torch.nn.functional.linear, default_qconfig),
|
|
(torch.nn.functional.relu, default_qconfig),
|
|
],
|
|
}
|
|
example_inputs = (torch.rand(5, 5),)
|
|
m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
|
|
m(*example_inputs)
|
|
m = convert_fx(m)
|
|
keys = m.state_dict().keys()
|
|
m(torch.randn(5, 5))
|
|
# TODO: probably don't want to hardcode the attribute names, since they are generated
|
|
for attr_name in [
|
|
"mods1_0_input_scale_0", "mods1_0_input_zero_point_0",
|
|
"mods1_0_scale_0", "mods1_0_zero_point_0",
|
|
"mods1_1_scale_0", "mods1_1_zero_point_0"]:
|
|
self.assertTrue(hasattr(m, attr_name), attr_name + " not found.")
|
|
|
|
def test_no_obs_between_unmatched_node_and_copy_node(self):
|
|
"""
|
|
Verifies that an observer is not inserted between an unmatched
|
|
node and a node matched to CopyNodeQuantizeHandler. This is done
|
|
because observers require activations to be Tensors, and there is
|
|
no guarantee that an output of an unmatched node is a Tensor.
|
|
"""
|
|
|
|
class M(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.relu = nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = _user_func_with_complex_return_type(x)
|
|
x1 = x[0] + 1
|
|
return x1, x[1]
|
|
|
|
m = M().eval()
|
|
|
|
qconfig_dict = {'': torch.ao.quantization.default_qconfig}
|
|
example_inputs = (torch.randn(4, 4, 4, 4),)
|
|
mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
|
|
# if an observer is inserted after _user_func_with_complex_return_type,
|
|
# the following call will fail
|
|
mp(*example_inputs)
|
|
mc = convert_fx(mp)
|
|
mc(*example_inputs)
|
|
|
|
def test_fold_quant_dequant(self):
|
|
""" Test that the sequence of quant-dequant nodes in the
|
|
graph gets folded and the extra dequant nodes are erased.
|
|
"""
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.w = torch.ones(5, 5)
|
|
self.b = torch.zeros(5)
|
|
|
|
def forward(self, x):
|
|
x = torch.cat((x,), 1)
|
|
tmp = x.size()
|
|
x = torch.nn.functional.linear(x, self.w, self.b)
|
|
y = x * tmp[0]
|
|
return y
|
|
|
|
model = M().eval()
|
|
qconfig_dict = {
|
|
"": None,
|
|
"object_type": [
|
|
(torch.nn.functional.linear, default_qconfig),
|
|
],
|
|
}
|
|
example_inputs = (torch.rand(5, 5),)
|
|
m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
|
|
m(*example_inputs)
|
|
m = convert_fx(m)
|
|
keys = m.state_dict().keys()
|
|
m(*example_inputs)
|
|
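# count the quantize/dequantize nodes that survive folding; exactly one of each is expected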
dequant = 0
|
|
quant = 0
|
|
for n in m.graph.nodes:
|
|
if n.op == "call_method" and n.target == "dequantize":
|
|
dequant = dequant + 1
|
|
if n.op == "call_function" and n.target == torch.quantize_per_tensor:
|
|
quant = quant + 1
|
|
self.assertEqual(dequant, 1)
|
|
self.assertEqual(quant, 1)
|
|
|
|
def test_quant_output_always_observed(self):
|
|
"""
|
|
If the output is hardcoded to be quantized, ensure that
|
|
there is always an observer, even if the last non-output node is not
|
|
quantizeable.
|
|
"""
|
|
qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
|
|
prepare_custom_config_dict = {'output_quantized_idxs': [0]}
|
|
example_inputs = (torch.randn(4, 1, 4, 4),)
|
|
|
|
# non-quantizeable node, quantized output
|
|
class M1(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.identity = torch.nn.Identity()
|
|
|
|
def forward(self, x):
|
|
x = self.identity(x)
|
|
return x
|
|
|
|
m1 = M1()
|
|
self.checkGraphModeFxOp(
|
|
m1, example_inputs, QuantType.QAT,
|
|
prepare_expected_node_occurrence={
|
|
ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 2,
|
|
},
|
|
expected_node_occurrence={
|
|
ns.call_function(torch.quantize_per_tensor): 1,
|
|
},
|
|
prepare_custom_config=prepare_custom_config_dict)
|
|
|
|
# quantizeable node, quantized output
|
|
class M2(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
return x
|
|
|
|
m2 = M2()
|
|
self.checkGraphModeFxOp(
|
|
m2, example_inputs, QuantType.QAT,
|
|
prepare_expected_node_occurrence={
|
|
# one for weights, one for activations
|
|
ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 2,
|
|
},
|
|
expected_node_occurrence={
|
|
ns.call_function(torch.quantize_per_tensor): 1,
|
|
},
|
|
prepare_custom_config=prepare_custom_config_dict)
|
|
|
|
# quantizeable node, quantized dictionary output
|
|
class M3(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
return {"output": x}
|
|
|
|
m3 = M3()
|
|
self.checkGraphModeFxOp(
|
|
m3, example_inputs, QuantType.QAT,
|
|
prepare_expected_node_occurrence={
|
|
# one for weights, one for activations
|
|
ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 2,
|
|
},
|
|
expected_node_occurrence={
|
|
ns.call_function(torch.quantize_per_tensor): 1,
|
|
},
|
|
prepare_custom_config=prepare_custom_config_dict)
|
|
|
|
def test_deepcopy_preserve_attributes(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.attr = 3
|
|
|
|
def forward(self, x):
|
|
return x
|
|
|
|
m = M().eval()
|
|
m = prepare_fx(
|
|
m,
|
|
{"": default_qconfig},
|
|
example_inputs=(torch.randn(1),),
|
|
prepare_custom_config={"preserved_attributes": ["attr"]})
|
|
# preserved attributes are also stored in meta so that it doesn't get lost
|
|
# during deepcopy
|
|
self.assertTrue(hasattr(m, "attr"))
|
|
self.assertTrue("attr" in m.meta[_USER_PRESERVED_ATTRIBUTES_KEY])
|
|
m2 = copy.deepcopy(m)
|
|
self.assertTrue(hasattr(m2, "attr"))
|
|
self.assertTrue("attr" in m2.meta[_USER_PRESERVED_ATTRIBUTES_KEY])
|
|
m = convert_fx(m, convert_custom_config={"preserved_attributes": ["attr"]})
|
|
self.assertTrue(hasattr(m, "attr"))
|
|
self.assertTrue("attr" in m.meta[_USER_PRESERVED_ATTRIBUTES_KEY])
|
|
m2 = copy.deepcopy(m)
|
|
self.assertTrue(hasattr(m2, "attr"))
|
|
self.assertTrue("attr" in m2.meta[_USER_PRESERVED_ATTRIBUTES_KEY])
|
|
|
|
def test_output_lists_and_dicts(self):
|
|
"""Verify that specifying complicated output types does not crash.
|
|
"""
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
return {'foo': [x]}, [{'foo': [[x]]}]
|
|
|
|
m = M().eval()
|
|
qconfig_dict = {'': torch.ao.quantization.default_qconfig}
|
|
mp = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),))
|
|
mc = convert_fx(mp)
|
|
|
|
def test_shape_followed_by_quantized_op(self):
|
|
""" Make sure that shape does not dequantize
|
|
the Tensor before the next operator
|
|
"""
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv1 = torch.nn.Conv2d(2, 2, 2)
|
|
self.conv2 = torch.nn.Conv2d(2, 2, 2)
|
|
|
|
def forward(self, x):
|
|
x = self.conv1(x)
|
|
s = x.shape
|
|
torch._assert(s == x.shape, "")
|
|
x = self.conv2(x)
|
|
return x
|
|
|
|
# make sure quantization runs
|
|
m = M().eval()
|
|
example_inputs = (torch.randn(2, 2, 4, 4),)
|
|
m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs)
|
|
m = convert_fx(m)
|
|
m(*example_inputs)
|
|
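# one quantize at the input and one dequantize at the output means the tensor stayed quantized across the shape access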
node_occurrence = {
|
|
ns.call_function(torch.quantize_per_tensor): 1,
|
|
ns.call_method("dequantize"): 1
|
|
}
|
|
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
|
|
|
|
def test_trace_quantize_per_tensor(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
return x
|
|
|
|
m = M().eval()
|
|
m = prepare_fx(m, {"": default_qconfig}, example_inputs=(torch.randn(1, 1, 3, 3),))
|
|
m = convert_fx(m)
|
|
# Make sure this runs without error
|
|
m = torch.fx.Transformer(m).transform()
|
|
|
|
def test_copy_node_has_shared_actpp_instance(self):
|
|
""" Test the output of CopyNode to have the same
|
|
observer/fake_quant instance as the input
|
|
"""
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.avgpool2d = torch.nn.AvgPool2d(kernel_size=3)
|
|
|
|
def forward(self, x):
|
|
x = self.avgpool2d(x)
|
|
return x
|
|
|
|
for quant_type in self.static_quant_types:
|
|
m = M()
|
|
# Checks that we have an observer for both input and output
|
|
occurrence_map = {
|
|
QuantType.STATIC: {
|
|
ns.call_module(torch.ao.quantization.MinMaxObserver): 2
|
|
},
|
|
QuantType.QAT: {
|
|
ns.call_module(torch.ao.quantization.FakeQuantize): 2
|
|
}
|
|
}
|
|
if quant_type == QuantType.QAT:
|
|
m.train()
|
|
prepare = prepare_qat_fx
|
|
qconfig = default_qat_qconfig
|
|
actpp_module_class = torch.ao.quantization.FakeQuantize
|
|
else:
|
|
m.eval()
|
|
prepare = prepare_fx
|
|
qconfig = default_qconfig
|
|
actpp_module_class = torch.ao.quantization.MinMaxObserver
|
|
|
|
example_inputs = (torch.randn(1, 3, 3, 3),)
|
|
m = prepare(m, {"": qconfig}, example_inputs=example_inputs)
|
|
# check that there is a duplicated observer instance
|
|
actpp_module_count = 0
|
|
for name, module in m.named_modules(remove_duplicate=False):
|
|
if isinstance(module, actpp_module_class):
|
|
actpp_module_count += 1
|
|
self.assertEqual(actpp_module_count, 2)
|
|
|
|
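# with the default de-duplication in named_modules, the shared observer instance is reported only once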
actpp_module_count = 0
|
|
for name, module in m.named_modules():
|
|
if isinstance(module, actpp_module_class):
|
|
actpp_module_count += 1
|
|
self.assertEqual(actpp_module_count, 1)
|
|
|
|
m_copy = copy.deepcopy(m)
|
|
m = convert_fx(m)
|
|
m_reference = convert_to_reference_fx(m_copy)
|
|
|
|
# checks for non-reference quantized model
|
|
node_occurrence = {
|
|
ns.call_function(torch.quantize_per_tensor): 1,
|
|
ns.call_method("dequantize"): 1
|
|
}
|
|
node_list = [
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_module(torch.nn.AvgPool2d),
|
|
ns.call_method("dequantize"),
|
|
]
|
|
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence, expected_node_list=node_list)
|
|
|
|
# checks for reference quantized model, for copy nodes we'll have
|
|
# dequant - copy_node - quant patterns which will be fused later
|
|
# in the backend lowering step
|
|
node_occurrence = {
|
|
ns.call_function(torch.quantize_per_tensor): 2,
|
|
ns.call_method("dequantize"): 2
|
|
}
|
|
node_list = [
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_method("dequantize"),
|
|
ns.call_module(torch.nn.AvgPool2d),
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_method("dequantize"),
|
|
]
|
|
self.checkGraphModuleNodes(m_reference, expected_node_occurrence=node_occurrence, expected_node_list=node_list)
|
|
|
|
def test_linear_qint8_activation(self):
|
|
"""Test support for qint8 activation in reference pattern
|
|
"""
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(1, 2, 2, 2)
|
|
self.linear = torch.nn.Linear(8, 5)
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
x = torch.flatten(x, 1)
|
|
x = self.linear(x)
|
|
return x
|
|
|
|
m = M().eval()
|
|
example_inputs = (torch.rand(2, 1, 5, 5),)
|
|
m = prepare_fx(
|
|
m,
|
|
{"": torch.ao.quantization.QConfig(
|
|
activation=torch.ao.quantization.HistogramObserver.with_args(
|
|
qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
|
|
), weight=torch.ao.quantization.default_per_channel_weight_observer)},
|
|
example_inputs=example_inputs)
|
|
m = convert_to_reference_fx(m)
|
|
m(*example_inputs)
|
|
|
|
def test_preserve_tuple(self):
|
|
""" Test tuple input type is preserved
|
|
"""
|
|
|
|
class LSTM(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.lstm = nn.LSTM(50, 50, 1)
|
|
|
|
def forward(self, inputs: torch.Tensor, state: list[torch.Tensor]):
|
|
h = state[0]
|
|
c = state[1]
|
|
return self.lstm(inputs, (h, c))
|
|
|
|
m = LSTM().eval()
|
|
example_inputs = (torch.randn(5, 3, 50), torch.randn(2, 3, 50), torch.randn(2, 3, 50))
|
|
m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs)
|
|
# make sure the arg[1] of lstm module is a tuple
|
|
for n in m.graph.nodes:
|
|
if n.target == "lstm":
|
|
self.assertEqual(type(n.args[1]), tuple)
|
|
|
|
def _test_static_lstm_helper(self, model, prepare_node_occurrence, convert_node_occurrence):
|
|
"""
|
|
Helper method to validate the graph of a model with static LSTM.
|
|
"""
|
|
qconfig_mapping = get_default_qconfig_mapping()
|
|
prepare_custom_config = PrepareCustomConfig() \
|
|
.set_float_to_observed_mapping(torch.nn.LSTM, torch.ao.nn.quantizable.LSTM)
|
|
convert_custom_config = ConvertCustomConfig() \
|
|
.set_observed_to_quantized_mapping(torch.ao.nn.quantizable.LSTM, torch.ao.nn.quantized.LSTM)
|
|
example_inputs = (torch.rand(5, 3, 50), torch.rand(1, 3, 50), torch.randn(1, 3, 50))
|
|
|
|
model = prepare_fx(model, qconfig_mapping, example_inputs, prepare_custom_config=prepare_custom_config)
|
|
self.checkGraphModuleNodes(model, expected_node_occurrence=prepare_node_occurrence)
|
|
model(*example_inputs)
|
|
|
|
model = convert_fx(model, convert_custom_config=convert_custom_config)
|
|
self.checkGraphModuleNodes(model, expected_node_occurrence=convert_node_occurrence)
|
|
model(*example_inputs)
|
|
|
|
def test_static_lstm(self):
|
|
"""
|
|
Test statically quantized custom module LSTM followed by ops that consume individual
|
|
tensors of the output tuple.
|
|
"""
|
|
class MyModel(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.lstm = nn.LSTM(50, 50, 1)
|
|
self.linear1 = nn.Linear(50, 10)
|
|
self.linear2 = nn.Linear(50, 10)
|
|
self.linear3 = nn.Linear(50, 10)
|
|
|
|
def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor):
|
|
(out, (h0_out, c0_out)) = self.lstm(inputs, (h0, c0))
|
|
out = self.linear1(out)
|
|
h0_out = self.linear2(h0_out)
|
|
c0_out = self.linear3(c0_out)
|
|
return (out, (h0_out, c0_out))
|
|
|
|
m = MyModel()
|
|
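# prepare should swap the float LSTM for the observed quantizable LSTM custom module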
prepare_node_occurrence = {
|
|
ns.call_module(torch.ao.nn.quantizable.LSTM): 1,
|
|
}
|
|
convert_node_occurrence = {
|
|
ns.call_module(torch.ao.nn.quantized.LSTM): 1,
|
|
ns.call_function(torch.quantize_per_tensor): 3,
|
|
# lstm[0].dequantize()
|
|
# lstm[1][0].dequantize()
|
|
# lstm[1][1].dequantize()
|
|
ns.call_method("dequantize"): 3,
|
|
# lstm[0], lstm[1], lstm[1][0], lstm[1][1]
|
|
ns.call_function(operator.getitem): 4,
|
|
# No tuples are consumed
|
|
ns.call_function(tuple): 0,
|
|
}
|
|
self._test_static_lstm_helper(m, prepare_node_occurrence, convert_node_occurrence)
|
|
|
|
def test_static_lstm_consume_tuple(self):
|
|
"""
|
|
Test statically quantized custom module LSTM followed by a module that consumes the
|
|
output tuple, either as a whole or part of it.
|
|
"""
|
|
class ModuleAfterLSTM(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.identity = torch.nn.Identity()
|
|
|
|
def forward(self, x):
|
|
return self.identity(x)
|
|
|
|
class ConsumeWholeTuple(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.lstm = nn.LSTM(50, 50, 1)
|
|
self.module_after_lstm = ModuleAfterLSTM()
|
|
|
|
def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor):
|
|
x = self.lstm(inputs, (h0, c0))
|
|
x = self.module_after_lstm(x) # consume tuple (output, (hidden0, hidden1))
|
|
return x
|
|
|
|
class ConsumeHiddenTuple(ConsumeWholeTuple):
|
|
def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor):
|
|
x = self.lstm(inputs, (h0, c0))
|
|
x = self.module_after_lstm(x[1]) # consume tuple (hidden0, hidden1)
|
|
return x
|
|
|
|
# Test consuming the whole tuple (output, (hidden0, hidden1))
|
|
m1 = ConsumeWholeTuple()
|
|
prepare_node_occurrence = {
|
|
ns.call_module(torch.ao.nn.quantizable.LSTM): 1,
|
|
}
|
|
convert_node_occurrence1 = {
|
|
ns.call_module(torch.ao.nn.quantized.LSTM): 1,
|
|
ns.call_function(torch.quantize_per_tensor): 3,
|
|
# lstm[0].dequantize()
|
|
# lstm[1][0].dequantize()
|
|
# lstm[1][1].dequantize()
|
|
ns.call_method("dequantize"): 3,
|
|
# lstm[0], lstm[1], lstm[1][0], lstm[1][1]
|
|
ns.call_function(operator.getitem): 4,
|
|
# tuple(output_dq, tuple(hidden0_dq, hidden1_dq))
|
|
ns.call_function(tuple): 2,
|
|
}
|
|
self._test_static_lstm_helper(m1, prepare_node_occurrence, convert_node_occurrence1)
|
|
|
|
# Test consuming just the hidden tuple (hidden0, hidden1)
|
|
m2 = ConsumeHiddenTuple()
|
|
convert_node_occurrence2 = {
|
|
ns.call_module(torch.ao.nn.quantized.LSTM): 1,
|
|
ns.call_function(torch.quantize_per_tensor): 3,
|
|
# lstm[1][0].dequantize()
|
|
# lstm[1][1].dequantize()
|
|
ns.call_method("dequantize"): 2,
|
|
# lstm[1], lstm[1][0], lstm[1][1]
|
|
ns.call_function(operator.getitem): 3,
|
|
# tuple(hidden0_dq, hidden1_dq)
|
|
ns.call_function(tuple): 1,
|
|
}
|
|
self._test_static_lstm_helper(m2, prepare_node_occurrence, convert_node_occurrence2)
|
|
|
|
def test_static_lstm_with_custom_fixed_qparams(self):
|
|
"""
|
|
Test statically quantized LSTM with custom fixed qparams assigned to each of the
|
|
inner submodules. This flow requires users to extend `torch.ao.nn.quantizable.LSTM`
|
|
and use the child class in the custom module mapping.
|
|
"""
|
|
class MyModel(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.my_lstm = torch.nn.LSTM(50, 50, 1)
|
|
|
|
def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor):
|
|
x = self.my_lstm(inputs, (h0, c0))
|
|
return x
|
|
|
|
# Construct a BackendConfig that supports qint32 for certain ops
|
|
# TODO: build a BackendConfig from scratch instead of modifying an existing one
|
|
qint32_dtype_config = DTypeConfig(input_dtype=torch.qint32, output_dtype=torch.qint32)
|
|
my_backend_config = get_qnnpack_backend_config()
|
|
for config in my_backend_config.configs:
|
|
if config.pattern in [torch.nn.Sigmoid, torch.nn.Tanh, torch.add, torch.mul]:
|
|
config.add_dtype_config(qint32_dtype_config)
|
|
|
|
class UserObservedLSTM(torch.ao.nn.quantizable.LSTM):
|
|
"""
|
|
Example of user provided LSTM implementation that assigns fixed qparams
|
|
to the inner ops.
|
|
"""
|
|
@classmethod
|
|
def from_float(cls, float_lstm):
|
|
assert isinstance(float_lstm, cls._FLOAT_MODULE)
|
|
# uint16, [-16, 16)
|
|
linear_output_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -11, zero_point=2 ** 15, dtype=torch.qint32)
|
|
# uint16, [0, 1)
|
|
sigmoid_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -16, zero_point=0, dtype=torch.qint32)
|
|
# uint16, [-1, 1)
|
|
tanh_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -15, zero_point=2 ** 15, dtype=torch.qint32)
|
|
# int16, [-16, 16)
|
|
cell_state_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -11, zero_point=0, dtype=torch.qint32)
|
|
# uint8, [-1, 1)
|
|
hidden_state_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -7, zero_point=2 ** 7, dtype=torch.quint8)
|
|
example_inputs = (torch.rand(5, 3, 50), (torch.rand(1, 3, 50), torch.randn(1, 3, 50)))
|
|
return torch.ao.quantization.fx.lstm_utils._get_lstm_with_individually_observed_parts(
|
|
float_lstm=float_lstm,
|
|
example_inputs=example_inputs,
|
|
backend_config=my_backend_config,
|
|
linear_output_obs_ctr=linear_output_obs_ctr,
|
|
sigmoid_obs_ctr=sigmoid_obs_ctr,
|
|
tanh_obs_ctr=tanh_obs_ctr,
|
|
cell_state_obs_ctr=cell_state_obs_ctr,
|
|
hidden_state_obs_ctr=hidden_state_obs_ctr,
|
|
)
|
|
|
|
class UserQuantizedLSTM(torch.ao.nn.quantized.LSTM):
|
|
"""
|
|
Example of user provided LSTM implementation that produces a reference
|
|
quantized module from a `UserObservedLSTM`.
|
|
"""
|
|
@classmethod
|
|
def from_observed(cls, observed_lstm):
|
|
assert isinstance(observed_lstm, cls._FLOAT_MODULE)
|
|
return torch.ao.quantization.fx.lstm_utils._get_reference_quantized_lstm_module(
|
|
observed_lstm=observed_lstm,
|
|
backend_config=my_backend_config,
|
|
)
|
|
|
|
# FX graph mode quantization
|
|
m = MyModel()
|
|
qconfig_mapping = get_default_qconfig_mapping("qnnpack")
|
|
example_inputs = (torch.rand(5, 3, 50), torch.rand(1, 3, 50), torch.randn(1, 3, 50))
|
|
prepare_custom_config = PrepareCustomConfig() \
|
|
.set_float_to_observed_mapping(torch.nn.LSTM, UserObservedLSTM)
|
|
convert_custom_config = ConvertCustomConfig() \
|
|
.set_observed_to_quantized_mapping(torch.ao.nn.quantizable.LSTM, UserQuantizedLSTM)
|
|
prepared = prepare_fx(
|
|
m,
|
|
qconfig_mapping,
|
|
example_inputs,
|
|
prepare_custom_config,
|
|
backend_config=my_backend_config,
|
|
)
|
|
prepared(*example_inputs)
|
|
converted = convert_fx(
|
|
prepared,
|
|
convert_custom_config,
|
|
backend_config=my_backend_config,
|
|
)
|
|
converted(*example_inputs)
|
|
|
|
# Find the patterns [dq - op - q_to_specific_dtype] in the graph and
|
|
# verify that qparams and dtypes are set correctly in the quantize ops
|
|
node_name_to_expected_quantize_args = {
|
|
"igates": (None, None, torch.quint8),
|
|
"hgates": (None, None, torch.quint8),
|
|
"add": (2 ** -11, 2 ** 15, torch.qint32), # gates.add
|
|
"input_gate": (2 ** -16, 0, torch.qint32),
|
|
"forget_gate": (2 ** -16, 0, torch.qint32),
|
|
"cell_gate": (2 ** -15, 2 ** 15, torch.qint32),
|
|
"output_gate": (2 ** -16, 0, torch.qint32),
|
|
"mul": (2 ** -11, 0, torch.qint32), # fgate_cx.mul
|
|
"mul_1": (2 ** -11, 0, torch.qint32), # igate_cgate.mul
|
|
"add_1": (2 ** -11, 0, torch.qint32), # fgate_cx_igate_cgate.add
|
|
"mul_2": (2 ** -7, 2 ** 7, torch.quint8), # ogate_cy.mul
|
|
}
|
|
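# inspect the graph of the first layer's forward-direction cell, where the fixed-qparams observers were attached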
cell = converted.my_lstm.layers.get_submodule("0").layer_fw.cell
|
|
matched_names = set()
|
|
for node in cell.graph.nodes:
|
|
if node.name not in node_name_to_expected_quantize_args:
|
|
continue
|
|
matched_names.add(node.name)
|
|
# Match preceding dequantize
|
|
self.assertTrue(all(arg.target == "dequantize" for arg in node.args))
|
|
# Match following quantize with the specific qparams and dtypes
|
|
expected_scale, expected_zp, expected_dtype = node_name_to_expected_quantize_args[node.name]
|
|
for user in node.users.keys():
|
|
self.assertEqual(user.target, torch.quantize_per_tensor)
|
|
if expected_scale is not None:
|
|
self.assertEqual(getattr(cell, user.args[1].target), expected_scale)
|
|
if expected_zp is not None:
|
|
self.assertEqual(getattr(cell, user.args[2].target), expected_zp)
|
|
self.assertEqual(user.args[-1], expected_dtype)
|
|
# Ensure all patterns were matched
|
|
self.assertEqual(matched_names, set(node_name_to_expected_quantize_args.keys()))
|
|
|
|
def test_reroute_tuple_getitem_patterns(self):
|
|
"""
|
|
The following graph should redirect the output to `b`. After the transformation,
|
|
all other nodes, including the inputs `a` and `c`, are no longer needed.
|
|
|
|
a b c
|
|
| \\ /
|
|
\\ tuple
|
|
\\ /
|
|
tuple
|
|
/ \\
|
|
/ \\
|
|
| \\
|
|
| \\
|
|
| \\
|
|
getitem0 getitem1
|
|
| / \\
|
|
| getitem0 getitem1
|
|
| \\ /
|
|
\\ tuple
|
|
\\ /
|
|
\\ /
|
|
tuple
|
|
|
|
|
getitem1
|
|
|
|
|
getitem0
|
|
|
|
|
output
|
|
"""
|
|
# Construct graph manually because symbolic_trace does not insert tuple and getitem nodes
|
|
graph = torch.fx.Graph()
|
|
a = graph.create_node("placeholder", "a")
|
|
b = graph.create_node("placeholder", "b")
|
|
c = graph.create_node("placeholder", "c")
|
|
bc = graph.call_function(tuple, args=([b, c],))
|
|
abc = graph.call_function(tuple, args=([a, bc],))
|
|
|
|
# Break down tuple and reconstruct it again
|
|
a2 = graph.call_function(operator.getitem, args=(abc, 0))
|
|
bc2 = graph.call_function(operator.getitem, args=(abc, 1))
|
|
b2 = graph.call_function(operator.getitem, args=(bc2, 0))
|
|
c2 = graph.call_function(operator.getitem, args=(bc2, 1))
|
|
bc3 = graph.call_function(tuple, args=([b2, c2],))
|
|
abc2 = graph.call_function(tuple, args=([a2, bc3],))
|
|
|
|
# Output tuple[1][0]
|
|
bc4 = graph.call_function(operator.getitem, args=(abc2, 1))
|
|
b3 = graph.call_function(operator.getitem, args=(bc4, 0))
|
|
output = graph.output(b3)
|
|
|
|
# Do reroute
|
|
_reroute_tuple_getitem_pattern(graph)
|
|
|
|
# Assert that output reroutes to `b` directly, and all other nodes can be removed
|
|
output_ancestors = []
|
|
def gather_ancestors(current_node): # noqa: E306
|
|
for arg in current_node.args:
|
|
output_ancestors.append(arg)
|
|
gather_ancestors(arg)
|
|
gather_ancestors(output)
|
|
self.assertEqual(output_ancestors, [b])
|
|
self.assertEqual(output.args[0], b)
|
|
|
|
def test_relu_lowering(self):
|
|
class M(torch.nn.Module):
|
|
def forward(self, x):
|
|
return torch.nn.functional.relu(x)
|
|
|
|
m = M().eval()
|
|
m = prepare_fx(m, {"": default_qconfig}, example_inputs=(torch.randn(1),))
|
|
m_copy = copy.deepcopy(m)
|
|
m = convert_fx(m)
|
|
m_ref = convert_to_reference_fx(m_copy)
|
|
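# In the lowered model, relu runs directly on the quantized tensor, so we expect a single
# quantize/dequantize pair; the reference model keeps a q/dq pair around each op, hence two of each.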
node_occurrence = {
|
|
ns.call_function(torch.quantize_per_tensor): 1,
|
|
ns.call_method("dequantize"): 1
|
|
}
|
|
node_occurrence_ref = {
|
|
ns.call_function(torch.quantize_per_tensor): 2,
|
|
ns.call_method("dequantize"): 2
|
|
}
|
|
|
|
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
|
|
self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_dynamic_with_fusion(self):
|
|
"""
|
|
Tests that dynamic quantization APIs work with Linear + Relu fusion
|
|
"""
|
|
with override_quantized_engine('fbgemm'):
|
|
class LinearRelu(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(5, 5)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = self.linear(x)
|
|
return self.relu(x)
|
|
|
|
class Linear(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.w = torch.ones(5, 5)
|
|
self.b = torch.zeros(5)
|
|
|
|
def forward(self, x):
|
|
return torch.nn.functional.linear(x, self.w, self.b)
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.mods1 = torch.nn.Sequential(LinearRelu(), LinearRelu())
|
|
self.mods2 = Linear()
|
|
self.relu = F.relu
|
|
|
|
def forward(self, x):
|
|
x = self.mods1(x)
|
|
x = self.mods2(x)
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
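# Maps each dynamic qconfig to the fused dynamic quantized op expected after lowering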
dynamic_quantized_ops = {
|
|
float16_dynamic_qconfig: torch.ops.quantized.linear_relu_dynamic_fp16,
|
|
default_dynamic_qconfig: torch.ops.quantized.linear_relu_dynamic
|
|
}
|
|
for qconfig in [float16_dynamic_qconfig, default_dynamic_qconfig]:
|
|
model = M().eval()
|
|
qconfig_dict = {
|
|
"": qconfig
|
|
}
|
|
example_inputs = (torch.rand(5, 5),)
|
|
m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
|
|
m = convert_fx(m)
|
|
m(*example_inputs)
|
|
node_list = [
|
|
ns.call_module(nniqd.LinearReLU),
|
|
ns.call_module(nniqd.LinearReLU),
|
|
ns.call_function(dynamic_quantized_ops[qconfig]),
|
|
]
|
|
self.checkGraphModuleNodes(m, expected_node_list=node_list)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_dynamic_with_fusion_multiple_uses(self):
|
|
"""
|
|
Tests that dynamic quantization APIs work with Linear + Relu fusion
|
|
"""
|
|
with override_quantized_engine('fbgemm'):
|
|
class LinearRelu(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(5, 5)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = self.linear(x)
|
|
return self.relu(x)
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear_relu = LinearRelu()
|
|
|
|
def forward(self, x):
|
|
x = self.linear_relu(x)
|
|
x = self.linear_relu(x)
|
|
return x
|
|
|
|
for qconfig in [float16_dynamic_qconfig, default_dynamic_qconfig]:
|
|
model = M().eval()
|
|
qconfig_dict = {
|
|
"": qconfig
|
|
}
|
|
example_inputs = (torch.randn(5, 5),)
|
|
m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
|
|
m = convert_fx(m)
|
|
m(*example_inputs)
|
|
node_list = [
|
|
ns.call_module(nniqd.LinearReLU),
|
|
ns.call_module(nniqd.LinearReLU),
|
|
]
|
|
self.checkGraphModuleNodes(m, expected_node_list=node_list)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_dynamic_linear_input_multiple_use(self):
|
|
"""
|
|
Tests input for dynamic linear being used by multiple ops
|
|
"""
|
|
with override_quantized_engine('fbgemm'):
|
|
class LinearRelu(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(5, 5)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = self.linear(x)
|
|
return self.relu(x)
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.mod1 = LinearRelu()
|
|
self.mod2 = LinearRelu()
|
|
|
|
def forward(self, x):
|
|
y1 = self.mod1(x)
|
|
y2 = self.mod2(x)
|
|
return y1 + y2
|
|
|
|
for qconfig in [float16_dynamic_qconfig, default_dynamic_qconfig]:
|
|
model = M().eval()
|
|
qconfig_dict = {
|
|
"": qconfig
|
|
}
|
|
example_inputs = (torch.rand(5, 5, 5),)
|
|
m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
|
|
m = convert_fx(m)
|
|
m(*example_inputs)
|
|
node_list = [
|
|
ns.call_module(nniqd.LinearReLU),
|
|
ns.call_module(nniqd.LinearReLU),
|
|
]
|
|
self.checkGraphModuleNodes(m, expected_node_list=node_list)
|
|
|
|
def test_ref_linear_module(self):
|
|
""" Make sure the numerics for models with ref linear module
|
|
matches models with fbgemm/qnnpack module
|
|
"""
class M1(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(10, 5)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
class M2(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(10, 5)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
return self.relu(self.linear(x))
|
|
|
|
for M in [M1, M2]:
|
|
m = M().eval()
|
|
example_inputs = (torch.randn(5, 10),)
|
|
m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs)
|
|
m_copy = copy.deepcopy(m)
|
|
m = convert_fx(m)
|
|
m_ref = convert_to_reference_fx(m_copy)
|
|
result = m(*example_inputs)
|
|
result_ref = m_ref(*example_inputs)
|
|
self.assertTrue(torch.equal(result, result_ref))
|
|
|
|
def test_ref_conv_module(self):
|
|
""" Make sure the numerics for models with ref conv module
|
|
matches models with fbgemm/qnnpack module
|
|
"""
convs = {
|
|
1: nn.Conv1d,
|
|
2: nn.Conv2d,
|
|
3: nn.Conv3d,
|
|
}
|
|
|
|
class M1(torch.nn.Module):
|
|
def __init__(self, dim):
|
|
super().__init__()
|
|
self.conv = convs[dim](3, 3, 3)
|
|
|
|
def forward(self, x):
|
|
return self.conv(x)
|
|
|
|
class M2(torch.nn.Module):
|
|
def __init__(self, dim):
|
|
super().__init__()
|
|
self.conv = convs[dim](3, 3, 3)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
return self.relu(self.conv(x))
|
|
|
|
for dim, M in itertools.product([1, 2, 3], [M1, M2]):
|
|
m = M(dim).eval()
|
|
data = self.img_data_dict[dim][0][0]
|
|
m = prepare_fx(m, {"": default_qconfig}, example_inputs=(data,))
|
|
m_copy = copy.deepcopy(m)
|
|
m = convert_fx(m)
|
|
m_ref = convert_to_reference_fx(m_copy)
|
|
result = m(data)
|
|
result_ref = m_ref(data)
|
|
self.assertTrue(torch.equal(result, result_ref))
|
|
|
|
def test_sub_scalar(self):
|
|
class M(torch.nn.Module):
|
|
def forward(self, x):
|
|
x = x + 1
|
|
x = x - 1
|
|
x = x + 3
|
|
x = x - 4
|
|
return x
|
|
|
|
m = M().eval()
|
|
m = prepare_fx(m, {"": default_qconfig}, example_inputs=(torch.rand(3),))
|
|
m = convert_fx(m)
|
|
occurrence = {
|
|
ns.call_function(torch.quantize_per_tensor): 2,
|
|
ns.call_method("dequantize"): 2
|
|
}
|
|
self.checkGraphModuleNodes(m, expected_node_occurrence=occurrence)
|
|
|
|
def test_observer_fqn(self):
|
|
"""
|
|
Test to make sure the observer FQN is based on the quantizable op/module that it is observing
|
|
and uses the modules FQN to determine the observer name.
|
|
"""
class Linear(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.w = torch.ones(5, 5)
|
|
self.b = torch.zeros(5)
|
|
|
|
|
|
def forward(self, x):
|
|
return torch.nn.functional.linear(x, self.w, self.b)
|
|
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.mods1 = torch.nn.Sequential(
|
|
Linear(),
|
|
Linear()
|
|
)
|
|
self.mods2 = Linear()
|
|
self.mods3 = torch.nn.Linear(5, 5)
|
|
|
|
def forward(self, x):
|
|
x = self.mods1(x)
|
|
x = torch.add(x, 4)
|
|
x = self.mods2(x)
|
|
y = torch.add(x, 2)
|
|
z = torch.mul(x, 5)
|
|
a = self.mods3(y)
|
|
return a, z
|
|
|
|
model = M().eval()
|
|
|
|
prepared = prepare_fx(model, {"": default_qconfig}, example_inputs=(torch.randn(1, 5)))
name_list = []
|
|
for name, mod in prepared.named_modules():
|
|
if isinstance(mod, torch.ao.quantization.observer.MinMaxObserver):
|
|
name_list.append(name)
|
|
expected_name_list = ['activation_post_process_0',
|
|
'activation_post_process_1',
|
|
'activation_post_process_2',
|
|
'activation_post_process_3',
|
|
'activation_post_process_4',
|
|
'activation_post_process_6',
|
|
'activation_post_process_7',
|
|
'activation_post_process_10']
|
|
assert name_list == expected_name_list
|
|
|
|
def test_conv_lowering(self):
|
|
convs = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
|
|
qconvs = {1: nn.quantized.Conv1d, 2: nn.quantized.Conv2d, 3: nn.quantized.Conv3d}
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self, dim):
|
|
super().__init__()
|
|
self.conv = convs[dim](3, 3, 3)
|
|
|
|
def forward(self, x):
|
|
return self.conv(x)
|
|
|
|
for dim in range(1, len(convs) + 1):
|
|
m = M(dim).eval()
|
|
data = self.img_data_dict[dim][0][0]
|
|
m = prepare_fx(m, {"": default_qconfig}, example_inputs=(data,))
|
|
m_ref = copy.deepcopy(m)
|
|
m_ref = convert_to_reference_fx(m_ref)
|
|
m = convert_fx(m)
|
|
out_ref = m_ref(data)
|
|
out = m(data)
|
|
# check that reference pattern for quantized conv module is fused
|
|
expected_node_occurrence = {
|
|
ns.call_function(torch.quantize_per_tensor): 1,
|
|
ns.call_module(qconvs[dim]): 1,
|
|
ns.call_method("dequantize"): 1
|
|
}
|
|
self.checkGraphModuleNodes(m, expected_node_occurrence=expected_node_occurrence)
|
|
# checking result match
|
|
self.assertTrue(torch.equal(out_ref, out))
|
|
|
|
def test_convert_qconfig_mapping(self):
|
|
class Linear(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.w = torch.ones(5, 5)
|
|
self.b = torch.zeros(5)
|
|
|
|
def forward(self, x):
|
|
return torch.nn.functional.linear(x, self.w, self.b)
|
|
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.mods1 = torch.nn.Sequential(
|
|
Linear(),
|
|
Linear()
|
|
)
|
|
self.mods3 = torch.nn.Linear(5, 5)
|
|
|
|
def forward(self, x):
|
|
x = self.mods1(x)
|
|
x = torch.add(x, 4)
|
|
z = torch.mul(x, 5)
|
|
x = self.mods3(z)
|
|
return x
|
|
|
|
model = M().train()
|
|
|
|
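# Check that qconfigs can be overridden at convert time, both by module name and by object type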
for check in ["module_name", "object_type"]:
|
|
qconfig_dict = {"": None,
|
|
"object_type": [
|
|
(nn.functional.linear, get_default_qat_qconfig("fbgemm")),
|
|
(torch.add, get_default_qat_qconfig("fbgemm")),
|
|
(nn.Linear, get_default_qat_qconfig("fbgemm")),
|
|
],
|
|
}
|
|
example_inputs = (torch.rand(5, 5),)
|
|
prepared = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs)
|
|
prepared(*example_inputs)
|
|
if check == "module_name":
|
|
convert_qconfig_dict = {"": None,
|
|
"object_type": [
|
|
(nn.functional.linear, get_default_qat_qconfig("fbgemm")),
|
|
(torch.add, get_default_qat_qconfig("fbgemm")),
|
|
(nn.Linear, get_default_qat_qconfig("fbgemm")),
|
|
],
|
|
"module_name": [("mods1.0", None)]}
|
|
|
|
node_occurrence = {
|
|
ns.call_function(torch.quantize_per_tensor): 2,
|
|
ns.call_function(torch.nn.functional.linear): 1,
|
|
ns.call_function(torch.ops.quantized.linear): 1,
|
|
ns.call_function(torch.ops.quantized.add): 1,
|
|
ns.call_method("dequantize"): 2
|
|
}
|
|
order_check = [
|
|
ns.call_function(torch.nn.functional.linear),
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_function(torch.ops.quantized.linear),
|
|
ns.call_function(torch.ops.quantized.add),
|
|
ns.call_method("dequantize"),
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_module(nnq.Linear),
|
|
ns.call_method("dequantize"),
|
|
]
|
|
elif check == "object_type":
|
|
convert_qconfig_dict = {"": None,
|
|
"object_type": [
|
|
(nn.functional.linear, get_default_qat_qconfig("fbgemm")),
|
|
(torch.add, get_default_qat_qconfig("fbgemm")),
|
|
(nn.Linear, None),
|
|
]}
|
|
|
|
node_occurrence = {
|
|
ns.call_function(torch.quantize_per_tensor): 1,
|
|
ns.call_function(torch.ops.quantized.linear): 2,
|
|
ns.call_function(torch.ops.quantized.add): 1,
|
|
ns.call_function(torch.mul): 1,
|
|
ns.call_method("dequantize"): 1
|
|
}
|
|
order_check = [
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_function(torch.ops.quantized.linear),
|
|
ns.call_function(torch.ops.quantized.linear),
|
|
ns.call_function(torch.ops.quantized.add),
|
|
ns.call_method("dequantize"),
|
|
ns.call_function(torch.mul),
|
|
ns.call_module(nn.Linear),
|
|
]
|
|
|
|
converted = convert_fx(prepared, qconfig_mapping=convert_qconfig_dict)
|
|
converted(torch.rand(5, 5))
|
|
self.checkGraphModuleNodes(
|
|
converted,
|
|
expected_node_occurrence=node_occurrence,
|
|
expected_node_list=order_check)
|
|
|
|
def _assertFixedQParamsFakeQuantizeEqual(self, fq1, fq2):
|
|
self.assertEqual(fq1()._observer_ctr, fq2()._observer_ctr)
|
|
|
|
def test_register_patterns(self):
|
|
def cleanUp():
|
|
del _DEFAULT_FUSION_PATTERNS["dummy_fusion"]
|
|
del _DEFAULT_QUANTIZATION_PATTERNS["dummy_quant"]
|
|
del _DEFAULT_QUANTIZATION_PATTERNS["dummy_quant2"]
|
|
del _DEFAULT_QUANTIZATION_PATTERNS["dummy_quant3"]
|
|
del _DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant2"]
|
|
del _DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant3"]
|
|
del _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant2"]
|
|
del _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant3"]
|
|
self.addCleanup(cleanUp)
|
|
|
|
@_register_fusion_pattern("dummy_fusion")
|
|
class DummyFusion:
|
|
pass
|
|
|
|
@_register_quant_pattern("dummy_quant")
|
|
class DummyQuant:
|
|
pass
|
|
|
|
@_register_quant_pattern("dummy_quant2", default_fixed_qparams_range_0to1_observer)
|
|
class DummyQuant2:
|
|
pass
|
|
|
|
@_register_quant_pattern("dummy_quant3", default_fixed_qparams_range_neg1to1_observer)
|
|
class DummyQuant3:
|
|
pass
|
|
|
|
self.assertEqual(_DEFAULT_FUSION_PATTERNS["dummy_fusion"], DummyFusion)
|
|
self.assertEqual(_DEFAULT_QUANTIZATION_PATTERNS["dummy_quant"], DummyQuant)
|
|
self.assertEqual(_DEFAULT_QUANTIZATION_PATTERNS["dummy_quant2"], DummyQuant2)
|
|
self.assertEqual(_DEFAULT_QUANTIZATION_PATTERNS["dummy_quant3"], DummyQuant3)
|
|
self.assertEqual(_DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant2"], default_fixed_qparams_range_0to1_observer)
|
|
self.assertEqual(_DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant3"], default_fixed_qparams_range_neg1to1_observer)
|
|
self._assertFixedQParamsFakeQuantizeEqual(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant2"],
|
|
default_fixed_qparams_range_0to1_fake_quant)
|
|
self._assertFixedQParamsFakeQuantizeEqual(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant3"],
|
|
default_fixed_qparams_range_neg1to1_fake_quant)
|
|
output_fake_quantize_map = get_default_output_activation_post_process_map(is_training=True)
|
|
output_observer_map = get_default_output_activation_post_process_map(is_training=False)
|
|
self.assertEqual(output_observer_map.get("dummy_quant3"), default_fixed_qparams_range_neg1to1_observer)
|
|
self._assertFixedQParamsFakeQuantizeEqual(output_fake_quantize_map.get("dummy_quant3"),
|
|
default_fixed_qparams_range_neg1to1_fake_quant)
|
|
|
|
|
|
|
|
def test_reuse_input_qconfig(self):
|
|
class M1(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(3, 3, 3)
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
x = x.reshape()
|
|
return x
|
|
|
|
class M2(torch.nn.Module):
|
|
def forward(self, x):
|
|
x = x.reshape()
|
|
return x
|
|
|
|
options = itertools.product([M1, M2], [True, False])
|
|
for M, is_qat in options:
|
|
m = M1().eval()
|
|
example_inputs = (torch.randn(1, 3, 3, 3),)
|
|
m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs)
|
|
m = convert_fx(m)
|
|
node_list = [
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_module(nnq.Conv2d),
|
|
ns.call_method("reshape"),
|
|
ns.call_method("dequantize"),
|
|
]
|
|
self.checkGraphModuleNodes(
|
|
m,
|
|
expected_node_list=node_list)
|
|
|
|
m = M2().eval()
|
|
m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs)
|
|
m = convert_fx(m)
|
|
node_occurrence = {
|
|
ns.call_function(torch.quantize_per_tensor): 0,
|
|
ns.call_method("dequnatize"): 0,
}
|
|
node_list = [
|
|
ns.call_method("reshape"),
|
|
]
|
|
self.checkGraphModuleNodes(
|
|
m,
|
|
expected_node_occurrence=node_occurrence,
|
|
expected_node_list=node_list)
|
|
|
|
def test_stack_trace_preserved_linear(self):
|
|
class M(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = nn.Linear(1, 1)
|
|
|
|
def forward(self, x):
|
|
x = self.linear(x)
|
|
return x
|
|
|
|
m = M().eval()
|
|
mp = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=(torch.randn(1, 1),))
|
|
|
|
found_stack_trace = False
|
|
for n in mp.graph.nodes:
|
|
if n.op == 'call_module' and n.target == 'linear':
|
|
found_stack_trace = n.stack_trace is not None
|
|
break
|
|
self.assertTrue(found_stack_trace)
|
|
|
|
# test reference model
|
|
mq = convert_to_reference_fx(copy.deepcopy(mp))
|
|
found_stack_trace = False
|
|
for n in mq.graph.nodes:
|
|
if n.op == 'call_module' and n.target == 'linear':
|
|
found_stack_trace = n.stack_trace is not None
|
|
break
|
|
self.assertTrue(found_stack_trace, f"stack trace not found, node: {n.format_node()}, is_reference: True")
|
|
|
|
# test quantized model
|
|
mq = convert_fx(mp)
|
|
found_stack_trace = False
|
|
for n in mq.graph.nodes:
|
|
if n.op == 'call_module' and n.target == 'linear':
|
|
found_stack_trace = n.stack_trace is not None
|
|
break
|
|
self.assertTrue(found_stack_trace, f"stack trace not found, node: {n.format_node()}, is_reference: False")
|
|
|
|
def test_qat_skip_untraced(self):
|
|
class UnTraceableModuleClass(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = nn.Linear(2, 2)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
class UnTraceableModuleName(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = nn.Linear(2, 2)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
class M(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.untraceable_module_class = UnTraceableModuleClass()
|
|
self.untraceable_module_name = UnTraceableModuleClass()
|
|
|
|
def forward(self, x):
|
|
x = self.untraceable_module_class(x)
|
|
x = self.untraceable_module_name(x)
|
|
return x
|
|
|
|
mod = M()
|
|
|
|
qconfig_dict = {"": torch.ao.quantization.get_default_qat_qconfig()}
|
|
prepare_custom_config_dict = {
|
|
"non_traceable_module_class": [UnTraceableModuleClass],
|
|
"non_traceable_module_name": ["untraceable_module_name"],
|
|
}
|
|
example_inputs = (torch.randn(2, 2),)
|
|
mod_prep = torch.ao.quantization.quantize_fx.prepare_qat_fx(
|
|
mod.train(), qconfig_dict, example_inputs=example_inputs,
|
|
prepare_custom_config=prepare_custom_config_dict
|
|
)
|
|
mod_prep = torch.ao.quantization.quantize_fx.prepare_qat_fx(
|
|
mod.train(), qconfig_dict, example_inputs=example_inputs,
|
|
prepare_custom_config=prepare_custom_config_dict
|
|
)
|
|
self.assertTrue(
|
|
isinstance(mod_prep.untraceable_module_class.linear, torch.nn.Linear)
|
|
)
|
|
self.assertTrue(
|
|
isinstance(mod_prep.untraceable_module_name.linear, torch.nn.Linear)
|
|
)
|
|
self.assertTrue(
|
|
type(mod_prep.untraceable_module_class.linear)
|
|
is not torch.ao.nn.qat.modules.linear.Linear,
|
|
"prepare_qat_fx shold not convert anything inside untraced module classes",
)
|
|
self.assertTrue(
|
|
type(mod_prep.untraceable_module_name.linear)
|
|
is not torch.ao.nn.qat.modules.linear.Linear,
|
|
"prepare_qat_fx shold not convert anything inside modules named in untraced_module_names",
)
|
|
|
|
def test_qconfig_dict_setup(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.Conv1d = torch.nn.Conv1d(1, 1, 1)
|
|
self.Conv2d = torch.nn.Conv2d(1, 1, 1)
|
|
self.Conv3d = torch.nn.Conv3d(1, 1, 1)
|
|
self.ConvTranspose1d = torch.nn.ConvTranspose1d(1, 1, 1)
|
|
self.ConvTranspose2d = torch.nn.ConvTranspose2d(1, 1, 1)
|
|
self.ConvTranspose3d = torch.nn.ConvTranspose3d(1, 1, 1)
|
|
self.Linear = torch.nn.Linear(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
x = self.Conv1d(x)
|
|
x = self.Conv2d(x)
|
|
x = self.Conv3d(x)
|
|
x = self.ConvTranspose1d(x)
|
|
x = self.ConvTranspose2d(x)
|
|
x = self.ConvTranspose3d(x)
|
|
x = self.Linear(x)
|
|
x = torch.nn.functional.conv1d(x, torch.rand(2, 2))
|
|
x = torch.nn.functional.conv2d(x, torch.rand(2, 2))
|
|
x = torch.nn.functional.conv3d(x, torch.rand(2, 2))
|
|
x = torch.nn.functional.linear(x, torch.rand(2, 2))
|
|
return x
|
|
|
|
backends = ["qnnpack", "fbgemm"]
|
|
for func in [get_default_qconfig_mapping, get_default_qat_qconfig_mapping]:
|
|
for backend in backends:
|
|
m = M().eval()
|
|
qconfig_dict = func(backend)
|
|
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),))
for name, mod in m.named_modules():
|
|
if _is_activation_post_process(mod) and mod.dtype == torch.quint8:
|
|
if backend == "fbgemm":
|
|
lower_bnd = 0
|
|
upper_bnd = 127
|
|
else:
|
|
lower_bnd = 0
|
|
upper_bnd = 255
|
|
if issubclass(type(mod), FakeQuantize):
|
|
self.assertEqual(mod.activation_post_process.quant_min, lower_bnd)
|
|
self.assertEqual(mod.activation_post_process.quant_max, upper_bnd)
|
|
else:
|
|
self.assertEqual(mod.quant_min, lower_bnd)
|
|
self.assertEqual(mod.quant_max, upper_bnd)
|
|
|
|
def test_prepare_mode(self):
|
|
class LinearModel(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(5, 10)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
def _test(prepare_fn, qconfig_dict):
|
|
m = LinearModel()
|
|
m1 = copy.deepcopy(m)
|
|
m1.train()
|
|
example_inputs = (torch.randn(1, 5),)
|
|
prepare_fn(m1, qconfig_dict, example_inputs=example_inputs)
|
|
m2 = copy.deepcopy(m)
|
|
m2.eval()
|
|
prepare_fn(m2, qconfig_dict, example_inputs=example_inputs)
|
|
|
|
# Ensure prepare_fx and prepare_qat_fx work in both training and eval modes
|
|
_test(prepare_fx, get_default_qconfig_mapping())
|
|
_test(prepare_qat_fx, get_default_qat_qconfig_mapping())
|
|
|
|
def _validate_qconfig_against_backend_config_constraints(
|
|
self,
|
|
model: torch.nn.Module,
|
|
qconfig: QConfig,
|
|
backend_config: BackendConfig,
|
|
satisfies_constraints: bool,
|
|
qconfig_name: Optional[str] = None):
|
|
"""
|
|
Helper method to validate whether `qconfig` satisfies the constraints specified in `backend_config`.
|
|
"""
|
|
qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
|
|
example_inputs = (torch.rand((1, 30), dtype=torch.float),)
|
|
model = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config)
|
|
model(*example_inputs)
|
|
model = convert_fx(model, backend_config=backend_config)
|
|
if satisfies_constraints:
|
|
expected_node_occurrence = {
|
|
ns.call_module(torch.ao.nn.quantized.Linear) : 1,
|
|
ns.call_module(torch.nn.Linear) : 0,
|
|
}
|
|
else:
|
|
expected_node_occurrence = {
|
|
ns.call_module(torch.ao.nn.quantized.Linear) : 0,
|
|
ns.call_module(torch.nn.Linear) : 1,
|
|
}
|
|
try:
|
|
self.checkGraphModuleNodes(model, expected_node_occurrence=expected_node_occurrence)
|
|
except AssertionError as e:
|
|
if qconfig_name is not None:
|
|
print(f"ERROR: Validation for QConfig '{qconfig_name}' failed")
|
|
raise e
|
|
|
|
def test_backend_config_quantization_range(self):
|
|
"""
|
|
Check that quantization ranges specified through the BackendConfig are reflected in
|
|
the observers inserted into the model.
|
|
"""
|
|
class MyModel(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(30, 4).float()
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
dtype_config = DTypeConfig(
|
|
input_dtype=DTypeWithConstraints(
|
|
dtype=torch.quint8,
|
|
quant_min_lower_bound=0,
|
|
quant_max_upper_bound=31,
|
|
),
|
|
output_dtype=DTypeWithConstraints(
|
|
dtype=torch.quint8,
|
|
quant_min_lower_bound=0,
|
|
quant_max_upper_bound=31,
|
|
),
|
|
weight_dtype=DTypeWithConstraints(
|
|
dtype=torch.qint8,
|
|
quant_min_lower_bound=-64,
|
|
quant_max_upper_bound=63,
|
|
),
|
|
bias_dtype=torch.float,
|
|
)
|
|
backend_config = BackendConfig() \
|
|
.set_backend_pattern_config(BackendPatternConfig(torch.nn.Linear)
|
|
.set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E128
|
|
.add_dtype_config(dtype_config)
|
|
.set_root_module(torch.nn.Linear)
|
|
.set_reference_quantized_module(nnqr.Linear))
|
|
|
|
def validate_qconfig(qconfig: QConfig, satisfies_constraints: bool):
|
|
self._validate_qconfig_against_backend_config_constraints(
|
|
MyModel(), qconfig, backend_config, satisfies_constraints)
|
|
|
|
# Case 1: QConfig ranges fit within backend ranges, OK
|
|
qconfig1 = QConfig(
|
|
activation=MinMaxObserver.with_args(quant_min=0, quant_max=15, dtype=torch.quint8),
|
|
weight=MinMaxObserver.with_args(quant_min=-32, quant_max=31, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))
|
|
validate_qconfig(qconfig1, satisfies_constraints=True)
|
|
|
|
# Case 2: QConfig activation range falls outside backend range, should fail
|
|
qconfig2 = QConfig(
|
|
activation=MinMaxObserver.with_args(quant_min=0, quant_max=63, dtype=torch.quint8),
|
|
weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))
|
|
validate_qconfig(qconfig2, satisfies_constraints=False)
|
|
|
|
# Case 3: QConfig weight range falls outside backend range, should fail
|
|
qconfig3 = QConfig(
|
|
activation=MinMaxObserver.with_args(dtype=torch.quint8),
|
|
weight=MinMaxObserver.with_args(quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))
|
|
validate_qconfig(qconfig3, satisfies_constraints=False)
|
|
|
|
# Case 4: QConfig doesn't specify range, should fail
|
|
qconfig4 = QConfig(activation=ReuseInputObserver, weight=ReuseInputObserver)
|
|
validate_qconfig(qconfig4, satisfies_constraints=False)
|
|
|
|
def test_backend_config_scale_min(self):
|
|
"""
|
|
Test QConfig eps validation against the BackendConfig's min scale value.
|
|
"""
|
|
class MyModel(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(30, 4).float()
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
dtype_config = DTypeConfig(
|
|
input_dtype=DTypeWithConstraints(dtype=torch.quint8, scale_min_lower_bound=2 ** -12),
|
|
output_dtype=DTypeWithConstraints(dtype=torch.quint8, scale_min_lower_bound=2 ** -12),
|
|
weight_dtype=DTypeWithConstraints(dtype=torch.qint8, scale_min_lower_bound=2 ** -12),
|
|
bias_dtype=torch.float,
|
|
)
|
|
|
|
backend_config = BackendConfig() \
|
|
.set_backend_pattern_config(BackendPatternConfig(torch.nn.Linear)
|
|
.set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E128
|
|
.add_dtype_config(dtype_config)
|
|
.set_root_module(torch.nn.Linear)
|
|
.set_reference_quantized_module(nnqr.Linear))
|
|
|
|
def validate_qconfig(qconfig: QConfig, satisfies_constraints: bool):
|
|
self._validate_qconfig_against_backend_config_constraints(
|
|
MyModel(), qconfig, backend_config, satisfies_constraints)
|
|
|
|
# Case 1: QConfig min scale value == backend min scale value, OK
|
|
qconfig1 = QConfig(
|
|
activation=MinMaxObserver.with_args(dtype=torch.quint8, eps=2 ** -12),
|
|
weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, eps=2 ** -12))
|
|
validate_qconfig(qconfig1, satisfies_constraints=True)
|
|
|
|
# Case 2: QConfig min scale value > backend min scale value, OK
|
|
qconfig2 = QConfig(
|
|
activation=MinMaxObserver.with_args(dtype=torch.quint8, eps=2 ** -10),
|
|
weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, eps=2 ** -10))
|
|
validate_qconfig(qconfig2, satisfies_constraints=True)
|
|
|
|
# Case 3: QConfig activation min scale value < backend min scale value, should fail
|
|
qconfig3 = QConfig(
|
|
activation=MinMaxObserver.with_args(dtype=torch.quint8, eps=2 ** -14),
|
|
weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))
|
|
validate_qconfig(qconfig3, satisfies_constraints=False)
|
|
|
|
# Case 4: QConfig weight min scale value < backend min scale value, should fail
qconfig4 = QConfig(
|
|
activation=MinMaxObserver.with_args(dtype=torch.quint8),
|
|
weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, eps=2 ** -14))
|
|
validate_qconfig(qconfig4, satisfies_constraints=False)
|
|
|
|
# Case 5: QConfig doesn't specify eps, should fail
|
|
qconfig5 = QConfig(
|
|
activation=FixedQParamsObserver.with_args(scale=1.0, zero_point=0),
|
|
weight=FixedQParamsObserver.with_args(scale=1.0, zero_point=0))
|
|
validate_qconfig(qconfig5, satisfies_constraints=False)
|
|
|
|
def test_qnnpack_backend_config(self):
|
|
"""
|
|
Test whether default QNNPACK QConfigs are compatible with the QNNPACK BackendConfig.
|
|
"""
|
|
class MyModel(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(30, 4).float()
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
all_qconfigs: list[tuple[QConfig, str]] = [
|
|
(get_default_qconfig("qnnpack", version=0), "default_qnnpack_qconfig_v0"),
|
|
(get_default_qat_qconfig("qnnpack", version=0), "default_qat_qnnpack_qconfig_v0"),
|
|
(get_default_qat_qconfig("qnnpack", version=1), "default_qat_qnnpack_qconfig_v1"),
|
|
(default_symmetric_qnnpack_qconfig, "default_symmetric_qnnpack_qconfig"),
|
|
(default_symmetric_qnnpack_qat_qconfig, "default_symmetric_qnnpack_qat_qconfig"),
|
|
# TODO: Test these QConfigs once they are fixed, see https://github.com/pytorch/pytorch/issues/85862
|
|
# (default_per_channel_symmetric_qnnpack_qconfig, "default_per_channel_symmetric_qnnpack_qconfig"),
|
|
# (default_per_channel_symmetric_qnnpack_qat_qconfig, "default_per_channel_symmetric_qnnpack_qat_qconfig"),
|
|
]
|
|
backend_config = get_qnnpack_backend_config()
|
|
for qconfig, qconfig_name in all_qconfigs:
|
|
self._validate_qconfig_against_backend_config_constraints(
|
|
MyModel(), qconfig, backend_config, satisfies_constraints=True, qconfig_name=qconfig_name)
|
|
|
|
def test_symmetric_qnnpack_qconfig_mapping(self):
|
|
"""
|
|
Test whether `torch.ao.quantization.qconfig_mapping._get_symmetric_qnnpack_qconfig_mapping`
|
|
works with the QNNPACK BackendConfig.
|
|
"""
|
|
if "qnnpack" not in supported_qengines:
|
|
return
|
|
|
|
class MyModel(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(30, 4).float()
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
with override_quantized_engine("qnnpack"):
|
|
qconfig_mapping = _get_symmetric_qnnpack_qconfig_mapping()
|
|
example_inputs = (torch.rand((1, 30), dtype=torch.float),)
|
|
backend_config = get_qnnpack_backend_config()
|
|
model = MyModel()
|
|
model = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config)
|
|
model(*example_inputs)
|
|
model = convert_fx(model, backend_config=backend_config)
|
|
expected_node_occurrence = {
|
|
ns.call_module(torch.ao.nn.quantized.Linear) : 1,
|
|
ns.call_module(torch.nn.Linear) : 0,
|
|
}
|
|
self.checkGraphModuleNodes(model, expected_node_occurrence=expected_node_occurrence)
|
|
model(*example_inputs)
|
|
|
|
def test_symmetric_qnnpack_qat_qconfig_mapping(self):
|
|
"""
|
|
Test whether `torch.ao.quantization.qconfig_mapping._get_symmetric_qnnpack_qat_qconfig_mapping`
|
|
works with the QNNPACK BackendConfig.
|
|
"""
|
|
if "qnnpack" not in supported_qengines:
|
|
return
|
|
|
|
class MyModel(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(30, 4).float()
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
with override_quantized_engine("qnnpack"):
|
|
qconfig_mapping = _get_symmetric_qnnpack_qat_qconfig_mapping()
|
|
example_inputs = (torch.rand((1, 30), dtype=torch.float),)
|
|
backend_config = get_qnnpack_backend_config()
|
|
model = MyModel()
|
|
model = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config)
|
|
model(*example_inputs)
|
|
model = convert_fx(model, backend_config=backend_config)
|
|
expected_node_occurrence = {
|
|
ns.call_module(torch.ao.nn.quantized.Linear) : 1,
|
|
ns.call_module(torch.nn.Linear) : 0,
|
|
}
|
|
self.checkGraphModuleNodes(model, expected_node_occurrence=expected_node_occurrence)
|
|
model(*example_inputs)
|
|
|
|
|
|
def test_get_executorch_backend_config(self):
|
|
from torch.ao.quantization.backend_config import get_executorch_backend_config
|
|
# make sure this runs
|
|
executorch_backend_config = get_executorch_backend_config()
|
|
|
|
def test_backend_config_check_for_weight_and_bias(self):
|
|
""" Test to make sure the backend_config check for weight and bias
|
|
runs when the qconfig is None for the ops with weight and bias
|
|
previously the error was not hit because we first check input, and
|
|
the check for weight and bias are skipped.
|
|
"""
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.weight = torch.tensor((5, 5))
|
|
self.bias = torch.tensor((5,))
|
|
|
|
def forward(self, x):
|
|
return torch.addmm(self.bias, x, self.weight)
|
|
|
|
m = M().eval()
|
|
qconfig_mapping = QConfigMapping()
|
|
observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
|
|
weighted_op_quint8_dtype_config = DTypeConfig(
|
|
input_dtype=torch.quint8,
|
|
output_dtype=torch.quint8,
|
|
weight_dtype=torch.qint8,
|
|
bias_dtype=torch.float,
|
|
)
|
|
dtype_configs = [weighted_op_quint8_dtype_config]
|
|
backend_pattern_config = BackendPatternConfig(torch.addmm) \
|
|
.set_observation_type(observation_type) \
|
|
.set_dtype_configs(dtype_configs) \
|
|
._set_input_type_to_index({"weight": 2, "bias": 0})
|
|
backend_config = BackendConfig() \
|
|
.set_backend_pattern_config(backend_pattern_config)
|
|
example_inputs = (torch.rand(1, 5),)
|
|
# make sure this runs
|
|
m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=backend_config)
|
|
|
|
def test_get_default_qconfig_valid_backend(self):
|
|
""" Checks that AssertionError is raised when non expected backend input is specified
|
|
"""
invalid_backends = ["imaginary_backend", 3]
|
|
for invalid_backend in invalid_backends:
|
|
with self.assertRaisesRegex(AssertionError, "not supported"):
|
|
qconfig = get_default_qconfig(invalid_backend)
|
|
with self.assertRaisesRegex(AssertionError, "not supported"):
|
|
qconfig = get_default_qat_qconfig(invalid_backend)
|
|
with self.assertRaisesRegex(AssertionError, "not supported"):
|
|
qconfig_mapping = get_default_qconfig_mapping(invalid_backend)
|
|
with self.assertRaisesRegex(AssertionError, "not supported"):
|
|
qconfig_mapping = get_default_qat_qconfig_mapping(invalid_backend)
|
|
|
|
def test__convert_to_reference_decomposed_fx(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(5, 10)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
m = M().eval()
|
|
qconfig_mapping = get_default_qconfig_mapping("fbgemm")
|
|
example_inputs = (torch.randn(1, 5),)
|
|
m = prepare_fx(m, qconfig_mapping, example_inputs)
|
|
m_ref = copy.deepcopy(m)
|
|
m_ref = convert_to_reference_fx(m_ref)
|
|
m = _convert_to_reference_decomposed_fx(m)
|
|
expected_occurrence = {
|
|
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
|
|
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
|
|
}
|
|
self.checkGraphModuleNodes(
|
|
m,
|
|
expected_node_occurrence=expected_occurrence)
|
|
# make sure it runs
|
|
res_ref = m_ref(*example_inputs)
|
|
res = m(*example_inputs)
|
|
self.assertEqual(res, res_ref)
|
|
|
|
@skipIfNoQNNPACK
|
|
def test__convert_to_reference_decomposed_fx_dynamic_quant(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(5, 10)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
# to avoid reduce_range
|
|
with override_quantized_engine("qnnpack"):
|
|
m = M().eval()
|
|
qconfig_mapping = get_default_qconfig_mapping("fbgemm") \
|
|
.set_object_type(torch.nn.Linear, default_dynamic_qconfig)
|
|
example_inputs = (torch.randn(1, 5),)
|
|
m = prepare_fx(m, qconfig_mapping, example_inputs)
|
|
m(*example_inputs)
|
|
m_ref = copy.deepcopy(m)
|
|
m_ref = convert_to_reference_fx(m_ref)
|
|
m = _convert_to_reference_decomposed_fx(m)
|
|
expected_occurrence = {
|
|
ns.call_function(torch.ops.quantized_decomposed.choose_qparams.tensor): 1,
|
|
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor): 1,
|
|
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor): 1,
|
|
}
|
|
self.checkGraphModuleNodes(
|
|
m,
|
|
expected_node_occurrence=expected_occurrence)
|
|
# make sure it runs
|
|
res_ref = m_ref(*example_inputs)
|
|
res = m(*example_inputs)
|
|
self.assertEqual(res, res_ref)
|
|
|
|
def test__convert_to_reference_decomposed_fx_per_channel_quant(self):
|
|
class M(torch.nn.Module):
|
|
def forward(self, x, weight, bias):
|
|
return F.linear(x, weight, bias)
|
|
|
|
m = M().eval()
|
|
qconfig_mapping = get_default_qconfig_mapping("fbgemm") \
|
|
.set_object_type(F.linear, default_per_channel_qconfig)
|
|
example_inputs = (torch.randn(1, 5), torch.randn(10, 5), torch.randn(10,))
|
|
m = prepare_fx(m, qconfig_mapping, example_inputs)
|
|
m(*example_inputs)
|
|
m_ref = copy.deepcopy(m)
|
|
m_ref = convert_to_reference_fx(m_ref)
|
|
m = _convert_to_reference_decomposed_fx(m)
|
|
expected_occurrence = {
|
|
# for input and output activations
|
|
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
|
|
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
|
|
# for weight
|
|
ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel.default): 1,
|
|
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1,
|
|
}
|
|
self.checkGraphModuleNodes(
|
|
m,
|
|
expected_node_occurrence=expected_occurrence)
|
|
# make sure it runs
|
|
res_ref = m_ref(*example_inputs)
|
|
res = m(*example_inputs)
|
|
self.assertEqual(res, res_ref)
|
|
|
|
def test_change_backend_config_for_fixed_qparam_ops(self):
|
|
""" Making sure we can skip validation of qconfigs for fixedqparam ops based
|
|
on BackendConfig
|
|
"""
class M(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.tanh = torch.nn.Tanh()
|
|
|
|
def forward(self, x: torch.Tensor):
|
|
x = self.tanh(x)
|
|
return x
|
|
|
|
model = M().eval()
|
|
# we set a global default_qconfig, which will be ignored since the backend
|
|
# we defined doesn't support anything
|
|
# this is to make sure we don't validate the qconfig when BackendConfig does not
|
|
# have fixed qparam op related configurations
|
|
qconfig_mapping = QConfigMapping().set_global(default_qconfig)
|
|
backend_config = BackendConfig()
|
|
# make sure this runs
|
|
model = prepare_fx(
|
|
model,
|
|
qconfig_mapping=qconfig_mapping,
|
|
example_inputs=(torch.randn(1, 2, 3, 4),),
|
|
backend_config=backend_config
|
|
)
|
|
|
|
def test_channel_shuffle_lowering(self):
|
|
# Three versions of channel shuffle
|
|
class M1(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.op = torch.nn.ChannelShuffle(2)
|
|
|
|
def forward(self, x):
|
|
return self.op(x + x) + x
|
|
|
|
class M2(torch.nn.Module):
|
|
def forward(self, x):
|
|
return torch.channel_shuffle(x + x, 2) + x
|
|
|
|
class M3(torch.nn.Module):
|
|
def forward(self, x):
|
|
return torch.nn.functional.channel_shuffle(x + x, 2) + x
|
|
|
|
x = torch.randn(4, 4, 4, 4)
|
|
# torch.channel_shuffle is equivalent to torch.nn.functional.channel_shuffle
|
|
model_node_pairs = [
|
|
(M1().eval(), ns.call_module(torch.nn.ChannelShuffle)),
|
|
(M2().eval(), ns.call_function(torch.channel_shuffle)),
|
|
(M3().eval(), ns.call_function(torch.channel_shuffle))
|
|
]
|
|
for m, node in model_node_pairs:
|
|
m = prepare_fx(m, {"": default_qconfig}, example_inputs=(x,))
|
|
m_copy = copy.deepcopy(m)
|
|
m = convert_fx(m)
|
|
m_ref = convert_to_reference_fx(m_copy)
|
|
node_occurrence = {
|
|
node: 1,
|
|
ns.call_function(torch.quantize_per_tensor): 1,
|
|
ns.call_method("dequantize"): 1
|
|
}
|
|
node_occurrence_ref = {
|
|
node: 1,
|
|
ns.call_function(torch.quantize_per_tensor): 4,
|
|
ns.call_method("dequantize"): 4
|
|
}
|
|
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
|
|
self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref)
|
|
|
|
def test_match_pattern_with_multiple_args(self):
|
|
""" Test that we can match a pattern that has multiple arguments
|
|
Pattern:
|
|
shape \
|
|
transpose (observed) -> reshape -> output (observed) ->
|
|
|
|
where `reshape` has two arguments
|
|
"""
|
|
|
|
def _get_pattern_configs():
|
|
backend_pattern_configs = []
|
|
observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
|
|
weighted_op_quint8_dtype_config = DTypeConfig(
|
|
input_dtype=torch.quint8,
|
|
output_dtype=torch.quint8,
|
|
weight_dtype=torch.qint8,
|
|
bias_dtype=torch.float,
|
|
)
|
|
dtype_configs = [weighted_op_quint8_dtype_config]
|
|
|
|
def root_node_getter(node_pattern):
|
|
reshape, transpose, shape = node_pattern
|
|
return transpose
|
|
|
|
backend_pattern_configs.append(
|
|
BackendPatternConfig()
|
|
._set_pattern_complex_format((torch.reshape, torch.transpose, MatchAllNode)) # noqa: E131
|
|
.set_observation_type(observation_type)
|
|
.set_dtype_configs(dtype_configs)
|
|
._set_root_node_getter(root_node_getter)
|
|
)
|
|
return backend_pattern_configs
|
|
|
|
backend_config = BackendConfig().set_backend_pattern_configs(_get_pattern_configs())
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, x):
|
|
x = torch.transpose(x, 0, 1)
|
|
x = torch.reshape(x, (-1,))
|
|
return x
|
|
|
|
m = M().eval()
|
|
qconfig_mapping = QConfigMapping().set_global(default_qconfig)
|
|
example_inputs = (torch.randn(1, 3, 3, 3),)
|
|
m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=backend_config)
|
|
node_occurrence = {
|
|
# one for input of the pattern and one for output of the pattern
|
|
ns.call_module(MinMaxObserver): 2
|
|
}
|
|
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
|
|
|
|
def _test_linear_activation_fusion_lowering_helper(
|
|
self, module, example_inputs, qconfig_mapping,
|
|
backend_config, fused_module, root_module, activation_module):
|
|
node_occurrence = {
|
|
ns.call_function(torch.quantize_per_tensor): 1,
|
|
ns.call_method("dequantize"): 1,
|
|
ns.call_module(fused_module): 1,
|
|
ns.call_module(root_module): 0,
|
|
ns.call_module(activation_module): 0,
|
|
}
|
|
node_occurrence_ref = {
|
|
ns.call_function(torch.quantize_per_tensor): 2,
|
|
ns.call_method("dequantize"): 2,
|
|
}
|
|
m = module.eval()
|
|
m = prepare_fx(m, qconfig_mapping,
|
|
example_inputs=example_inputs,
|
|
backend_config=backend_config)
|
|
m_copy = copy.deepcopy(m)
|
|
m = convert_fx(m, backend_config=backend_config)
|
|
m_ref = convert_to_reference_fx(m_copy)
|
|
|
|
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
|
|
self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref)
|
|
m(*example_inputs)
|
|
|
|
@skipIfNoONEDNN
|
|
def test_linear_leaky_relu_lowering(self):
|
|
""" Test fusion and lowering of Linear - (bn -) LeakyReLU
|
|
by FX. For onednn backedn only.
|
|
"""
from torch.ao.quantization.backend_config import get_onednn_backend_config
|
|
qconfig_mapping = get_default_qconfig_mapping('onednn')
|
|
with override_quantized_engine('onednn'):
|
|
for with_bn in [True, False]:
|
|
m = LinearBnLeakyReluModel(with_bn)
|
|
self._test_linear_activation_fusion_lowering_helper(
|
|
m,
|
|
m.get_example_inputs(),
|
|
qconfig_mapping,
|
|
get_onednn_backend_config(),
|
|
nniq.LinearLeakyReLU,
|
|
nn.Linear,
|
|
nn.LeakyReLU)
|
|
|
|
@skipIfNoONEDNN
|
|
def test_linear_tanh_lowering(self):
|
|
""" Test fusion and lowering of Linear - Tanh
|
|
by FX. For onednn backedn only.
|
|
"""
from torch.ao.quantization.backend_config import get_onednn_backend_config
|
|
qconfig_mapping = get_default_qconfig_mapping('onednn')
|
|
# TODO Currently it's required that separate ops in a fused op/module have the same qconfig.
|
|
# Need to be able to support fusion of ops with different qconfigs
|
|
# Since tanh must have 'fixed_qparams_qconfig' while linear should use
|
|
# the global qconfig, we need to set qconfigs for them manually here for
|
|
# fusion and cannot put such configs in onednn's default qconfig_mapping.
|
|
# Known issue:
|
|
# Cannot fuse linear - tanh and quantize standalone tanh at the same time.
|
|
qconfig = get_default_qconfig('onednn')
|
|
qconfig_mapping.set_object_type(torch.nn.Linear, qconfig)
|
|
qconfig_mapping.set_object_type(torch.nn.Tanh, qconfig)
|
|
with override_quantized_engine('onednn'):
|
|
m = LinearTanhModel()
|
|
self._test_linear_activation_fusion_lowering_helper(
|
|
m,
|
|
m.get_example_inputs(),
|
|
qconfig_mapping,
|
|
get_onednn_backend_config(),
|
|
nniq.LinearTanh,
|
|
nn.Linear,
|
|
nn.Tanh)
|
|
|
|
@override_qengines
|
|
def test_linear_size_view(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self, use_relu=False):
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(16, 32)
|
|
self.relu = torch.nn.ReLU()
|
|
self.use_relu = use_relu
|
|
|
|
def forward(self, x):
|
|
x = self.linear(x)
|
|
if self.use_relu:
|
|
x = self.relu(x)
|
|
return x.view(x.size(0), 1, 4, 8)
|
|
|
|
for use_relu in [False, True]:
|
|
model_fp32 = M(use_relu).eval()
|
|
qengine = torch.backends.quantized.engine
|
|
qconfig_mapping = get_default_qconfig_mapping(qengine)
|
|
x = torch.randn((5, 16))
|
|
model_fp32(x)
|
|
prepared_model = prepare_fx(model_fp32, qconfig_mapping, x)
|
|
prepared_model(x)
|
|
quantized_model = convert_fx(prepared_model)
|
|
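# The view after linear stays in the quantized domain, so only one quantize/dequantize pair is expected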
node_occurrence = {
|
|
ns.call_module(nnq.Linear): 0 if use_relu else 1,
|
|
ns.call_module(nniq.LinearReLU): 1 if use_relu else 0,
|
|
ns.call_function(torch.quantize_per_tensor): 1,
|
|
ns.call_method("dequantize"): 1
|
|
}
|
|
self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence)
|
|
|
|
@override_qengines
|
|
def test_linear_shape_view(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self, use_relu=False):
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(16, 32)
|
|
self.relu = torch.nn.ReLU()
|
|
self.use_relu = use_relu
|
|
|
|
def forward(self, x):
|
|
x = self.linear(x)
|
|
if self.use_relu:
|
|
x = self.relu(x)
|
|
return x.view(x.shape[0], 1, 4, 8)
|
|
|
|
for use_relu in [False, True]:
|
|
model_fp32 = M(use_relu).eval()
|
|
qengine = torch.backends.quantized.engine
|
|
qconfig_mapping = get_default_qconfig_mapping(qengine)
|
|
x = torch.randn((5, 16))
|
|
model_fp32(x)
|
|
prepared_model = prepare_fx(model_fp32, qconfig_mapping, x)
|
|
prepared_model(x)
|
|
quantized_model = convert_fx(prepared_model)
|
|
node_occurrence = {
|
|
ns.call_module(nnq.Linear): 0 if use_relu else 1,
|
|
ns.call_module(nniq.LinearReLU): 1 if use_relu else 0,
|
|
ns.call_function(torch.quantize_per_tensor): 1,
|
|
ns.call_method("dequantize"): 1
|
|
}
|
|
self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence)
|
|
|
|
def test_mixed_dtypes(self):
|
|
"""
|
|
Test that multiple dtypes can be used in the same model for different layers,
|
|
and the dtypes will be converted correctly between the layers.
|
|
"""
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear1 = torch.nn.Linear(5, 5)
|
|
self.linear2 = torch.nn.Linear(5, 5)
|
|
self.sigmoid = torch.nn.Sigmoid()
|
|
self.tanh = torch.nn.Tanh()
|
|
self.float_functional = torch.ao.nn.quantized.FloatFunctional()
|
|
|
|
def forward(self, x: torch.Tensor):
|
|
x = self.linear1(x) # qint32
|
|
x = self.linear2(x) # quint8
|
|
linear2 = x
|
|
x = self.sigmoid(x) # back to qint32
|
|
x = self.tanh(x) # back to quint8
|
|
x = self.float_functional.add(linear2, x) # adding two quint8's together
|
|
return x
|
|
|
|
def make_qconfig(scale, zp, dtype):
|
|
return QConfig(
|
|
activation=FixedQParamsObserver.with_args(scale=scale, zero_point=zp, dtype=dtype),
|
|
weight=torch.ao.quantization.default_weight_observer)
|
|
|
|
# Set up a QConfigMapping that specifies different qparams and dtypes for different layers
|
|
qconfig_mapping = QConfigMapping() \
|
|
.set_global(get_default_qconfig("qnnpack")) \
|
|
.set_module_name("linear1", make_qconfig(1234, 11, torch.qint32)) \
|
|
.set_module_name("linear2", make_qconfig(2345, 22, torch.quint8)) \
|
|
.set_object_type(torch.nn.Sigmoid, make_qconfig(3456, 33, torch.qint32)) \
|
|
.set_object_type(torch.nn.Tanh, make_qconfig(4567, 44, torch.quint8))
|
|
|
|
# Set up BackendConfig that supports the dtypes configured in the above QConfigMapping
|
|
weighted_op_qint32_dtype_config = DTypeConfig(
|
|
input_dtype=torch.qint32,
|
|
output_dtype=torch.qint32,
|
|
weight_dtype=torch.qint8,
|
|
bias_dtype=torch.float,
|
|
)
|
|
fixed_qparams_op_quint8_dtype_config = DTypeConfig(
|
|
input_dtype=torch.quint8,
|
|
output_dtype=torch.quint8,
|
|
)
|
|
fixed_qparams_op_qint32_dtype_config = DTypeConfig(
|
|
input_dtype=torch.qint32,
|
|
output_dtype=torch.qint32,
|
|
)
|
|
backend_config = get_qnnpack_backend_config()
|
|
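# Patch the QNNPACK BackendConfig so Linear also accepts qint32 activations and the
# fixed qparams ops (Sigmoid, Tanh) accept both quint8 and qint32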
for config in backend_config.configs:
|
|
if config.pattern == torch.nn.Linear:
|
|
config.add_dtype_config(weighted_op_qint32_dtype_config)
|
|
elif config.pattern in [torch.nn.Sigmoid, torch.nn.Tanh]:
|
|
config.add_dtype_config(fixed_qparams_op_quint8_dtype_config)
|
|
config.add_dtype_config(fixed_qparams_op_qint32_dtype_config)
|
|
|
|
# Produce the reference quantized model
|
|
m = MyModule()
|
|
example_inputs = (torch.rand(5, 5),)
|
|
prepared = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=backend_config)
|
|
prepared(*example_inputs) # calibrate
|
|
converted = convert_to_reference_fx(prepared, backend_config=backend_config)
|
|
converted(*example_inputs)
|
|
|
|
# Verify that the reference model is correct
|
|
#
|
|
# Reference model until add should be:
|
|
# fp32_input -> q_to_int32 -> [dq -> linear1_fp32 -> q_to_int32] -> dq ->
|
|
# q_to_uint8 -> [dq -> linear2_fp32 -> q_to_uint8] -> dq (linear2_dq) ->
|
|
# q_to_int32 -> [dq -> sigmoid_fp32 -> q_to_int32] -> dq ->
|
|
# q_to_uint8 -> [dq -> tanh_fp32 -> q_to_uint8] -> dq (tanh_dq)
|
|
#
|
|
# Complete reference model with add should be:
|
|
# [(linear2_dq, tanh_dq) -> add_fp32 -> q_to_uint8] -> dq -> fp32_output
|
|
|
|
target_to_expected_dtypes = {
|
|
"linear1": torch.qint32,
|
|
"linear2": torch.quint8,
|
|
"sigmoid": torch.qint32,
|
|
"tanh": torch.quint8,
|
|
torch.add: torch.quint8,
|
|
}
|
|
# Find the patterns [dq - op_fp32 - q_to_specific_dtype] in the graph
|
|
linear2_node = tanh_node = None
|
|
for node in converted.graph.nodes:
|
|
if node.target not in target_to_expected_dtypes:
|
|
continue
|
|
|
|
# Match preceding dequantize
|
|
self.assertTrue(len(node.args) == 1 or len(node.args) == 2)
|
|
self.assertTrue(all(arg.target == "dequantize" for arg in node.args))
|
|
|
|
# Match following quantize with the specific dtypes
|
|
self.assertEqual(len(node.users), 1)
|
|
user = next(iter(node.users.keys()))
|
|
self.assertEqual(user.target, torch.quantize_per_tensor)
|
|
self.assertEqual(user.args[-1], target_to_expected_dtypes[node.target])
|
|
|
|
# Match [dq - torch.add(linear2_dq, tanh_dq) - q]
|
|
if node.target == "linear2":
|
|
linear2_node = node
|
|
elif node.target == "tanh":
|
|
tanh_node = node
|
|
elif node.target == torch.add:
|
|
linear2_dq, tanh_dq = node.args
|
|
self.assertEqual(tanh_dq.args[0].args[0], tanh_node)
|
|
self.assertEqual(linear2_dq.args[0].args[0], linear2_node)
|
|
|
|
def test_lowering_functional_conv_with_kwargs(self):
|
|
dim_to_op = {
|
|
1: F.conv1d,
|
|
2: F.conv2d,
|
|
3: F.conv3d,
|
|
}
|
|
dim_to_qop = {
|
|
1: torch.ops.quantized.conv1d,
|
|
2: torch.ops.quantized.conv2d,
|
|
3: torch.ops.quantized.conv3d,
|
|
}
|
|
|
|
class Mod(nn.Module):
|
|
def __init__(self, in_channels, out_channels, kernel_size, dimension):
|
|
super().__init__()
|
|
self.dim = dimension
|
|
self.op = dim_to_op[dimension]
|
|
kernel_sizes = [kernel_size] * self.dim
|
|
self.weight = nn.Parameter(torch.randn(out_channels, in_channels, *kernel_sizes))
|
|
|
|
def forward(self, input):
|
|
return self.op(input, self.weight, bias=None, stride=[1] * self.dim,
|
|
padding=[0] * self.dim, dilation=[1] * self.dim, groups=1)
|
|
|
|
for dimension in [1, 2, 3]:
|
|
model = Mod(3, 16, 3, dimension)
|
|
model.eval()
|
|
qconfig_mapping = get_default_qconfig_mapping()
|
|
input_shape = (1, 3, *([8] * dimension))
|
|
example_inputs = torch.randn(input_shape)
|
|
prepared_model = prepare_fx(model, qconfig_mapping, example_inputs)
|
|
prepared_model(example_inputs)
|
|
quantized_model = convert_fx(prepared_model)
|
|
# This should pass
|
|
quantized_model(example_inputs)
|
|
# Ensure the quantized model has the expected op
|
|
node_occurrence = {
|
|
ns.call_function(dim_to_qop[dimension]): 1,
|
|
}
|
|
self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence)
|
|
|
|
def test_lowering_functional_conv_transpose_with_kwargs(self):
|
|
dim_to_op = {
|
|
1: F.conv_transpose1d,
|
|
2: F.conv_transpose2d,
|
|
3: F.conv_transpose3d,
|
|
}
|
|
dim_to_qop = {
|
|
1: torch.ops.quantized.conv_transpose1d,
|
|
2: torch.ops.quantized.conv_transpose2d,
|
|
3: torch.ops.quantized.conv_transpose3d,
|
|
}
|
|
|
|
class Mod(nn.Module):
|
|
def __init__(self, in_channels, out_channels, kernel_size, dimension):
|
|
super().__init__()
|
|
self.dim = dimension
|
|
self.op = dim_to_op[dimension]
|
|
kernel_sizes = [kernel_size] * self.dim
|
|
self.weight = nn.Parameter(torch.randn(in_channels, out_channels, *kernel_sizes))
|
|
|
|
def forward(self, input):
|
|
return self.op(input, self.weight, bias=None, stride=[1] * self.dim,
|
|
padding=[0] * self.dim, output_padding=[0] * self.dim,
|
|
dilation=[1] * self.dim, groups=1)
|
|
|
|
for dimension in [1, 2, 3]:
|
|
model = Mod(3, 16, 3, dimension)
|
|
model.eval()
|
|
qconfig_mapping = get_default_qconfig_mapping()
|
|
input_shape = (1, 3, *([8] * dimension))
|
|
example_inputs = torch.randn(input_shape)
|
|
prepared_model = prepare_fx(model, qconfig_mapping, example_inputs)
|
|
prepared_model(example_inputs)
|
|
quantized_model = convert_fx(prepared_model)
|
|
# This should pass
|
|
quantized_model(example_inputs)
|
|
# Ensure the quantized model has the expected op
|
|
node_occurrence = {
|
|
ns.call_function(dim_to_qop[dimension]): 1,
|
|
}
|
|
self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence)
|
|
|
|
def test_lowering_functional_linear_with_kwargs(self):
|
|
class Mod(nn.Module):
|
|
def __init__(self, in_channels, out_channels):
|
|
super().__init__()
|
|
self.weight = nn.Parameter(torch.randn(out_channels, in_channels))
|
|
|
|
def forward(self, input):
|
|
return F.linear(input, self.weight, bias=None)
|
|
|
|
model = Mod(8, 4)
|
|
model.eval()
|
|
qconfig_mapping = get_default_qconfig_mapping()
|
|
example_inputs = torch.randn(1, 8)
|
|
prepared_model = prepare_fx(model, qconfig_mapping, example_inputs)
|
|
prepared_model(example_inputs)
|
|
quantized_model = convert_fx(prepared_model)
|
|
# This should pass
|
|
quantized_model(example_inputs)
|
|
# Ensure the quantized model has the expected op
|
|
node_occurrence = {
|
|
ns.call_function(torch.ops.quantized.linear): 1,
|
|
}
|
|
self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_keep_original_weights(self):
|
|
class SubModule(nn.Module):
|
|
"""
|
|
A simple submodule containing a linear layer.
|
|
"""
|
|
|
|
def __init__(self, input_dim, output_dim):
|
|
super().__init__()
|
|
self.w = nn.Parameter(torch.randn(input_dim, output_dim))
|
|
self.b = nn.Parameter(torch.randn(input_dim))
|
|
|
|
def forward(self, x):
|
|
return F.linear(x, self.w, self.b)
|
|
|
|
class MainModule(nn.Module):
|
|
"""
|
|
The main module containing the submodule.
|
|
"""
|
|
|
|
def __init__(self, input_dim, hidden_dim, output_dim):
|
|
super().__init__()
|
|
self.submodule_1 = SubModule(hidden_dim, input_dim)
|
|
setattr(self, 'submodule|2', SubModule(hidden_dim, hidden_dim))
|
|
setattr(self, 'submodule/3', SubModule(hidden_dim, hidden_dim))
|
|
setattr(self, 'submodule:4', SubModule(hidden_dim, hidden_dim))
|
|
setattr(self, 'submodule: 5', SubModule(hidden_dim, hidden_dim))
|
|
self._w = nn.Parameter(torch.randn(output_dim, hidden_dim))
|
|
|
|
def forward(self, x):
|
|
x1 = self.submodule_1(x)
|
|
x2 = getattr(self, 'submodule|2')(x1)
|
|
x3 = getattr(self, 'submodule/3')(x2)
|
|
x4 = getattr(self, 'submodule:4')(x3)
|
|
x5 = getattr(self, 'submodule: 5')(x4)
|
|
x6 = F.linear(x5, self._w)
|
|
return x6
|
|
|
|
input_dim = 10
|
|
hidden_dim = 20
|
|
output_dim = 5
|
|
model = MainModule(input_dim, hidden_dim, output_dim)
|
|
model.eval()
|
|
example_inputs = torch.randn(1, input_dim)
|
|
_ = model(*example_inputs)
|
|
qconfig_mapping = QConfigMapping().set_object_type(nn.functional.linear, float16_dynamic_qconfig)
|
|
prepared_model = prepare_fx(model, qconfig_mapping, example_inputs)
|
|
prepared_model(example_inputs)
|
|
quantized_model = convert_fx(prepared_model, keep_original_weights=True)
|
|
|
|
self.assertTrue(len(quantized_model.original_weights_lookup) == 6)
|
|
self.assertTrue("submodule_1_packed_weight_0" in quantized_model.original_weights_lookup)
|
|
torch.testing.assert_close(
|
|
quantized_model.original_weights_lookup["submodule_1_packed_weight_0"][0],
|
|
model.submodule_1.w
|
|
)
|
|
torch.testing.assert_close(
|
|
quantized_model.original_weights_lookup["submodule_1_packed_weight_0"][1],
|
|
model.submodule_1.b
|
|
)
|
|
self.assertTrue("submodule_2_packed_weight_0" in quantized_model.original_weights_lookup)
|
|
torch.testing.assert_close(
|
|
quantized_model.original_weights_lookup["submodule_2_packed_weight_0"][0],
|
|
getattr(model, "submodule|2").w
|
|
)
|
|
torch.testing.assert_close(
|
|
quantized_model.original_weights_lookup["submodule_2_packed_weight_0"][1],
|
|
getattr(model, "submodule|2").b
|
|
)
|
|
self.assertTrue("submodule_3_packed_weight_0" in quantized_model.original_weights_lookup)
|
|
torch.testing.assert_close(
|
|
quantized_model.original_weights_lookup["submodule_3_packed_weight_0"][0],
|
|
getattr(model, "submodule/3").w
|
|
)
|
|
torch.testing.assert_close(
|
|
quantized_model.original_weights_lookup["submodule_3_packed_weight_0"][1],
|
|
getattr(model, "submodule/3").b
|
|
)
|
|
self.assertTrue("submodule_4_packed_weight_0" in quantized_model.original_weights_lookup)
|
|
torch.testing.assert_close(
|
|
quantized_model.original_weights_lookup["submodule_4_packed_weight_0"][0],
|
|
getattr(model, "submodule:4").w
|
|
)
|
|
torch.testing.assert_close(
|
|
quantized_model.original_weights_lookup["submodule_4_packed_weight_0"][1],
|
|
getattr(model, "submodule:4").b
|
|
)
|
|
self.assertTrue("submodule_5_packed_weight_0" in quantized_model.original_weights_lookup)
|
|
torch.testing.assert_close(
|
|
quantized_model.original_weights_lookup["submodule_5_packed_weight_0"][0],
|
|
getattr(model, "submodule: 5").w
|
|
)
|
|
torch.testing.assert_close(
|
|
quantized_model.original_weights_lookup["submodule_5_packed_weight_0"][1],
|
|
getattr(model, "submodule: 5").b
|
|
)
|
|
self.assertTrue("_packed_weight_0" in quantized_model.original_weights_lookup)
|
|
torch.testing.assert_close(
|
|
quantized_model.original_weights_lookup["_packed_weight_0"][0],
|
|
model._w
|
|
)
|
|
torch.testing.assert_close(
|
|
quantized_model.original_weights_lookup["_packed_weight_0"][1],
|
|
None
|
|
)
|
|
|
|
@skipIfNoFBGEMM
|
|
class TestQuantizeFxOps(QuantizationTestCase):
|
|
def setUp(self):
|
|
super().setUp()
|
|
self.custom_qconfig = torch.ao.quantization.QConfig(
|
|
activation=torch.ao.quantization.observer.HistogramObserver.with_args(
|
|
qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
|
|
),
|
|
weight=torch.ao.quantization.default_per_channel_weight_observer
|
|
)
|
|
self.common_quant_patterns = {
|
|
torch.nn.ConvTranspose1d: DefaultNodeQuantizeHandler,
|
|
torch.nn.ConvTranspose2d: DefaultNodeQuantizeHandler,
|
|
torch.nn.ELU: DefaultNodeQuantizeHandler,
|
|
torch.nn.LeakyReLU: DefaultNodeQuantizeHandler,
|
|
torch.nn.Hardswish: DefaultNodeQuantizeHandler,
|
|
torch.nn.InstanceNorm1d: DefaultNodeQuantizeHandler,
|
|
torch.nn.InstanceNorm2d: DefaultNodeQuantizeHandler,
|
|
torch.nn.InstanceNorm3d: DefaultNodeQuantizeHandler,
|
|
torch.nn.LayerNorm: DefaultNodeQuantizeHandler,
|
|
torch.nn.SiLU: DefaultNodeQuantizeHandler,
|
|
torch.nn.Mish: DefaultNodeQuantizeHandler,
|
|
torch.nn.GELU: DefaultNodeQuantizeHandler,
|
|
torch.nn.Softmax: DefaultNodeQuantizeHandler,
|
|
torch.nn.functional.elu: DefaultNodeQuantizeHandler,
|
|
torch.nn.functional.hardswish: DefaultNodeQuantizeHandler,
|
|
torch.nn.functional.instance_norm: DefaultNodeQuantizeHandler,
|
|
torch.nn.functional.layer_norm: DefaultNodeQuantizeHandler,
|
|
torch.nn.functional.leaky_relu: DefaultNodeQuantizeHandler,
|
|
torch.nn.functional.silu: DefaultNodeQuantizeHandler,
|
|
torch.nn.functional.mish: DefaultNodeQuantizeHandler,
|
|
torch.nn.functional.gelu: DefaultNodeQuantizeHandler,
|
|
torch.nn.functional.softmax: DefaultNodeQuantizeHandler,
|
|
torch.sum: DefaultNodeQuantizeHandler
|
|
}
|
|
|
|
"""Unit tests for individual ops
|
|
"""
|
|
@skipIfNoFBGEMM
|
|
def test_linear_module(self):
|
|
with override_quantized_engine('fbgemm'):
|
|
class LinearModel(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(30, 4).float()
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
class LinearReLUModel(torch.nn.Module):
|
|
def __init__(self, f_relu=False):
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(30, 4).float()
|
|
if f_relu:
|
|
self.relu = F.relu
|
|
else:
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = self.linear(x)
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
class LinearBnModel(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(4, 4).float()
|
|
self.bn = torch.nn.BatchNorm1d(4)
|
|
|
|
def forward(self, x):
|
|
x = self.linear(x)
|
|
x = self.bn(x)
|
|
return x
|
|
|
|
# Test linear
|
|
data = (torch.rand((1, 30), dtype=torch.float),)
|
|
for quant_type in self.all_quant_types:
|
|
model = LinearModel()
|
|
quantized_module = nnqd.Linear if quant_type == QuantType.DYNAMIC else nnq.Linear
|
|
quantized_node = ns.call_module(quantized_module)
|
|
result_dict = self.checkGraphModeFxOp(model, data, quant_type, quantized_node)
|
|
if quant_type in self.static_quant_types:
|
|
self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"])
|
|
|
|
# TODO: enable test for dynamic quant
|
|
# Test linear-relu
|
|
for f_relu, quant_type in itertools.product([True, False], [QuantType.STATIC, QuantType.QAT]):
|
|
model = LinearReLUModel(f_relu)
|
|
quantized_node = ns.call_module(nniq.LinearReLU)
|
|
result_dict = self.checkGraphModeFxOp(model, data, quant_type, quantized_node)
|
|
self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"])
|
|
|
|
# Test linear-bn
|
|
data = (torch.rand((4, 4), dtype=torch.float),)
|
|
for quant_type in self.static_quant_types:
|
|
model = LinearBnModel()
|
|
quantized_node = ns.call_module(nnq.Linear)
|
|
result_dict = self.checkGraphModeFxOp(model, data, quant_type, quantized_node)
|
|
self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"])
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_functional_linear(self):
|
|
with override_quantized_engine('fbgemm'):
|
|
class FuncLinear(torch.nn.Module):
|
|
def __init__(self, use_bias, has_relu, f_relu):
|
|
super().__init__()
|
|
self.w = torch.randn(4, 30)
|
|
self.b = torch.randn(4)
|
|
self.use_bias = use_bias
|
|
if has_relu:
|
|
if f_relu:
|
|
self.relu_or_id = F.relu
|
|
else:
|
|
self.relu_or_id = torch.nn.ReLU()
|
|
else:
|
|
self.relu_or_id = torch.nn.Identity()
|
|
|
|
def forward(self, x):
|
|
if self.use_bias:
|
|
x = F.linear(x, self.w, self.b)
|
|
else:
|
|
x = F.linear(x, self.w)
|
|
x = self.relu_or_id(x)
|
|
return x
|
|
|
|
data = (torch.rand((1, 30), dtype=torch.float),)
|
|
quant_type_to_qlinear_fun = {
|
|
QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_dynamic),
|
|
QuantType.STATIC: ns.call_function(torch.ops.quantized.linear),
|
|
QuantType.QAT: ns.call_function(torch.ops.quantized.linear),
|
|
}
|
|
quant_type_to_qlinear_relu_fun = {
|
|
                # dynamic quant lowers to the fused linear_relu_dynamic op
|
|
QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_relu_dynamic),
|
|
QuantType.STATIC: ns.call_function(torch.ops.quantized.linear_relu),
|
|
QuantType.QAT: ns.call_function(torch.ops.quantized.linear_relu),
|
|
}
|
|
|
|
options = itertools.product(
|
|
self.all_quant_types,
|
|
(True, False), # use_bias
|
|
(True, False), # has_relu
|
|
(True, False), # functional relu
|
|
)
|
|
for quant_type, use_bias, has_relu, f_relu in options:
|
|
# when has_relu is False, we are using an nn.Identity and
|
|
# we will insert observer/fake_quant for the output of nn.Identity since
|
|
# it is a copy node, that's why we have extra observer/fake_quant
|
|
# when has_relu is False
|
|
quant_type_to_prepare_expected_node_occurrence = {
|
|
QuantType.DYNAMIC: {
|
|
ns.call_module(torch.ao.quantization.PlaceholderObserver): 1,
|
|
ns.call_module(torch.ao.quantization.MinMaxObserver): 1,
|
|
},
|
|
# There should be 3 observers: after input, weight and activation.
|
|
# one more observer for torch.nn.Identity when there is no relu
|
|
QuantType.STATIC: {
|
|
ns.call_module(torch.ao.quantization.HistogramObserver): 2 if has_relu else 3,
|
|
ns.call_module(torch.ao.quantization.PerChannelMinMaxObserver): 1,
|
|
},
|
|
# There should be 3 observers: after input, weight and activation.
|
|
QuantType.QAT: {
|
|
ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 3 if has_relu else 4,
|
|
},
|
|
}
|
|
model = FuncLinear(use_bias, has_relu, f_relu)
|
|
if has_relu:
|
|
qlinear_fun = quant_type_to_qlinear_relu_fun[quant_type]
|
|
else:
|
|
qlinear_fun = quant_type_to_qlinear_fun[quant_type]
|
|
|
|
if quant_type != QuantType.DYNAMIC:
|
|
num_dequantize = 1
|
|
else:
|
|
# we will have an extra quantize_per_tensor_dynamic + dequantize for
|
|
# nn.Identity right now, but it will be fixed after we use
|
|
# backend_config to configure the default pt backend
|
|
num_dequantize = int(not has_relu)
|
|
|
|
convert_node_occurrence = {
|
|
ns.call_function(torch.quantize_per_tensor): 1 if quant_type != QuantType.DYNAMIC else 0,
|
|
qlinear_fun: 1,
|
|
ns.call_method("dequantize"): num_dequantize if quant_type != QuantType.DYNAMIC else 0,
|
|
}
|
|
prepare_expected_node_occurrence = \
|
|
quant_type_to_prepare_expected_node_occurrence[quant_type]
|
|
result_dict = self.checkGraphModeFxOp(
|
|
model, data, quant_type, qlinear_fun,
|
|
prepare_expected_node_occurrence=prepare_expected_node_occurrence,
|
|
expected_node_occurrence=convert_node_occurrence)
|
|
if quant_type != QuantType.DYNAMIC:
|
|
self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"])
|
|
# Ensure packed weights in lowered models are folded
|
|
self.assertIn("_packed_weight_0", result_dict["quantized"].state_dict().keys())
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_linear_dynamic_fp16(self):
|
|
with override_quantized_engine('fbgemm'):
|
|
class FuncLinear(torch.nn.Module):
|
|
def __init__(self, use_bias, has_relu, f_relu):
|
|
super().__init__()
|
|
self.w = torch.randn(4, 30)
|
|
self.b = torch.randn(4)
|
|
self.use_bias = use_bias
|
|
if has_relu:
|
|
if f_relu:
|
|
self.relu = F.relu
|
|
else:
|
|
self.relu = torch.nn.ReLU()
|
|
else:
|
|
self.relu = torch.nn.Identity()
|
|
|
|
def forward(self, x):
|
|
if self.use_bias:
|
|
x = F.linear(x, self.w, self.b)
|
|
else:
|
|
x = F.linear(x, self.w)
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
data = (torch.rand((1, 30), dtype=torch.float),)
|
|
options = itertools.product(
|
|
(True, False), # use_bias
|
|
(True, False), # has_relu
|
|
(True, False), # functional relu
|
|
(True, False), # is_reference
|
|
)
|
|
for use_bias, has_relu, f_relu, is_reference in options:
|
|
model = FuncLinear(use_bias, has_relu, f_relu)
|
|
if is_reference:
|
|
qlinear_fun = ns.call_function(torch.nn.functional.linear)
|
|
else:
|
|
if has_relu:
|
|
qlinear_fun = ns.call_function(torch.ops.quantized.linear_relu_dynamic_fp16)
|
|
else:
|
|
qlinear_fun = ns.call_function(torch.ops.quantized.linear_dynamic_fp16)
|
|
prepare_node_occurrence = {
|
|
# activation and weight
|
|
ns.call_module(torch.ao.quantization.PlaceholderObserver): 2
|
|
}
|
|
convert_node_occurrence = {
|
|
qlinear_fun: 1,
|
|
# weight
|
|
ns.call_method("to"): 1 if is_reference else 0
|
|
}
|
|
self.checkGraphModeFxOp(
|
|
model, data, QuantType.DYNAMIC, qlinear_fun,
|
|
is_reference=is_reference,
|
|
custom_qconfig_dict={"": float16_dynamic_qconfig},
|
|
prepare_expected_node_occurrence=prepare_node_occurrence,
|
|
expected_node_occurrence=convert_node_occurrence)
|
|
|
|
def test_linear_static_fp16(self):
|
|
class FuncLinear(torch.nn.Module):
|
|
def __init__(self, use_bias, has_relu, f_relu):
|
|
super().__init__()
|
|
self.w = torch.randn(4, 30)
|
|
self.b = torch.randn(4)
|
|
self.use_bias = use_bias
|
|
if has_relu:
|
|
if f_relu:
|
|
self.relu = F.relu
|
|
else:
|
|
self.relu = torch.nn.ReLU()
|
|
else:
|
|
self.relu = torch.nn.Identity()
|
|
|
|
def forward(self, x):
|
|
if self.use_bias:
|
|
x = F.linear(x, self.w, self.b)
|
|
else:
|
|
x = F.linear(x, self.w)
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
data = (torch.rand((1, 30), dtype=torch.float),)
|
|
options = itertools.product(
|
|
(True, False), # use_bias
|
|
(True, False), # has_relu
|
|
(True, False), # functional relu
|
|
(True, False), # is_reference
|
|
)
|
|
backend_config = get_test_only_legacy_native_backend_config()
|
|
for use_bias, has_relu, f_relu, is_reference in options:
|
|
model = FuncLinear(use_bias, has_relu, f_relu)
|
|
linear_fun = ns.call_function(torch.nn.functional.linear)
|
|
# when has_relu is False, we are using an nn.Identity and
|
|
# we will insert observer/fake_quant for the output of nn.Identity since
|
|
# it is a copy node, that's why we have extra observer/fake_quant
|
|
# when has_relu is False
|
|
prepare_node_occurrence = {
|
|
# activation, weight, bias and output
|
|
ns.call_module(torch.ao.quantization.PlaceholderObserver): 3 + int(use_bias) + int(not has_relu),
|
|
}
|
|
# We have extra to and dequantize when is_reference is True
|
|
# and has_relu is False since when has_relu is False, we
|
|
# have an nn.Identity in the model, which is a CopyNode
|
|
# and we would add extra quant - dequant for CopyNode in
|
|
# reference patterns
|
|
convert_node_occurrence = {
|
|
# we don't support static fp16 ops, so the linear function
|
|
# is unfused
|
|
linear_fun: 1,
|
|
# activation, weight, bias and output
|
|
ns.call_method("to"): 3 + int(use_bias) + int(not has_relu and is_reference),
|
|
ns.call_method("dequantize"): 3 + int(use_bias) + int(not has_relu and is_reference)
|
|
}
|
|
self.checkGraphModeFxOp(
|
|
model, data, QuantType.DYNAMIC, linear_fun,
|
|
is_reference=is_reference,
|
|
custom_qconfig_dict={"": float16_static_qconfig},
|
|
prepare_expected_node_occurrence=prepare_node_occurrence,
|
|
expected_node_occurrence=convert_node_occurrence,
|
|
backend_config=backend_config)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_conv_module(self):
|
|
conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}
|
|
|
|
class ConvWrapper(torch.nn.Module):
|
|
def __init__(self, dim):
|
|
super().__init__()
|
|
self.conv = conv_module[dim](3, 3, 3).float()
|
|
|
|
def forward(self, x):
|
|
return self.conv(x)
|
|
|
|
options = itertools.product([1, 2, 3], self.static_quant_types)
|
|
quantized_nodes = {
|
|
# dim
|
|
1: ns.call_module(nnq.Conv1d),
|
|
2: ns.call_module(nnq.Conv2d),
|
|
3: ns.call_module(nnq.Conv3d),
|
|
}
|
|
for dim, quant_type in options:
|
|
self.checkGraphModeFxOp(
|
|
ConvWrapper(dim), self.img_data_dict[dim], quant_type,
|
|
quantized_nodes[dim])
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_functional_conv(self):
|
|
with override_quantized_engine('fbgemm'):
|
|
""" Test for function conv and functional conv + relu
|
|
"""
|
|
convs = {
|
|
1: torch.nn.functional.conv1d,
|
|
2: torch.nn.functional.conv2d,
|
|
3: torch.nn.functional.conv3d,
|
|
}
|
|
|
|
class FuncConv(torch.nn.Module):
|
|
def __init__(self, dim, use_bias, has_relu, f_relu):
|
|
super().__init__()
|
|
self.dim = dim
|
|
self.w = torch.randn(tuple([3] * (dim + 2)))
|
|
self.b = torch.randn(3) if use_bias else None
|
|
self.stride = tuple([1] * dim)
|
|
self.padding = tuple([0] * dim)
|
|
self.dilation = tuple([1] * dim)
|
|
self.groups = 1
|
|
self.use_bias = use_bias
|
|
if has_relu:
|
|
if f_relu:
|
|
self.relu = F.relu
|
|
else:
|
|
self.relu = torch.nn.ReLU()
|
|
else:
|
|
self.relu = torch.nn.Identity()
|
|
|
|
def forward(self, x):
|
|
x = convs[self.dim](x, self.w, self.b, self.stride, self.padding, self.dilation, self.groups)
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
quant_type_to_qconv_fun = {
|
|
QuantType.STATIC: {
|
|
1: ns.call_function(torch.ops.quantized.conv1d),
|
|
2: ns.call_function(torch.ops.quantized.conv2d),
|
|
3: ns.call_function(torch.ops.quantized.conv3d)
|
|
},
|
|
QuantType.QAT: {
|
|
1: ns.call_function(torch.ops.quantized.conv1d),
|
|
2: ns.call_function(torch.ops.quantized.conv2d),
|
|
3: ns.call_function(torch.ops.quantized.conv3d)
|
|
},
|
|
}
|
|
quant_type_to_qconv_relu_fun = {
|
|
QuantType.STATIC: {
|
|
1: ns.call_function(torch.ops.quantized.conv1d_relu),
|
|
2: ns.call_function(torch.ops.quantized.conv2d_relu),
|
|
3: ns.call_function(torch.ops.quantized.conv3d_relu)
|
|
},
|
|
QuantType.QAT: {
|
|
1: ns.call_function(torch.ops.quantized.conv1d_relu),
|
|
2: ns.call_function(torch.ops.quantized.conv2d_relu),
|
|
3: ns.call_function(torch.ops.quantized.conv3d_relu)
|
|
},
|
|
}
|
|
|
|
options = itertools.product(
|
|
[1, 2, 3], # dims
|
|
self.static_quant_types,
|
|
(True, False), # use_bias
|
|
(True, False), # has_relu
|
|
(True, False), # functional relu
|
|
)
|
|
for dim, quant_type, use_bias, has_relu, f_relu in options:
|
|
# when has_relu is False, we are using an nn.Identity and
|
|
# we will insert observer/fake_quant for the output of nn.Identity since
|
|
# it is a copy node, that's why we have extra observer/fake_quant
|
|
# when has_relu is False
|
|
quant_type_to_prepare_expected_node_occurrence = {
|
|
QuantType.DYNAMIC: {},
|
|
# There should be 3 observers: after input, weight and activation.
|
|
QuantType.STATIC: {
|
|
ns.call_module(torch.ao.quantization.HistogramObserver): 2 if has_relu else 3,
|
|
ns.call_module(torch.ao.quantization.PerChannelMinMaxObserver): 1,
|
|
},
|
|
# There should be 3 observers: after input, weight and activation.
|
|
QuantType.QAT: {
|
|
ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 3 if has_relu else 4,
|
|
},
|
|
}
|
|
data_dims = [2, 3] + [4] * dim
|
|
data = (torch.randn(tuple(data_dims), dtype=torch.float),)
|
|
model = FuncConv(dim, use_bias, has_relu, f_relu)
|
|
if has_relu:
|
|
qconv_fun = quant_type_to_qconv_relu_fun[quant_type][dim]
|
|
else:
|
|
qconv_fun = quant_type_to_qconv_fun[quant_type][dim]
|
|
|
|
convert_node_occurrence = {
|
|
ns.call_function(torch.quantize_per_tensor): 1,
|
|
qconv_fun: 1,
|
|
ns.call_method("dequantize"): 1
|
|
}
|
|
prepare_expected_node_occurrence = \
|
|
quant_type_to_prepare_expected_node_occurrence[quant_type]
|
|
result_dict = self.checkGraphModeFxOp(
|
|
model, data, quant_type, qconv_fun,
|
|
prepare_expected_node_occurrence=prepare_expected_node_occurrence,
|
|
expected_node_occurrence=convert_node_occurrence)
|
|
if quant_type != QuantType.DYNAMIC:
|
|
self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"])
|
|
# Ensure packed weights in lowered models are folded
|
|
self.assertIn("_packed_weight_0", result_dict["quantized"].state_dict().keys())
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_quantized_conv_relu(self):
|
|
"""tests for conv1d_relu/conv2d_relu/conv3d_relu"""
conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}
|
|
|
|
class ConvNdRelu(torch.nn.Module):
|
|
def __init__(self, dim, inplace):
|
|
super().__init__()
|
|
self.conv = conv_module[dim](3, 3, 3).float()
|
|
self.relu = torch.nn.ReLU(inplace)
|
|
|
|
def forward(self, x):
|
|
return self.relu(self.conv(x))
|
|
|
|
class ConvNdFunctionalRelu(torch.nn.Module):
|
|
def __init__(self, dim):
|
|
super().__init__()
|
|
self.conv = conv_module[dim](3, 3, 3).float()
|
|
|
|
def forward(self, x):
|
|
return F.relu(self.conv(x))
|
|
|
|
class ConvNdInplaceFunctionalRelu(torch.nn.Module):
|
|
def __init__(self, dim):
|
|
super().__init__()
|
|
self.conv = conv_module[dim](3, 3, 3).float()
|
|
|
|
def forward(self, x):
|
|
return F.relu(self.conv(x), True)
|
|
|
|
options = itertools.product([1, 2, 3], self.static_quant_types)
|
|
quantized_nodes = {
|
|
# dim
|
|
1: ns.call_module(nniq.ConvReLU1d),
|
|
2: ns.call_module(nniq.ConvReLU2d),
|
|
3: ns.call_module(nniq.ConvReLU3d),
|
|
}
|
|
for dim, quant_type in options:
|
|
for m in [ConvNdRelu(dim, True),
|
|
ConvNdRelu(dim, False),
|
|
ConvNdFunctionalRelu(dim),
|
|
ConvNdInplaceFunctionalRelu(dim)]:
|
|
self.checkGraphModeFxOp(
|
|
m, self.img_data_dict[dim], quant_type,
|
|
quantized_nodes[dim])
|
|
|
|
|
|
def _test_binary_op_int8_impl(self, binary_op, ibinary_op, quantized_op):
|
|
data = (torch.randn(1, 1, 1, 1, dtype=torch.float),
|
|
torch.randn(1, 1, 1, 1, dtype=torch.float))
|
|
options = itertools.product([True, False], [True, False], [True, False])
|
|
quant_type = QuantType.STATIC
|
|
# testing for default int8 static quant
|
|
for is_inplace, is_scalar, is_reference in options:
|
|
if is_reference:
|
|
node_list = [
|
|
ns.call_method("dequantize"),
|
|
ns.call_function(binary_op),
|
|
ns.call_function(torch.quantize_per_tensor)
|
|
]
|
|
quantized_node = None
|
|
else:
|
|
node_list = None
|
|
quantized_node = ns.call_function(quantized_op)
|
|
|
|
self.checkGraphModeFxOp(
|
|
BinaryOp(binary_op, ibinary_op, is_inplace, is_scalar), data, quant_type,
|
|
quantized_node, expected_node_list=node_list, is_reference=is_reference)
|
|
            # This tests that the binary op is quantized even when it is not fed a
            # quantized input
self.checkGraphModeFxOp(
|
|
BinaryOpNonQuantizedInput(binary_op, ibinary_op, is_inplace, is_scalar),
|
|
data, quant_type, quantized_node,
|
|
expected_node_list=node_list, is_reference=is_reference)
|
|
|
|
|
|
def _test_binary_op_float16_impl(self, binary_op, ibinary_op):
|
|
data = (torch.randn(1, 1, 1, 1, dtype=torch.float),
|
|
torch.randn(1, 1, 1, 1, dtype=torch.float))
|
|
quant_type = QuantType.STATIC
|
|
# testing for fp16 static quant
|
|
# we are producing fp16 patterns
|
|
options = itertools.product([True, False], [True, False])
|
|
custom_qconfig_dict = {
|
|
"object_type": [(binary_op, float16_static_qconfig)]
|
|
}
|
|
backend_config = get_test_only_legacy_native_backend_config()
|
|
for is_inplace, is_scalar in options:
|
|
node_occurrence = {
|
|
# output_conv1, output_add1, output_add2 for scalar
|
|
# output_conv1, output_conv2, output_add1, output_add2 for non-scalar
|
|
ns.call_method("to"): 3 if is_scalar else 4
|
|
}
|
|
self.checkGraphModeFxOp(
|
|
BinaryOp(binary_op, ibinary_op, is_inplace, is_scalar), data, quant_type,
|
|
expected_node_occurrence=node_occurrence,
|
|
custom_qconfig_dict=custom_qconfig_dict,
|
|
backend_config=backend_config)
|
|
|
|
node_occurrence = {
|
|
# input_add, output_add for scalar
|
|
# input_add1, input_add2, output_add for non-scalar
|
|
ns.call_method("to"): 2 if is_scalar else 3
|
|
}
|
|
self.checkGraphModeFxOp(
|
|
BinaryOpNonQuantizedInput(binary_op, ibinary_op, is_inplace, is_scalar), data, quant_type,
|
|
expected_node_occurrence=node_occurrence,
|
|
custom_qconfig_dict=custom_qconfig_dict,
|
|
backend_config=backend_config)
|
|
|
|
def _test_binary_op_relu_int8_impl(self, binary_op, ibinary_op, quantized_op):
|
|
data = (torch.rand((1, 1, 1, 1), dtype=torch.float),
|
|
torch.rand((1, 1, 1, 1), dtype=torch.float))
|
|
quant_type = QuantType.STATIC
|
|
quantized_node = ns.call_function(quantized_op)
|
|
options = itertools.product(
|
|
[True, False], [nn.ReLU, F.relu, torch.relu], [True, False])
|
|
for is_inplace_op, relu_callable, is_scalar in options:
|
|
model = BinaryOpRelu(
|
|
binary_op, ibinary_op, is_inplace_op, relu_callable, is_scalar)
|
|
self.checkGraphModeFxOp(
|
|
model, data, quant_type, quantized_node)
|
|
|
|
def _test_binary_op_relu_float16_impl(self, binary_op, ibinary_op):
|
|
data = (torch.rand((1, 1, 1, 1), dtype=torch.float),
|
|
torch.rand((1, 1, 1, 1), dtype=torch.float))
|
|
quant_type = QuantType.STATIC
|
|
options = itertools.product(
|
|
[True, False], [nn.ReLU, F.relu, torch.relu], [True, False])
|
|
custom_qconfig_dict = {
|
|
"": float16_static_qconfig,
|
|
"object_type": [(torch.nn.Conv2d, None)]
|
|
}
|
|
backend_config = get_test_only_legacy_native_backend_config()
|
|
        for is_inplace_op, relu_callable, is_scalar in options:
|
|
node_occurrence = {
|
|
ns.call_method("to"): 3 if is_scalar else 4
|
|
}
|
|
model = BinaryOpRelu(
|
|
                binary_op, ibinary_op, is_inplace_op, relu_callable, is_scalar)
|
|
self.checkGraphModeFxOp(
|
|
model, data, quant_type, custom_qconfig_dict=custom_qconfig_dict,
|
|
expected_node_occurrence=node_occurrence,
|
|
backend_config=backend_config)
|
|
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_add(self):
|
|
self._test_binary_op_int8_impl(
|
|
operator.add, operator.iadd, torch.ops.quantized.add)
|
|
self._test_binary_op_float16_impl(
|
|
operator.add, operator.iadd)
|
|
|
|
@unittest.skip("This is no longer needed right now, can enable later with new api")
|
|
def test_sub(self):
|
|
self._test_binary_op_float16_impl(operator.sub, operator.isub)
|
|
self._test_binary_op_float16_impl(torch.sub, None)
|
|
|
|
@unittest.skip("This is no longer needed right now, can enable later with new api")
|
|
def test_div(self):
|
|
self._test_binary_op_float16_impl(operator.truediv, operator.itruediv)
|
|
self._test_binary_op_float16_impl(torch.div, None)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_mul(self):
|
|
self._test_binary_op_int8_impl(
|
|
operator.mul, operator.imul, torch.ops.quantized.mul)
|
|
self._test_binary_op_float16_impl(operator.mul, operator.imul)
|
|
|
|
@unittest.skip("This is no longer needed right now, can enable later with new api")
|
|
def test_sum(self):
|
|
class Sum(torch.nn.Module):
|
|
def forward(self, x):
|
|
x = torch.sum(x, [1], keepdim=True)
|
|
x = torch.sum(x, [1])
|
|
return x
|
|
|
|
data = torch.randn(1, 2, 3, 4, dtype=torch.float)
|
|
quant_type = QuantType.STATIC
|
|
# testing for fp16 static quant
|
|
# we are producing fp16 patterns
|
|
custom_qconfig_dict = {
|
|
"object_type": [(torch.sum, float16_static_qconfig)]
|
|
}
|
|
node_occurrence = {
|
|
# input_sum1, output_sum1, output_sum2
|
|
ns.call_method("to"): 3
|
|
}
|
|
self.checkGraphModeFxOp(
|
|
Sum(), data, quant_type,
|
|
expected_node_occurrence=node_occurrence,
|
|
custom_qconfig_dict=custom_qconfig_dict)
|
|
|
|
@unittest.skip("This is no longer needed right now, can enable later with new api")
|
|
def test_bmm(self):
|
|
class BMMMethod(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
return x.bmm(y)
|
|
|
|
data = (torch.randn(1, 1, 1, dtype=torch.float),
|
|
torch.randn(1, 1, 1, dtype=torch.float))
|
|
quant_type = QuantType.STATIC
|
|
# testing for fp16 static quant
|
|
# we are producing fp16 patterns
|
|
custom_qconfig_dict = {
|
|
"object_type": [(torch.bmm, float16_static_qconfig),
|
|
("bmm", float16_static_qconfig)]
|
|
}
|
|
node_occurrence = {
|
|
# input_bmm1, input_bmm2, output_bmm
|
|
ns.call_method("to"): 3
|
|
}
|
|
self.checkGraphModeFxOp(
|
|
BinaryOpNonQuantizedInput(torch.bmm, None, False, False), data, quant_type,
|
|
expected_node_occurrence=node_occurrence,
|
|
custom_qconfig_dict=custom_qconfig_dict)
|
|
|
|
# TODO: support call_method("bmm")
|
|
# we can transform call_method("bmm") to call_function(torch.bmm)
|
|
# self.checkGraphModeFxOp(
|
|
# BMMMethod(), data, quant_type,
|
|
# expected_node_occurrence=node_occurrence,
|
|
# custom_qconfig_dict=custom_qconfig_dict,
|
|
# print_debug_info=True)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_add_relu(self):
|
|
self._test_binary_op_relu_int8_impl(
|
|
operator.add, operator.iadd, torch.ops.quantized.add_relu)
|
|
self._test_binary_op_relu_float16_impl(
|
|
operator.add, operator.iadd)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_add_relu_multiple_uses_of_relu(self):
|
|
class Sub(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.relu = torch.nn.ReLU(inplace=True)
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.sub = Sub()
|
|
|
|
def forward(self, x, y):
|
|
x = x + y
|
|
x = self.sub.relu(x)
|
|
x = x + y
|
|
x = self.sub.relu(x)
|
|
return x
|
|
|
|
m = M().eval()
|
|
example_inputs = (torch.randn(3), torch.randn(3))
|
|
m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs)
|
|
m = convert_fx(m)
|
|
node_occurrence = {
|
|
ns.call_function(torch.quantize_per_tensor): 2,
|
|
ns.call_function(torch.ops.quantized.add_relu): 2,
|
|
ns.call_method("dequantize"): 1,
|
|
}
|
|
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
|
|
# check the model is scriptable
|
|
m = torch.jit.script(m)
|
|
# check the model is runnable
|
|
m(*example_inputs)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_mul_relu(self):
|
|
self._test_binary_op_relu_int8_impl(
|
|
operator.mul, operator.imul, torch.ops.quantized.mul_relu)
|
|
self._test_binary_op_relu_float16_impl(
|
|
operator.mul, operator.imul)
|
|
|
|
# TODO(future PR): make more generic
|
|
def _test_quantized_add_mul_qat(self, model, example_inputs, expected_node_occurrence):
|
|
qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
|
|
mp = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs)
|
|
self.checkGraphModuleNodes(
|
|
mp, expected_node_occurrence=expected_node_occurrence)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_quantized_add_qat(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv1 = torch.nn.Conv2d(1, 1, 1)
|
|
self.conv2 = torch.nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
x = torch.add(x, 1.0)
|
|
x = self.conv1(x)
|
|
x = torch.add(x, 1.0)
|
|
x = torch.relu(x)
|
|
x = self.conv2(x)
|
|
return x
|
|
|
|
m = M()
|
|
example_inputs = (torch.randn(1, 1, 1, 1),)
|
|
expected_node_occurrence = {
|
|
ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 5,
|
|
}
|
|
self._test_quantized_add_mul_qat(m, example_inputs, expected_node_occurrence)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_quantized_mul_qat(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv1 = torch.nn.Conv2d(1, 1, 1)
|
|
self.conv2 = torch.nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
x = torch.mul(x, 1.0)
|
|
x = self.conv1(x)
|
|
x = torch.mul(x, 1.0)
|
|
x = torch.relu(x)
|
|
x = self.conv2(x)
|
|
return x
|
|
|
|
m = M()
|
|
example_inputs = (torch.randn(1, 1, 1, 1),)
|
|
expected_node_occurrence = {
|
|
ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 5,
|
|
}
|
|
self._test_quantized_add_mul_qat(m, example_inputs, expected_node_occurrence)
|
|
|
|
def test_int8_input_no_unnecessary_fq(self):
|
|
"""
|
|
If the inputs to the graph are quantized and the only node
|
|
does not need an activation observer, verifies that the
|
|
activation observer is not inserted.
|
|
"""
class M(nn.Module):
|
|
def __init__(self, scalar):
|
|
super().__init__()
|
|
self.scalar = scalar
|
|
self.add_func = torch.ao.nn.quantized.FloatFunctional()
|
|
|
|
def forward(self, x):
|
|
return self.add_func.add_scalar(x, self.scalar)
|
|
|
|
m = M(0.5)
|
|
mp = torch.ao.quantization.quantize_fx.prepare_qat_fx(
|
|
m, {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')},
|
|
example_inputs=(torch.randn(1),),
|
|
prepare_custom_config={"input_quantized_idxs": [0]})
|
|
expected_node_occurrence = {
|
|
ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 1,
|
|
}
|
|
self.checkGraphModuleNodes(
|
|
mp, expected_node_occurrence=expected_node_occurrence)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_cat(self):
|
|
""" quantization of the output of cat will depend on the
|
|
input of cat. we only quantize the output of cat when its inputs are quantized.
|
|
"""
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
|
|
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
|
|
|
|
def forward(self, x, y):
|
|
x = self.conv1(x)
|
|
y = self.conv2(y)
|
|
return torch.cat([x, y], 1)
|
|
|
|
example_inputs = (torch.randn(1, 2, 5, 5, dtype=torch.float),
|
|
torch.randn(1, 2, 5, 5, dtype=torch.float))
|
|
quantized_node = ns.call_function(torch.cat)
|
|
options = itertools.product(self.static_quant_types, [True, False])
|
|
for quant_type, is_reference in options:
|
|
if is_reference:
|
|
converted_node_list = [
|
|
ns.call_method("dequantize"),
|
|
ns.call_function(torch.cat),
|
|
ns.call_function(torch.quantize_per_tensor)
|
|
]
|
|
converted_node_occurrence = {
|
|
# inputs and outputs of the two conv, and output of cat
|
|
ns.call_method("dequantize"): 5,
|
|
ns.call_function(torch.cat): 1,
|
|
# inputs and outputs of the two conv, and output of cat
|
|
ns.call_function(torch.quantize_per_tensor): 5,
|
|
}
|
|
else:
|
|
converted_node_list = None
|
|
converted_node_occurrence = {
|
|
# output of cat
|
|
ns.call_method("dequantize"): 1,
|
|
ns.call_function(torch.cat): 1,
|
|
# for two inputs
|
|
ns.call_function(torch.quantize_per_tensor): 2,
|
|
}
|
|
|
|
self.checkGraphModeFxOp(
|
|
M(),
|
|
example_inputs,
|
|
quant_type,
|
|
quantized_node,
|
|
expected_node_list=converted_node_list,
|
|
expected_node_occurrence=converted_node_occurrence,
|
|
is_reference=is_reference)
|
|
|
|
# check cat is using the same observer for input and output
|
|
m = M().eval()
|
|
m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs)
|
|
        # the two inputs and the output of torch.cat share the same observer, so
        # 2 of the observer entries are duplicates
all_observers = len(dict(m.named_modules(remove_duplicate=False)))
|
|
distinct_observers = len(dict(m.named_modules()))
|
|
self.assertEqual(all_observers, distinct_observers + 2)
|
|
# make sure the converted model runs
|
|
m = convert_fx(m)
|
|
m(*example_inputs)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_qbatch_norm(self):
|
|
bn_module = {
|
|
# TODO: quantized batchnorm 1d module is missing
|
|
# 1 : torch.nn.BatchNorm1d,
|
|
2 : torch.nn.BatchNorm2d,
|
|
3 : torch.nn.BatchNorm3d,
|
|
}
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self, dim):
|
|
super().__init__()
|
|
self.bn = bn_module[dim](3).to(torch.float)
|
|
|
|
def forward(self, x):
|
|
return self.bn(x)
|
|
|
|
options = itertools.product(self.static_quant_types, [2, 3], [True, False])
|
|
quantized_nodes = {
|
|
False: {
|
|
# 1: ns.call_module(nnq.BatchNorm1d),
|
|
2: ns.call_module(nnq.BatchNorm2d),
|
|
3: ns.call_module(nnq.BatchNorm3d),
|
|
},
|
|
True: {
|
|
# 1: ns.call_module(nn.BatchNorm1d),
|
|
2: ns.call_module(nn.BatchNorm2d),
|
|
3: ns.call_module(nn.BatchNorm3d),
|
|
}
|
|
}
|
|
for quant_type, dim, is_reference in options:
|
|
self.checkGraphModeFxOp(
|
|
M(dim), self.img_data_dict[dim], quant_type, quantized_nodes[is_reference][dim], is_reference=is_reference)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_qbatch_norm_relu(self):
|
|
bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}
|
|
|
|
class BNRelu(torch.nn.Module):
|
|
def __init__(self, dim, inplace):
|
|
super().__init__()
|
|
self.bn = bn_module[dim](3).to(torch.float)
|
|
self.relu = torch.nn.ReLU(inplace=inplace)
|
|
|
|
def forward(self, x):
|
|
return self.relu(self.bn(x))
|
|
|
|
class BNFuncRelu(torch.nn.Module):
|
|
def __init__(self, dim):
|
|
super().__init__()
|
|
self.bn = bn_module[dim](3).to(torch.float)
|
|
|
|
def forward(self, x):
|
|
return F.relu(self.bn(x), False)
|
|
|
|
class BNFuncInplaceRelu(torch.nn.Module):
|
|
def __init__(self, dim):
|
|
super().__init__()
|
|
self.bn = bn_module[dim](3).to(torch.float)
|
|
|
|
def forward(self, x):
|
|
return F.relu(self.bn(x), True)
|
|
|
|
options = itertools.product(self.static_quant_types, [2, 3], [True, False])
|
|
quantized_nodes = {
|
|
True: {
|
|
2: ns.call_module(nni.BNReLU2d),
|
|
3: ns.call_module(nni.BNReLU3d),
|
|
},
|
|
False: {
|
|
2: ns.call_module(nniq.BNReLU2d),
|
|
3: ns.call_module(nniq.BNReLU3d),
|
|
}
|
|
}
|
|
for quant_type, dim, is_reference in options:
|
|
for instance in [BNRelu(dim, True), BNRelu(dim, False),
|
|
BNFuncRelu(dim), BNFuncInplaceRelu(dim)]:
|
|
self.checkGraphModeFxOp(
|
|
instance, self.img_data_dict[dim], quant_type,
|
|
quantized_nodes[is_reference][dim], is_reference=is_reference)
|
|
|
|
def _test_activation_impl(
|
|
self, float_module, float_op, quantized_module, quantized_op):
|
|
        ''' Test for activation op (with inplace options), float_op can be
        a torch op or a functional op
        '''
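        # Example usage (see test_hardswish below), exercising both the module and
        # the functional variant, with and without inplace:
        #   self._test_activation_impl(nn.Hardswish, F.hardswish, nnq.Hardswish,
        #                              torch.ops.quantized.hardswish)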
class M(torch.nn.Module):
|
|
def __init__(self, is_module, inplace):
|
|
super().__init__()
|
|
self.is_module = is_module
|
|
self.inplace = inplace
|
|
if self.is_module:
|
|
self.op = float_module(self.inplace)
|
|
else:
|
|
self.op = float_op
|
|
|
|
def forward(self, input):
|
|
if self.is_module:
|
|
return self.op(input)
|
|
else:
|
|
return self.op(input, self.inplace)
|
|
|
|
options = itertools.product([True, False], [True, False], self.static_quant_types, [True, False])
|
|
quantized_nodes = {
|
|
# is_module
|
|
True: {
|
|
# is_reference
|
|
True: ns.call_module(float_module),
|
|
False: ns.call_module(quantized_module),
|
|
},
|
|
False: {
|
|
True: ns.call_function(float_op),
|
|
False: ns.call_function(quantized_op),
|
|
}
|
|
}
|
|
|
|
for is_module, is_inplace, quant_type, is_reference in options:
|
|
self.checkGraphModeFxOp(
|
|
M(is_module, is_inplace), self.img_data_2d,
|
|
quant_type, quantized_nodes[is_module][is_reference], is_reference=is_reference)
|
|
|
|
def test_hardswish(self):
|
|
self._test_activation_impl(nn.Hardswish, F.hardswish, nnq.Hardswish, torch.ops.quantized.hardswish)
|
|
|
|
def test_elu(self):
|
|
self._test_activation_impl(nn.ELU, F.elu, nnq.ELU, torch.ops.quantized.elu)
|
|
|
|
def test_leaky_relu(self):
|
|
self._test_activation_impl(nn.LeakyReLU, F.leaky_relu, nnq.LeakyReLU, torch.ops.quantized.leaky_relu)
|
|
|
|
def test_prelu(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self, num_param: int):
|
|
super().__init__()
|
|
self.op = torch.nn.PReLU(num_parameters=num_param)
|
|
|
|
def forward(self, input):
|
|
return self.op(input)
|
|
|
|
X = [[torch.randn(4, 4, 4, 4, dtype=torch.float)]]
|
|
options = itertools.product([1, 4], self.static_quant_types, [True, False])
|
|
quantized_nodes = {
|
|
# is_reference
|
|
True: ns.call_module(torch.nn.PReLU),
|
|
False: ns.call_module(torch.ao.nn.quantized.PReLU),
|
|
}
|
|
|
|
for num_parameter, quant_type, is_reference in options:
|
|
self.checkGraphModeFxOp(
|
|
M(num_parameter), X, quant_type, quantized_nodes[is_reference],
|
|
is_reference=is_reference)
|
|
|
|
def _test_norm_impl(
|
|
self, float_module, float_op, op_args, data, quantized_module, quantized_op,
|
|
skip_op_arg_for_functional=False):
|
|
        ''' Test for normalization op, float_op can be a torch op or a functional op,
        op_args is a list of positional arguments for the module/op
        '''
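        # Example usage (see test_layer_norm below); op_args is reused for the
        # functional form unless skip_op_arg_for_functional is set:
        #   self._test_norm_impl(nn.LayerNorm, F.layer_norm, [[2, 5, 5]], data,
        #                        nnq.LayerNorm, torch.ops.quantized.layer_norm)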
class M(torch.nn.Module):
|
|
def __init__(self, is_module):
|
|
super().__init__()
|
|
self.is_module = is_module
|
|
if self.is_module:
|
|
self.op = float_module(*op_args)
|
|
else:
|
|
self.op = float_op
|
|
|
|
def forward(self, input):
|
|
if self.is_module:
|
|
return self.op(input)
|
|
else:
|
|
args = [input]
|
|
if not skip_op_arg_for_functional:
|
|
args += op_args
|
|
return self.op(*args)
|
|
|
|
options = itertools.product([True, False], self.static_quant_types)
|
|
quantized_nodes = {
|
|
# is_module
|
|
True: ns.call_module(quantized_module),
|
|
False: ns.call_function(quantized_op),
|
|
}
|
|
|
|
for is_module, quant_type in options:
|
|
self.checkGraphModeFxOp(
|
|
M(is_module), data, quant_type, quantized_nodes[is_module])
|
|
|
|
def _test_norm_float16_impl(
|
|
self, float_module, float_op, op_args, data,
|
|
skip_op_arg_for_functional=False):
|
|
        ''' Test for normalization op, float_op can be a torch op or a functional op,
        op_args is a list of positional arguments for the module/op
        '''
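        # Unlike the int8 variant above, the fp16 path is verified by counting the
        # "to" casts inserted around the op rather than quantize/dequantize nodes.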
class M(torch.nn.Module):
|
|
def __init__(self, is_module):
|
|
super().__init__()
|
|
self.is_module = is_module
|
|
if self.is_module:
|
|
self.op = float_module(*op_args)
|
|
else:
|
|
self.op = float_op
|
|
|
|
def forward(self, input):
|
|
if self.is_module:
|
|
return self.op(input)
|
|
else:
|
|
args = [input]
|
|
if not skip_op_arg_for_functional:
|
|
args += op_args
|
|
return self.op(*args)
|
|
|
|
options = itertools.product([True, False], self.static_quant_types)
|
|
qconfig_dict = {
|
|
"object_type": [
|
|
(float_module, float16_static_qconfig),
|
|
(float_op, float16_static_qconfig)
|
|
]
|
|
}
|
|
node_occurrence = {
|
|
ns.call_method("to"): 2
|
|
}
|
|
for is_module, quant_type in options:
|
|
self.checkGraphModeFxOp(
|
|
M(is_module), data, quant_type, custom_qconfig_dict=qconfig_dict, expected_node_occurrence=node_occurrence)
|
|
|
|
def test_layer_norm(self):
|
|
data = (torch.rand((1, 2, 5, 5), dtype=torch.float),)
|
|
self._test_norm_impl(
|
|
nn.LayerNorm, F.layer_norm, [[2, 5, 5]], data, nnq.LayerNorm, torch.ops.quantized.layer_norm)
|
|
|
|
def test_instance_norm(self):
|
|
data_1d = (torch.rand((1, 4, 5), dtype=torch.float),)
|
|
data_2d = (torch.rand((1, 4, 5, 1), dtype=torch.float),)
|
|
data_3d = (torch.rand((1, 4, 5, 1, 1), dtype=torch.float),)
|
|
data_dict = {1 : data_1d, 2 : data_2d, 3 : data_3d}
|
|
instance_norm_modules = {1 : nn.InstanceNorm1d,
|
|
2 : nn.InstanceNorm2d,
|
|
3 : nn.InstanceNorm3d}
|
|
quantized_instance_norm_modules = {
|
|
1 : nnq.InstanceNorm1d,
|
|
2 : nnq.InstanceNorm2d,
|
|
3 : nnq.InstanceNorm3d
|
|
}
|
|
for dim in [1, 2, 3]:
|
|
data = data_dict[dim]
|
|
module = instance_norm_modules[dim]
|
|
quantized_module = quantized_instance_norm_modules[dim]
|
|
self._test_norm_impl(
|
|
module, F.instance_norm, [4], data,
|
|
quantized_module, torch.ops.quantized.instance_norm,
|
|
skip_op_arg_for_functional=True)
|
|
|
|
def test_norm_weight_bias(self):
|
|
class Linear(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.w = torch.ones(5, 5)
|
|
self.b = torch.zeros(5)
|
|
|
|
def forward(self, x):
|
|
return torch.nn.functional.linear(x, self.w, self.b)
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.mods1 = Linear()
|
|
self.scale = torch.randn(5, 5)
|
|
self.bias = torch.randn(5, 5)
|
|
|
|
def forward(self, x):
|
|
x1 = self.mods1(x)
|
|
y = F.layer_norm(x1, [5, 5], weight=self.scale, bias=self.bias)
|
|
return y
|
|
|
|
model = M()
|
|
expected_occurrence = {
|
|
ns.call_function(torch.quantize_per_tensor): 1,
|
|
ns.call_function(torch.ops.quantized.linear): 1,
|
|
ns.call_function(torch.ops.quantized.layer_norm): 1,
|
|
ns.call_method("dequantize"): 1,
|
|
}
|
|
|
|
self.checkGraphModeFxOp(
|
|
model,
|
|
(torch.rand(5, 5),),
|
|
QuantType.STATIC,
|
|
expected_node_occurrence=expected_occurrence,
|
|
custom_qconfig_dict=get_default_qconfig_mapping().to_dict()
|
|
)
|
|
|
|
def _test_default_node_quant_handler_ops(
|
|
self, module, functional, qconfig, is_reference=True, node_list=None, additional_quant_pattern_dict=None
|
|
):
|
|
class M(torch.nn.Module):
|
|
def __init__(self, mod, func):
|
|
super().__init__()
|
|
self.module = mod()
|
|
self.functional = func
|
|
|
|
def forward(self, x):
|
|
x = self.module(x)
|
|
x = self.functional(x)
|
|
return x
|
|
|
|
if node_list is None:
|
|
node_list = []
|
|
if additional_quant_pattern_dict is None:
|
|
additional_quant_pattern_dict = {}
|
|
|
|
data = torch.randn((2, 2, 2, 2))
|
|
quant_type = QuantType.STATIC
|
|
prepare_custom_qconfig_dict = {"additional_quant_pattern": additional_quant_pattern_dict}
|
|
qconfig_dict = {"": qconfig}
|
|
|
|
m = M(module, functional).eval()
|
|
        m_prep = prepare_fx(
            m, qconfig_dict, example_inputs=(data,),
            prepare_custom_config=prepare_custom_qconfig_dict)
|
|
m_prep(data)
|
|
convert_fn = convert_to_reference_fx if is_reference else convert_fx
|
|
        m_quant = convert_fn(m_prep)
|
|
m_quant(data)
|
|
|
|
self.checkGraphModuleNodes(m_quant, expected_node_list=node_list)
|
|
|
|
@unittest.skip("TODO: reenable with backend_config api")
|
|
def test_gelu_normal(self):
|
|
module = torch.nn.GELU
|
|
functional = torch.nn.functional.gelu
|
|
qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
|
|
is_reference = False
|
|
node_list = [
|
|
ns.call_module(module),
|
|
ns.call_function(functional),
|
|
]
|
|
self._test_default_node_quant_handler_ops(
|
|
module, functional, qconfig, is_reference, node_list)
|
|
|
|
@unittest.skip("TODO: reenable with backend_config api")
|
|
def test_softmax_normal(self):
|
|
module = torch.nn.Softmax
|
|
functional = torch.nn.functional.softmax
|
|
qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
|
|
is_reference = False
|
|
node_list = [
|
|
ns.call_module(torch.ao.nn.quantized.Softmax),
|
|
ns.call_function(functional),
|
|
]
|
|
self._test_default_node_quant_handler_ops(
|
|
module, functional, qconfig, is_reference, node_list)
|
|
|
|
@unittest.skip("This is no longer needed right now, can enable later with new api")
|
|
def test_gelu_reference(self):
|
|
module = torch.nn.GELU
|
|
functional = torch.nn.functional.gelu
|
|
qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
|
|
is_reference = True
|
|
node_list = [
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_method("dequantize"),
|
|
ns.call_module(module),
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_method('dequantize'),
|
|
ns.call_function(functional),
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_method('dequantize')
|
|
]
|
|
# TODO: change these to use backend_config
|
|
additional_patterns = {torch.nn.GELU: DefaultNodeQuantizeHandler,
|
|
torch.nn.functional.gelu: DefaultNodeQuantizeHandler}
|
|
self._test_default_node_quant_handler_ops(
|
|
module, functional, qconfig, is_reference, node_list, additional_patterns)
|
|
|
|
self._test_default_node_quant_handler_ops(module, functional, self.custom_qconfig, is_reference, node_list,
|
|
additional_quant_pattern_dict=self.common_quant_patterns)
|
|
|
|
@unittest.skip("This is no longer needed right now, can enable later with new api")
|
|
def test_softmax_reference(self):
|
|
module = torch.nn.Softmax
|
|
functional = torch.nn.functional.softmax
|
|
qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
|
|
is_reference = True
|
|
node_list = [
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_method("dequantize"),
|
|
ns.call_module(module),
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_method('dequantize'),
|
|
ns.call_function(functional),
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_method('dequantize')
|
|
]
|
|
additional_patterns = {torch.nn.Softmax: DefaultNodeQuantizeHandler,
|
|
torch.nn.functional.softmax: DefaultNodeQuantizeHandler}
|
|
self._test_default_node_quant_handler_ops(
|
|
module, functional, qconfig, is_reference, node_list, additional_patterns)
|
|
|
|
self._test_default_node_quant_handler_ops(module, functional, self.custom_qconfig, is_reference, node_list,
|
|
additional_quant_pattern_dict=self.common_quant_patterns)
|
|
|
|
@unittest.skip("This is no longer needed right now, can enable later with new api")
|
|
def test_silu_reference(self):
|
|
module = torch.nn.SiLU
|
|
functional = torch.nn.functional.silu
|
|
qconfig = float16_static_qconfig
|
|
is_reference = True
|
|
node_list = [
|
|
ns.call_method("to"),
|
|
ns.call_method("dequantize"),
|
|
ns.call_module(module),
|
|
ns.call_method("to"),
|
|
ns.call_method('dequantize'),
|
|
ns.call_function(functional),
|
|
ns.call_method("to"),
|
|
ns.call_method('dequantize')
|
|
]
|
|
self._test_default_node_quant_handler_ops(
|
|
module, functional, qconfig, is_reference, node_list)
|
|
|
|
node_list = [
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_method("dequantize"),
|
|
ns.call_module(module),
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_method("dequantize"),
|
|
ns.call_function(functional),
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_method("dequantize")
|
|
]
|
|
self._test_default_node_quant_handler_ops(module, functional, self.custom_qconfig, is_reference, node_list,
|
|
additional_quant_pattern_dict=self.common_quant_patterns)
|
|
|
|
@unittest.skip("This is no longer needed right now, can enable later with new api")
|
|
def test_mish_reference(self):
|
|
module = torch.nn.Mish
|
|
functional = torch.nn.functional.mish
|
|
qconfig = float16_static_qconfig
|
|
is_reference = True
|
|
node_list = [
|
|
ns.call_method("to"),
|
|
ns.call_method("dequantize"),
|
|
ns.call_module(module),
|
|
ns.call_method("to"),
|
|
ns.call_method('dequantize'),
|
|
ns.call_function(functional),
|
|
ns.call_method("to"),
|
|
ns.call_method('dequantize')
|
|
]
|
|
self._test_default_node_quant_handler_ops(
|
|
module, functional, qconfig, is_reference, node_list)
|
|
|
|
node_list = [
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_method("dequantize"),
|
|
ns.call_module(module),
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_method("dequantize"),
|
|
ns.call_function(functional),
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_method("dequantize")
|
|
]
|
|
self._test_default_node_quant_handler_ops(module, functional, self.custom_qconfig, is_reference, node_list,
|
|
additional_quant_pattern_dict=self.common_quant_patterns)
|
|
|
|
def test_bmm_int_reference(self):
|
|
""" int8 is not supported for bmm so we won't produce reference
|
|
pattern for it
|
|
"""
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.bmm = torch.bmm
|
|
|
|
def forward(self, x, y):
|
|
out = self.bmm(x, y)
|
|
return out
|
|
|
|
data_x = torch.randn((2, 2, 2,))
|
|
data_y = torch.randn((2, 2, 2,))
|
|
example_inputs = (data_x, data_y)
|
|
qconfig_dict = {"": torch.ao.quantization.get_default_qconfig("fbgemm")}
|
|
is_reference = True
|
|
node_list = [
|
|
ns.call_function(torch.bmm),
|
|
]
|
|
|
|
m = M().eval()
|
|
m_prep = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
|
|
m_prep(*example_inputs)
|
|
convert_fn = convert_to_reference_fx if is_reference else convert_fx
|
|
m_quant = convert_fn(m_prep)
|
|
m_quant(*example_inputs)
|
|
|
|
self.checkGraphModuleNodes(m_quant, expected_node_list=node_list)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_clamp(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(2, 2, 2).float()
|
|
self.relu6 = torch.nn.ReLU6()
|
|
self.relu6_ = torch.nn.ReLU6(True)
|
|
self.hardtanh = torch.nn.Hardtanh()
|
|
self.hardtanh_ = torch.nn.Hardtanh(inplace=True)
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
x = self.relu6(x)
|
|
self.relu6_(x)
|
|
x = F.relu6(x)
|
|
x = torch.clamp(x, -3, 3)
|
|
x = x.clamp(-2.5, 2.5)
|
|
# x = x.clamp_(-2, 2) # Enable when quantized `clamp_` is ready
|
|
x = self.hardtanh(x)
|
|
self.hardtanh_(x)
|
|
x = F.hardtanh(x)
|
|
return x
|
|
|
|
data = (torch.rand((1, 2, 5, 5), dtype=torch.float),)
|
|
# list of node that should occur in order
|
|
node_list = [
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_module(nnq.Conv2d),
|
|
ns.call_method('dequantize')
|
|
]
|
|
for quant_type in self.static_quant_types:
|
|
self.checkGraphModeFxOp(
|
|
M(), data, quant_type, expected_node_list=node_list)
|
|
|
|
def test_fixed_qparams_ops_fp16(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.sigmoid = torch.nn.Sigmoid()
|
|
self.tanh = torch.nn.Tanh()
|
|
|
|
def forward(self, x):
|
|
x = self.sigmoid(x)
|
|
x = torch.sigmoid(x)
|
|
x = x.sigmoid()
|
|
x = self.tanh(x)
|
|
x = torch.tanh(x)
|
|
x = x.tanh()
|
|
return x
|
|
|
|
data = (torch.randn((2, 2, 2, 2), dtype=torch.float),)
|
|
quant_type = QuantType.STATIC
|
|
# TODO: use get_default_qconfig_mapping once it handles fp16
|
|
qconfig_mapping = QConfigMapping().set_global(float16_static_qconfig)
|
|
backend_config = get_test_only_legacy_native_backend_config()
|
|
node_occurrence = {
|
|
ns.call_method("to"): 7
|
|
}
|
|
self.checkGraphModeFxOp(
|
|
M(), data, quant_type, custom_qconfig_dict=qconfig_mapping,
|
|
expected_node_occurrence=node_occurrence,
|
|
backend_config=backend_config)
|
|
|
|
def test_fixed_qparams_ops_qint8(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.sigmoid = torch.nn.Sigmoid()
|
|
self.tanh = torch.nn.Tanh()
|
|
|
|
def forward(self, x):
|
|
x = self.sigmoid(x)
|
|
x = torch.sigmoid(x)
|
|
x = x.sigmoid()
|
|
x = self.tanh(x)
|
|
x = torch.tanh(x)
|
|
x = x.tanh()
|
|
return x
|
|
|
|
data = (torch.randn((2, 2, 2, 2), dtype=torch.float),)
|
|
quant_type = QuantType.STATIC
|
|
qconfig = torch.ao.quantization.QConfig(
|
|
activation=HistogramObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.quint8),
|
|
weight=default_weight_observer)
|
|
qconfig_mapping = get_default_qconfig_mapping().set_global(qconfig)
|
|
node_occurrence = {
|
|
ns.call_function(torch.quantize_per_tensor): 7,
|
|
ns.call_method("dequantize"): 7
|
|
}
|
|
self.checkGraphModeFxOp(
|
|
M(), data, quant_type, custom_qconfig_dict=qconfig_mapping,
|
|
expected_node_occurrence=node_occurrence, is_reference=True)
|
|
|
|
def test_fixed_qparams_ops_wrong_qconfig(self):
|
|
""" Test that wrong qconfigs for fixed qparams ops results in the ops not being quantized.
|
|
"""
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.sigmoid = torch.nn.Sigmoid()
|
|
self.tanh = torch.nn.Tanh()
|
|
|
|
def forward(self, x):
|
|
x = self.sigmoid(x)
|
|
x = torch.sigmoid(x)
|
|
x = x.sigmoid()
|
|
x = self.tanh(x)
|
|
x = torch.tanh(x)
|
|
x = x.tanh()
|
|
return x
|
|
|
|
data = (torch.randn((2, 2, 2, 2), dtype=torch.float),)
|
|
qconfig_mapping = QConfigMapping().set_global(default_qconfig)
|
|
m = M().eval()
|
|
node_occurrence = {
|
|
ns.call_function(torch.quantize_per_tensor): 0,
|
|
ns.call_method("dequantize"): 0,
|
|
}
|
|
self.checkGraphModeFxOp(
|
|
m, data, QuantType.STATIC, custom_qconfig_dict=qconfig_mapping,
|
|
expected_node_occurrence=node_occurrence, is_reference=True)
|
|
self.assertTrue(isinstance(m.sigmoid, torch.nn.Sigmoid))
|
|
self.assertTrue(isinstance(m.tanh, torch.nn.Tanh))
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_general_shape_ops(self):
|
|
""" A test that checks dequantize will be swapped for
|
|
all supported general shape ops like aten::flatten
|
|
without actually checking for execution of these ops
|
|
"""
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3)
|
|
self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3)
|
|
self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3)
|
|
self.dropout = torch.nn.Dropout()
|
|
self.conv1 = torch.nn.Conv2d(3, 3, 3)
|
|
self.conv2 = torch.nn.Conv2d(3, 3, 3)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = self.conv1(x)
|
|
# add_scalar
|
|
x = x + 3
|
|
# mul_scalar
|
|
x = x * 3
|
|
# add_scalar_out
|
|
x += 3
|
|
# mul_scalar_out
|
|
x *= 3
|
|
# add_scalar_relu
|
|
x = x + 3
|
|
x = F.relu(x)
|
|
# add_scalar_relu_out
|
|
x += 3
|
|
x = F.relu(x)
|
|
# mul_scalar_relu
|
|
x = x * 3
|
|
x = F.relu(x)
|
|
# mul_scalar_relu_out
|
|
x *= 3
|
|
x = F.relu(x)
|
|
x = self.maxpool1d(x)
|
|
x = self.maxpool2d(x)
|
|
x = self.maxpool3d(x)
|
|
x = torch.flatten(x)
|
|
x = x.reshape([-1])
|
|
x = x.resize_(1, 1, x)
|
|
x = x.view(-1)
|
|
# prim::ListConstruct
|
|
xs = [x, x]
|
|
# prim::ListUnpack
|
|
x, y = xs
|
|
# prim::TupleConstruct
|
|
xs = (x, x)
|
|
# prim::TupleUnpack
|
|
x, y = xs
|
|
x = x.transpose(1, 2)
|
|
x = x.contiguous()
|
|
# chunk is not supported since observer only supports
|
|
# observing single Tensor currently
|
|
x, y = torch.chunk(x, 2)
|
|
x = F.dropout(x)
|
|
x = self.dropout(x)
|
|
x = x.permute(0, 2, 3, 1)
|
|
x = x.repeat_interleave(3, 1)
|
|
x = torch.repeat_interleave(x, 3, 1)
|
|
x = self.relu(x)
|
|
x = F.relu(x)
|
|
x = F.relu(x, inplace=True)
|
|
x = x.relu()
|
|
x.relu_()
|
|
x = x.squeeze(0)
|
|
x.squeeze_(0)
|
|
x = torch.squeeze(x, 0)
|
|
x = x.unsqueeze(0)
|
|
x.unsqueeze_(0)
|
|
x = torch.unsqueeze(x, 0)
|
|
x = x.detach()
|
|
x.detach_()
|
|
x = x.repeat(4, 2)
|
|
y = []
|
|
y.append(x)
|
|
z = torch.stack(y, 0)
|
|
z = [z, z]
|
|
x, _ = z
|
|
x = self.conv2(x)
|
|
return x
|
|
|
|
example_inputs = (torch.rand(1, 3, 10, 10),)
|
|
# This model is not executable since we just put all ops
|
|
# in the same forward
|
|
m = M().eval()
|
|
qconfig_dict = {'': default_qconfig}
|
|
prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
|
|
# not runnable
|
|
quantized = convert_fx(prepared)
|
|
|
|
# This checks that the dequantize from the output of first conv
|
|
# is being propagated to the end, so that we don't insert extra
|
|
# observers and also successfully fused two quantized::conv2d
|
|
# patterns
|
|
# one quantize_per_tensor for input
|
|
# check exact counts of quantize and dequantize
|
|
count_check = {
|
|
# input of conv and two outputs of getitem
|
|
ns.call_function(torch.quantize_per_tensor) : 2,
|
|
# output of the model and two outputs of getitem
|
|
ns.call_method('dequantize') : 2
|
|
}
|
|
order_check = [
|
|
ns.call_function(torch.quantize_per_tensor),
|
|
ns.call_module(nnq.Conv2d),
|
|
ns.call_module(nnq.Conv2d),
|
|
ns.call_method('dequantize'),
|
|
]
|
|
self.checkGraphModuleNodes(
|
|
quantized,
|
|
expected_node_occurrence=count_check,
|
|
expected_node_list=order_check)
|
|
|
|
|
|
# Checking the is_reference output
|
|
m = M().eval()
|
|
qconfig_dict = {'': default_qconfig}
|
|
prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
|
|
# not runnable
|
|
quantized = convert_to_reference_fx(prepared)
|
|
|
|
|
|
    @skipIfNoFBGEMM
    def test_ave_pool_with_custom_cfg(self):
        """ A test that checks correct patterns are produced for
        avg_pool2d with customized config
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.avg_pool2d = torch.nn.AvgPool2d(3)

            def forward(self, x):
                x = self.avg_pool2d(x)
                return x

        # This model is not executable since we just put all ops
        # in the same forward
        m = M().eval()
        # nothing to fuse so skipping the fuse step
        qconfig_dict = {'': default_qconfig}
        example_inputs = (torch.randn(1, 3, 3, 3),)
        prepared = prepare_fx(
            m, qconfig_dict, example_inputs=example_inputs,
            prepare_custom_config={"input_quantized_idxs": [0]})

        # not runnable
        quantized = convert_fx(prepared)

        # This checks that the dequantize from the output of first conv
        # is being propagated to the end, so that we don't insert extra
        # observers
        # check exact counts of quantize and dequantize
        count_check = {
            ns.call_method('dequantize') : 1
        }
        order_check = [
            ns.call_module(nn.AvgPool2d),
            ns.call_method('dequantize'),
        ]
        self.checkGraphModuleNodes(
            quantized,
            expected_node_occurrence=count_check,
            expected_node_list=order_check)

    @skipIfNoFBGEMM
    def test_general_value_ops(self):
        """ A test that checks correct patterns are produced for
        all supported general value ops like aten::avg_pool2d
        without actually checking for execution of these ops
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.avg_pool1d = torch.nn.AvgPool1d(3)
                self.avg_pool2d = torch.nn.AvgPool2d(3)
                self.avg_pool3d = torch.nn.AvgPool3d(3)
                self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d(1)
                self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
                self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1))

            def forward(self, x):
                x = self.conv(x)
                x = self.avg_pool1d(x)
                x = self.avg_pool2d(x)
                x = self.avg_pool3d(x)
                x = self.adaptive_avg_pool1d(x)
                x = self.adaptive_avg_pool2d(x)
                x = self.adaptive_avg_pool3d(x)
                x = F.avg_pool1d(x, 3)
                x = F.avg_pool2d(x, 3)
                x = F.avg_pool3d(x, 3)
                x = F.adaptive_avg_pool1d(x, (1))
                x = F.adaptive_avg_pool2d(x, (1, 1))
                x = F.adaptive_avg_pool3d(x, (1, 1, 1))
                x = torch.mean(x)
                x = torch.mean(x, [2, 3], False)
                x = x.mean()
                x = x.mean([2, 3], True)
                x = F.interpolate(x, 4, mode='nearest')
                x = F.interpolate(x, 4, mode='linear')
                x = self.conv(x)
                return x

        # This model is not executable since we just put all ops
        # in the same forward
        m = M().eval()
        # nothing to fuse so skipping the fuse step
        qconfig_dict = {'': default_qconfig}
        example_inputs = (torch.randn(1, 3, 3, 3),)
        prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
        # not runnable
        quantized = convert_fx(prepared)

        # This checks that the dequantize from the output of first conv
        # is being propagated to the end, so that we don't insert extra
        # observers
        # check exact counts of quantize and dequantize
        count_check = {
            ns.call_function(torch.quantize_per_tensor) : 1,
            ns.call_method('dequantize') : 1
        }
        order_check = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize'),
        ]
        self.checkGraphModuleNodes(
            quantized,
            expected_node_occurrence=count_check,
            expected_node_list=order_check)

    def test_copy_node_fp32_input(self):
        """ CopyNode works for both fp32 and int8 inputs, this is a test to make
        sure that a CopyNode can be successfully quantized in both cases
        """
        class M(torch.nn.Module):
            def forward(self, x):
                x = x.relu()
                return x

        m = M().eval()
        m = prepare_fx(m, {"": default_reuse_input_qconfig}, example_inputs=(torch.randn(1),))
        m = convert_fx(m)
        # make sure it runs
        m(torch.rand(1))

    def test_getitem(self):
        """ Make sure we only insert observer for getitem if the following node is matched
        or needs to be quantized
        """
        class M(torch.nn.Module):
            def forward(self, xs):
                x = xs[0]
                return x

        m = M().eval()
        example_inputs = (torch.rand(1, 2),)
        qconfig_mapping = get_default_qconfig_mapping()
        m = prepare_fx(m, qconfig_mapping, example_inputs=example_inputs)
        self.checkGraphModuleNodes(m, expected_node_occurrence={
            ns.call_module(torch.ao.quantization.MinMaxObserver): 0
        })
        m = convert_fx(m)
        m(*example_inputs)

        class M2(torch.nn.Module):
            def forward(self, xs):
                x = xs[0]
                x = torch.sigmoid(x)
                return x

        m2 = M2().eval()
        example_inputs = ([torch.rand(1, 2)],)
        qconfig_mapping = get_default_qconfig_mapping()
        m2 = prepare_fx(m2, qconfig_mapping, example_inputs=example_inputs)
        self.checkGraphModuleNodes(m2, expected_node_occurrence={
            ns.call_module(torch.ao.quantization.FixedQParamsObserver): 2
        })
        m2 = convert_fx(m2)
        self.checkGraphModuleNodes(m2, expected_node_list=[
            ns.call_function(torch.quantize_per_tensor),
            ns.call_method("dequantize")
        ])
        m2(*example_inputs)

        # testing prepare recognizes non-Tensor input for getitem
        class M3(torch.nn.Module):
            def forward(self, x):
                s = x.shape
                n, c = s[:2]
                x = torch.sigmoid(x)
                return x

        m3 = M3().eval()
        example_inputs = (torch.rand(1, 2, 3, 4),)
        qconfig_mapping = get_default_qconfig_mapping()
        m3 = prepare_fx(m3, qconfig_mapping, example_inputs=example_inputs)
        self.checkGraphModuleNodes(m3, expected_node_occurrence={
            ns.call_module(torch.ao.quantization.FixedQParamsObserver): 2
        })
        m3 = convert_fx(m3)
        self.checkGraphModuleNodes(m3, expected_node_list=[
            ns.call_function(torch.quantize_per_tensor),
            ns.call_method("dequantize")
        ])
        m3(*example_inputs)

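    # Fixed-qparams ops (sigmoid, hardsigmoid, tanh, softmax) have a known
    # output range, so prepare inserts FixedQParamsObserver /
    # FixedQParamsFakeQuantize for them instead of regular min/max observers;
    # the test below counts those modules in the prepared graph.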
    @skipIfNoFBGEMM
    def test_fixed_qparams_ops(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.sigmoid = torch.nn.Sigmoid()
                self.hardsigmoid = torch.nn.Hardsigmoid()
                self.tanh = torch.nn.Tanh()
                self.softmax = torch.nn.Softmax(dim=0)

            def forward(self, x):
                x = self.conv(x)
                # F.sigmoid is deprecated
                x = self.sigmoid(x)
                x = torch.sigmoid(x)
                x = x.sigmoid()
                x = self.hardsigmoid(x)
                x = F.hardsigmoid(x)
                x = F.hardsigmoid(x, inplace=True)
                x = self.tanh(x)
                # F.tanh is deprecated
                x = torch.tanh(x)
                x = x.tanh()
                # TODO(future PR): handle F.softmax
                x = self.softmax(x)
                return x

        for eval_mode in [True, False]:
            # This model is not executable since we just put all ops
            # in the same forward
            m = M()
            if eval_mode:
                m.eval()
                qconfig_mapping = get_default_qconfig_mapping()
                prepare = prepare_fx
                fq_count = 10
            else:
                m.train()
                qconfig_mapping = get_default_qat_qconfig_mapping()
                prepare = prepare_qat_fx
                fq_count = 10
            # nothing to fuse so skipping the fuse step
            m_copy = copy.deepcopy(m)
            example_inputs = (torch.rand(3, 3, 3, 3),)
            prepared = prepare(m, qconfig_mapping, example_inputs=example_inputs)
            prepared_copy = copy.deepcopy(prepared)
            # check that prepare does not change model result
            if eval_mode:
                self.assertEqual(m_copy(*example_inputs), prepared_copy(*example_inputs))
            # check the correct number of activation_post_process is inserted
            expected_activation_post_process = FixedQParamsObserver if eval_mode else FixedQParamsFakeQuantize
            count_check = {
                ns.call_module(expected_activation_post_process) : fq_count,
            }
            self.checkGraphModuleNodes(
                prepared,
                expected_node_occurrence=count_check)
            # not runnable
            quantized = convert_fx(prepared)
            quantized_reference = convert_to_reference_fx(prepared_copy)

            # This checks that the dequantize from the output of first conv
            # is being propagated to the end, so that we don't insert extra
            # observers
            # check exact counts of quantize and dequantize
            count_check = {
                ns.call_function(torch.quantize_per_tensor) : 1,
                ns.call_method('dequantize') : 1
            }
            order_check = [
                ns.call_function(torch.quantize_per_tensor),
                ns.call_module(nnq.Conv2d),
                ns.call_module(nn.Sigmoid),
                ns.call_module(nnq.Softmax),
                ns.call_method('dequantize'),
            ]
            self.checkGraphModuleNodes(
                quantized,
                expected_node_occurrence=count_check,
                expected_node_list=order_check)

            reference_count_check = {
                ns.call_function(torch.quantize_per_tensor) : 12,
                ns.call_method('dequantize') : 12
            }
            reference_order_check = [
                ns.call_function(torch.quantize_per_tensor),
                ns.call_method('dequantize'),
                ns.call_module(nnqr.Conv2d),
                ns.call_function(torch.quantize_per_tensor),
                ns.call_method('dequantize'),
                ns.call_module(nn.Sigmoid),
                ns.call_function(torch.quantize_per_tensor),
                ns.call_method('dequantize'),
                ns.call_module(nn.Softmax),
                ns.call_function(torch.quantize_per_tensor),
                ns.call_method('dequantize'),
            ]
            self.checkGraphModuleNodes(
                quantized_reference,
                expected_node_occurrence=reference_count_check,
                expected_node_list=reference_order_check)

            # Verify that softmax scale and zero_point are correct
            self.assertTrue(quantized.softmax.scale - (1.0 / 256) <= 1e-8)
            self.assertTrue(quantized.softmax.zero_point == 0)

    def test_float_functional(self):
        class TorchAdd(nn.Module):
            """Wrapper around torch.add so that all ops can be found at build"""
            def __init__(self) -> None:
                super().__init__()
                self.add_func = nnq.FloatFunctional()

            def forward(self, x, y):
                return self.add_func.add(x, y)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.ff1 = TorchAdd()
                self.ff2 = nnq.FloatFunctional()
                self.ff3 = nnq.FloatFunctional()
                self.ff4 = nnq.FloatFunctional()
                self.ff5 = nnq.FloatFunctional()
                self.ff6 = nnq.FloatFunctional()

            def forward(self, x):
                x = self.ff1(x, x)
                x = self.ff2.add_scalar(x, 3)
                x = self.ff3.mul(x, x)
                x = self.ff4.mul_scalar(x, 3)
                x = self.ff5.add_relu(x, x)
                x = self.ff6.cat([x])
                return x

        example_inputs = (torch.rand(3, 3),)
        # Note: QAT test succeeded by chance, to make it actually work
        # we need to fix eager mode FloatFunctional by removing
        # activation_post_process in add_scalar and mul_scalar
        for quant_type in self.static_quant_types:
            m = M()
            ref_m = torch.ao.quantization.QuantWrapper(M())
            is_qat = quant_type == QuantType.QAT
            if is_qat:
                m.train()
                ref_m.train()
                qconfig = default_qat_qconfig
                expected_act_post_process = torch.ao.quantization.FakeQuantize
            else:
                m.eval()
                ref_m.eval()
                qconfig = default_qconfig
                expected_act_post_process = torch.ao.quantization.MinMaxObserver

            prepare_fx_function = prepare_qat_fx if is_qat else prepare_fx
            qconfig_dict = {"": qconfig}
            m = prepare_fx_function(m, qconfig_dict, example_inputs=example_inputs)
            node_occurrence = {
                ns.call_module(expected_act_post_process): 7,
                ns.call_module(torch.ao.nn.quantized.FloatFunctional): 0
            }
            self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
            m(*example_inputs)
            node_list = [
                ns.call_function(torch.quantize_per_tensor),
                ns.call_function(torch.ops.quantized.add),
                ns.call_function(torch.ops.quantized.add),
                ns.call_function(torch.ops.quantized.mul),
                ns.call_function(torch.ops.quantized.mul),
                ns.call_function(torch.ops.quantized.add_relu),
                ns.call_function(torch.cat),
                ns.call_method('dequantize')
            ]
            m = convert_fx(m)
            self.checkGraphModuleNodes(m, expected_node_list=node_list)

            # make sure numerics match with eager mode
            ref_m.qconfig = qconfig
            prepare_function = prepare_qat if is_qat else prepare
            ref_m = prepare_function(ref_m)
            ref_m(*example_inputs)
            ref_m = convert(ref_m)
            # FX Graph Mode and Eager Mode now diverge in numerics of add_scalar and mul_scalar
            # self.assertEqual(m(data), ref_m(data))

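    # Embedding ops are quantized weight-only: the float_qparams_weight_only
    # qconfigs below only attach a per-channel float-qparams weight observer,
    # so no activation quantize/dequantize nodes are expected in the graph.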
    def test_embedding(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)

            def forward(self, indices):
                return self.emb(indices)

        for qconfig_type in [float_qparams_weight_only_qconfig, float_qparams_weight_only_qconfig_4bit]:
            model = M().eval()
            indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
            example_inputs = (indices,)
            quantized_node = ns.call_module(nnq.Embedding)

            # check dynamic quant
            self.checkGraphModeFxOp(
                model,
                example_inputs,
                QuantType.DYNAMIC,
                quantized_node,
                custom_qconfig_dict={"": qconfig_type}
            )
            model = M().eval()

            configs = [
                (qconfig_type, ns.call_module(nnq.Embedding)),
                (None, ns.call_module(nn.Embedding)),
                (default_qconfig, ns.call_module(nn.Embedding)),
            ]

            # check static quantization
            for qconfig, node in configs:
                qconfig_dict = {"": qconfig}
                m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
                self.checkGraphModuleNodes(m, expected_node_occurrence={
                    ns.call_module(torch.ao.quantization.MinMaxObserver): 0
                })
                m = convert_fx(m)
                self.checkGraphModuleNodes(m, expected_node=node)
                # make sure it runs
                m(*example_inputs)

    def test_embedding_bag(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True)

            def forward(self, indices, offsets):
                return self.emb(indices, offsets)

        indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
        offsets = torch.tensor([0, 19, 20, 28, 28, 32])
        quantized_node = ns.call_module(nnq.EmbeddingBag)
        example_inputs = (indices, offsets)

        for dtype in [torch.quint8, torch.quint4x2]:
            model = M().eval()
            float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype,
                                                                        qscheme=torch.per_channel_affine_float_qparams,
                                                                        ch_axis=0)
            float_qparams_qconfig = QConfig(activation=default_placeholder_observer,
                                            weight=float_qparams_observer)
            self.checkGraphModeFxOp(
                model,
                example_inputs,
                QuantType.DYNAMIC,
                quantized_node,
                custom_qconfig_dict={"": float_qparams_qconfig}
            )

        # check it works in None and static qconfig
        for qconfig in [None, default_qconfig]:
            qconfig_dict = {"": default_qconfig}
            m = M().eval()
            m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
            self.checkGraphModuleNodes(m, expected_node_occurrence={
                ns.call_module(torch.ao.quantization.MinMaxObserver): 0
            })
            m = convert_fx(m)
            self.checkGraphModuleNodes(m, expected_node=ns.call_module(nn.EmbeddingBag))
            # make sure it runs
            m(*example_inputs)

    def _test_rnn_impl(self, qconfigs, M, module_type_strs, module_types, sample_input):
        options = itertools.product(qconfigs, module_type_strs)
        for qconfig, module_type_str in options:
            model_eager = M(module_type_str).eval()
            model_graph = copy.deepcopy(model_eager)
            if torch.backends.quantized.engine == 'qnnpack' and \
                    qconfig is float16_dynamic_qconfig:
                # fp16 dynamic quant is not supported for qnnpack
                continue

            eager_qconfig_dict = dict.fromkeys(module_types, qconfig)
            model_eager = quantize_dynamic(model_eager, qconfig_spec=eager_qconfig_dict)

            graph_qconfig_dict = {
                "object_type": [
                    (x, qconfig) for x in module_types
                ]
            }
            model_graph = prepare_fx(model_graph, graph_qconfig_dict, example_inputs=(sample_input,))
            model_graph = convert_fx(model_graph)
            self.assertEqual(model_eager(sample_input), model_graph(sample_input))
            self.checkScriptable(model_graph, [[sample_input]], True)

    @override_qengines
    def test_rnn_cell(self):
        if torch.backends.quantized.engine not in ('fbgemm', 'qnnpack'):
            return
        qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig]
        module_type_strs = ['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU']
        module_types = [torch.nn.LSTMCell, torch.nn.GRUCell, torch.nn.RNNCell]
        sample_input = torch.tensor([[100, -155],
                                     [-155, 100],
                                     [100, -155]], dtype=torch.float)
        self._test_rnn_impl(qconfigs, RNNCellDynamicModel, module_type_strs, module_types, sample_input)

    @override_qengines
    def test_rnn(self):
        if torch.backends.quantized.engine not in ('fbgemm', 'qnnpack'):
            return
        qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig]
        module_type_strs = ['LSTM', 'GRU']
        module_types = [torch.nn.LSTM, torch.nn.GRU]
        niter = 10
        sample_input = torch.tensor([[100, -155],
                                     [-155, 100],
                                     [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1)
        self._test_rnn_impl(qconfigs, RNNDynamicModel, module_type_strs, module_types, sample_input)

    def _test_conv_transpose_impl(
            self, float_cls: Callable, q_cls: Callable, data: torch.Tensor):
        with override_quantized_engine('qnnpack'):
            # Create fp32 versions of FX and Eager models
            m1 = torch.nn.Sequential(float_cls(1, 1, 1))
            m2 = torch.nn.Sequential(float_cls(1, 1, 1))
            m2.load_state_dict(m1.state_dict())
            m2 = torch.ao.quantization.QuantWrapper(m2)
            # FX graph
            result_dict = self.checkGraphModeFxOp(
                m1, (data,), QuantType.STATIC,
                expected_node_occurrence={
                    ns.call_module(q_cls): 1,
                })
            q_result1 = result_dict["quantized_output"]
            # Eager
            m2.qconfig = get_default_qconfig(torch.backends.quantized.engine)
            m2.eval()
            m2p = torch.ao.quantization.prepare(m2)
            m2p(data)
            m2q = torch.ao.quantization.convert(m2p)
            q_result2 = m2q(data)
            # verify results match
            self.assertEqual(q_result1, q_result2)

    @unittest.skipUnless('qnnpack' in supported_qengines,
                         "This Pytorch Build has not been built with or does not support QNNPACK")
    def test_conv_transpose_1d(self):
        self._test_conv_transpose_impl(
            torch.nn.ConvTranspose1d, nnq.ConvTranspose1d, torch.randn(4, 1, 4))

    @unittest.skipUnless('qnnpack' in supported_qengines,
                         "This Pytorch Build has not been built with or does not support QNNPACK")
    def test_conv_transpose_2d(self):
        self._test_conv_transpose_impl(
            torch.nn.ConvTranspose2d, nnq.ConvTranspose2d, torch.randn(4, 1, 4, 4))

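    # float16_static_qconfig routes a value through fp16 instead of int8, which
    # shows up in the converted graph as `to` (dtype cast) calls rather than
    # quantize_per_tensor/dequantize pairs; the next two tests check how this
    # mixes with an int8 qconfig on the surrounding linear ops.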
    def test_reshape_fp16(self):
        class M(torch.nn.Module):
            def __init__(self, w, b):
                super().__init__()
                self.w = w
                self.b = b

            def forward(self, x):
                x = torch.nn.functional.linear(x, self.w)
                x = x.reshape(-1, 4)
                x = torch.nn.functional.linear(x, self.w)
                return x

        w = torch.randn(4, 4)
        b = torch.randn(4)
        m = M(w, b).eval()
        qconfig_dict = {
            # reshape will be quantized to fp16 as requested by this qconfig
            "": float16_static_qconfig,
            "object_type": [
                (torch.nn.functional.linear, default_qconfig)
            ]
        }
        backend_config = get_test_only_legacy_native_backend_config()
        example_inputs = (torch.randn(1, 4),)
        m = prepare_fx(
            m, qconfig_dict, example_inputs=example_inputs,
            backend_config=backend_config)
        expected_occurrence = {
            # input and weight of first and second linear, output of first and second linear
            ns.call_module(torch.ao.quantization.MinMaxObserver): 6,
            # we insert placeholder observer for both input and output of reshape
            ns.call_module(torch.ao.quantization.PlaceholderObserver): 2
        }
        self.checkGraphModuleNodes(
            m,
            expected_node_occurrence=expected_occurrence
        )
        m = convert_fx(m, backend_config=backend_config)
        expected_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 2,
            # dequantize after first linear, before reshape and before output
            ns.call_method("dequantize"): 3,
            # before reshape, to(fp16)
            ns.call_method("to"): 1,
            ns.call_function(torch.ops.quantized.linear): 2
        }
        self.checkGraphModuleNodes(
            m,
            expected_node_occurrence=expected_occurrence
        )
        # make sure it runs
        m(torch.randn(2, 4))

    def test_multiple_qconfigs_for_single_value(self):
        """ Test multiple qconfigs for a single value"""
        class M(torch.nn.Module):
            def __init__(self, w, b):
                super().__init__()
                self.w = w
                self.b = b

            def forward(self, x):
                x = torch.nn.functional.linear(x, self.w)
                x = torch.sigmoid(x)
                return x

        w = torch.randn(4, 4)
        b = torch.randn(4)
        m = M(w, b).eval()
        # TODO: use get_default_qconfig_mapping once it handles fp16
        qconfig_mapping = QConfigMapping() \
            .set_global(float16_static_qconfig) \
            .set_object_type(torch.nn.functional.linear, default_qconfig)
        example_inputs = (torch.randn(1, 4),)
        backend_config = get_test_only_legacy_native_backend_config()
        m = prepare_fx(
            m, qconfig_mapping, example_inputs=example_inputs,
            backend_config=backend_config)
        expected_occurrence = {
            # input and weight of linear, output of linear
            ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
            # input and output of sigmoid
            ns.call_module(torch.ao.quantization.PlaceholderObserver): 2,
        }
        self.checkGraphModuleNodes(
            m,
            expected_node_occurrence=expected_occurrence
        )
        # make sure it runs
        m = convert_fx(m)
        expected_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method("dequantize"): 3,
            ns.call_method("to"): 2
        }
        self.checkGraphModuleNodes(
            m,
            expected_node_occurrence=expected_occurrence
        )

    def test_boolean_tensor(self):
        """ Make sure we don't insert observer for boolean Tensors """
        class M(torch.nn.Module):
            def forward(self, x, mask):
                mask = mask.unsqueeze(0)
                mask = mask.unsqueeze(1)
                x = x.masked_fill(mask, 1)
                return x

        m = M().eval()
        example_inputs = (torch.rand(1, 2, 3, 4), torch.rand(3, 4).bool())
        m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs)
        expected_occurrence = {
            ns.call_module(torch.ao.quantization.MinMaxObserver): 0
        }
        self.checkGraphModuleNodes(
            m,
            expected_node_occurrence=expected_occurrence)
        m = convert_fx(m)
        m(*example_inputs)

    def test_chunk(self):
        class M(torch.nn.Module):
            def forward(self, x):
                x, y = torch.chunk(x, 2)
                x = x + y
                return x

        m = M().eval()
        example_inputs = (torch.rand(2, 2, 2, 2),)
        m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs)
        m(*example_inputs)
        m = convert_fx(m)
        m(*example_inputs)
        # make sure everything runs

    def test_ref_pattern_multi_use(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.linear1 = torch.nn.Linear(5, 5)

            def forward(self, x):
                y = self.linear(x)
                z = self.linear1(x)
                a = torch.mul(z, 5)
                b = torch.add(z, 5)
                return (y, a, b)

        m = M().eval()
        qconfig_dict = {
            "": None,
            "object_type": [
                (torch.nn.Linear, get_default_qconfig("fbgemm")),
                (torch.nn.ReLU, get_default_qconfig("fbgemm")),
            ],
        }
        example_inputs = (torch.randn(1, 5),)
        m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
        m = convert_fx(m)
        expected_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_module(nnq.Linear): 2,
            ns.call_method("dequantize"): 2,
            ns.call_function(torch.add): 1,
            ns.call_function(torch.mul): 1,
        }
        self.checkGraphModuleNodes(
            m,
            expected_node_occurrence=expected_occurrence)

    def test_qmatmul(self):
        class M(torch.nn.Module):
            def forward(self, x, y):
                z = torch.matmul(x, y)
                return z

        m = M().eval()
        example_inputs = (torch.randn(2, 2), torch.randn(2, 2))
        qconfig_dict = get_default_qconfig_mapping("fbgemm")
        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
        mp(*example_inputs)
        mq = convert_fx(mp)
        expected_occurrence = {
            ns.call_function(torch.matmul): 0,
            ns.call_function(torch.ops.quantized.matmul): 1,
        }
        self.checkGraphModuleNodes(
            mq,
            expected_node_occurrence=expected_occurrence)
        # verify no crash
        res = mq(*example_inputs)

    def test_pixel_shuffle(self):
        class MyBias(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = nn.Parameter(torch.randn(8))

        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(8, 8, 1, bias=False)
                self.bias = MyBias()

            def forward(self, x):
                x = self.conv(x)
                x = nn.functional.pixel_shuffle(x, 2)
                x = x.view(-1, 8, 2, 2)
                bias = self.bias.bias
                return x + bias

        backend_config = get_qnnpack_backend_config()
        qconfig_mapping = get_default_qconfig_mapping("qnnpack")
        model = MyModel()
        m = prepare_fx(
            model,
            qconfig_mapping=qconfig_mapping,
            example_inputs=(torch.randn(1, 8, 3, 3),),
            backend_config=backend_config
        )
        m = convert_fx(m)
        expected_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_method("dequantize"): 1,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence)

    def test_pixel_shuffle_module(self) -> None:
        class MyBias(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = nn.Parameter(torch.randn(8))

        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(8, 8, 1, bias=False)
                self.ps = nn.PixelShuffle(upscale_factor=2)
                self.bias = MyBias()

            def forward(self, x):
                x = self.conv(x)
                x = self.ps(x)
                x = x.view(-1, 8, 2, 2)
                bias = self.bias.bias
                return x + bias

        backend_config = get_qnnpack_backend_config()
        qconfig_mapping = get_default_qconfig_mapping("qnnpack")
        model = MyModel()
        m = prepare_fx(
            model,
            qconfig_mapping=qconfig_mapping,
            example_inputs=(torch.randn(1, 8, 3, 3),),
            backend_config=backend_config
        )
        m = convert_fx(m)
        expected_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_method("dequantize"): 1,
            ns.call_module(nn.PixelShuffle): 1,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence)

    def test_pixel_unshuffle(self):
        class MyBias(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = nn.Parameter(torch.randn(64))

        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(8, 8, 1, bias=False)
                self.bias = MyBias()

            def forward(self, x):
                x = self.conv(x)
                x = nn.functional.pixel_unshuffle(x, 2)
                bias = self.bias.bias
                return x + bias

        for backend in ["fbgemm", "qnnpack"]:
            if backend == "fbgemm":
                backend_config = get_fbgemm_backend_config()
            else:
                backend_config = get_qnnpack_backend_config()
            qconfig_mapping = get_default_qconfig_mapping(backend)
            model = MyModel()
            m = prepare_fx(
                model,
                qconfig_mapping=qconfig_mapping,
                example_inputs=(torch.randn(1, 8, 6, 6),),
                backend_config=backend_config
            )
            m = convert_fx(m)
            expected_occurrence = {
                ns.call_function(torch.quantize_per_tensor): 2,
                ns.call_method("dequantize"): 1,
            }
            self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence)

    def test_pixel_unshuffle_module(self) -> None:
        class MyBias(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = nn.Parameter(torch.randn(64))

        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(8, 8, 1, bias=False)
                self.unshuffle = nn.PixelUnshuffle(downscale_factor=2)
                self.bias = MyBias()

            def forward(self, x):
                x = self.conv(x)
                x = self.unshuffle(x)
                bias = self.bias.bias
                return x + bias

        for backend in ["fbgemm", "qnnpack"]:
            if backend == "fbgemm":
                backend_config = get_fbgemm_backend_config()
            else:
                backend_config = get_qnnpack_backend_config()
            qconfig_mapping = get_default_qconfig_mapping(backend)
            model = MyModel()
            m = prepare_fx(
                model,
                qconfig_mapping=qconfig_mapping,
                example_inputs=(torch.randn(1, 8, 6, 6),),
                backend_config=backend_config
            )
            m = convert_fx(m)
            expected_occurrence = {
                ns.call_function(torch.quantize_per_tensor): 2,
                ns.call_method("dequantize"): 1,
                ns.call_module(nn.PixelUnshuffle): 1,
            }
            self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence)

    def test_narrow(self):
        class MyBias(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = nn.Parameter(torch.randn(4))

        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(8, 8, 1, bias=False)
                self.bias = MyBias()

            def forward(self, x):
                x = self.conv(x)
                x = torch.narrow(x, 1, 0, 4)
                bias = self.bias.bias
                return x + bias

        for backend in ["fbgemm", "qnnpack"]:
            if backend == "fbgemm":
                backend_config = get_fbgemm_backend_config()
            else:
                backend_config = get_qnnpack_backend_config()
            qconfig_mapping = get_default_qconfig_mapping(backend)
            model = MyModel()
            m = prepare_fx(
                model,
                qconfig_mapping=qconfig_mapping,
                example_inputs=(torch.randn(1, 8, 3, 3),),
                backend_config=backend_config
            )
            m = convert_fx(m)
            expected_occurrence = {
                ns.call_function(torch.quantize_per_tensor): 2,
                ns.call_method("dequantize"): 1,
            }
            self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence)

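# The class below runs end-to-end, model-level checks (GPU conversion, device
# switching, torchvision models, QAT embedding/linear flows) on top of the
# same prepare_fx/prepare_qat_fx + convert_fx APIs exercised op-by-op above.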
class TestQuantizeFxModels(QuantizationTestCase):
    @skipIfNoFBGEMM
    @unittest.skipIf(not TEST_CUDA, "gpu is not available.")
    def test_static_gpu_convert_basic(self):

        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu1 = nn.ReLU()
                self.conv1 = nn.Conv2d(1, 6, 5)
                self.linear1 = nn.Linear(120, 1)

            def forward(self, x):
                x = self.relu1(self.conv1(x))
                y = self.linear1(x.view(-1))
                return y

        input = torch.randn((5, 1, 6, 6)).to('cuda')
        example_inputs = (input,)
        model = Net().to('cuda').eval()
        qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')}
        model_prepared = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
        model_prepared(*example_inputs)
        model_quantized = convert_to_reference_fx(model_prepared)
        out = model_quantized(*example_inputs)
        self.assertEqual(out.device.type, 'cuda')

    @skipIfNoFBGEMM
    @unittest.skipIf(not TEST_CUDA, "gpu is not available.")
    def test_switch_device_prepare_convert(self):

        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu1 = nn.ReLU()
                self.conv1 = nn.Conv2d(1, 6, 5)
                self.linear1 = nn.Linear(120, 1)

            def forward(self, x):
                x = self.relu1(self.conv1(x))
                y = self.linear1(x.view(-1))
                return y

        for device in ['cuda', 'cpu']:
            device_after = 'cuda' if device == 'cpu' else 'cpu'
            input = torch.randn((5, 1, 6, 6)).to(device)
            model = Net().to(device).eval()
            qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')}
            model_prepared = prepare_fx(model, qconfig_dict, example_inputs=(input,))
            model_prepared(input)
            model_prepared.to(device_after)
            model_quantized = convert_to_reference_fx(model_prepared)
            out = model_quantized(input.to(device_after))
            self.assertEqual(out.device.type, device_after)

    @skipIfNoFBGEMM
    @unittest.skipIf(not TEST_CUDA, "gpu is not available.")
    def test_prepare_serialize_switch_device_convert(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 6, 5)
                self.linear1 = nn.Linear(120, 1)

            def forward(self, x):
                x = self.conv1(x)
                y = self.linear1(x.view(-1))
                return y

        for device in ['cuda', 'cpu']:
            for device_after in ['cuda', 'cpu']:
                input = torch.randn((5, 1, 6, 6)).to(device)
                model = Net().to(device).eval()
                qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')}
                model_prepared_first = prepare_fx(model, qconfig_dict, example_inputs=(input,))
                model_prepared_second = prepare_fx(model, qconfig_dict, example_inputs=(input,))
                model_prepared_first(input)
                state_dict = model_prepared_first.state_dict()
                del model_prepared_first
                model_prepared_second.load_state_dict(state_dict)
                model_prepared_second.to(device_after)
                model_quantized = convert_to_reference_fx(model_prepared_second)
                out = model_quantized(input.to(device_after))
                self.assertEqual(out.device.type, device_after)

    @skipIfTorchDynamo("too slow")
    @skip_if_no_torchvision
    def test_model_dropout(self):
        from torchvision import models
        m = models.mobilenet_v3_small()
        qconfig_mapping = torch.ao.quantization.get_default_qat_qconfig_mapping('fbgemm')
        example_inputs = (torch.randn(1, 3, 224, 224),)
        mp = prepare_qat_fx(m, qconfig_mapping, example_inputs=example_inputs)
        mp(*example_inputs)
        with override_quantized_engine("qnnpack") if IS_ARM64 else contextlib.nullcontext():
            mq = convert_fx(mp)
            mq(*example_inputs)

    def _test_model_impl(
            self, mode, name, model, eager_quantizable_model,
            check_with_eager=True,
            diff_of_quant=None,
            diff_from_eager=None):
        if diff_of_quant is None or diff_from_eager is None:
            diff_of_quant = {}
            diff_from_eager = {}

        if mode not in diff_of_quant or mode not in diff_from_eager:
            diff_of_quant[mode] = {}
            diff_from_eager[mode] = {}

        input_tensor = torch.rand(1, 3, 224, 224)
        input_tensor_inception = torch.rand(1, 3, 299, 299)
        output_value = torch.randint(0, 1, (1,))

        # print('quantizing:', name, ' mode:', mode)
        if name == 'inception_v3':
            input_value = input_tensor_inception
        else:
            input_value = input_tensor

        qconfig = default_qconfig if mode == 'static' else default_qat_qconfig
        qconfig_dict = {'': qconfig}
        script = torch.jit.script(model)

        # make sure graph module and script module are both runnable
        original_out = model(input_value)
        is_not_tuple_out = not isinstance(original_out, tuple)
        script_out = script(input_value)

        # set to train just before quantization
        prepare_fx_fn = prepare_fx
        if mode != 'static':
            model.train()
            prepare_fx_fn = prepare_qat_fx

        prepared = prepare_fx_fn(model, qconfig_dict)

        if mode == 'ddp':
            mp.spawn(run_ddp,
                     args=(world_size, prepared),  # noqa: F821
                     nprocs=world_size,  # noqa: F821
                     join=True)
        elif mode == 'qat':
            assert prepared.training, 'prepared must be in training mode for qat'
            optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001)
            criterion = nn.CrossEntropyLoss()
            train_one_epoch(prepared, criterion, optimizer, [(input_value, output_value)], torch.device('cpu'), 1)
        else:
            for i in range(10):
                prepared(input_value)

        # print('after observation root:', prepared.root)

        qgraph = convert_fx(prepared)
        # print('after quantization root:', qgraph.root)
        # print('after quantization code:', qgraph.src)
        qgraph.eval()
        qgraph_script = torch.jit.script(qgraph)
        # print('quantized and scripted:', qgraph_script.graph)

        qgraph_out = qgraph(input_value)
        qgraph_script = qgraph_script(input_value)

        if is_not_tuple_out:
            diff_of_quant[mode][name] = (original_out - qgraph_out).abs().max()
            assert torch.allclose(qgraph_out, qgraph_script), 'graph, scripted graph'
        else:
            print('tuple output')

        if eager_quantizable_model is not None:
            # comparing to eager mode quantization
            qeager = eager_quantizable_model
            ref_out = qeager(input_value)
            qeager.qconfig = qconfig
            if mode == 'static':
                qeager.fuse_model()
                prepare(qeager, inplace=True)
            else:
                qeager.train()
                qeager.fuse_model()
                prepare_qat(qeager, inplace=True)

            # calibration
            if mode == 'ddp':
                mp.spawn(run_ddp,
                         args=(world_size, qeager),  # noqa: F821
                         nprocs=world_size,  # noqa: F821
                         join=True)
            elif mode == 'qat':
                assert qeager.training, 'qeager should be in training mode for qat'
                optimizer = torch.optim.SGD(qeager.parameters(), lr=0.0001)
                train_one_epoch(qeager, criterion, optimizer, [(input_value, output_value)], torch.device('cpu'), 1)
            else:
                for i in range(10):
                    qeager(input_value)

            # print('ref after observation:', qeager)

            convert(qeager, inplace=True)
            qeager.eval()

            # print('ref after quantization:', qeager)
            qeager_out = qeager(input_value)
            qeager_script = torch.jit.script(qeager)
            qscript_out = qeager_script(input_value)
            if is_not_tuple_out:
                diff_from_eager[mode][name] = (qeager_out - qgraph_out).abs().max()
                if check_with_eager:
                    self.assertEqual(diff_from_eager[mode][name], 0,
                                     'Result of graph mode quantization and ' +
                                     'eager mode quantization on model: ' + name +
                                     ' should match. Mode: ' + mode +
                                     ' diff:' + str(diff_from_eager[mode][name]))

    def _test_building_block(self, quant_type, BB):
        eager = BB().float()
        graph = copy.deepcopy(eager)

        if quant_type == QuantType.STATIC:
            qconfig = default_qconfig
            eager_prepare = prepare
            graph_prepare = prepare_fx
            eager.eval()
            graph.eval()
            calibrate_or_train = test_only_eval_fn
            data = self.img_data_2d
            is_qat = False
        else:
            assert quant_type == QuantType.QAT
            qconfig = default_qat_qconfig
            eager_prepare = prepare_qat
            graph_prepare = prepare_qat_fx
            eager.train()
            graph.train()
            calibrate_or_train = test_only_train_fn
            data = self.img_data_2d_train
            is_qat = True

        if hasattr(eager, "fuse_model"):
            eager.fuse_model()
        eager = QuantWrapper(eager)
        eager.qconfig = qconfig
        eager = eager_prepare(eager)

        qconfig_dict = {"": qconfig}
        graph = graph_prepare(graph, qconfig_dict, example_inputs=(data[0][0],))

        eager_out = eager(data[0][0])
        graph_out = graph(data[0][0])
        # Eager Mode and FX Graph Mode QAT now differ in numerics both
        # in Post Training and QAT because FX Graph Mode uses same fake_quant instances
        # for input and output of CopyNode
        # self.assertEqual(eager_out, graph_out)

        calibrate_or_train(eager, data)
        calibrate_or_train(graph, data)

        eager = convert(eager)
        graph = convert_fx(graph)

        eager_out = eager(data[0][0])
        graph_out = graph(data[0][0])

    @override_qengines
    def test_resnet_base(self):
        models = [ResNetBase]
        options = itertools.product(self.static_quant_types, models)
        for quant_type, M in options:
            self._test_building_block(quant_type, M)

    @skip_if_no_torchvision
    @skipIfNoFBGEMM
    @unittest.skip("skip for now since tbb failed")
    def test_torchvision(self):
        from torchvision import models
        from torchvision.models import quantization as quantized_models
        from torchvision.models.quantization.utils import _replace_relu

        def get_available_classification_models(models):
            return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]

        model_list = get_available_classification_models(models)
        quantized_model_list = get_available_classification_models(quantized_models)

        quantized_model_list = set(quantized_model_list)
        # test eager and graph consistency
        model_list = quantized_model_list
        # mobilenet/inception_v3/googlenet qat is not working due to AdaptiveAveragePool qat
        # we might observe the output of AdaptiveAveragePool in the future
        # and re-enable the test
        fx_eager_not_matching = [
            ("mobilenet_v2", "qat"),
            ("inception_v3", "qat"),
            ("googlenet", "qat")
        ]  # because relu6 is replaced with relu in mobilenetv2

        diff_of_quant = {}
        diff_from_eager = {}
        modes = ['static', 'qat']
        options = itertools.product(modes, model_list)
        for mode, name in options:
            pretrained = name in quantized_model_list  # load pretrained model to compare with quantized model
            kwargs = {}
            # turn off transform input for inception_v3 since
            # it's not quantized in eager mode and in fx graph
            # mode we can't skip quantizing a method right now
            # (might be supported in the future)
            if name in ["inception_v3", "googlenet"]:
                kwargs["transform_input"] = False
            eager_quantizable_model = None
            if name in quantized_model_list:
                eager_quantizable_model = quantized_models.__dict__[name](pretrained=False, quantize=False, **kwargs).eval().float()
            # compare with eager mode quantized model when it is available
            pretrained = eager_quantizable_model is not None
            model = models.__dict__[name](pretrained=pretrained, **kwargs).eval().float()
            if name == "mobilenet_v2":
                _replace_relu(model)
            # disable aux logits
            if hasattr(model, "aux_logits"):
                model.aux_logits = False
                model.AuxLogits = None
                if eager_quantizable_model:
                    eager_quantizable_model.aux_logits = False
                    eager_quantizable_model.AuxLogits = None

            check_with_eager = (name, mode) not in fx_eager_not_matching
            self._test_model_impl(
                mode, name, model, eager_quantizable_model,
                check_with_eager,
                diff_of_quant, diff_from_eager)

        def print_diffs(diffs):
            for mode, diffs_for_mode in diffs.items():
                print('mode:', mode)
                for name, diff in diffs_for_mode.items():
                    print(name, ':', diff)

        # print('differences between float and quantized')
        # print_diffs(diff_of_quant)
        # print('----------------------')
        # print('differences between graph mode and eager mode')
        # print_diffs(diff_from_eager)
        # print('----------------------')

    @skip_if_no_torchvision
    @skipIfNoFBGEMM
    @unittest.skip("TODO: Test is always failing - https://github.com/pytorch/pytorch/issues/54979")
    def test_resnet18_ddp(self):
        from torchvision import models
        from torchvision.models import quantization as quantized_models
        eager_quantizable_model = quantized_models.__dict__[name](pretrained=False, quantize=False).eval().float()  # noqa: F821
        model = models.__dict__[name](pretrained=False).eval().float()  # noqa: F821
        self._test_model_impl(
            'ddp', 'resnet18', model, eager_quantizable_model)

    @override_qengines
    def test_qat_embeddingbag_linear(self):
        for device in get_supported_device_types():
            class EmbeddingBagLinear(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode='sum')
                    self.linear = torch.nn.Linear(12, 1).to(dtype=torch.float)

                def forward(self, input: torch.Tensor, offsets: Optional[torch.Tensor] = None,
                            per_sample_weights: Optional[torch.Tensor] = None):
                    x = self.emb(input, offsets, per_sample_weights)
                    x = self.linear(x)
                    return x

            qengine = torch.backends.quantized.engine
            qconfig_dict = QConfigMapping() \
                .set_global(get_default_qat_qconfig(qengine)) \
                .set_object_type(torch.nn.EmbeddingBag, default_embedding_qat_qconfig)

            train_indices = [[torch.randint(0, 10, (12, 12)), torch.randn((12, 1))] for _ in range(2)]
            eval_output = [[torch.randint(0, 10, (12, 1))]]

            model = EmbeddingBagLinear().train()
            prepared_fx_model = prepare_qat_fx(model, qconfig_dict, example_inputs=(train_indices[0][0],))
            test_only_train_fn(prepared_fx_model, train_indices)
            quant_model = convert_fx(prepared_fx_model,
                                     qconfig_mapping=qconfig_dict)

            def checkQuantized(model):
                # Make sure EmbeddingBag is now a quantized EmbeddingBag.
                self.assertTrue(type(model.emb), nn.quantized.EmbeddingBag)
                # Also test that Linear has been quantized.
                self.assertTrue(type(model.linear), nnq.Linear)

                test_only_eval_fn(model, eval_output)
                self.checkScriptable(model, eval_output)
                self.checkNoQconfig(model)
            checkQuantized(quant_model)

    @override_qengines
    def test_qat_embedding_linear(self):
        for device in get_supported_device_types():
            class EmbeddingLinear(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
                    self.linear = torch.nn.Linear(12, 1).to(dtype=torch.float)

                def forward(self, input: torch.Tensor):
                    x = torch.sum(self.emb(input), dim=1)
                    x = self.linear(x)
                    return x

            qengine = torch.backends.quantized.engine
            qconfig_dict = {"": get_default_qat_qconfig(qengine),
                            "object_type": [(torch.nn.Embedding, default_embedding_qat_qconfig)]}

            train_indices = [[torch.randint(0, 10, (12, 12)), torch.randn((12, 1))] for _ in range(2)]
            eval_output = [[torch.randint(0, 10, (12, 1))]]

            model = EmbeddingLinear().train()
            prepared_fx_model = prepare_qat_fx(model, qconfig_dict, example_inputs=(train_indices[0][0],))
            test_only_train_fn(prepared_fx_model, train_indices)
            quant_model = convert_fx(prepared_fx_model,
                                     qconfig_mapping=qconfig_dict)

            def checkQuantized(model):
                # Make sure Embedding is now a quantized Embedding.
                self.assertTrue(type(model.emb), nn.quantized.Embedding)
                # Also test that Linear has been quantized.
                self.assertTrue(type(model.linear), nnq.Linear)

                test_only_eval_fn(model, eval_output)
                self.checkScriptable(model, eval_output)
                self.checkNoQconfig(model)
            checkQuantized(quant_model)

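    # The test below compares the reference QAT setup (FakeQuantize wrapping a
    # MovingAverageMinMaxObserver) against FusedMovingAvgObsFakeQuantize, which
    # fuses the observer and fake-quant steps; outputs and input gradients are
    # expected to match between the two configurations.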
    @given(
        device=st.sampled_from(
            ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
        )
    )
    @settings(deadline=None)
    @override_qengines
    def test_qat_functional_linear(self, device):
        if torch.backends.quantized.engine not in ('fbgemm', 'qnnpack'):
            return

        class Linear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.ones(5, 5)
                self.b = torch.zeros(5)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.w, self.b)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mods1 = torch.nn.Sequential(Linear(), Linear())
                self.mods2 = Linear()

            def forward(self, x):
                x = self.mods1(x)
                x = self.mods2(x)
                return x

        model = M().train()
        ref_fake_quant = FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=0,
            quant_max=255,
            dtype=torch.quint8,
            reduce_range=False,
        )
        ref_weight_fake_quant = FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=-128,
            quant_max=127,
            dtype=torch.qint8,
            reduce_range=False,
        )
        ref_qat_qconfig = QConfig(
            activation=ref_fake_quant, weight=ref_weight_fake_quant
        )
        qconfig_dict = {"": ref_qat_qconfig}
        example_inputs = (torch.randn(1, 5),)
        prepared_ref = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs)

        custom_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=0,
            quant_max=255,
            dtype=torch.quint8,
            reduce_range=False,
        )
        custom_weight_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=-128,
            quant_max=127,
            dtype=torch.qint8,
            reduce_range=False,
        )
        custom_qconfig = QConfig(
            activation=custom_fake_quant, weight=custom_weight_fake_quant
        )
        custom_qconfig_dict = {"": custom_qconfig}
        prepared = prepare_qat_fx(model, custom_qconfig_dict, example_inputs=example_inputs)

        prepared.to(device)
        prepared_ref.to(device)

        prepared.apply(torch.ao.quantization.disable_fake_quant)
        prepared.apply(torch.ao.quantization.disable_observer)
        prepared_ref.apply(torch.ao.quantization.disable_fake_quant)
        prepared_ref.apply(torch.ao.quantization.disable_observer)

        inp = torch.randn(5, 5, device=device, requires_grad=True)
        for i in range(10):
            if i == 2:
                prepared.apply(torch.ao.quantization.enable_observer)
                prepared_ref.apply(torch.ao.quantization.enable_observer)
            if i == 4:
                prepared.apply(torch.ao.quantization.enable_fake_quant)
                prepared_ref.apply(torch.ao.quantization.enable_fake_quant)

            inp = torch.randn(5, 5, device=device, requires_grad=True)
            out_ref = prepared_ref(inp)
            out = prepared(inp)
            torch.testing.assert_close(out, out_ref)

            # try backward pass
            labels = torch.randn(5, 5, device=device)
            loss = (out - labels).sum()
            grad = torch.autograd.grad(loss, [inp])
            loss_ref = (out_ref - labels).sum()
            grad_ref = torch.autograd.grad(loss_ref, [inp])
            torch.testing.assert_close(grad[0], grad_ref[0])

        if 'fbgemm' in torch.backends.quantized.supported_engines:
            # During the lowering step in convert, fold_weight calls quantized::linear_prepack
            # which doesn't support QuantizedCuda backend
            prepared.cpu()
            prepared_ref.cpu()
            converted = convert_fx(prepared)
            converted_ref = convert_fx(prepared_ref)
            inp = torch.rand(5, 5)
            out = converted(inp)
            out_ref = converted_ref(inp)

            torch.testing.assert_close(out, out_ref)


if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_quantization.py TESTNAME\n\n"
                       "instead.")