[quant][pt2e] Add more precise representation for quantized add (#104130)

Summary:
The planned e2e for quantization in pytorch 2.0 export is the following:

float_model -> prepare_pt2e -> calibration -> convert_pt2e -> ...
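
For orientation, here is a minimal sketch of this flow, modeled on the test added in this PR (the import path for `prepare_pt2e_quantizer`/`convert_pt2e` is an assumption and may differ across versions; the quantizer setup is taken from the test diff below):

```
# Minimal sketch of the pt2e flow, modeled on test_representation_add below.
# NOTE: the import path for prepare_pt2e_quantizer/convert_pt2e is an assumption
# for this point in time and may differ across versions.
import copy
import torch
import torch._dynamo as torchdynamo
import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq
from torch.ao.quantization._quantize_pt2e import prepare_pt2e_quantizer, convert_pt2e

class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y

example_inputs = (torch.randn(1, 3, 3, 3), torch.randn(1, 3, 3, 3))
quantizer = qq.QNNPackQuantizer()
quantizer.set_global(qq.get_symmetric_quantization_config(is_per_channel=True))

# program capture (aten-level graph)
m, _ = torchdynamo.export(M().eval(), *copy.deepcopy(example_inputs), aten_graph=True)
m = prepare_pt2e_quantizer(m, quantizer)   # insert observers
m(*example_inputs)                         # calibration
m = convert_pt2e(m, use_reference_representation=True)  # q/dq + reference rewrite
```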

inside convert_pt2e, we will first produce a q/dq representation of the quantized model, similar to the previous output of
convert_to_reference_fx in FX graph mode quantization:

```
torch.ops.quantized_decomposed.dequantize_per_tensor -> torch.ops.aten.add -> torch.ops.quantized_decomposed.quantize_per_tensor
torch.ops.quantized_decomposed.dequantize_per_tensor   /
```

Then we'll rewrite the above into a more precise representation that expresses the intent directly: here we actually want to do
int8 addition rather than simulate int8 addition with fp32 operations. The representation for quantized add is:

```
def quantized_add(x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point, out_scale, out_zero_point):
    x = (x_scale / out_scale) * x_i8
    y = (y_scale / out_scale) * y_i8
    out = x + y
    out -= (x_zero_point * x_scale + y_zero_point * y_scale) / out_scale
    out += out_zero_point
    return out
```
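
Rounding and clamping to the int8 range are omitted above for brevity; the actual rewrite (see rewrite.py in the diff below) performs them. As a sanity check, here is a small standalone sketch, with helper names chosen purely for illustration, comparing the rewritten form against the naive dequantize -> add -> quantize reference:

```
import torch

def naive_dq_add_q(x_i8, x_scale, x_zp, y_i8, y_scale, y_zp, out_scale, out_zp):
    # dequantize, add in fp32, re-quantize (what the q/dq pattern computes)
    x_fp32 = (x_i8.float() - x_zp) * x_scale
    y_fp32 = (y_i8.float() - y_zp) * y_scale
    out = torch.round((x_fp32 + y_fp32) / out_scale) + out_zp
    return torch.clamp(out, -128, 127).to(torch.int8)

def rewritten_quantized_add(x_i8, x_scale, x_zp, y_i8, y_scale, y_zp, out_scale, out_zp):
    # the representation above, plus the rounding/clamping the real rewrite applies
    out = (x_scale / out_scale) * x_i8.float() + (y_scale / out_scale) * y_i8.float()
    out = out - (x_zp * x_scale + y_zp * y_scale) / out_scale
    out = torch.round(out) + out_zp
    return torch.clamp(out, -128, 127).to(torch.int8)

x = torch.randint(-128, 128, (8,), dtype=torch.int8)
y = torch.randint(-128, 128, (8,), dtype=torch.int8)
a = naive_dq_add_q(x, 0.05, 2, y, 0.03, -1, 0.1, 3)
b = rewritten_quantized_add(x, 0.05, 2, y, 0.03, -1, 0.1, 3)
# up to floating-point reassociation the two agree (off by at most one quantized step)
assert (a.int() - b.int()).abs().max() <= 1
```

This matches the off-by-at-most-one tolerance used in the (currently commented-out) numerical check in the test below.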

Test Plan:
```
buck2 test caffe2/test:quantization_pt2e -- --exact 'caffe2/test:quantization_pt2e - test_representation_add (quantization.pt2e.test_quantize_pt2e.TestQuantizePT2E)'
```

Reviewed By: kimishpatel

Differential Revision: D45628032

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104130
Approved by: https://github.com/kimishpatel
Author: Jerry Zhang
Date: 2023-06-27 20:11:26 +00:00
Committed by: PyTorch MergeBot
Parent: 7bf27cf163
Commit: c98896b76f
8 changed files with 181 additions and 25 deletions


@@ -1686,6 +1686,59 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
            M(), example_inputs, is_per_channel=True, verify_convert=True,
        )

    def test_representation_add(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x + y

        import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq
        quantizer = QNNPackQuantizer()
        operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)
        m_eager = M().eval()
        example_inputs = (torch.randn(1, 3, 3, 3), torch.randn(1, 3, 3, 3),)

        # program capture
        m = m_eager
        # m_copy = copy.deepcopy(m)
        m, guards = torchdynamo.export(
            m,
            *copy.deepcopy(example_inputs),
            aten_graph=True,
        )

        m = prepare_pt2e_quantizer(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m, use_reference_representation=True)
        # make sure it runs
        pt2_quant_output = m(*example_inputs)
        # TODO: torchdynamo times out when we do this, we can enable numerical checking
        # after that is fixed
        # m_copy = prepare_pt2e_quantizer(m_copy, quantizer)
        # # Calibrate
        # m_copy(*example_inputs)
        # m_copy = convert_pt2e(m_copy, use_reference_representation=False)
        # pt2_quant_output_copy = m_copy(*example_inputs)
        # output_scale = None
        # idx = 0
        # for n in m_copy.graph.nodes:
        #     if n.target == torch.ops.quantized_decomposed.quantize_per_tensor.default:
        #         idx += 1
        #         if idx == 3:
        #             output_scale = n.args[1]
        # assert output_scale is not None
        # # make sure the result is off by one at most in the quantized integer representation
        # self.assertTrue(
        #     torch.max(torch.abs(pt2_quant_output_copy - pt2_quant_output)) <= (2 * output_scale + 1e-5)
        # )


@skipIfNoQNNPACK
class TestQuantizePT2EOps(QuantizationTestCase):


@@ -141,6 +141,7 @@ FILENAME_ALLOWLIST |= {torch.utils._foreach_utils.__file__}
FILENAME_ALLOWLIST |= {
    _module_dir(torch) + "ao/quantization/_pt2e/qat_utils.py",
    _module_dir(torch) + "ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py",
    _module_dir(torch) + "ao/quantization/_pt2e/representation/rewrite.py",
}
# TODO (zhxchen17) Make exportdb importable here.


@@ -1,4 +1,3 @@
import copy
import dataclasses
import itertools
import operator
@@ -16,6 +15,7 @@ from .quantizer import (
    QuantizationSpecBase,
)
from .utils import _fold_bn_weights_into_conv_node
from .utils import _get_aten_graph_module

# Example inputs for `_conv2d_bn_pattern`, `_qat_conv2d_bn_pattern`, and `_qat_conv2d_bn_pattern_no_bias`
_conv2d_bn_pattern_example_inputs = (
@@ -299,27 +299,6 @@ def _get_folded_quantized_qat_conv2d_bn_pattern(
        return x

    return _folded_quantized_qat_conv2d_bn_pattern

def _get_aten_graph_module(
    pattern: Callable,
    example_inputs: Tuple[Any, ...],
    **kwargs,
) -> GraphModule:
    """
    Convert the pattern to an FX graph with decomposed aten ops.
    """
    # Avoid circular imports
    import torch._dynamo
    aten_pattern, _ = torch._dynamo.export(
        pattern,
        *copy.deepcopy(example_inputs),
        aten_graph=True,
        tracing_mode="real",
        **kwargs,
    )
    aten_pattern.graph.eliminate_dead_code()
    aten_pattern.recompile()
    return aten_pattern

def _has_conv_bias_filter(
    match: "InternalMatch",  # type: ignore[name-defined]
    original_graph: Graph,


@@ -0,0 +1,5 @@
from .rewrite import reference_representation_rewrite

__all__ = [
    "reference_representation_rewrite",
]


@@ -0,0 +1,74 @@
import torch
from torch.fx import GraphModule
from ..utils import _get_aten_graph_module
from ..utils import _remove_tensor_overload_for_qdq_ops
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
from torch.fx.subgraph_rewriter import replace_pattern

__all__ = [
    "reference_representation_rewrite",
]

_QUANTIZED_ADD_EXAMPLE_INPUTS = (
    torch.randn(1, 3, 3, 3).to(torch.int8),
    torch.randn(1).to(torch.float),
    torch.zeros(1).to(torch.int),
    torch.randn(1, 3, 3, 3).to(torch.int8),
    torch.randn(1).to(torch.float),
    torch.zeros(1).to(torch.int),
    torch.randn(1).to(torch.float),
    torch.zeros(1).to(torch.int),
)

def _qdq_quantized_add(x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point, out_scale, out_zero_point):
    quant_min = -128
    quant_max = 127
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8)
    y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8)
    out_fp32 = x_fp32 + y_fp32
    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8
    )
    return out_i8

def _reference_quantized_add(x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point, out_scale, out_zero_point):
    """
    # How to derive the formula for out_i8 based on x_i8 and y_i8
    # (since quantized add takes x_i8, y_i8 and their quantization parameters, and produces an out_i8)

    # out_i8 is the quantized output, so we can write down the formula for it first:
    out_i8 = out_fp32 / out_scale + out_zero_point              (1)

    # then out_fp32 is computed from x_fp32 + y_fp32, and x_fp32 and y_fp32 are the dequantized x_i8 and y_i8:
    out_fp32 = x_fp32 + y_fp32                                  (2)
    x_fp32 = (x_i8 - x_zero_point) * x_scale                    (3)
    y_fp32 = (y_i8 - y_zero_point) * y_scale                    (4)

    # applying the above formulas to the out_i8 equation we get the following:
    out_i8 = out_fp32 / out_scale + out_zero_point              # (1)
           = (x_fp32 + y_fp32) / out_scale + out_zero_point     # apply (2) to substitute out_fp32 with x_fp32 + y_fp32
           = ((x_i8 - x_zero_point) * x_scale + (y_i8 - y_zero_point) * y_scale) / out_scale + out_zero_point  # apply (3) and (4)
    """
    x_i32 = x_i8.to(torch.int32)
    y_i32 = y_i8.to(torch.int32)
    # TODO: use out_dtype op
    x_i32 = torch.round((x_scale / out_scale) * (x_i32 - x_zero_point)).to(torch.int32)
    y_i32 = torch.round((y_scale / out_scale) * (y_i32 - y_zero_point)).to(torch.int32)
    out_i32 = x_i32 + y_i32 + out_zero_point
    quant_min = -128
    quant_max = 127
    out_i8 = torch.ops.aten.clamp(out_i32, quant_min, quant_max).to(torch.int8)
    return out_i8

_EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS = [
    (_QUANTIZED_ADD_EXAMPLE_INPUTS, _qdq_quantized_add, _reference_quantized_add)
]

def reference_representation_rewrite(model: GraphModule) -> GraphModule:
    _remove_tensor_overload_for_qdq_ops(model)
    for example_inputs, pattern, replacement in _EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS:
        pattern_graph = _get_aten_graph_module(pattern, example_inputs)
        _remove_tensor_overload_for_qdq_ops(pattern_graph)
        replacement_graph = _get_aten_graph_module(replacement, example_inputs)
        matches = replace_pattern(model, pattern_graph, replacement_graph)
    return model
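
reference_representation_rewrite drives the replacement through torch.fx.subgraph_rewriter.replace_pattern. For readers less familiar with that mechanism, a tiny self-contained toy example (unrelated to quantization, names chosen purely for illustration):

```
import torch
from torch.fx import symbolic_trace
from torch.fx.subgraph_rewriter import replace_pattern

class Toy(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

def pattern(a, b):
    return torch.add(a, b)

def replacement(a, b):
    return torch.sub(a, b)

gm = symbolic_trace(Toy())
matches = replace_pattern(gm, pattern, replacement)  # rewrites add -> sub in place
print(gm.code)  # the traced graph now calls torch.sub
```

In this PR the pattern and replacement are aten-level graphs produced by _get_aten_graph_module so that they match the exported model.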


@@ -9,7 +9,8 @@ from torch.ao.quantization.fx.prepare import (
    _is_activation_post_process_node,
)
import operator
from typing import Dict, Optional, Tuple
from typing import Dict, Optional, Tuple, Callable, Any
import copy
def _get_tensor_constant_from_node(node, m):
@@ -172,3 +173,42 @@ def _get_node_name_to_scope(model: GraphModule) -> Dict[str, Tuple[str, type]]:
            current_scope = (bt[0].split(".")[-1], bt[1])
        node_name_to_scope[n.name] = current_scope
    return node_name_to_scope

def _get_aten_graph_module(
    pattern: Callable,
    example_inputs: Tuple[Any, ...],
    **kwargs,
) -> GraphModule:
    """
    Convert the pattern to an FX graph with decomposed aten ops.
    """
    # Avoid circular imports
    import torch._dynamo
    aten_pattern, _ = torch._dynamo.export(
        pattern,
        *copy.deepcopy(example_inputs),
        aten_graph=True,
        tracing_mode="real",
        **kwargs,
    )
    aten_pattern.graph.eliminate_dead_code()
    aten_pattern.recompile()
    return aten_pattern

def _remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
    """ Remove .tensor overload for quantize/dequantize ops so that we can
    use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e
    """
    _MAP = {
        torch.ops.quantized_decomposed.quantize_per_tensor.default: torch.ops.quantized_decomposed.quantize_per_tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor.default: torch.ops.quantized_decomposed.dequantize_per_tensor,
        torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor,
        torch.ops.quantized_decomposed.quantize_per_channel.default: torch.ops.quantized_decomposed.quantize_per_channel,
        torch.ops.quantized_decomposed.dequantize_per_channel.default: torch.ops.quantized_decomposed.dequantize_per_channel
    }
    for n in match_pattern.graph.nodes:
        if n.op != "call_function":
            continue
        if n.target in _MAP:
            n.target = _MAP[n.target]


@@ -11,6 +11,7 @@ from ._pt2e.utils import (
    _fuse_conv_bn_,
    _rearrange_weight_observer_for_decomposed_linear,
)
from ._pt2e.representation import reference_representation_rewrite
from .fx.prepare import prepare as fx_prepare
from .quantize_fx import _convert_to_reference_decomposed_fx
from torch.ao.quantization import QConfigMapping
@@ -82,8 +83,11 @@ def prepare_qat_pt2e_quantizer(
    return model

def convert_pt2e(
    model: GraphModule
    model: GraphModule,
    use_reference_representation: bool = False,
) -> GraphModule:
    model = _convert_to_reference_decomposed_fx(model)
    model = _fold_conv_bn_qat(model)
    if use_reference_representation:
        model = reference_representation_rewrite(model)
    return model


@@ -133,7 +133,7 @@ def dequantize_per_tensor(
    Returns:
       dequantized float32 Tensor
    """
    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
    if dtype in [torch.uint8, torch.int8, torch.int32]:
        # TODO: investigate why
        # (input - zero_point).to(torch.float32) * scale