Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[quant][graphmode][fx] Produce conv reference static quant modules (#60138)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60138

Test Plan: python test/test_quantization.py TestQuantizeFx

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D29184791

fbshipit-source-id: 971a40012dbba0cf687c62a3a4af9358513c253b
committed by Facebook GitHub Bot
parent b298013cd5
commit 47d727fe1b
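At a glance, the change lets FX graph mode quantization swap float conv modules for the reference modules in torch.nn.quantized._reference (instead of the packed modules in torch.nn.quantized) when is_reference is set. Below is a minimal illustrative sketch of the new mapping behavior, based on the hunks that follow; it is not part of this commit's diff, and the import path used here is an assumption, since file names are not shown on this page.

# Illustrative only: exercises the is_reference flag added to
# get_static_quant_module_class() in this commit.
import torch.nn as nn
import torch.nn.quantized as nnq
import torch.nn.quantized._reference as nnqr
from torch.quantization.quantization_mappings import get_static_quant_module_class

# Without the flag, the packed quantized module class is returned as before.
assert get_static_quant_module_class(nn.Conv2d) is nnq.Conv2d

# With is_reference=True, the reference module class is returned instead,
# per DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS.
assert get_static_quant_module_class(nn.Conv2d, is_reference=True) is nnqr.Conv2d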
@@ -2,6 +2,7 @@ import torch
 import torch.nn.functional as F
 import torch.nn as nn
 import torch.nn.quantized as nnq
+import torch.nn.quantized._reference as nnqr
 import torch.nn.quantized.dynamic as nnqd
 import torch.nn.intrinsic as nni
 import torch.nn.intrinsic.quantized as nniq
@@ -312,12 +313,12 @@ class TestQuantizeFx(QuantizationTestCase):
             if n.op == 'call_module' and type(modules[n.target]) == nn.ReLU:
                 self.assertTrue(is_match(modules, n, pattern))

-    def _get_conv_linear_test_cases(self):
+    def _get_conv_linear_test_cases(self, is_reference):
         """ Returns a list of test cases, with format:
             is_dynamic, ModuleClass, module_constructor_inputs,
             inputs, quantized_node, weight_prepack_op
         """
-        class Conv1d(torch.nn.Module):
+        class FunctionalConv1d(torch.nn.Module):
             def __init__(self, weight):
                 super().__init__()
                 self.weight = torch.nn.Parameter(weight)
@@ -329,10 +330,20 @@ class TestQuantizeFx(QuantizationTestCase):
             def forward(self, x):
                 return F.conv1d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups)

+
+        class Conv1d(torch.nn.Module):
+            def __init__(self, *args):
+                super().__init__()
+                self.conv = torch.nn.Conv1d(*args)
+
+            def forward(self, x):
+                return self.conv(x)
+
         conv1d_input = torch.rand(1, 3, 224)
         conv1d_weight = torch.rand(3, 3, 3)
+        conv1d_module_args = (3, 3, 3)

-        class Conv2d(torch.nn.Module):
+        class FunctionalConv2d(torch.nn.Module):
             def __init__(self, weight):
                 super().__init__()
                 self.weight = torch.nn.Parameter(weight)
@@ -344,10 +355,19 @@ class TestQuantizeFx(QuantizationTestCase):
             def forward(self, x):
                 return F.conv2d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups)

+        class Conv2d(torch.nn.Module):
+            def __init__(self, *args):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(*args)
+
+            def forward(self, x):
+                return self.conv(x)
+
         conv2d_input = torch.rand(1, 3, 224, 224)
         conv2d_weight = torch.rand(3, 3, 3, 3)
+        conv2d_module_args = (3, 3, 3)

-        class Conv3d(torch.nn.Module):
+        class FunctionalConv3d(torch.nn.Module):
             def __init__(self, weight):
                 super().__init__()
                 self.weight = torch.nn.Parameter(weight)
@@ -367,8 +387,17 @@ class TestQuantizeFx(QuantizationTestCase):
                     self.groups,
                 )

+        class Conv3d(torch.nn.Module):
+            def __init__(self, *args):
+                super().__init__()
+                self.conv = torch.nn.Conv3d(*args)
+
+            def forward(self, x):
+                return self.conv(x)
+
         conv3d_input = torch.rand(1, 3, 32, 224, 224)
         conv3d_weight = torch.rand(3, 3, 3, 3, 3)
+        conv3d_module_args = (3, 3, 3)

         class Linear(torch.nn.Module):
             def __init__(self, weight):
@@ -391,37 +420,63 @@ class TestQuantizeFx(QuantizationTestCase):
         linear_module_input = torch.rand(8, 5)

         # is_dynamic, ModuleClass, module_constructor_inputs,
         # inputs, quantized_node, weight_prepack_node
         tests = [
             (
                 False,
-                Conv1d,
+                FunctionalConv1d,
                 (conv1d_weight,),
                 (conv1d_input,),
-                ns.call_function(torch.ops.quantized.conv1d),
+                ns.call_function(torch.nn.functional.conv1d if is_reference else torch.ops.quantized.conv1d) ,
                 ns.call_function(torch.ops.quantized.conv1d_prepack),
             ),
             (
                 False,
-                Conv2d,
+                FunctionalConv2d,
                 (conv2d_weight,),
                 (conv2d_input,),
-                ns.call_function(torch.ops.quantized.conv2d),
+                ns.call_function(torch.nn.functional.conv2d if is_reference else torch.ops.quantized.conv2d),
                 ns.call_function(torch.ops.quantized.conv2d_prepack),
             ),
             (
                 False,
-                Conv3d,
+                FunctionalConv3d,
                 (conv3d_weight,),
                 (conv3d_input,),
-                ns.call_function(torch.ops.quantized.conv3d),
+                ns.call_function(torch.nn.functional.conv3d if is_reference else torch.ops.quantized.conv3d),
                 ns.call_function(torch.ops.quantized.conv3d_prepack),
             ),
+            (
+                False,
+                Conv1d,
+                conv1d_module_args,
+                (conv1d_input,),
+                ns.call_module(nnqr.Conv1d if is_reference else nnq.Conv1d),
+                None
+            ),
+            (
+                False,
+                Conv2d,
+                conv2d_module_args,
+                (conv2d_input,),
+                ns.call_module(nnqr.Conv2d if is_reference else nnq.Conv2d),
+                None
+            ),
+            (
+                False,
+                Conv3d,
+                conv3d_module_args,
+                (conv3d_input,),
+                ns.call_module(nnqr.Conv3d if is_reference else nnq.Conv3d),
+                None
+            ),
             (
                 True,
                 Linear,
                 (linear_weight,),
                 (linear_input,),
-                ns.call_function(torch.ops.quantized.linear_dynamic),
+                None if is_reference else ns.call_function(torch.ops.quantized.linear_dynamic),
                 ns.call_function(torch.ops.quantized.linear_prepack),
             ),
             (
@@ -429,7 +484,7 @@ class TestQuantizeFx(QuantizationTestCase):
                 Linear,
                 (linear_weight,),
                 (linear_input,),
-                ns.call_function(torch.ops.quantized.linear),
+                ns.call_function(torch.nn.functional.linear if is_reference else torch.ops.quantized.linear),
                 ns.call_function(torch.ops.quantized.linear_prepack),
             ),
             (
@@ -458,7 +513,7 @@ class TestQuantizeFx(QuantizationTestCase):
     def test_functional_not_reference(self):
         """ Test quantizing functional conv and linear
         """
-        tests = self._get_conv_linear_test_cases()
+        tests = self._get_conv_linear_test_cases(is_reference=False)
         for (is_dynamic, ModuleClass, module_constructor_inputs,
              inputs, quantized_node, weight_prepack_node) in tests:
             quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
@@ -476,17 +531,17 @@ class TestQuantizeFx(QuantizationTestCase):
     def test_functional_reference(self):
         """ Test quantizing functional conv and linear with reference option
         """
-        tests = self._get_conv_linear_test_cases()
+        tests = self._get_conv_linear_test_cases(is_reference=True)
         for (is_dynamic, ModuleClass, module_constructor_inputs,
              inputs, quantized_node, weight_prepack_node) in tests:
             quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
             node_occurrence = dict()
             if weight_prepack_node:
                 node_occurrence[weight_prepack_node] = 0
                 node_occurrence[quantized_node] = 0
             self.checkGraphModeFxOp(
                 ModuleClass(*module_constructor_inputs),
                 inputs, quant_type,
                 expected_node=quantized_node,
                 expected_node_occurrence=node_occurrence,
                 is_reference=True)

@@ -3,14 +3,14 @@ import torch.nn.quantized._reference as nnqr
 import torch.nn.functional as F

 class ConvReLU1d(nnqr.Conv1d):
-    _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU1d  # type: ignore[assignment]
+    _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU1d

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_dequant = x.dequantize()
         weight_dequant = self._qweight.dequantize()
         float_result = F.conv1d(
-            x_dequant, weight_dequant, self._bias, self._conv1d_stride,
-            self._conv1d_padding, self._conv1d_dilation, self.groups)
+            x_dequant, weight_dequant, self._bias, self._conv1d_stride,  # type: ignore[has-type]
+            self._conv1d_padding, self._conv1d_dilation, self.groups)  # type: ignore[has-type]
         float_result = F.relu(float_result, inplace=True)
         # NEEDFIX: we don't have dtype in the Linear module APIs right now!
         result = torch.quantize_per_tensor(
@@ -22,7 +22,7 @@ class ConvReLU1d(nnqr.Conv1d):


 class ConvReLU2d(nnqr.Conv2d):
-    _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU2d  # type: ignore[assignment]
+    _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU2d

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_dequant = x.dequantize()
@@ -40,7 +40,7 @@ class ConvReLU2d(nnqr.Conv2d):
         return "QuantizedConvReLU2d(Reference)"

 class ConvReLU3d(nnqr.Conv3d):
-    _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU3d  # type: ignore[assignment]
+    _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU3d

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_dequant = x.dequantize()

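The reference ConvReLU modules above all follow the same dequantize, run the float op, requantize pattern. The following rough standalone sketch restates that forward pass for clarity; it is not code from this diff, and the output scale, zero_point, and dtype arguments are assumptions, since the quantize_per_tensor call is truncated in the first hunk.

# Sketch of the reference-module forward pattern shown in the diff above.
import torch
import torch.nn.functional as F

def reference_conv_relu1d_forward(module, x: torch.Tensor) -> torch.Tensor:
    # 1. Drop to float: dequantize the quantized input and the stored weight.
    x_dequant = x.dequantize()
    weight_dequant = module._qweight.dequantize()
    # 2. Run the ordinary float ops, so F.conv1d stays visible in the graph.
    float_result = F.conv1d(
        x_dequant, weight_dequant, module._bias, module._conv1d_stride,
        module._conv1d_padding, module._conv1d_dilation, module.groups)
    float_result = F.relu(float_result, inplace=True)
    # 3. Requantize with the module's output qparams (assumed attribute names).
    return torch.quantize_per_tensor(
        float_result, module.scale, module.zero_point, torch.quint8)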
@@ -497,7 +497,7 @@ class ConvReluQuantizeHandler(QuantizeHandler):
                 self.conv.activation_post_process = output_activation_post_process
                 # 2. select quantized class
                 qconv_cls = get_static_quant_module_class(
-                    type(self.conv), additional_static_quant_mapping)
+                    type(self.conv), additional_static_quant_mapping, is_reference=is_reference)
                 quantized = qconv_cls.from_float(self.conv)
                 parent_name, name = _parent_name(self.conv_node.target)
                 setattr(modules[parent_name], name, quantized)

@@ -6,8 +6,10 @@ from torch import nn
 import torch.nn.functional as F
 import torch.nn.intrinsic as nni
 import torch.nn.intrinsic.quantized as nniq
+import torch.nn.intrinsic.quantized._reference as nniqr
 import torch.nn.intrinsic.qat as nniqat
 import torch.nn.quantized as nnq
+import torch.nn.quantized._reference as nnqr
 import torch.nn.quantized.dynamic as nnqd
 import torch.nn.qat as nnqat
@@ -20,6 +22,31 @@ from .fake_quantize import (
 )
 from .utils import get_combined_dict

+# Default map for swapping float module to reference quantized modules
+DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
+    nn.Conv1d: nnqr.Conv1d,
+    nn.Conv2d: nnqr.Conv2d,
+    nn.Conv3d: nnqr.Conv3d,
+    # nn.Linear, nnqr.Linear,
+    nni.ConvReLU1d: nniqr.ConvReLU1d,
+    nni.ConvReLU2d: nniqr.ConvReLU2d,
+    nni.ConvReLU3d: nniqr.ConvReLU3d,
+    nni.LinearReLU: nniqr.LinearReLU,
+    # QAT Modules
+    # nnqat.Linear: nnqr.Linear,
+    nnqat.Conv2d: nnqr.Conv2d,
+    nnqat.Conv3d: nnqr.Conv3d,
+    nniqat.ConvBn1d: nnqr.Conv1d,
+    nniqat.ConvBn2d: nnqr.Conv2d,
+    nniqat.ConvBn3d: nnqr.Conv3d,
+    nniqat.ConvBnReLU1d: nniqr.ConvReLU1d,
+    nniqat.ConvBnReLU2d: nniqr.ConvReLU2d,
+    nniqat.ConvBnReLU3d: nniqr.ConvReLU3d,
+    nniqat.ConvReLU2d: nniqr.ConvReLU2d,
+    nniqat.ConvReLU3d: nniqr.ConvReLU3d,
+    # nniqat.LinearReLU: nniqr.LinearReLU,
+}
+
 # Default map for swapping float module to quantized ones
 DEFAULT_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
     QuantStub: nnq.Quantize,
@@ -134,13 +161,16 @@ def get_default_static_quant_module_mappings() -> Dict[Callable, Any]:

 def get_static_quant_module_class(
         float_module_class: Callable,
-        additional_static_quant_mapping: Optional[Dict[Callable, Any]] = None) -> Any:
+        additional_static_quant_mapping: Optional[Dict[Callable, Any]] = None,
+        is_reference: bool = False) -> Any:
     r"""n Get the statically quantized module class corresponding to
     the floating point module class
     """
     if additional_static_quant_mapping is None:
         additional_static_quant_mapping = {}
-    all_mappings = get_combined_dict(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, additional_static_quant_mapping)
+    all_mappings = get_combined_dict(
+        DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS if is_reference
+        else DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, additional_static_quant_mapping)
     static_quant_module_class = all_mappings.get(float_module_class, None)
     assert static_quant_module_class is not None, \
         "Floating point module class {}".format(str(float_module_class)) + \
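To make the selection logic of the last hunk explicit, here is a small hedged re-implementation of the same flow; the function name, parameters, and assertion message are hypothetical, and get_combined_dict is assumed to overlay its second argument on the first.

# Hypothetical restatement of get_static_quant_module_class's selection logic.
from typing import Any, Callable, Dict, Optional

def pick_static_quant_module_class(
        float_module_class: Callable,
        default_mappings: Dict[Callable, Any],
        reference_mappings: Dict[Callable, Any],
        additional_mapping: Optional[Dict[Callable, Any]] = None,
        is_reference: bool = False) -> Any:
    # Choose the reference mapping when requested, otherwise the default one.
    base = reference_mappings if is_reference else default_mappings
    # Merge in any user-supplied mapping; user entries take precedence.
    all_mappings = dict(base)
    all_mappings.update(additional_mapping or {})
    cls = all_mappings.get(float_module_class, None)
    assert cls is not None, (
        "Floating point module class {} does not have a corresponding "
        "quantized module class".format(float_module_class))
    return cls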