# Owner(s): ["oncall: quantization"]
import copy
import unittest

# Needed for the Dict[str, Any] annotation and the observer-constructor type
# alias used below; the alias is assumed to live in torch.ao.quantization.qconfig.
from typing import Any, Dict

import torch
import torch._export as export

from torch.ao.quantization.observer import (
    HistogramObserver,
    MinMaxObserver,
    PlaceholderObserver,
)
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
    Quantizer,
    SharedQuantizationSpec,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    OP_TO_ANNOTATOR,
    QuantizationConfig,
)

from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS

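# The helper modules below produce graphs where one tensor has several
# consumers, which is exactly what the duplicate-DQ pass has to handle:
# every quantized consumer should end up with its own dequantize node.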
class TestHelperModules:
    class Conv2dWithObsSharingOps(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.hardtanh = torch.nn.Hardtanh()
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.conv(x)
            x = self.adaptive_avg_pool2d(x)
            x = self.hardtanh(x)
            x = x.view(-1, 3)
            x = self.linear(x)
            return x

    class Conv2dWithSharedDQ(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 1)
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.conv1(x)
            z = x.view(-1, 3)
            w = self.linear(z)

            y = self.conv2(x)
            add_output = x + y

            extra_output = x * 2
            return w, add_output, extra_output

    class ModuleForDifferentQconfig(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 1)
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.conv1(x)
            w = self.adaptive_avg_pool2d(x)

            y = self.conv2(x)
            add_output = x + y

            extra_output = x + 2
            return w, add_output, extra_output

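# Dequantize overloads that convert_pt2e emits in the decomposed graph; the
# assertions below look for these as direct inputs of annotated nodes.
# A sketch of how one of them is called, assuming the signature of the
# quantized_decomposed ops at the time of this PR:
#   torch.ops.quantized_decomposed.dequantize_per_tensor(
#       int8_tensor, scale, zero_point, quant_min, quant_max, torch.int8)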
_DEQUANTIZE_OPS = [
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
]

@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
class TestDuplicateDQPass(QuantizationTestCase):
    def _test_duplicate_dq(
        self,
        model,
        example_inputs,
        quantizer,
    ):
        m_eager = model.eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = export.capture_pre_autograd_graph(
            m,
            example_inputs,
        )

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m)

        pt2_quant_output = m(*example_inputs)
        # Every dequantize node feeding an annotated node must have exactly
        # one user; otherwise the duplicate-DQ pass failed to give that
        # consumer its own copy.
        for n in m.graph.nodes:
            annotation = n.meta.get("quantization_annotation", None)
            if annotation is not None:
                for arg in n.args:
                    if isinstance(arg, torch.fx.Node) and arg.target in _DEQUANTIZE_OPS:
                        self.assertEqual(len(arg.users.keys()), 1)

    def test_no_need_for_duplicate_dq(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Check that quantization tags on conv2d, avgpool, and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
                OP_TO_ANNOTATOR["adaptive_avg_pool2d"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
        )

    def test_simple_duplicate_dq(self):
        """
        Model under test
        conv2d -> conv2d -> add
             |          |
              --------->
             |
              -----> view_copy --> linear
             |
              -----> mul
        There should be three dq nodes for the first conv2d's output because
        it feeds the next conv2d, add, and view_copy + linear, and all three
        of those uses are quantized. The dq node is therefore duplicated so
        that each of those three uses gets its own copy.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
                OP_TO_ANNOTATOR["add"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.Conv2dWithSharedDQ(),
            example_inputs,
            BackendAQuantizer(),
        )

    def test_no_add_quant_duplicate_dq(self):
        """
        Model under test
        conv2d -> conv2d -> add
             |          |
              --------->
             |
              -----> view_copy --> linear
             |
              -----> mul
        There should be three dq nodes because the output of the first
        conv2d is fed to the next conv2d and to view_copy + linear, both of
        which are quantized. However, the skip connections to add and mul
        are not quantized, so the dq node is not duplicated for those two uses.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.Conv2dWithSharedDQ(),
            example_inputs,
            BackendAQuantizer(),
        )

    def test_avgpool_use_different_qconfig(self):
        """
        Model under test
        conv2d -> conv2d -> add
             |          |
              --------->
             |
              -----> adaptive_avgpool2d (different qconfig)
             |
              -----> add
        Expected output graph
        conv2d -> dq -> conv2d -> add
             |                 |
              -------> dq ----->
             |
              -> dq -> q -> dq -----> adaptive_avgpool2d (different qconfig)
             |
              -> dq -----> add
        """

        def _get_uint8_quantization_config():
            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]
            act_quantization_spec = QuantizationSpec(
                dtype=torch.uint8,
                quant_min=0,
                quant_max=255,
                qscheme=torch.per_tensor_affine,
                observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
                    eps=2**-12
                ),
            )
            weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
                MinMaxObserver
            )

            extra_args: Dict[str, Any] = {"eps": 2**-12}
            weight_quantization_spec = QuantizationSpec(
                dtype=torch.uint8,
                quant_min=0,
                quant_max=255,
                qscheme=torch.per_tensor_affine,
                ch_axis=0,
                is_dynamic=False,
                observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
                    **extra_args
                ),
            )

            bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
                PlaceholderObserver
            )
            bias_quantization_spec = QuantizationSpec(
                dtype=torch.float,
                observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr,
            )
            # QuantizationConfig fields are (input_activation, output_activation,
            # weight, bias); input and output share the same activation spec here.
            quantization_config = QuantizationConfig(
                act_quantization_spec,
                act_quantization_spec,
                weight_quantization_spec,
                bias_quantization_spec,
            )
            return quantization_config

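        # The quantizer below uses the default symmetric config for conv2d and
        # add, but annotates aten.mean.dim (what AdaptiveAvgPool2d((1, 1))
        # lowers to in this captured graph) with the uint8 config above;
        # crossing between the two configs is what forces the extra q -> dq
        # pair in the expected output graph.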
        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                avgpool_qconfig = _get_uint8_quantization_config()
                OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
                OP_TO_ANNOTATOR["add"](gm, quantization_config)
                for n in gm.graph.nodes:
                    if n.op == "call_function" and n.target == torch.ops.aten.mean.dim:
                        qspec = avgpool_qconfig.input_activation
                        input_act = n.args[0]
                        output_qspec = SharedQuantizationSpec((input_act, n))
                        n.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={input_act: qspec},
                            output_qspec=output_qspec,
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.ModuleForDifferentQconfig(),
            example_inputs,
            BackendAQuantizer(),
        )