mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[quant][graphmode][fx] Produce conv reference static quant modules (#60138)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60138 Test Plan: python test/test_quantization.py TestQuantizeFx Imported from OSS Reviewed By: vkuzo Differential Revision: D29184791 fbshipit-source-id: 971a40012dbba0cf687c62a3a4af9358513c253b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
b298013cd5
commit
47d727fe1b
@ -497,7 +497,7 @@ class ConvReluQuantizeHandler(QuantizeHandler):
|
||||
self.conv.activation_post_process = output_activation_post_process
|
||||
# 2. select quantized class
|
||||
qconv_cls = get_static_quant_module_class(
|
||||
type(self.conv), additional_static_quant_mapping)
|
||||
type(self.conv), additional_static_quant_mapping, is_reference=is_reference)
|
||||
quantized = qconv_cls.from_float(self.conv)
|
||||
parent_name, name = _parent_name(self.conv_node.target)
|
||||
setattr(modules[parent_name], name, quantized)
|
||||
|
@ -6,8 +6,10 @@ from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.intrinsic as nni
|
||||
import torch.nn.intrinsic.quantized as nniq
|
||||
import torch.nn.intrinsic.quantized._reference as nniqr
|
||||
import torch.nn.intrinsic.qat as nniqat
|
||||
import torch.nn.quantized as nnq
|
||||
import torch.nn.quantized._reference as nnqr
|
||||
import torch.nn.quantized.dynamic as nnqd
|
||||
import torch.nn.qat as nnqat
|
||||
|
||||
@ -20,6 +22,31 @@ from .fake_quantize import (
|
||||
)
|
||||
from .utils import get_combined_dict
|
||||
|
||||
# Default map for swapping float module to reference quantized modules
|
||||
DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
|
||||
nn.Conv1d: nnqr.Conv1d,
|
||||
nn.Conv2d: nnqr.Conv2d,
|
||||
nn.Conv3d: nnqr.Conv3d,
|
||||
# nn.Linear, nnqr.Linear,
|
||||
nni.ConvReLU1d: nniqr.ConvReLU1d,
|
||||
nni.ConvReLU2d: nniqr.ConvReLU2d,
|
||||
nni.ConvReLU3d: nniqr.ConvReLU3d,
|
||||
nni.LinearReLU: nniqr.LinearReLU,
|
||||
# QAT Modules
|
||||
# nnqat.Linear: nnqr.Linear,
|
||||
nnqat.Conv2d: nnqr.Conv2d,
|
||||
nnqat.Conv3d: nnqr.Conv3d,
|
||||
nniqat.ConvBn1d: nnqr.Conv1d,
|
||||
nniqat.ConvBn2d: nnqr.Conv2d,
|
||||
nniqat.ConvBn3d: nnqr.Conv3d,
|
||||
nniqat.ConvBnReLU1d: nniqr.ConvReLU1d,
|
||||
nniqat.ConvBnReLU2d: nniqr.ConvReLU2d,
|
||||
nniqat.ConvBnReLU3d: nniqr.ConvReLU3d,
|
||||
nniqat.ConvReLU2d: nniqr.ConvReLU2d,
|
||||
nniqat.ConvReLU3d: nniqr.ConvReLU3d,
|
||||
# nniqat.LinearReLU: nniqr.LinearReLU,
|
||||
}
|
||||
|
||||
# Default map for swapping float module to quantized ones
|
||||
DEFAULT_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
|
||||
QuantStub: nnq.Quantize,
|
||||
@ -134,13 +161,16 @@ def get_default_static_quant_module_mappings() -> Dict[Callable, Any]:
|
||||
|
||||
def get_static_quant_module_class(
|
||||
float_module_class: Callable,
|
||||
additional_static_quant_mapping: Optional[Dict[Callable, Any]] = None) -> Any:
|
||||
additional_static_quant_mapping: Optional[Dict[Callable, Any]] = None,
|
||||
is_reference: bool = False) -> Any:
|
||||
r"""n Get the statically quantized module class corresponding to
|
||||
the floating point module class
|
||||
"""
|
||||
if additional_static_quant_mapping is None:
|
||||
additional_static_quant_mapping = {}
|
||||
all_mappings = get_combined_dict(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, additional_static_quant_mapping)
|
||||
all_mappings = get_combined_dict(
|
||||
DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS if is_reference
|
||||
else DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, additional_static_quant_mapping)
|
||||
static_quant_module_class = all_mappings.get(float_module_class, None)
|
||||
assert static_quant_module_class is not None, \
|
||||
"Floating point module class {}".format(str(float_module_class)) + \
|
||||
|
Reference in New Issue
Block a user