# Owner(s): ["oncall: quantization"]
import copy
import operator
import unittest
from typing import Any, Optional

import torch
from torch.ao.quantization import (
    default_fake_quant,
    FusedMovingAvgObsFakeQuantize,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
    QConfigMapping,
)
from torch.ao.quantization.backend_config import get_qnnpack_backend_config
from torch.ao.quantization.qconfig import (
    default_per_channel_symmetric_qnnpack_qat_qconfig,
    default_symmetric_qnnpack_qat_qconfig,
)
from torch.ao.quantization.quantize_fx import prepare_qat_fx
from torch.ao.quantization.quantize_pt2e import (
    _convert_to_reference_decomposed_fx,
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)
from torch.ao.quantization.quantizer import (
    DerivedQuantizationSpec,
    QuantizationAnnotation,
    QuantizationSpec,
    Quantizer,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.export import export
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    QuantizationTestCase,
    skip_if_no_torchvision,
    skipIfNoQNNPACK,
)
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import raise_on_run_directly


class PT2EQATTestCase(QuantizationTestCase):
    """
    Base QuantizationTestCase for PT2E QAT with some helper methods.
    """

    class _BaseConvBnModel(torch.nn.Module):
        def __init__(
            self,
            conv_class: type[torch.nn.Module],
            bn_class: type[torch.nn.Module],
            has_conv_bias: bool,
            has_bn: bool,
            has_relu: bool,
            **conv_kwargs,
        ):
            super().__init__()
            conv_kwargs.setdefault("in_channels", 3)
            conv_kwargs.setdefault("out_channels", 3)
            conv_kwargs.setdefault("kernel_size", 3)
            conv_kwargs.setdefault("bias", has_conv_bias)
            self.conv = conv_class(**conv_kwargs)
            self.bn = bn_class(conv_kwargs["out_channels"]) if has_bn else None
            self.relu = torch.nn.ReLU() if has_relu else None

        def forward(self, x):
            x = self.conv(x)
            if self.bn is not None:
                x = self.bn(x)
            if self.relu is not None:
                x = self.relu(x)
            return x

    def _get_conv_bn_model(
        self,
        has_conv_bias: bool = True,
        has_bn: bool = True,
        has_relu: bool = False,
        transpose: bool = False,
        **conv_kwargs,
    ):
        """
        Return an instance of a simple test model containing the
        conv[-bn][-relu] pattern. By default, this returns a
        conv-bn model with conv bias.
        """
        return self._BaseConvBnModel(
            self.conv_transpose_class if transpose else self.conv_class,
            self.bn_class,
            has_conv_bias,
            has_bn,
            has_relu,
            **conv_kwargs,
        )

    def _verify_symmetric_xnnpack_qat_numerics(
        self,
        model: torch.nn.Module,
        example_inputs: tuple[Any, ...],
    ):
        self._verify_symmetric_xnnpack_qat_numerics_helper(
            model,
            example_inputs,
            is_per_channel=True,
        )
        self._verify_symmetric_xnnpack_qat_numerics_helper(
            model,
            example_inputs,
            is_per_channel=False,
        )

    def _verify_symmetric_xnnpack_qat_numerics_helper(
        self,
        model: torch.nn.Module,
        example_inputs: tuple[Any, ...],
        is_per_channel: bool,
        verify_convert: bool = True,
    ):
        """
        Helper method to verify that the QAT numerics for PT2E quantization match those of
        FX graph mode quantization for symmetric qnnpack.
        """
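        # The PT2E flow under test (export -> prepare_qat_pt2e -> convert_pt2e)
        # is compared against the FX graph mode reference flow
        # (prepare_qat_fx -> _convert_to_reference_decomposed_fx) built below.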
        # resetting dynamo cache
        torch._dynamo.reset()
        MANUAL_SEED = 100

        # PT2 export

        model_pt2e = copy.deepcopy(model)
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(
            get_symmetric_quantization_config(
                is_per_channel=is_per_channel, is_qat=True
            )
        )
        model_pt2e = export(model_pt2e, example_inputs, strict=True).module()
        model_pt2e = prepare_qat_pt2e(model_pt2e, quantizer)
        torch.manual_seed(MANUAL_SEED)
        after_prepare_result_pt2e = model_pt2e(*example_inputs)

        model_fx = copy.deepcopy(model)
        if is_per_channel:
            default_qconfig = default_per_channel_symmetric_qnnpack_qat_qconfig
        else:
            default_qconfig = default_symmetric_qnnpack_qat_qconfig
        qconfig_mapping = QConfigMapping().set_global(default_qconfig)
        backend_config = get_qnnpack_backend_config()
        model_fx = prepare_qat_fx(
            model_fx, qconfig_mapping, example_inputs, backend_config=backend_config
        )
        torch.manual_seed(MANUAL_SEED)
        after_prepare_result_fx = model_fx(*example_inputs)

        # Verify that numerics match
        self.assertEqual(after_prepare_result_pt2e, after_prepare_result_fx)

        if verify_convert:
            # We don't want to impose any ordering requirements between move_exported_model_to_eval and convert_pt2e
            torch.ao.quantization.move_exported_model_to_eval(model_pt2e)
            model_pt2e = convert_pt2e(model_pt2e)
            quant_result_pt2e = model_pt2e(*example_inputs)
            model_fx.eval()
            model_fx = _convert_to_reference_decomposed_fx(
                model_fx,
                backend_config=backend_config,
            )
            quant_result_fx = model_fx(*example_inputs)
            self.assertEqual(quant_result_pt2e, quant_result_fx)

    def _verify_symmetric_xnnpack_qat_graph(
        self,
        m: torch.fx.GraphModule,
        example_inputs: tuple[Any, ...],
        has_relu: bool,
        has_bias: bool = True,
        is_cuda: bool = False,
        expected_conv_literal_args: Optional[tuple[Any, ...]] = None,
        # TODO: set this to true by default
        verify_convert: bool = False,
    ):
        self._verify_symmetric_xnnpack_qat_graph_helper(
            m,
            example_inputs,
            is_per_channel=True,
            has_relu=has_relu,
            has_bias=has_bias,
            is_cuda=is_cuda,
            expected_conv_literal_args=expected_conv_literal_args,
            verify_convert=verify_convert,
        )
        self._verify_symmetric_xnnpack_qat_graph_helper(
            m,
            example_inputs,
            is_per_channel=False,
            has_relu=has_relu,
            has_bias=has_bias,
            is_cuda=is_cuda,
            expected_conv_literal_args=expected_conv_literal_args,
            verify_convert=verify_convert,
        )

    def _verify_symmetric_xnnpack_qat_graph_helper(
        self,
        m: torch.fx.GraphModule,
        example_inputs: tuple[Any, ...],
        is_per_channel: bool,
        has_relu: bool,
        has_bias: bool = True,
        is_cuda: bool = False,
        expected_conv_literal_args: Optional[tuple[Any, ...]] = None,
        verify_convert: bool = False,
    ):
        """
        Verify that the graph module matches the fused QAT [conv - bn (- relu)] pattern
        with fake quantizes inserted into the correct places.
        # TODO: also verify that metadata is copied over to the new nodes.
        """
        m = copy.deepcopy(m)
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(
            get_symmetric_quantization_config(is_per_channel, is_qat=True)
        )
        m = export(m, example_inputs, strict=True).module()
        m = prepare_qat_pt2e(m, quantizer)
        m(*example_inputs)

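        # Rough sketch of the fused QAT pattern checked below (standard BN folding;
        # names are informal):
        #
        #   scale_factor = bn.weight / sqrt(bn.running_var + eps)
        #   x = conv(fq(input), fq(weight * scale_factor.reshape), zeros_like(bias))
        #   x = x / scale_factor.reshape [+ bias.reshape]
        #   x = batch_norm(x) [-> relu] -> fq -> output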
        # Verify: getitem output activation fake quantize
        output_node = list(m.graph.nodes)[-1]
        output_fq_node = output_node.args[0][0]
        self.assertTrue(output_fq_node.target.startswith("activation_post_process_"))
        output_fq_mod = getattr(m, output_fq_node.target)
        self.assertEqual(type(output_fq_mod), FusedMovingAvgObsFakeQuantize)
        self.assertEqual(
            type(output_fq_mod.activation_post_process), MovingAverageMinMaxObserver
        )
        self.assertEqual(output_fq_mod.dtype, torch.int8)
        self.assertEqual(output_fq_mod.quant_min, -128)
        self.assertEqual(output_fq_mod.quant_max, 127)

        # Verify: getitem(bn, 0) or relu(getitem(bn, 0))
        if has_relu:
            relu_node = output_fq_node.args[0]
            bn_node = relu_node.args[0]
            self.assertEqual(relu_node.target, torch.ops.aten.relu.default)
        else:
            relu_node = None
            bn_node = output_fq_node.args[0]

        # The relu node takes in the output of bn.
        # See NOTE [training ir has no getitem for bn node].
        self.assertEqual(bn_node.target, torch.ops.aten.batch_norm.default)

        # Verify: conv / scale_factor.reshape [+ bias.reshape]
        if has_bias:
            add_bias_node = bn_node.args[0]
            (div_scale_factor_node, bias_reshape_node) = add_bias_node.args
            self.assertEqual(add_bias_node.target, torch.ops.aten.add.Tensor)
            self.assertEqual(bias_reshape_node.target, torch.ops.aten.reshape.default)
        else:
            div_scale_factor_node = bn_node.args[0]
        (conv_node, scale_factor_reshape_node) = div_scale_factor_node.args
        conv_op = conv_node.target
        self.assertEqual(div_scale_factor_node.target, torch.ops.aten.div.Tensor)
        self.assertTrue(_is_conv_node(conv_node))
        self.assertEqual(
            scale_factor_reshape_node.target, torch.ops.aten.reshape.default
        )

        # Verify: conv literal args
        if expected_conv_literal_args is not None:
            assert len(expected_conv_literal_args) == 6, (
                "wrong num conv args, bad test setup"
            )
            for i in range(6):
                if i + 3 < len(conv_node.args):
                    self.assertEqual(
                        conv_node.args[i + 3], expected_conv_literal_args[i]
                    )

        # Verify: conv input activation fake quantize
        conv_input_fq_node = conv_node.args[0]
        conv_input_node = conv_input_fq_node.args[0]
        self.assertTrue(
            conv_input_fq_node.target.startswith("activation_post_process_")
        )
        conv_input_fq_mod = getattr(m, conv_input_fq_node.target)
        self.assertEqual(type(conv_input_fq_mod), FusedMovingAvgObsFakeQuantize)
        self.assertEqual(
            type(conv_input_fq_mod.activation_post_process), MovingAverageMinMaxObserver
        )
        self.assertEqual(conv_input_fq_mod.dtype, torch.int8)
        self.assertEqual(conv_input_fq_mod.quant_min, -128)
        self.assertEqual(conv_input_fq_mod.quant_max, 127)
        self.assertEqual(conv_input_node.op, "placeholder")

        # Verify: conv weight fake quantize
        conv_weight_fq_node = conv_node.args[1]
        self.assertTrue(
            conv_weight_fq_node.target.startswith("activation_post_process_")
        )
        conv_weight_fq_mod = getattr(m, conv_weight_fq_node.target)
        if is_per_channel:
            expected_weight_observer_type = MovingAveragePerChannelMinMaxObserver
        else:
            expected_weight_observer_type = MovingAverageMinMaxObserver
        self.assertEqual(type(conv_weight_fq_mod), FusedMovingAvgObsFakeQuantize)
        self.assertEqual(
            type(conv_weight_fq_mod.activation_post_process),
            expected_weight_observer_type,
        )
        self.assertEqual(conv_weight_fq_mod.dtype, torch.int8)
        self.assertEqual(conv_weight_fq_mod.quant_min, -127)
        self.assertEqual(conv_weight_fq_mod.quant_max, 127)

        # Verify: conv(fq(input), fq(weight * scale_factor.reshape), zero_bias)
        zero_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None
        mul_weight_scale_factor_node = conv_weight_fq_node.args[0]
        (
            conv_weight_fq_node,
            scale_factor_reshape_node,
        ) = mul_weight_scale_factor_node.args
        if has_bias:
            self.assertEqual(zero_bias_node.target, torch.ops.aten.zeros_like.default)
        else:
            self.assertTrue(zero_bias_node is None)
        self.assertEqual(mul_weight_scale_factor_node.target, torch.ops.aten.mul.Tensor)
        self.assertEqual(
            scale_factor_reshape_node.target, torch.ops.aten.reshape.default
        )

        # Verify: scale_factor = bn_weight / sqrt(bn_running_var + eps)
        scale_factor_node = scale_factor_reshape_node.args[0]
        (bn_weight_node, sqrt_node) = scale_factor_node.args
        bn_running_var_add_node = sqrt_node.args[0]
        (bn_running_var_node, eps) = bn_running_var_add_node.args
        self.assertEqual(scale_factor_node.target, torch.ops.aten.div.Tensor)
        self.assertTrue("bn.weight" in bn_weight_node.target)
        self.assertTrue("bn.running_var" in bn_running_var_node.target)
        self.assertEqual(sqrt_node.target, torch.ops.aten.sqrt.default)
        self.assertEqual(bn_running_var_add_node.target, torch.ops.aten.add.Tensor)
        self.assertEqual(eps, 1e-5)

        # Optionally check the converted graph
        if verify_convert:
            m = convert_pt2e(m)
            m(*example_inputs)

            if is_per_channel:
                conv_weight_dq_op = (
                    torch.ops.quantized_decomposed.dequantize_per_channel.default
                )
                node_occurrence = {
                    ns.call_function(
                        torch.ops.quantized_decomposed.quantize_per_tensor.default
                    ): 2,
                    ns.call_function(
                        torch.ops.quantized_decomposed.dequantize_per_tensor.default
                    ): 2,
                    ns.call_function(
                        torch.ops.quantized_decomposed.dequantize_per_channel.default
                    ): 1,
                }
            else:
                conv_weight_dq_op = (
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default
                )
                node_occurrence = {
                    ns.call_function(
                        torch.ops.quantized_decomposed.quantize_per_tensor.default
                    ): 2,
                    ns.call_function(
                        torch.ops.quantized_decomposed.dequantize_per_tensor.default
                    ): 3,
                }
            node_list = [
                ns.call_function(
                    torch.ops.quantized_decomposed.quantize_per_tensor.default
                ),
                ns.call_function(
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default
                ),
                ns.call_function(conv_weight_dq_op),
                ns.call_function(conv_op),
                ns.call_function(
                    torch.ops.quantized_decomposed.quantize_per_tensor.default
                ),
                ns.call_function(
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default
                ),
            ]

            self.checkGraphModuleNodes(
                m,
                expected_node_list=node_list,
                expected_node_occurrence=node_occurrence,
            )


class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
    """
    Base TestCase to be used for all conv-bn[-relu] fusion patterns.
    """

    # TODO: how can we avoid adding every new test to dynamo/expected_test_failures?
    # Otherwise it fails with the following error:
    # torch._dynamo.exc.InternalTorchDynamoError:
    # 'QuantizationConfig' object has no attribute '__bool__'

    def setUp(self):
        # NB: Skip the test if this is a base class, this is to handle the test
        # discovery logic in buck which finds and runs all tests here including
        # the base class which we don't want to run
        if self.id() and "_Base" in self.id():
            self.skipTest("Skipping test running from base class")

    def test_qat_conv_no_bias(self):
        m1 = self._get_conv_bn_model(has_conv_bias=False, has_bn=False, has_relu=True)
        m2 = self._get_conv_bn_model(has_conv_bias=False, has_bn=False, has_relu=False)
        self._verify_symmetric_xnnpack_qat_numerics(m1, self.example_inputs)
        self._verify_symmetric_xnnpack_qat_numerics(m2, self.example_inputs)

    def test_qat_conv_bn_fusion(self):
        m = self._get_conv_bn_model()
        self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=False)
        self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_qat_conv_bn_fusion_cuda(self):
        m = self._get_conv_bn_model().cuda()
        example_inputs = (self.example_inputs[0].cuda(),)
        self._verify_symmetric_xnnpack_qat_graph(
            m,
            example_inputs,
            has_relu=False,
            is_cuda=True,
        )
        self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

    def test_qat_conv_bn_fusion_literal_args(self):
        class M(torch.nn.Module):
            def __init__(self, conv_class, bn_class):
                super().__init__()
                self.conv = conv_class(3, 3, 3, stride=2, padding=4)
                self.bn = bn_class(3)

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        assert self.dim in [1, 2]
        if self.dim == 1:
            # stride, padding, dilation, transposed, output_padding, groups
            conv_args = ((2,), (4,), (1,), False, (0,), 1)
            example_inputs = (torch.randn(1, 3, 5),)
        else:
            # stride, padding, dilation, transposed, output_padding, groups
            conv_args = ((2, 2), (4, 4), (1, 1), False, (0, 0), 1)
            example_inputs = (torch.randn(1, 3, 5, 5),)

        m = M(self.conv_class, self.bn_class)

        self._verify_symmetric_xnnpack_qat_graph(
            m,
            example_inputs,
            has_relu=False,
            expected_conv_literal_args=conv_args,
        )
        self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

    def test_qat_conv_bn_fusion_no_conv_bias(self):
        class M2(torch.nn.Module):
            """
            Mixed conv + BN with and without conv bias.
            """

            def __init__(self, conv_class, bn_class):
                super().__init__()
                self.conv1 = conv_class(3, 3, 3, bias=False)
                self.bn1 = bn_class(3)
                self.conv2 = conv_class(3, 3, 3, bias=True)
                self.bn2 = bn_class(3)

            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.conv2(x)
                x = self.bn2(x)
                return x

        m1 = self._get_conv_bn_model(has_conv_bias=False)
        m2 = M2(self.conv_class, self.bn_class)

        assert self.dim in [1, 2]
        if self.dim == 1:
            example_inputs = (torch.randn(3, 3, 5),)
        else:
            example_inputs = (torch.randn(3, 3, 5, 5),)

        self._verify_symmetric_xnnpack_qat_graph(
            m1,
            example_inputs,
            has_relu=False,
            has_bias=False,
        )
        self._verify_symmetric_xnnpack_qat_numerics(m1, example_inputs)
        self._verify_symmetric_xnnpack_qat_numerics(m2, example_inputs)

    def test_qat_conv_bn_relu_fusion(self):
        m = self._get_conv_bn_model(has_relu=True)
        self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=True)
        self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_qat_conv_bn_relu_fusion_cuda(self):
        m = self._get_conv_bn_model(has_relu=True).cuda()
        example_inputs = (self.example_inputs[0].cuda(),)
        self._verify_symmetric_xnnpack_qat_graph(
            m,
            example_inputs,
            has_relu=True,
            is_cuda=True,
        )
        self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

    def test_qat_conv_bn_relu_fusion_no_conv_bias(self):
        m = self._get_conv_bn_model(has_conv_bias=False, has_relu=True)
        self._verify_symmetric_xnnpack_qat_graph(
            m,
            self.example_inputs,
            has_relu=True,
            has_bias=False,
        )
        self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)

    def test_qat_inplace_add_relu(self):
        class M(torch.nn.Module):
            def __init__(self, conv_class):
                super().__init__()
                self.conv = conv_class(1, 1, 1)
                self.relu = torch.nn.ReLU(inplace=True)

            def forward(self, x):
                x0 = x
                x = self.conv(x)
                x += x0
                x = self.relu(x)
                return x

        assert self.dim in [1, 2]
        if self.dim == 1:
            example_inputs = (torch.randn(1, 1, 3),)
        else:
            example_inputs = (torch.randn(1, 1, 3, 3),)

        m = M(self.conv_class)
        self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

    def test_qat_update_shared_qspec(self):
        """
        Test the case where nodes used in SharedQuantizationSpec were replaced
        during QAT subgraph rewriting.
        """

        class M(torch.nn.Module):
            def __init__(self, conv_class, bn_class):
                super().__init__()
                self.conv = conv_class(3, 3, 3)
                self.bn = bn_class(3)
                self.hardtanh = torch.nn.Hardtanh()

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                x = self.hardtanh(x)
                return x

        m = M(self.conv_class, self.bn_class)
        self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)

    def test_qat_preserve_source_fn_stack(self):
        """
        Test whether `source_fn_stack` is preserved after QAT fusion.
        """

        class M(torch.nn.Module):
            def __init__(self, conv_class, bn_class, backbone):
                super().__init__()
                self.conv = conv_class(5, 3, 3)
                self.bn = bn_class(3)
                self.relu = torch.nn.ReLU()
                self.backbone = backbone

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                x = self.relu(x)
                x = self.backbone(x)
                return x

        assert self.dim in [1, 2]
        if self.dim == 1:
            example_inputs = (torch.randn(1, 5, 10),)
        else:
            example_inputs = (torch.randn(1, 5, 10, 10),)

        # QAT prepare + convert
        backbone = self._get_conv_bn_model(has_relu=True)
        m = M(self.conv_class, self.bn_class, backbone)
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
        m = export(m, example_inputs, strict=True).module()
        m = prepare_qat_pt2e(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)

        # Extract the conv and relu nodes (bn was folded into conv)
        first_conv, first_relu, second_conv, second_relu = None, None, None, None
        for n in m.graph.nodes:
            if n.target == torch.ops.aten.relu.default:
                if first_relu is None:
                    assert first_conv is None, "bad test setup"
                    first_relu = n
                    first_conv = n.args[0]
                else:
                    assert second_conv is None, "bad test setup"
                    second_relu = n
                    second_conv = n.args[0]

        # Extract the conv weight and bias nodes
        def get_conv_weight_and_bias(conv_node: torch.fx.Node):
            weight_dq_node = conv_node.args[1]
            qweight_node = weight_dq_node.args[0]
            bias_node = conv_node.args[2]
            assert isinstance(qweight_node, torch.fx.Node)
            assert isinstance(bias_node, torch.fx.Node)
            return (qweight_node, bias_node)

        _, first_conv_bias = get_conv_weight_and_bias(first_conv)
        _, second_conv_bias = get_conv_weight_and_bias(second_conv)

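        # node.meta["source_fn_stack"] records which eager module/function a node
        # came from (see the example in get_source_fn below); the nodes produced by
        # QAT fusion are expected to inherit it from the original conv.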
        # Assert that each set of conv, conv weight, and conv bias are in the same partition
        def get_source_fn(node: torch.fx.Node):
            # E.g. [('l__self___backbone1_conv', <class 'torch.nn.modules.conv.Conv2d'>)]
            return node.meta["source_fn_stack"][0][0]

        # We don't preserve this in the quantized weight currently, since it's folded,
        # but the user can attach "quantization_tag" to the node and it will be preserved.
        # self.assertEqual(get_source_fn(first_conv), get_source_fn(first_conv_qweight))
        # self.assertEqual(get_source_fn(second_conv), get_source_fn(second_conv_qweight))

        self.assertEqual(get_source_fn(first_conv), get_source_fn(first_conv_bias))
        self.assertEqual(get_source_fn(second_conv), get_source_fn(second_conv_bias))

        # Assert that different sets of convs and relus have different partitions
        self.assertNotEqual(get_source_fn(first_conv), get_source_fn(first_relu))
        self.assertNotEqual(get_source_fn(first_conv), get_source_fn(second_conv))
        self.assertNotEqual(get_source_fn(second_conv), get_source_fn(second_relu))
        self.assertNotEqual(get_source_fn(first_relu), get_source_fn(second_relu))

    def test_qat_conv_bn_bias_derived_qspec(self):
        m = self._get_conv_bn_model()
        example_inputs = self.example_inputs
        m = export(m, example_inputs, strict=True).module()
        quantizer = ConvBnDerivedBiasQuantizer()
        m = prepare_qat_pt2e(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)
        m(*example_inputs)

        # Assert that both weight and bias are quantized
        (conv_node, _, _) = _get_conv_bn_getitem_nodes(m)
        weight_dq = conv_node.args[1]
        bias_dq = conv_node.args[2]
        self.assertEqual(
            weight_dq.target,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        )
        self.assertEqual(
            bias_dq.target,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        )
        weight_getattr = weight_dq.args[0]
        bias_getattr = bias_dq.args[0]
        self.assertEqual(
            weight_getattr.op,
            "get_attr",
        )
        self.assertEqual(
            bias_getattr.op,
            "get_attr",
        )

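        # The derived bias scale follows the usual integer conv convention: the
        # int32 accumulator holds sums of (q_input * q_weight) products, so its
        # effective scale is input_scale * weight_scale, and the bias is quantized
        # with that same scale.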
        # Assert that bias scale = weight scale * input scale
        input_dq = conv_node.args[0]
        input_scale = input_dq.args[1]
        bias_scale = bias_dq.args[1]
        weight_scale = weight_dq.args[1]
        self.assertEqual(bias_scale, input_scale * weight_scale)

        # Assert that args for the bias' quantize and dequantize ops
        # are copied correctly after subgraph rewriting
        (bias_qmin, bias_qmax, bias_dtype) = bias_dq.args[3:]
        self.assertEqual(bias_qmin, -(2**31))
        self.assertEqual(bias_qmax, 2**31 - 1)
        self.assertEqual(bias_dtype, torch.int32)

    def test_qat_per_channel_weight_custom_dtype(self):
        m = self._get_conv_bn_model()
        example_inputs = self.example_inputs
        m = export(m, example_inputs, strict=True).module()
        quantizer = ConvBnInt32WeightQuantizer()
        m = prepare_qat_pt2e(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)
        m(*example_inputs)

        # Assert that conv weight is quantized per channel
        (conv_node, _, _) = _get_conv_bn_getitem_nodes(m)
        weight_dq = conv_node.args[1]
        self.assertEqual(
            weight_dq.target,
            torch.ops.quantized_decomposed.dequantize_per_channel.default,
        )
        weight_getattr = weight_dq.args[0]
        self.assertEqual(
            weight_getattr.op,
            "get_attr",
        )

        # Assert that args for the weight's dequantize ops
        # are copied correctly after subgraph rewriting
        (dq_axis, dq_qmin, dq_qmax, dq_dtype) = weight_dq.args[3:]
        self.assertEqual(dq_axis, 0)
        self.assertEqual(dq_qmin, 0)
        self.assertEqual(dq_qmax, 2**31 - 1)
        self.assertEqual(dq_dtype, torch.int32)

    def _do_test_qat_conv_transpose_bn(self, has_relu: bool):
        # Use different in/out channel sizes to test if conv weight is
        # properly transposed in QAT pattern
        m = self._get_conv_bn_model(
            has_relu=has_relu,
            transpose=True,
            in_channels=3,
            out_channels=5,
            kernel_size=3,
        )
        self._verify_symmetric_xnnpack_qat_graph(
            m,
            self.example_inputs,
            has_relu=has_relu,
            verify_convert=True,
        )

    def test_qat_conv_transpose_bn(self):
        self._do_test_qat_conv_transpose_bn(has_relu=False)

    def test_qat_conv_transpose_bn_relu(self):
        self._do_test_qat_conv_transpose_bn(has_relu=True)

    def test_qat_conv_bn_per_channel_weight_bias(self):
        m = self._get_conv_bn_model()
        example_inputs = self.example_inputs
        m = export(m, example_inputs, strict=True).module()
        quantizer = ConvBnDerivedBiasQuantizer(is_per_channel=True)
        m = prepare_qat_pt2e(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)
        m(*example_inputs)

        # Expected graph:
        #   x -> q_tensor -> dq_tensor -> conv -> q_tensor -> dq_tensor -> output
        #  weight -> q_channel -> dq_channel /
        #    bias -> q_channel -> dq_channel /

        (conv_node, _, _) = _get_conv_bn_getitem_nodes(m)
        conv_op = conv_node.target
        conv_weight_dq_op = (
            torch.ops.quantized_decomposed.dequantize_per_channel.default
        )
        node_occurrence = {
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 2,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 2,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_channel.default
            ): 2,
        }
        node_list = [
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(conv_weight_dq_op),
            ns.call_function(conv_weight_dq_op),
            ns.call_function(conv_op),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
        ]
        self.checkGraphModuleNodes(
            m,
            expected_node_list=node_list,
            expected_node_occurrence=node_occurrence,
        )

    def test_fold_bn_erases_bn_node(self):
        """
        Ensure the BN node is erased from the graph after folding
        it into conv in `convert_pt2e` even in train mode.
        """
        m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False)
        m = export(m, self.example_inputs, strict=True).module()
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(
            get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
        )
        m = prepare_qat_pt2e(m, quantizer)
        m = convert_pt2e(m)
        (conv_node, bn_node, _) = _get_conv_bn_getitem_nodes(m)
        self.assertTrue(conv_node is not None)
        self.assertTrue(bn_node is None)

    def test_fold_bn_erases_add_node(self):
        """
        Test that batch norm stat tracking (which results in an add_ tensor) is removed when folding batch norm.
        """
        m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False)
        m = export(m, self.example_inputs, strict=True).module()

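        # The add_ node comes from BN stat tracking during training:
        # num_batches_tracked is incremented in place on every forward.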
        def _has_add_(graph):
            for node in graph.nodes:
                if node.target == torch.ops.aten.add_.Tensor:
                    return True
            return False

        # Verify that add_ tensor exists in the exported model (for tracking batch norm stats)
        has_add_tensor_before = _has_add_(m.graph)
        self.assertTrue(
            has_add_tensor_before, "Expected to find add_ tensor in the exported model"
        )

        quantizer = XNNPACKQuantizer()
        quantizer.set_global(
            get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
        )
        m = prepare_qat_pt2e(m, quantizer)
        m = convert_pt2e(m)

        # Verify that add_ tensor is removed in the quantized model
        has_add_tensor_after = _has_add_(m.graph)
        self.assertFalse(
            has_add_tensor_after,
            "Expected add_ tensor to be removed in the quantized model",
        )


@skipIfNoQNNPACK
class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base):
    dim = 1
    example_inputs = (torch.randn(1, 3, 5),)
    conv_class = torch.nn.Conv1d
    conv_transpose_class = torch.nn.ConvTranspose1d
    bn_class = torch.nn.BatchNorm1d


@skipIfNoQNNPACK
class TestQuantizePT2EQAT_ConvBn2d(TestQuantizePT2EQAT_ConvBn_Base):
    dim = 2
    example_inputs = (torch.randn(1, 3, 5, 5),)
    conv_class = torch.nn.Conv2d
    conv_transpose_class = torch.nn.ConvTranspose2d
    bn_class = torch.nn.BatchNorm2d


def _is_conv_node(n: torch.fx.Node):
    return n.op == "call_function" and n.target in [
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
        torch.ops.aten.conv_transpose1d,
        torch.ops.aten.conv_transpose1d.default,
        torch.ops.aten.conv_transpose2d,
        torch.ops.aten.conv_transpose2d.input,
    ]


def _get_conv_bn_getitem_nodes(model: torch.fx.GraphModule):
    """
    Return a 3-tuple of (conv, bn, getitem) nodes from the graph.
    """
    model.graph.eliminate_dead_code()
    model.recompile()
    conv_node = None
    bn_node = None
    getitem_node = None
    for n in model.graph.nodes:
        if _is_conv_node(n):
            conv_node = n
        if n.target in (
            torch.ops.aten._native_batch_norm_legit.default,
            torch.ops.aten.batch_norm.default,
        ):
            bn_node = n
        if n.target == operator.getitem:
            getitem_node = n
    assert conv_node is not None, "bad test setup"
    return (conv_node, bn_node, getitem_node)


class ConvBnInt32WeightQuantizer(Quantizer):
    """
    Dummy quantizer that annotates conv bn in such a way that the weights
    are quantized per channel to int32.
    """

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        conv_node, bn_node, getitem_node = _get_conv_bn_getitem_nodes(model)
        act_qspec = QuantizationSpec(
            dtype=torch.uint8,
            quant_min=0,
            quant_max=255,
            qscheme=torch.per_tensor_affine,
            observer_or_fake_quant_ctr=default_fake_quant,
        )
        weight_qspec = QuantizationSpec(
            dtype=torch.int32,
            quant_min=0,
            quant_max=2**31 - 1,
            qscheme=torch.per_channel_affine,
            observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args(
                observer=MovingAveragePerChannelMinMaxObserver,
            ),
        )
        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={
                conv_node.args[0]: act_qspec,
                conv_node.args[1]: weight_qspec,
            },
            _annotated=True,
        )

        # See NOTE [training ir has no getitem for bn node].
        bn_node.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=act_qspec,
            _annotated=True,
        )
        return model

    def validate(self, model: torch.fx.GraphModule):
        pass


class ConvBnDerivedBiasQuantizer(Quantizer):
    """
    Dummy quantizer that annotates conv bn in such a way that the bias qparams are
    derived from the conv input activation and weight qparams.
    """

    def __init__(self, is_per_channel: bool = False):
        super().__init__()
        self.is_per_channel = is_per_channel

    def _derive_bias_qparams_from_act_and_weight_qparams(self, obs_or_fqs):
        act_scale, _ = obs_or_fqs[0].calculate_qparams()
        weight_scale, _ = obs_or_fqs[1].calculate_qparams()
        if self.is_per_channel:
            bias_scale = act_scale * weight_scale
            bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32)
        else:
            bias_scale = torch.tensor([act_scale * weight_scale], dtype=torch.float32)
            bias_zero_point = torch.tensor([0], dtype=torch.int32)
        return bias_scale, bias_zero_point

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        if self.is_per_channel:
            weight_qscheme = torch.per_channel_symmetric
            weight_fq = FusedMovingAvgObsFakeQuantize.with_args(
                observer=MovingAveragePerChannelMinMaxObserver,
            )
        else:
            weight_qscheme = torch.per_tensor_affine
            weight_fq = default_fake_quant
        conv_node, bn_node, getitem_node = _get_conv_bn_getitem_nodes(model)
        act_qspec = QuantizationSpec(
            dtype=torch.uint8,
            quant_min=0,
            quant_max=255,
            qscheme=torch.per_tensor_affine,
            observer_or_fake_quant_ctr=default_fake_quant,
        )
        weight_qspec = QuantizationSpec(
            dtype=torch.uint8,
            quant_min=0,
            quant_max=255,
            qscheme=weight_qscheme,
            observer_or_fake_quant_ctr=weight_fq,
        )
        bias_qspec = DerivedQuantizationSpec(
            derived_from=[
                (conv_node.args[0], conv_node),
                (conv_node.args[1], conv_node),
            ],
            derive_qparams_fn=self._derive_bias_qparams_from_act_and_weight_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=weight_qscheme,
            ch_axis=0 if self.is_per_channel else None,
        )
        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={
                conv_node.args[0]: act_qspec,
                conv_node.args[1]: weight_qspec,
                conv_node.args[2]: bias_qspec,
            },
            _annotated=True,
        )

        # NOTE [training ir has no getitem for bn node].
        # getitem is None when we use the training IR: the graph contains
        # aten.batch_norm.default, which does not need any getitem node.
        # In this case, we need to annotate the batch norm node instead.
        # The getitem node should only be None if we are using the training IR.

        bn_node.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=act_qspec,
            _annotated=True,
        )
        return model

    def validate(self, model: torch.fx.GraphModule):
        pass


@skipIfNoQNNPACK
class TestQuantizePT2EQATModels(PT2EQATTestCase):
    @skip_if_no_torchvision
    @skipIfNoQNNPACK
    def test_qat_resnet18(self):
        import torchvision

        with override_quantized_engine("qnnpack"):
            example_inputs = (torch.randn(1, 3, 224, 224),)
            m = torchvision.models.resnet18()
            self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

    @skip_if_no_torchvision
    @skipIfNoQNNPACK
    def test_qat_mobilenet_v2(self):
        import torchvision

        with override_quantized_engine("qnnpack"):
            example_inputs = (torch.randn(1, 3, 224, 224),)
            m = torchvision.models.mobilenet_v2()
            self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)


class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
    class TwoLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear1 = torch.nn.Linear(16, 8, bias=False)
            self.linear2 = torch.nn.Linear(8, 8)

        def forward(self, x):
            return self.linear2(self.linear1(x))

    class QATPTQTestModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 3)
            self.linears = TestQuantizeMixQATAndPTQ.TwoLinear()
            self.my_linear = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
            linear_out = self.linears(permute_out)
            my_linear_out = self.my_linear(linear_out)
            # Hardtanh doesn't get quantized via the XNNPACK quantizer in this test
            # because it relies on the propagation rules.
            # Need to fix this.
            return torch.nn.functional.hardtanh(my_linear_out)

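    # Selectively prepare only the linear submodules for QAT: each Linear /
    # TwoLinear child is exported and prepared on its own, then swapped back
    # into the parent so the rest of the model can later go through PTQ.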
    def _prepare_qat_linears(self, model):
        for name, child in model.named_children():
            if isinstance(child, (torch.nn.Linear, TestQuantizeMixQATAndPTQ.TwoLinear)):
                if isinstance(child, torch.nn.Linear):
                    in_channels = child.weight.size(1)
                else:
                    in_channels = child.linear1.weight.size(1)

                example_input = (torch.rand((1, in_channels)),)
                traced_child = export(child, example_input, strict=True).module()
                quantizer = XNNPACKQuantizer()
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True, is_qat=True
                )
                quantizer.set_global(quantization_config)
                traced_child_prepared = prepare_qat_pt2e(traced_child, quantizer)
                setattr(model, name, traced_child_prepared)
            else:
                self._prepare_qat_linears(child)

    def _convert_qat_linears(self, model):
        for name, child in model.named_children():
            if isinstance(child, torch.fx.GraphModule):
                torch.ao.quantization.move_exported_model_to_eval(child)
                converted_child = convert_pt2e(child)
                setattr(model, name, converted_child)
            else:
                self._convert_qat_linears(child)

    def test_mixing_qat_ptq(self):
        example_inputs = (torch.randn(2, 3, 4, 4),)
        model = TestQuantizeMixQATAndPTQ.QATPTQTestModule()

        self._prepare_qat_linears(model)

        model(*example_inputs)
        # must be fixed model.eval()
        self._convert_qat_linears(model)
        model(*example_inputs)

        model_pt2e = export(model, example_inputs, strict=True).module()

        quantizer = XNNPACKQuantizer()
        quantizer.set_module_type(torch.nn.Linear, None)
        quantization_config = get_symmetric_quantization_config()
        quantizer.set_global(quantization_config)
        model_pt2e = prepare_pt2e(model_pt2e, quantizer)
        after_prepare_result_pt2e = model_pt2e(*example_inputs)  # noqa: F841
        model_pt2e = convert_pt2e(model_pt2e)
        quant_result_pt2e = model_pt2e(*example_inputs)  # noqa: F841

        exported_model = torch.export.export(model_pt2e, example_inputs, strict=True)

        node_occurrence = {
            # conv2d: 1 for act, 1 for weight, 1 for output
            # 3 x linear: 1 for act, 1 for output
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 8,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 9,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_channel.default
            ): 3,
            # There needs to be one for hardtanh
        }
        self.checkGraphModuleNodes(
            exported_model.graph_module, expected_node_occurrence=node_occurrence
        )


if __name__ == "__main__":
    raise_on_run_directly("test/test_quantization.py")