[quant][graphmode][fx] Produce conv reference static quant modules (#60138)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60138

Test Plan:
python test/test_quantization.py TestQuantizeFx

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D29184791

fbshipit-source-id: 971a40012dbba0cf687c62a3a4af9358513c253b
Author: Jerry Zhang
Date: 2021-06-20 19:24:42 -07:00
Committed by: Facebook GitHub Bot
Parent: b298013cd5
Commit: 47d727fe1b

4 changed files with 108 additions and 23 deletions

@@ -497,7 +497,7 @@ class ConvReluQuantizeHandler(QuantizeHandler):
 self.conv.activation_post_process = output_activation_post_process
 # 2. select quantized class
 qconv_cls = get_static_quant_module_class(
-    type(self.conv), additional_static_quant_mapping)
+    type(self.conv), additional_static_quant_mapping, is_reference=is_reference)
 quantized = qconv_cls.from_float(self.conv)
 parent_name, name = _parent_name(self.conv_node.target)
 setattr(modules[parent_name], name, quantized)
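
For readers unfamiliar with the handler's convert step, the following standalone sketch mirrors the swap shown above, assuming `modules` is `dict(model.named_modules())`; `parent_split` and `swap_submodule` are hypothetical names standing in for the handler's `_parent_name` helper and the inline setattr.

import torch.nn as nn

def parent_split(target: str):
    # "features.3" -> ("features", "3"); the root module is named ""
    return target.rsplit(".", 1) if "." in target else ("", target)

def swap_submodule(model: nn.Module, target: str, new_module: nn.Module):
    # Replace the submodule at `target` on its parent, as the handler does
    # after building the quantized module with from_float.
    modules = dict(model.named_modules())
    parent_name, name = parent_split(target)
    setattr(modules[parent_name], name, new_module)

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
swap_submodule(model, "0", nn.Identity())  # the conv at index "0" is replaced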

@@ -6,8 +6,10 @@ from torch import nn
 import torch.nn.functional as F
 import torch.nn.intrinsic as nni
 import torch.nn.intrinsic.quantized as nniq
+import torch.nn.intrinsic.quantized._reference as nniqr
 import torch.nn.intrinsic.qat as nniqat
 import torch.nn.quantized as nnq
+import torch.nn.quantized._reference as nnqr
 import torch.nn.quantized.dynamic as nnqd
 import torch.nn.qat as nnqat
@@ -20,6 +22,31 @@ from .fake_quantize import (
 )
 from .utils import get_combined_dict

+# Default map for swapping float module to reference quantized modules
+DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
+    nn.Conv1d: nnqr.Conv1d,
+    nn.Conv2d: nnqr.Conv2d,
+    nn.Conv3d: nnqr.Conv3d,
+    # nn.Linear, nnqr.Linear,
+    nni.ConvReLU1d: nniqr.ConvReLU1d,
+    nni.ConvReLU2d: nniqr.ConvReLU2d,
+    nni.ConvReLU3d: nniqr.ConvReLU3d,
+    nni.LinearReLU: nniqr.LinearReLU,
+    # QAT Modules
+    # nnqat.Linear: nnqr.Linear,
+    nnqat.Conv2d: nnqr.Conv2d,
+    nnqat.Conv3d: nnqr.Conv3d,
+    nniqat.ConvBn1d: nnqr.Conv1d,
+    nniqat.ConvBn2d: nnqr.Conv2d,
+    nniqat.ConvBn3d: nnqr.Conv3d,
+    nniqat.ConvBnReLU1d: nniqr.ConvReLU1d,
+    nniqat.ConvBnReLU2d: nniqr.ConvReLU2d,
+    nniqat.ConvBnReLU3d: nniqr.ConvReLU3d,
+    nniqat.ConvReLU2d: nniqr.ConvReLU2d,
+    nniqat.ConvReLU3d: nniqr.ConvReLU3d,
+    # nniqat.LinearReLU: nniqr.LinearReLU,
+}
+
 # Default map for swapping float module to quantized ones
 DEFAULT_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
     QuantStub: nnq.Quantize,
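
To make the new table concrete, here is a minimal lookup sketch; it assumes the dict is exposed at module level in torch.quantization.quantization_mappings, as the hunk suggests.

import torch.nn as nn
import torch.nn.intrinsic.qat as nniqat
import torch.nn.intrinsic.quantized._reference as nniqr
import torch.nn.quantized._reference as nnqr
from torch.quantization.quantization_mappings import (
    DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS as REF_MAPPINGS,
)

# Plain convs map to reference convs; a fused QAT Conv+BN+ReLU folds into the
# fused reference ConvReLU module of the matching dimension.
assert REF_MAPPINGS[nn.Conv2d] is nnqr.Conv2d
assert REF_MAPPINGS[nniqat.ConvBnReLU2d] is nniqr.ConvReLU2d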
@@ -134,13 +161,16 @@ def get_static_quant_module_class(
 def get_static_quant_module_class(
         float_module_class: Callable,
-        additional_static_quant_mapping: Optional[Dict[Callable, Any]] = None) -> Any:
+        additional_static_quant_mapping: Optional[Dict[Callable, Any]] = None,
+        is_reference: bool = False) -> Any:
     r"""Get the statically quantized module class corresponding to
     the floating point module class
     """
     if additional_static_quant_mapping is None:
         additional_static_quant_mapping = {}
-    all_mappings = get_combined_dict(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, additional_static_quant_mapping)
+    all_mappings = get_combined_dict(
+        DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS if is_reference
+        else DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, additional_static_quant_mapping)
     static_quant_module_class = all_mappings.get(float_module_class, None)
     assert static_quant_module_class is not None, \
         "Floating point module class {}".format(str(float_module_class)) + \