[quant][graphmode][fx] Produce conv reference static quant modules (#60138)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60138

Test Plan:
python test/test_quantization.py TestQuantizeFx

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D29184791

fbshipit-source-id: 971a40012dbba0cf687c62a3a4af9358513c253b
Author: Jerry Zhang
Date: 2021-06-20 19:24:42 -07:00
Committed by: Facebook GitHub Bot
Parent: b298013cd5
Commit: 47d727fe1b

4 changed files with 108 additions and 23 deletions
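
For orientation: the reference modules swapped in here keep quantized tensors at module boundaries but compute in float internally, the dequantize -> float op -> requantize pattern visible in the ConvReLU reference modules further down in this diff. A minimal standalone sketch of that pattern (the scale/zero_point values are illustrative placeholders, not taken from this PR):

import torch
import torch.nn.functional as F

# Reference static quant pattern: dequantize input and weight,
# run the float op, then requantize the result.
# scale/zero_point below are placeholder values for illustration.
xq = torch.quantize_per_tensor(torch.rand(1, 3, 224, 224), scale=0.1, zero_point=0, dtype=torch.quint8)
wq = torch.quantize_per_tensor(torch.rand(3, 3, 3, 3), scale=0.05, zero_point=0, dtype=torch.qint8)
float_result = F.conv2d(xq.dequantize(), wq.dequantize(), bias=None, stride=1, padding=0)
out = torch.quantize_per_tensor(float_result, scale=0.1, zero_point=0, dtype=torch.quint8)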


@@ -2,6 +2,7 @@ import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.nn.quantized as nnq
import torch.nn.quantized._reference as nnqr
import torch.nn.quantized.dynamic as nnqd
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
@@ -312,12 +313,12 @@ class TestQuantizeFx(QuantizationTestCase):
            if n.op == 'call_module' and type(modules[n.target]) == nn.ReLU:
                self.assertTrue(is_match(modules, n, pattern))
    def _get_conv_linear_test_cases(self):
    def _get_conv_linear_test_cases(self, is_reference):
        """ Returns a list of test cases, with format:
        is_dynamic, ModuleClass, module_constructor_inputs,
        inputs, quantized_node, weight_prepack_op
        """
        class Conv1d(torch.nn.Module):
        class FunctionalConv1d(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = torch.nn.Parameter(weight)
@@ -329,10 +330,20 @@ class TestQuantizeFx(QuantizationTestCase):
            def forward(self, x):
                return F.conv1d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups)
        class Conv1d(torch.nn.Module):
            def __init__(self, *args):
                super().__init__()
                self.conv = torch.nn.Conv1d(*args)
            def forward(self, x):
                return self.conv(x)
        conv1d_input = torch.rand(1, 3, 224)
        conv1d_weight = torch.rand(3, 3, 3)
        conv1d_module_args = (3, 3, 3)
        class Conv2d(torch.nn.Module):
        class FunctionalConv2d(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = torch.nn.Parameter(weight)
@@ -344,10 +355,19 @@ class TestQuantizeFx(QuantizationTestCase):
            def forward(self, x):
                return F.conv2d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups)
        class Conv2d(torch.nn.Module):
            def __init__(self, *args):
                super().__init__()
                self.conv = torch.nn.Conv2d(*args)
            def forward(self, x):
                return self.conv(x)
        conv2d_input = torch.rand(1, 3, 224, 224)
        conv2d_weight = torch.rand(3, 3, 3, 3)
        conv2d_module_args = (3, 3, 3)
        class Conv3d(torch.nn.Module):
        class FunctionalConv3d(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = torch.nn.Parameter(weight)
@@ -367,8 +387,17 @@ class TestQuantizeFx(QuantizationTestCase):
                    self.groups,
                )
        class Conv3d(torch.nn.Module):
            def __init__(self, *args):
                super().__init__()
                self.conv = torch.nn.Conv3d(*args)
            def forward(self, x):
                return self.conv(x)
        conv3d_input = torch.rand(1, 3, 32, 224, 224)
        conv3d_weight = torch.rand(3, 3, 3, 3, 3)
        conv3d_module_args = (3, 3, 3)
        class Linear(torch.nn.Module):
            def __init__(self, weight):
@@ -391,37 +420,63 @@ class TestQuantizeFx(QuantizationTestCase):
        linear_module_input = torch.rand(8, 5)
        # is_dynamic, ModuleClass, module_constructor_inputs,
        # inputs, quantized_node, weight_prepack_node
        tests = [
            (
                False,
                Conv1d,
                FunctionalConv1d,
                (conv1d_weight,),
                (conv1d_input,),
                ns.call_function(torch.ops.quantized.conv1d),
                ns.call_function(torch.nn.functional.conv1d if is_reference else torch.ops.quantized.conv1d),
                ns.call_function(torch.ops.quantized.conv1d_prepack),
            ),
            (
                False,
                Conv2d,
                FunctionalConv2d,
                (conv2d_weight,),
                (conv2d_input,),
                ns.call_function(torch.ops.quantized.conv2d),
                ns.call_function(torch.nn.functional.conv2d if is_reference else torch.ops.quantized.conv2d),
                ns.call_function(torch.ops.quantized.conv2d_prepack),
            ),
            (
                False,
                Conv3d,
                FunctionalConv3d,
                (conv3d_weight,),
                (conv3d_input,),
                ns.call_function(torch.ops.quantized.conv3d),
                ns.call_function(torch.nn.functional.conv3d if is_reference else torch.ops.quantized.conv3d),
                ns.call_function(torch.ops.quantized.conv3d_prepack),
            ),
            (
                False,
                Conv1d,
                conv1d_module_args,
                (conv1d_input,),
                ns.call_module(nnqr.Conv1d if is_reference else nnq.Conv1d),
                None
            ),
            (
                False,
                Conv2d,
                conv2d_module_args,
                (conv2d_input,),
                ns.call_module(nnqr.Conv2d if is_reference else nnq.Conv2d),
                None
            ),
            (
                False,
                Conv3d,
                conv3d_module_args,
                (conv3d_input,),
                ns.call_module(nnqr.Conv3d if is_reference else nnq.Conv3d),
                None
            ),
            (
                True,
                Linear,
                (linear_weight,),
                (linear_input,),
                ns.call_function(torch.ops.quantized.linear_dynamic),
                None if is_reference else ns.call_function(torch.ops.quantized.linear_dynamic),
                ns.call_function(torch.ops.quantized.linear_prepack),
            ),
            (
@@ -429,7 +484,7 @@ class TestQuantizeFx(QuantizationTestCase):
                Linear,
                (linear_weight,),
                (linear_input,),
                ns.call_function(torch.ops.quantized.linear),
                ns.call_function(torch.nn.functional.linear if is_reference else torch.ops.quantized.linear),
                ns.call_function(torch.ops.quantized.linear_prepack),
            ),
            (
@@ -458,7 +513,7 @@ class TestQuantizeFx(QuantizationTestCase):
    def test_functional_not_reference(self):
        """ Test quantizing functional conv and linear
        """
        tests = self._get_conv_linear_test_cases()
        tests = self._get_conv_linear_test_cases(is_reference=False)
        for (is_dynamic, ModuleClass, module_constructor_inputs,
                inputs, quantized_node, weight_prepack_node) in tests:
            quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
@@ -476,17 +531,17 @@ class TestQuantizeFx(QuantizationTestCase):
    def test_functional_reference(self):
        """ Test quantizing functional conv and linear with reference option
        """
        tests = self._get_conv_linear_test_cases()
        tests = self._get_conv_linear_test_cases(is_reference=True)
        for (is_dynamic, ModuleClass, module_constructor_inputs,
                inputs, quantized_node, weight_prepack_node) in tests:
            quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
            node_occurrence = dict()
            if weight_prepack_node:
                node_occurrence[weight_prepack_node] = 0
                node_occurrence[quantized_node] = 0
            self.checkGraphModeFxOp(
                ModuleClass(*module_constructor_inputs),
                inputs, quant_type,
                expected_node=quantized_node,
                expected_node_occurrence=node_occurrence,
                is_reference=True)


@@ -3,14 +3,14 @@ import torch.nn.quantized._reference as nnqr
import torch.nn.functional as F
class ConvReLU1d(nnqr.Conv1d):
    _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU1d # type: ignore[assignment]
    _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU1d
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_dequant = x.dequantize()
        weight_dequant = self._qweight.dequantize()
        float_result = F.conv1d(
            x_dequant, weight_dequant, self._bias, self._conv1d_stride,
            self._conv1d_padding, self._conv1d_dilation, self.groups)
            x_dequant, weight_dequant, self._bias, self._conv1d_stride, # type: ignore[has-type]
            self._conv1d_padding, self._conv1d_dilation, self.groups) # type: ignore[has-type]
        float_result = F.relu(float_result, inplace=True)
        # NEEDFIX: we don't have dtype in the Linear module APIs right now!
        result = torch.quantize_per_tensor(
@@ -22,7 +22,7 @@ class ConvReLU1d(nnqr.Conv1d):
class ConvReLU2d(nnqr.Conv2d):
    _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU2d # type: ignore[assignment]
    _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU2d
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_dequant = x.dequantize()
@@ -40,7 +40,7 @@ class ConvReLU2d(nnqr.Conv2d):
        return "QuantizedConvReLU2d(Reference)"
class ConvReLU3d(nnqr.Conv3d):
    _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU3d # type: ignore[assignment]
    _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU3d
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_dequant = x.dequantize()


@@ -497,7 +497,7 @@ class ConvReluQuantizeHandler(QuantizeHandler):
            self.conv.activation_post_process = output_activation_post_process
            # 2. select quantized class
            qconv_cls = get_static_quant_module_class(
                type(self.conv), additional_static_quant_mapping)
                type(self.conv), additional_static_quant_mapping, is_reference=is_reference)
            quantized = qconv_cls.from_float(self.conv)
            parent_name, name = _parent_name(self.conv_node.target)
            setattr(modules[parent_name], name, quantized)
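
Since is_reference is now passed through, the class lookup resolves against the reference mapping added in the mappings file below rather than the regular static mapping. A minimal usage sketch of that lookup (assuming the torch.quantization.quantization_mappings import path in use at the time of this commit; not code from this PR):

import torch.nn as nn
import torch.nn.quantized as nnq
import torch.nn.quantized._reference as nnqr
from torch.quantization.quantization_mappings import get_static_quant_module_class

# With is_reference=True the lookup goes through the reference mapping,
# so float conv classes resolve to their _reference counterparts.
assert get_static_quant_module_class(nn.Conv2d, is_reference=True) is nnqr.Conv2d
assert get_static_quant_module_class(nn.Conv2d) is nnq.Conv2d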


@@ -6,8 +6,10 @@ from torch import nn
import torch.nn.functional as F
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.quantized._reference as nniqr
import torch.nn.intrinsic.qat as nniqat
import torch.nn.quantized as nnq
import torch.nn.quantized._reference as nnqr
import torch.nn.quantized.dynamic as nnqd
import torch.nn.qat as nnqat
@@ -20,6 +22,31 @@ from .fake_quantize import (
)
from .utils import get_combined_dict
# Default map for swapping float module to reference quantized modules
DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
    nn.Conv1d: nnqr.Conv1d,
    nn.Conv2d: nnqr.Conv2d,
    nn.Conv3d: nnqr.Conv3d,
    # nn.Linear, nnqr.Linear,
    nni.ConvReLU1d: nniqr.ConvReLU1d,
    nni.ConvReLU2d: nniqr.ConvReLU2d,
    nni.ConvReLU3d: nniqr.ConvReLU3d,
    nni.LinearReLU: nniqr.LinearReLU,
    # QAT Modules
    # nnqat.Linear: nnqr.Linear,
    nnqat.Conv2d: nnqr.Conv2d,
    nnqat.Conv3d: nnqr.Conv3d,
    nniqat.ConvBn1d: nnqr.Conv1d,
    nniqat.ConvBn2d: nnqr.Conv2d,
    nniqat.ConvBn3d: nnqr.Conv3d,
    nniqat.ConvBnReLU1d: nniqr.ConvReLU1d,
    nniqat.ConvBnReLU2d: nniqr.ConvReLU2d,
    nniqat.ConvBnReLU3d: nniqr.ConvReLU3d,
    nniqat.ConvReLU2d: nniqr.ConvReLU2d,
    nniqat.ConvReLU3d: nniqr.ConvReLU3d,
    # nniqat.LinearReLU: nniqr.LinearReLU,
}
# Default map for swapping float module to quantized ones
DEFAULT_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
    QuantStub: nnq.Quantize,
@@ -134,13 +161,16 @@ def get_default_static_quant_module_mappings() -> Dict[Callable, Any]:
def get_static_quant_module_class(
        float_module_class: Callable,
        additional_static_quant_mapping: Optional[Dict[Callable, Any]] = None) -> Any:
        additional_static_quant_mapping: Optional[Dict[Callable, Any]] = None,
        is_reference: bool = False) -> Any:
r"""n Get the statically quantized module class corresponding to
the floating point module class
"""
    if additional_static_quant_mapping is None:
        additional_static_quant_mapping = {}
    all_mappings = get_combined_dict(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, additional_static_quant_mapping)
    all_mappings = get_combined_dict(
        DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS if is_reference
        else DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, additional_static_quant_mapping)
    static_quant_module_class = all_mappings.get(float_module_class, None)
    assert static_quant_module_class is not None, \
        "Floating point module class {}".format(str(float_module_class)) + \