# Owner(s): ["oncall: quantization"]
import copy
import operator
import unittest
from typing import Any, Optional
import torch
from torch.ao.quantization import (
default_fake_quant,
FusedMovingAvgObsFakeQuantize,
MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver,
QConfigMapping,
)
from torch.ao.quantization.backend_config import get_qnnpack_backend_config
from torch.ao.quantization.qconfig import (
default_per_channel_symmetric_qnnpack_qat_qconfig,
default_symmetric_qnnpack_qat_qconfig,
)
from torch.ao.quantization.quantize_fx import prepare_qat_fx
from torch.ao.quantization.quantize_pt2e import (
_convert_to_reference_decomposed_fx,
convert_pt2e,
prepare_pt2e,
prepare_qat_pt2e,
)
from torch.ao.quantization.quantizer import (
DerivedQuantizationSpec,
QuantizationAnnotation,
QuantizationSpec,
Quantizer,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from torch.export import export
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
QuantizationTestCase,
skip_if_no_torchvision,
skipIfNoQNNPACK,
)
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import raise_on_run_directly
class PT2EQATTestCase(QuantizationTestCase):
"""
Base QuantizationTestCase for PT2E QAT with some helper methods.
"""
class _BaseConvBnModel(torch.nn.Module):
def __init__(
self,
conv_class: type[torch.nn.Module],
bn_class: type[torch.nn.Module],
has_conv_bias: bool,
has_bn: bool,
has_relu: bool,
**conv_kwargs,
):
super().__init__()
conv_kwargs.setdefault("in_channels", 3)
conv_kwargs.setdefault("out_channels", 3)
conv_kwargs.setdefault("kernel_size", 3)
conv_kwargs.setdefault("bias", has_conv_bias)
self.conv = conv_class(**conv_kwargs)
self.bn = bn_class(conv_kwargs["out_channels"]) if has_bn else None
self.relu = torch.nn.ReLU() if has_relu else None
def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
if self.relu is not None:
x = self.relu(x)
return x
def _get_conv_bn_model(
self,
has_conv_bias: bool = True,
has_bn: bool = True,
has_relu: bool = False,
transpose: bool = False,
**conv_kwargs,
):
"""
Return an instance of a simple test model containing the
conv[-bn][-relu] pattern. By default, this returns a
conv-bn model with conv bias.
"""
return self._BaseConvBnModel(
self.conv_transpose_class if transpose else self.conv_class,
self.bn_class,
has_conv_bias,
has_bn,
has_relu,
**conv_kwargs,
)
def _verify_symmetric_xnnpack_qat_numerics(
self,
model: torch.nn.Module,
example_inputs: tuple[Any, ...],
):
self._verify_symmetric_xnnpack_qat_numerics_helper(
model,
example_inputs,
is_per_channel=True,
)
self._verify_symmetric_xnnpack_qat_numerics_helper(
model,
example_inputs,
is_per_channel=False,
)
def _verify_symmetric_xnnpack_qat_numerics_helper(
self,
model: torch.nn.Module,
example_inputs: tuple[Any, ...],
is_per_channel: bool,
verify_convert: bool = True,
):
"""
Helper method to verify that the QAT numerics for PT2E quantization match those of
FX graph mode quantization for symmetric qnnpack.
"""
# resetting dynamo cache
torch._dynamo.reset()
MANUAL_SEED = 100
# PT2 export
model_pt2e = copy.deepcopy(model)
quantizer = XNNPACKQuantizer()
quantizer.set_global(
get_symmetric_quantization_config(
is_per_channel=is_per_channel, is_qat=True
)
)
model_pt2e = export(model_pt2e, example_inputs, strict=True).module()
model_pt2e = prepare_qat_pt2e(model_pt2e, quantizer)
torch.manual_seed(MANUAL_SEED)
after_prepare_result_pt2e = model_pt2e(*example_inputs)
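        # FX graph mode quantization (the reference flow we compare against)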
model_fx = copy.deepcopy(model)
if is_per_channel:
default_qconfig = default_per_channel_symmetric_qnnpack_qat_qconfig
else:
default_qconfig = default_symmetric_qnnpack_qat_qconfig
qconfig_mapping = QConfigMapping().set_global(default_qconfig)
backend_config = get_qnnpack_backend_config()
model_fx = prepare_qat_fx(
model_fx, qconfig_mapping, example_inputs, backend_config=backend_config
)
torch.manual_seed(MANUAL_SEED)
after_prepare_result_fx = model_fx(*example_inputs)
# Verify that numerics match
self.assertEqual(after_prepare_result_pt2e, after_prepare_result_fx)
if verify_convert:
# We don't want to impose any ordering requirements between move_exported_model_to_eval and convert_pt2e
torch.ao.quantization.move_exported_model_to_eval(model_pt2e)
model_pt2e = convert_pt2e(model_pt2e)
quant_result_pt2e = model_pt2e(*example_inputs)
model_fx.eval()
model_fx = _convert_to_reference_decomposed_fx(
model_fx,
backend_config=backend_config,
)
quant_result_fx = model_fx(*example_inputs)
self.assertEqual(quant_result_pt2e, quant_result_fx)
def _verify_symmetric_xnnpack_qat_graph(
self,
m: torch.fx.GraphModule,
example_inputs: tuple[Any, ...],
has_relu: bool,
has_bias: bool = True,
is_cuda: bool = False,
expected_conv_literal_args: Optional[tuple[Any, ...]] = None,
# TODO: set this to true by default
verify_convert: bool = False,
):
self._verify_symmetric_xnnpack_qat_graph_helper(
m,
example_inputs,
is_per_channel=True,
has_relu=has_relu,
has_bias=has_bias,
is_cuda=is_cuda,
expected_conv_literal_args=expected_conv_literal_args,
verify_convert=verify_convert,
)
self._verify_symmetric_xnnpack_qat_graph_helper(
m,
example_inputs,
is_per_channel=False,
has_relu=has_relu,
has_bias=has_bias,
is_cuda=is_cuda,
expected_conv_literal_args=expected_conv_literal_args,
verify_convert=verify_convert,
)
def _verify_symmetric_xnnpack_qat_graph_helper(
self,
m: torch.fx.GraphModule,
example_inputs: tuple[Any, ...],
is_per_channel: bool,
has_relu: bool,
has_bias: bool = True,
is_cuda: bool = False,
expected_conv_literal_args: Optional[tuple[Any, ...]] = None,
verify_convert: bool = False,
):
"""
Verify that the graph module matches the fused QAT [conv - bn (- relu)] pattern
with fake quantizes inserted into the correct places.
# TODO: also verify that metadata is copied over to the new nodes.
"""
m = copy.deepcopy(m)
quantizer = XNNPACKQuantizer()
quantizer.set_global(
get_symmetric_quantization_config(is_per_channel, is_qat=True)
)
m = export(m, example_inputs, strict=True).module()
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
# Verify: getitem output activation fake quantize
output_node = list(m.graph.nodes)[-1]
output_fq_node = output_node.args[0][0]
self.assertTrue(output_fq_node.target.startswith("activation_post_process_"))
output_fq_mod = getattr(m, output_fq_node.target)
self.assertEqual(type(output_fq_mod), FusedMovingAvgObsFakeQuantize)
self.assertEqual(
type(output_fq_mod.activation_post_process), MovingAverageMinMaxObserver
)
self.assertEqual(output_fq_mod.dtype, torch.int8)
self.assertEqual(output_fq_mod.quant_min, -128)
self.assertEqual(output_fq_mod.quant_max, 127)
# Verify: getitem(bn, 0) or relu(getitem(bn, 0))
if has_relu:
relu_node = output_fq_node.args[0]
bn_node = relu_node.args[0]
self.assertEqual(relu_node.target, torch.ops.aten.relu.default)
else:
relu_node = None
bn_node = output_fq_node.args[0]
        # In the training IR, bn produces its output directly (no getitem);
        # the relu node, if present, consumes the bn output.
        # See NOTE [training ir has no getitem for bn node].
self.assertEqual(bn_node.target, torch.ops.aten.batch_norm.default)
# Verify: conv / scale_factor.reshape [+ bias.reshape]
if has_bias:
add_bias_node = bn_node.args[0]
(div_scale_factor_node, bias_reshape_node) = add_bias_node.args
self.assertEqual(add_bias_node.target, torch.ops.aten.add.Tensor)
self.assertEqual(bias_reshape_node.target, torch.ops.aten.reshape.default)
else:
div_scale_factor_node = bn_node.args[0]
(conv_node, scale_factor_reshape_node) = div_scale_factor_node.args
conv_op = conv_node.target
self.assertEqual(div_scale_factor_node.target, torch.ops.aten.div.Tensor)
self.assertTrue(_is_conv_node(conv_node))
self.assertEqual(
scale_factor_reshape_node.target, torch.ops.aten.reshape.default
)
# Verify: conv literal args
if expected_conv_literal_args is not None:
assert len(expected_conv_literal_args) == 6, (
"wrong num conv args, bad test setup"
)
for i in range(6):
if i + 3 < len(conv_node.args):
self.assertEqual(
conv_node.args[i + 3], expected_conv_literal_args[i]
)
# Verify: conv input activation fake quantize
conv_input_fq_node = conv_node.args[0]
conv_input_node = conv_input_fq_node.args[0]
self.assertTrue(
conv_input_fq_node.target.startswith("activation_post_process_")
)
conv_input_fq_mod = getattr(m, conv_input_fq_node.target)
self.assertEqual(type(conv_input_fq_mod), FusedMovingAvgObsFakeQuantize)
self.assertEqual(
type(conv_input_fq_mod.activation_post_process), MovingAverageMinMaxObserver
)
self.assertEqual(conv_input_fq_mod.dtype, torch.int8)
self.assertEqual(conv_input_fq_mod.quant_min, -128)
self.assertEqual(conv_input_fq_mod.quant_max, 127)
        self.assertEqual(conv_input_node.op, "placeholder")
# Verify: conv weight fake quantize
conv_weight_fq_node = conv_node.args[1]
self.assertTrue(
conv_weight_fq_node.target.startswith("activation_post_process_")
)
conv_weight_fq_mod = getattr(m, conv_weight_fq_node.target)
if is_per_channel:
expected_weight_observer_type = MovingAveragePerChannelMinMaxObserver
else:
expected_weight_observer_type = MovingAverageMinMaxObserver
self.assertEqual(type(conv_weight_fq_mod), FusedMovingAvgObsFakeQuantize)
self.assertEqual(
type(conv_weight_fq_mod.activation_post_process),
expected_weight_observer_type,
)
self.assertEqual(conv_weight_fq_mod.dtype, torch.int8)
self.assertEqual(conv_weight_fq_mod.quant_min, -127)
self.assertEqual(conv_weight_fq_mod.quant_max, 127)
# Verify: conv(fq(input), fq(weight * scale_factor.reshape), zero_bias)
zero_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None
mul_weight_scale_factor_node = conv_weight_fq_node.args[0]
(
conv_weight_fq_node,
scale_factor_reshape_node,
) = mul_weight_scale_factor_node.args
if has_bias:
self.assertEqual(zero_bias_node.target, torch.ops.aten.zeros_like.default)
else:
self.assertTrue(zero_bias_node is None)
self.assertEqual(mul_weight_scale_factor_node.target, torch.ops.aten.mul.Tensor)
self.assertEqual(
scale_factor_reshape_node.target, torch.ops.aten.reshape.default
)
# Verify: scale_factor = bn_weight / sqrt(bn_running_var + eps)
scale_factor_node = scale_factor_reshape_node.args[0]
(bn_weight_node, sqrt_node) = scale_factor_node.args
bn_running_var_add_node = sqrt_node.args[0]
(bn_running_var_node, eps) = bn_running_var_add_node.args
self.assertEqual(scale_factor_node.target, torch.ops.aten.div.Tensor)
self.assertTrue("bn.weight" in bn_weight_node.target)
self.assertTrue("bn.running_var" in bn_running_var_node.target)
self.assertEqual(sqrt_node.target, torch.ops.aten.sqrt.default)
self.assertEqual(bn_running_var_add_node.target, torch.ops.aten.add.Tensor)
self.assertEqual(eps, 1e-5)
# Optionally check the converted graph
if verify_convert:
m = convert_pt2e(m)
m(*example_inputs)
if is_per_channel:
conv_weight_dq_op = (
torch.ops.quantized_decomposed.dequantize_per_channel.default
)
node_occurrence = {
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
): 2,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
): 2,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_channel.default
): 1,
}
else:
conv_weight_dq_op = (
torch.ops.quantized_decomposed.dequantize_per_tensor.default
)
node_occurrence = {
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
): 2,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
): 3,
}
node_list = [
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
),
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
),
ns.call_function(conv_weight_dq_op),
ns.call_function(conv_op),
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
),
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
),
]
self.checkGraphModuleNodes(
m,
expected_node_list=node_list,
expected_node_occurrence=node_occurrence,
)
class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
"""
Base TestCase to be used for all conv-bn[-relu] fusion patterns.
"""
# TODO: how can we avoid adding every new test to dynamo/expected_test_failures?
# Otherwise it fails with the following error:
# torch._dynamo.exc.InternalTorchDynamoError:
# 'QuantizationConfig' object has no attribute '__bool__'
def setUp(self):
        # NB: Skip the test if this is run from the base class. This handles the
        # test discovery logic in buck, which finds and runs all tests here,
        # including the base class, which we don't want to run.
if self.id() and "_Base" in self.id():
self.skipTest("Skipping test running from base class")
def test_qat_conv_no_bias(self):
m1 = self._get_conv_bn_model(has_conv_bias=False, has_bn=False, has_relu=True)
m2 = self._get_conv_bn_model(has_conv_bias=False, has_bn=False, has_relu=False)
self._verify_symmetric_xnnpack_qat_numerics(m1, self.example_inputs)
self._verify_symmetric_xnnpack_qat_numerics(m2, self.example_inputs)
def test_qat_conv_bn_fusion(self):
m = self._get_conv_bn_model()
self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=False)
self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_qat_conv_bn_fusion_cuda(self):
m = self._get_conv_bn_model().cuda()
example_inputs = (self.example_inputs[0].cuda(),)
self._verify_symmetric_xnnpack_qat_graph(
m,
example_inputs,
has_relu=False,
is_cuda=True,
)
self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)
def test_qat_conv_bn_fusion_literal_args(self):
class M(torch.nn.Module):
def __init__(self, conv_class, bn_class):
super().__init__()
self.conv = conv_class(3, 3, 3, stride=2, padding=4)
self.bn = bn_class(3)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
assert self.dim in [1, 2]
if self.dim == 1:
# stride, padding, dilation, transposed, output_padding, groups
conv_args = ((2,), (4,), (1,), False, (0,), 1)
example_inputs = (torch.randn(1, 3, 5),)
else:
# stride, padding, dilation, transposed, output_padding, groups
conv_args = ((2, 2), (4, 4), (1, 1), False, (0, 0), 1)
example_inputs = (torch.randn(1, 3, 5, 5),)
m = M(self.conv_class, self.bn_class)
self._verify_symmetric_xnnpack_qat_graph(
m,
example_inputs,
has_relu=False,
expected_conv_literal_args=conv_args,
)
self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)
def test_qat_conv_bn_fusion_no_conv_bias(self):
class M2(torch.nn.Module):
"""
Mixed conv + BN with and without conv bias.
"""
def __init__(self, conv_class, bn_class):
super().__init__()
self.conv1 = conv_class(3, 3, 3, bias=False)
self.bn1 = bn_class(3)
self.conv2 = conv_class(3, 3, 3, bias=True)
self.bn2 = bn_class(3)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.conv2(x)
x = self.bn2(x)
return x
m1 = self._get_conv_bn_model(has_conv_bias=False)
m2 = M2(self.conv_class, self.bn_class)
assert self.dim in [1, 2]
if self.dim == 1:
example_inputs = (torch.randn(3, 3, 5),)
else:
example_inputs = (torch.randn(3, 3, 5, 5),)
self._verify_symmetric_xnnpack_qat_graph(
m1,
example_inputs,
has_relu=False,
has_bias=False,
)
self._verify_symmetric_xnnpack_qat_numerics(m1, example_inputs)
self._verify_symmetric_xnnpack_qat_numerics(m2, example_inputs)
def test_qat_conv_bn_relu_fusion(self):
m = self._get_conv_bn_model(has_relu=True)
self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=True)
self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_qat_conv_bn_relu_fusion_cuda(self):
m = self._get_conv_bn_model(has_relu=True).cuda()
example_inputs = (self.example_inputs[0].cuda(),)
self._verify_symmetric_xnnpack_qat_graph(
m,
example_inputs,
has_relu=True,
is_cuda=True,
)
self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)
def test_qat_conv_bn_relu_fusion_no_conv_bias(self):
m = self._get_conv_bn_model(has_conv_bias=False, has_relu=True)
self._verify_symmetric_xnnpack_qat_graph(
m,
self.example_inputs,
has_relu=True,
has_bias=False,
)
self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)
def test_qat_inplace_add_relu(self):
class M(torch.nn.Module):
def __init__(self, conv_class):
super().__init__()
self.conv = conv_class(1, 1, 1)
self.relu = torch.nn.ReLU(inplace=True)
def forward(self, x):
x0 = x
x = self.conv(x)
x += x0
x = self.relu(x)
return x
assert self.dim in [1, 2]
if self.dim == 1:
example_inputs = (torch.randn(1, 1, 3),)
else:
example_inputs = (torch.randn(1, 1, 3, 3),)
m = M(self.conv_class)
self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)
def test_qat_update_shared_qspec(self):
"""
Test the case where nodes used in SharedQuantizationSpec were replaced
during QAT subgraph rewriting.
"""
class M(torch.nn.Module):
def __init__(self, conv_class, bn_class):
super().__init__()
self.conv = conv_class(3, 3, 3)
self.bn = bn_class(3)
self.hardtanh = torch.nn.Hardtanh()
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.hardtanh(x)
return x
m = M(self.conv_class, self.bn_class)
self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)
def test_qat_preserve_source_fn_stack(self):
"""
Test whether `source_fn_stack` is preserved after QAT fusion.
"""
class M(torch.nn.Module):
def __init__(self, conv_class, bn_class, backbone):
super().__init__()
self.conv = conv_class(5, 3, 3)
self.bn = bn_class(3)
self.relu = torch.nn.ReLU()
self.backbone = backbone
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
x = self.backbone(x)
return x
assert self.dim in [1, 2]
if self.dim == 1:
example_inputs = (torch.randn(1, 5, 10),)
else:
example_inputs = (torch.randn(1, 5, 10, 10),)
# QAT prepare + convert
backbone = self._get_conv_bn_model(has_relu=True)
m = M(self.conv_class, self.bn_class, backbone)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
m = export(m, example_inputs, strict=True).module()
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
m = convert_pt2e(m)
# Extract the conv and relu nodes (bn was folded into conv)
first_conv, first_relu, second_conv, second_relu = None, None, None, None
for n in m.graph.nodes:
if n.target == torch.ops.aten.relu.default:
if first_relu is None:
assert first_conv is None, "bad test setup"
first_relu = n
first_conv = n.args[0]
else:
assert second_conv is None, "bad test setup"
second_relu = n
second_conv = n.args[0]
# Extract the conv weight and bias nodes
def get_conv_weight_and_bias(conv_node: torch.fx.Node):
weight_dq_node = conv_node.args[1]
qweight_node = weight_dq_node.args[0]
bias_node = conv_node.args[2]
assert isinstance(qweight_node, torch.fx.Node)
assert isinstance(bias_node, torch.fx.Node)
return (qweight_node, bias_node)
_, first_conv_bias = get_conv_weight_and_bias(first_conv)
_, second_conv_bias = get_conv_weight_and_bias(second_conv)
# Assert that each set of conv, conv weight, and conv bias are in the same partition
def get_source_fn(node: torch.fx.Node):
# E.g. [('l__self___backbone1_conv', <class 'torch.nn.modules.conv.Conv2d'>)]
return node.meta["source_fn_stack"][0][0]
        # We don't currently preserve this for the quantized weight since it's folded,
        # but the user can attach "quantization_tag" to the node and it will be preserved
# self.assertEqual(get_source_fn(first_conv), get_source_fn(first_conv_qweight))
# self.assertEqual(get_source_fn(second_conv), get_source_fn(second_conv_qweight))
self.assertEqual(get_source_fn(first_conv), get_source_fn(first_conv_bias))
self.assertEqual(get_source_fn(second_conv), get_source_fn(second_conv_bias))
# Assert that different sets of convs and relus have different partitions
self.assertNotEqual(get_source_fn(first_conv), get_source_fn(first_relu))
self.assertNotEqual(get_source_fn(first_conv), get_source_fn(second_conv))
self.assertNotEqual(get_source_fn(second_conv), get_source_fn(second_relu))
self.assertNotEqual(get_source_fn(first_relu), get_source_fn(second_relu))
def test_qat_conv_bn_bias_derived_qspec(self):
m = self._get_conv_bn_model()
example_inputs = self.example_inputs
m = export(m, example_inputs, strict=True).module()
quantizer = ConvBnDerivedBiasQuantizer()
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
m = convert_pt2e(m)
m(*example_inputs)
# Assert that both weight and bias are quantized
(conv_node, _, _) = _get_conv_bn_getitem_nodes(m)
weight_dq = conv_node.args[1]
bias_dq = conv_node.args[2]
self.assertEqual(
weight_dq.target,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
)
self.assertEqual(
bias_dq.target,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
)
weight_getattr = weight_dq.args[0]
bias_getattr = bias_dq.args[0]
self.assertEqual(
weight_getattr.op,
"get_attr",
)
self.assertEqual(
bias_getattr.op,
"get_attr",
)
# Assert that bias scale = weight scale * input scale
input_dq = conv_node.args[0]
input_scale = input_dq.args[1]
bias_scale = bias_dq.args[1]
weight_scale = weight_dq.args[1]
self.assertEqual(bias_scale, input_scale * weight_scale)
# Assert that args for the bias' quantize and dequantize ops
# are copied correctly after subgraph rewriting
(bias_qmin, bias_qmax, bias_dtype) = bias_dq.args[3:]
self.assertEqual(bias_qmin, -(2**31))
self.assertEqual(bias_qmax, 2**31 - 1)
self.assertEqual(bias_dtype, torch.int32)
def test_qat_per_channel_weight_custom_dtype(self):
m = self._get_conv_bn_model()
example_inputs = self.example_inputs
m = export(m, example_inputs, strict=True).module()
quantizer = ConvBnInt32WeightQuantizer()
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
m = convert_pt2e(m)
m(*example_inputs)
# Assert that conv weight is quantized per channel
(conv_node, _, _) = _get_conv_bn_getitem_nodes(m)
weight_dq = conv_node.args[1]
self.assertEqual(
weight_dq.target,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
)
weight_getattr = weight_dq.args[0]
self.assertEqual(
weight_getattr.op,
"get_attr",
)
# Assert that args for the weight's dequantize ops
# are copied correctly after subgraph rewriting
(dq_axis, dq_qmin, dq_qmax, dq_dtype) = weight_dq.args[3:]
self.assertEqual(dq_axis, 0)
self.assertEqual(dq_qmin, 0)
self.assertEqual(dq_qmax, 2**31 - 1)
self.assertEqual(dq_dtype, torch.int32)
def _do_test_qat_conv_transpose_bn(self, has_relu: bool):
# Use different in/out channel sizes to test if conv weight is
# properly transposed in QAT pattern
m = self._get_conv_bn_model(
has_relu=has_relu,
transpose=True,
in_channels=3,
out_channels=5,
kernel_size=3,
)
self._verify_symmetric_xnnpack_qat_graph(
m,
self.example_inputs,
has_relu=has_relu,
verify_convert=True,
)
def test_qat_conv_transpose_bn(self):
self._do_test_qat_conv_transpose_bn(has_relu=False)
def test_qat_conv_transpose_bn_relu(self):
self._do_test_qat_conv_transpose_bn(has_relu=True)
def test_qat_conv_bn_per_channel_weight_bias(self):
m = self._get_conv_bn_model()
example_inputs = self.example_inputs
m = export(m, example_inputs, strict=True).module()
quantizer = ConvBnDerivedBiasQuantizer(is_per_channel=True)
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
m = convert_pt2e(m)
m(*example_inputs)
# Expected graph:
# x -> q_tensor -> dq_tensor -> conv -> q_tensor -> dq_tensor -> output
# weight -> q_channel -> dq_channel /
# bias -> q_channel -> dq_channel /
(conv_node, _, _) = _get_conv_bn_getitem_nodes(m)
conv_op = conv_node.target
conv_weight_dq_op = (
torch.ops.quantized_decomposed.dequantize_per_channel.default
)
node_occurrence = {
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
): 2,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
): 2,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_channel.default
): 2,
}
node_list = [
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
),
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
),
ns.call_function(conv_weight_dq_op),
ns.call_function(conv_weight_dq_op),
ns.call_function(conv_op),
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
),
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
),
]
self.checkGraphModuleNodes(
m,
expected_node_list=node_list,
expected_node_occurrence=node_occurrence,
)
def test_fold_bn_erases_bn_node(self):
"""
Ensure the BN node is erased from the graph after folding
it into conv in `convert_pt2e` even in train mode.
"""
m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False)
m = export(m, self.example_inputs, strict=True).module()
quantizer = XNNPACKQuantizer()
quantizer.set_global(
get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
)
m = prepare_qat_pt2e(m, quantizer)
m = convert_pt2e(m)
(conv_node, bn_node, _) = _get_conv_bn_getitem_nodes(m)
self.assertTrue(conv_node is not None)
self.assertTrue(bn_node is None)
def test_fold_bn_erases_add_node(self):
"""
        Test that batch norm stat tracking (which shows up as an aten.add_.Tensor node
        in the graph) is removed when folding batch norm.
"""
m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False)
m = export(m, self.example_inputs, strict=True).module()
def _has_add_(graph):
for node in graph.nodes:
if node.target == torch.ops.aten.add_.Tensor:
return True
return False
# Verify that add_ tensor exists in the exported model (for tracking batch norm stats)
has_add_tensor_before = _has_add_(m.graph)
self.assertTrue(
has_add_tensor_before, "Expected to find add_ tensor in the exported model"
)
quantizer = XNNPACKQuantizer()
quantizer.set_global(
get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
)
m = prepare_qat_pt2e(m, quantizer)
m = convert_pt2e(m)
# Verify that add_ tensor is removed in the quantized model
has_add_tensor_after = _has_add_(m.graph)
self.assertFalse(
has_add_tensor_after,
"Expected add_ tensor to be removed in the quantized model",
)
@skipIfNoQNNPACK
class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base):
dim = 1
example_inputs = (torch.randn(1, 3, 5),)
conv_class = torch.nn.Conv1d
conv_transpose_class = torch.nn.ConvTranspose1d
bn_class = torch.nn.BatchNorm1d
@skipIfNoQNNPACK
class TestQuantizePT2EQAT_ConvBn2d(TestQuantizePT2EQAT_ConvBn_Base):
dim = 2
example_inputs = (torch.randn(1, 3, 5, 5),)
conv_class = torch.nn.Conv2d
conv_transpose_class = torch.nn.ConvTranspose2d
bn_class = torch.nn.BatchNorm2d
def _is_conv_node(n: torch.fx.Node):
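    # Matches conv1d/conv2d as well as conv_transpose; both the op overload
    # packets and the specific overloads are listed for the transpose variants.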
return n.op == "call_function" and n.target in [
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.conv_transpose1d,
torch.ops.aten.conv_transpose1d.default,
torch.ops.aten.conv_transpose2d,
torch.ops.aten.conv_transpose2d.input,
]
def _get_conv_bn_getitem_nodes(model: torch.fx.GraphModule):
"""
Return a 3-tuple of (conv, bn, getitem) nodes from the graph.
"""
model.graph.eliminate_dead_code()
model.recompile()
conv_node = None
bn_node = None
getitem_node = None
for n in model.graph.nodes:
if _is_conv_node(n):
conv_node = n
if n.target in (
torch.ops.aten._native_batch_norm_legit.default,
torch.ops.aten.batch_norm.default,
):
bn_node = n
if n.target == operator.getitem:
getitem_node = n
assert conv_node is not None, "bad test setup"
return (conv_node, bn_node, getitem_node)
class ConvBnInt32WeightQuantizer(Quantizer):
"""
Dummy quantizer that annotates conv bn in such a way that the weights
are quantized per channel to int32.
"""
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
conv_node, bn_node, getitem_node = _get_conv_bn_getitem_nodes(model)
act_qspec = QuantizationSpec(
dtype=torch.uint8,
quant_min=0,
quant_max=255,
qscheme=torch.per_tensor_affine,
observer_or_fake_quant_ctr=default_fake_quant,
)
weight_qspec = QuantizationSpec(
dtype=torch.int32,
quant_min=0,
quant_max=2**31 - 1,
qscheme=torch.per_channel_affine,
observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args(
observer=MovingAveragePerChannelMinMaxObserver,
),
)
conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map={
conv_node.args[0]: act_qspec,
conv_node.args[1]: weight_qspec,
},
_annotated=True,
)
# See NOTE [training ir has no getitem for bn node].
bn_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=act_qspec,
_annotated=True,
)
return model
def validate(self, model: torch.fx.GraphModule):
pass
class ConvBnDerivedBiasQuantizer(Quantizer):
"""
Dummy quantizer that annotates conv bn in such a way that the bias qparams are
derived from the conv input activation and weight qparams.
"""
def __init__(self, is_per_channel: bool = False):
super().__init__()
self.is_per_channel = is_per_channel
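    # Bias qparams are derived from the conv input activation and weight
    # observers: bias_scale = act_scale * weight_scale, with a zero point of 0,
    # matching the int32 bias convention asserted in the tests above.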
def _derive_bias_qparams_from_act_and_weight_qparams(self, obs_or_fqs):
act_scale, _ = obs_or_fqs[0].calculate_qparams()
weight_scale, _ = obs_or_fqs[1].calculate_qparams()
if self.is_per_channel:
bias_scale = act_scale * weight_scale
bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32)
else:
bias_scale = torch.tensor([act_scale * weight_scale], dtype=torch.float32)
bias_zero_point = torch.tensor([0], dtype=torch.int32)
return bias_scale, bias_zero_point
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
if self.is_per_channel:
weight_qscheme = torch.per_channel_symmetric
weight_fq = FusedMovingAvgObsFakeQuantize.with_args(
observer=MovingAveragePerChannelMinMaxObserver,
)
else:
weight_qscheme = torch.per_tensor_affine
weight_fq = default_fake_quant
conv_node, bn_node, getitem_node = _get_conv_bn_getitem_nodes(model)
act_qspec = QuantizationSpec(
dtype=torch.uint8,
quant_min=0,
quant_max=255,
qscheme=torch.per_tensor_affine,
observer_or_fake_quant_ctr=default_fake_quant,
)
weight_qspec = QuantizationSpec(
dtype=torch.uint8,
quant_min=0,
quant_max=255,
qscheme=weight_qscheme,
observer_or_fake_quant_ctr=weight_fq,
)
bias_qspec = DerivedQuantizationSpec(
derived_from=[
(conv_node.args[0], conv_node),
(conv_node.args[1], conv_node),
],
derive_qparams_fn=self._derive_bias_qparams_from_act_and_weight_qparams,
dtype=torch.int32,
quant_min=-(2**31),
quant_max=2**31 - 1,
qscheme=weight_qscheme,
ch_axis=0 if self.is_per_channel else None,
)
conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map={
conv_node.args[0]: act_qspec,
conv_node.args[1]: weight_qspec,
conv_node.args[2]: bias_qspec,
},
_annotated=True,
)
# NOTE [training ir has no getitem for bn node].
        # getitem is None when we use the training IR, which outputs
        # aten.batch_norm.default and therefore does not need a getitem node.
        # In this case, we need to annotate the batch norm node instead.
        # The getitem node should only be None when we are using the training IR.
bn_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=act_qspec,
_annotated=True,
)
return model
def validate(self, model: torch.fx.GraphModule):
pass
@skipIfNoQNNPACK
class TestQuantizePT2EQATModels(PT2EQATTestCase):
@skip_if_no_torchvision
@skipIfNoQNNPACK
def test_qat_resnet18(self):
import torchvision
with override_quantized_engine("qnnpack"):
example_inputs = (torch.randn(1, 3, 224, 224),)
m = torchvision.models.resnet18()
self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)
@skip_if_no_torchvision
@skipIfNoQNNPACK
def test_qat_mobilenet_v2(self):
import torchvision
with override_quantized_engine("qnnpack"):
example_inputs = (torch.randn(1, 3, 224, 224),)
m = torchvision.models.mobilenet_v2()
self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)
class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
class TwoLinear(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear1 = torch.nn.Linear(16, 8, bias=False)
self.linear2 = torch.nn.Linear(8, 8)
def forward(self, x):
return self.linear2(self.linear1(x))
class QATPTQTestModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, 3)
self.linears = TestQuantizeMixQATAndPTQ.TwoLinear()
self.my_linear = torch.nn.Linear(8, 8)
def forward(self, x):
conv_out = self.conv(x)
permute_out = torch.permute(conv_out, (0, 2, 3, 1))
linear_out = self.linears(permute_out)
my_linear_out = self.my_linear(linear_out)
            # Hardtanh doesn't get quantized via the XNNPACK quantizer in this test
            # because it relies on the propagation rules.
            # TODO: fix this
return torch.nn.functional.hardtanh(my_linear_out)
def _prepare_qat_linears(self, model):
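        # Export and QAT-prepare each Linear / TwoLinear submodule in place,
        # using a dummy (1, in_features) example input; other children are
        # recursed into.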
for name, child in model.named_children():
if isinstance(child, (torch.nn.Linear, TestQuantizeMixQATAndPTQ.TwoLinear)):
if isinstance(child, torch.nn.Linear):
in_channels = child.weight.size(1)
else:
in_channels = child.linear1.weight.size(1)
example_input = (torch.rand((1, in_channels)),)
traced_child = export(child, example_input, strict=True).module()
quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(
is_per_channel=True, is_qat=True
)
quantizer.set_global(quantization_config)
traced_child_prepared = prepare_qat_pt2e(traced_child, quantizer)
setattr(model, name, traced_child_prepared)
else:
self._prepare_qat_linears(child)
def _convert_qat_linears(self, model):
for name, child in model.named_children():
if isinstance(child, torch.fx.GraphModule):
torch.ao.quantization.move_exported_model_to_eval(child)
converted_child = convert_pt2e(child)
setattr(model, name, converted_child)
else:
self._convert_qat_linears(child)
def test_mixing_qat_ptq(self):
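        # Flow: QAT prepare + convert each Linear submodule individually, then
        # run PT2E PTQ on the whole model with Linear excluded from the
        # quantizer via set_module_type(torch.nn.Linear, None), so only the
        # remaining ops (e.g. conv) get the PTQ treatment.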
example_inputs = (torch.randn(2, 3, 4, 4),)
model = TestQuantizeMixQATAndPTQ.QATPTQTestModule()
self._prepare_qat_linears(model)
model(*example_inputs)
        # TODO: model.eval() needs to be fixed so it can be called here
self._convert_qat_linears(model)
model(*example_inputs)
model_pt2e = export(model, example_inputs, strict=True).module()
quantizer = XNNPACKQuantizer()
quantizer.set_module_type(torch.nn.Linear, None)
quantization_config = get_symmetric_quantization_config()
quantizer.set_global(quantization_config)
model_pt2e = prepare_pt2e(model_pt2e, quantizer)
after_prepare_result_pt2e = model_pt2e(*example_inputs) # noqa: F841
model_pt2e = convert_pt2e(model_pt2e)
quant_result_pt2e = model_pt2e(*example_inputs) # noqa: F841
exported_model = torch.export.export(model_pt2e, example_inputs, strict=True)
node_occurrence = {
# conv2d: 1 for act, 1 for weight, 1 for output
# 3 x linear: 1 for act, 1 for output
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
): 8,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
): 9,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_channel.default
): 3,
# There needs to be one for hardtanh
}
self.checkGraphModuleNodes(
exported_model.graph_module, expected_node_occurrence=node_occurrence
)
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")