Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[quant][pt2e] Add more precise representation for quantized add (#104130)
Summary:

The planned end-to-end flow for quantization in PyTorch 2.0 export is:

float_model -> prepare_pt2e -> calibration -> convert_pt2e -> ...

Inside convert_pt2e, we first produce a q/dq representation of the quantized model, similar to the previous output of convert_to_reference_fx in fx graph mode quantization:

```
torch.ops.quantized_decomposed.dequantize_per_tensor -> torch.ops.aten.add -> torch.ops.quantized_decomposed.quantize_per_tensor
torch.ops.quantized_decomposed.dequantize_per_tensor /
```

We then rewrite the above into a representation that expresses the intent more precisely: here we actually want to do int8 addition, not simulate int8 addition with fp32 operations. The representation for quantized add is:

```
def quantized_add(x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point, out_scale, out_zero_point):
    x = (x_scale / out_scale) * x_i8
    y = (y_scale / out_scale) * y_i8
    out = x + y
    out -= (x_zero_point * x_scale + y_zero_point * y_scale) / out_scale
    out += out_zero_point
    return out
```

Test Plan:
```
buck2 test caffe2/test:quantization_pt2e -- --exact 'caffe2/test:quantization_pt2e - test_representation_add (quantization.pt2e.test_quantize_pt2e.TestQuantizePT2E)'
```

Reviewed By: kimishpatel

Differential Revision: D45628032

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104130
Approved by: https://github.com/kimishpatel
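The q/dq form and the rearranged form above are algebraically equivalent; a small self-contained check (not part of this PR; the helper names and the concrete scales/zero points are invented for illustration) compares the two:

```
import torch

def qdq_add(x_i8, x_s, x_zp, y_i8, y_s, y_zp, out_s, out_zp):
    # fp32 simulation: dequantize, add in fp32, re-quantize
    x_fp32 = (x_i8.to(torch.float32) - x_zp) * x_s
    y_fp32 = (y_i8.to(torch.float32) - y_zp) * y_s
    out = torch.round((x_fp32 + y_fp32) / out_s) + out_zp
    return torch.clamp(out, -128, 127).to(torch.int8)

def int_domain_add(x_i8, x_s, x_zp, y_i8, y_s, y_zp, out_s, out_zp):
    # rearranged form: rescale each int8 operand directly into the output scale
    x = torch.round((x_s / out_s) * (x_i8.to(torch.int32) - x_zp))
    y = torch.round((y_s / out_s) * (y_i8.to(torch.int32) - y_zp))
    return torch.clamp(x + y + out_zp, -128, 127).to(torch.int8)

x = torch.randint(-128, 128, (8,), dtype=torch.int8)
y = torch.randint(-128, 128, (8,), dtype=torch.int8)
a = qdq_add(x, 0.02, 1, y, 0.03, -2, 0.05, 3)
b = int_domain_add(x, 0.02, 1, y, 0.03, -2, 0.05, 3)
# the results differ by at most one quantized unit, because the second form rounds each operand separately
print((a.to(torch.int32) - b.to(torch.int32)).abs().max())
```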
This commit is contained in:
committed by PyTorch MergeBot
parent 7bf27cf163
commit c98896b76f
@ -1686,6 +1686,59 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
            M(), example_inputs, is_per_channel=True, verify_convert=True,
        )

    def test_representation_add(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x + y

        import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq

        quantizer = QNNPackQuantizer()
        operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)
        m_eager = M().eval()

        example_inputs = (torch.randn(1, 3, 3, 3), torch.randn(1, 3, 3, 3),)
        # program capture
        m = m_eager
        # m_copy = copy.deepcopy(m)
        m, guards = torchdynamo.export(
            m,
            *copy.deepcopy(example_inputs),
            aten_graph=True,
        )

        m = prepare_pt2e_quantizer(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m, use_reference_representation=True)
        # make sure it runs
        pt2_quant_output = m(*example_inputs)

        # TODO: torchdynamo times out when we do this, we can enable numerical checking
        # after that is fixed
        # m_copy = prepare_pt2e_quantizer(m_copy, quantizer)
        # # Calibrate
        # m_copy(*example_inputs)
        # m_copy = convert_pt2e(m_copy, use_reference_representation=False)
        # pt2_quant_output_copy = m_copy(*example_inputs)

        # output_scale = None
        # idx = 0
        # for n in m_copy.graph.nodes:
        #     if n.target == torch.ops.quantized_decomposed.quantize_per_tensor.default:
        #         idx += 1
        #         if idx == 3:
        #             output_scale = n.args[1]
        # assert output_scale is not None

        # # make sure the result is off by one at most in the quantized integer representation
        # self.assertTrue(
        #     torch.max(torch.abs(pt2_quant_output_copy - pt2_quant_output)) <= (2 * output_scale + 1e-5)
        # )

@skipIfNoQNNPACK
class TestQuantizePT2EOps(QuantizationTestCase):
@ -141,6 +141,7 @@ FILENAME_ALLOWLIST |= {torch.utils._foreach_utils.__file__}
FILENAME_ALLOWLIST |= {
    _module_dir(torch) + "ao/quantization/_pt2e/qat_utils.py",
    _module_dir(torch) + "ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py",
    _module_dir(torch) + "ao/quantization/_pt2e/representation/rewrite.py",
}

# TODO (zhxchen17) Make exportdb importable here.
@ -1,4 +1,3 @@
import copy
import dataclasses
import itertools
import operator
@ -16,6 +15,7 @@ from .quantizer import (
    QuantizationSpecBase,
)
from .utils import _fold_bn_weights_into_conv_node
from .utils import _get_aten_graph_module

# Example inputs for `_conv2d_bn_pattern`, `_qat_conv2d_bn_pattern`, and `_qat_conv2d_bn_pattern_no_bias`
_conv2d_bn_pattern_example_inputs = (
@ -299,27 +299,6 @@ def _get_folded_quantized_qat_conv2d_bn_pattern(
        return x
    return _folded_quantized_qat_conv2d_bn_pattern

def _get_aten_graph_module(
    pattern: Callable,
    example_inputs: Tuple[Any, ...],
    **kwargs,
) -> GraphModule:
    """
    Convert the pattern to an FX graph with decomposed aten ops.
    """
    # Avoid circular imports
    import torch._dynamo
    aten_pattern, _ = torch._dynamo.export(
        pattern,
        *copy.deepcopy(example_inputs),
        aten_graph=True,
        tracing_mode="real",
        **kwargs,
    )
    aten_pattern.graph.eliminate_dead_code()
    aten_pattern.recompile()
    return aten_pattern

def _has_conv_bias_filter(
    match: "InternalMatch",  # type: ignore[name-defined]
    original_graph: Graph,
torch/ao/quantization/_pt2e/representation/__init__.py (new file, 5 lines)
@ -0,0 +1,5 @@
from .rewrite import reference_representation_rewrite

__all__ = [
    "reference_representation_rewrite",
]
torch/ao/quantization/_pt2e/representation/rewrite.py (new file, 74 lines)
@ -0,0 +1,74 @@
import torch
from torch.fx import GraphModule
from ..utils import _get_aten_graph_module
from ..utils import _remove_tensor_overload_for_qdq_ops
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
from torch.fx.subgraph_rewriter import replace_pattern

__all__ = [
    "reference_representation_rewrite",
]

_QUANTIZED_ADD_EXAMPLE_INPUTS = (
    torch.randn(1, 3, 3, 3).to(torch.int8),
    torch.randn(1).to(torch.float),
    torch.zeros(1).to(torch.int),
    torch.randn(1, 3, 3, 3).to(torch.int8),
    torch.randn(1).to(torch.float),
    torch.zeros(1).to(torch.int),
    torch.randn(1).to(torch.float),
    torch.zeros(1).to(torch.int),
)

def _qdq_quantized_add(x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point, out_scale, out_zero_point):
    quant_min = -128
    quant_max = 127
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8)
    y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8)
    out_fp32 = x_fp32 + y_fp32
    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8
    )
    return out_i8

def _reference_quantized_add(x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point, out_scale, out_zero_point):
"""
|
||||
# How to Derive the formula for out_i8 based on x_i8 and y_i8
|
||||
# (since quantized add takes x_i8, y_i8 and their quantization parameters, and produce an out_i8)
|
||||
|
||||
# out_i8 is quantized output, we can write down the formula for it first:
|
||||
out_i8 = out_f32 / out_scale + out_zero_point (1)
|
||||
|
||||
# then out_fp32 is computed from x_f32 + y_f32, and the x_fp32 and y_fp32 are the dequantized x_i8 and y_i8
|
||||
out_f32 = x_f32 + y_f32 (2)
|
||||
x_fp32 = (x_i8 - x_zero_point) * x_scale (3)
|
||||
y_fp32 = (y_i8 - y_zero_point) * y_scale (4)
|
||||
|
||||
# applying the above fomula to the out_i8 equation we can get the following:
|
||||
out_i8 = out_fp32 / out_scale + out_zero_point # (1)
|
||||
= (x_f32 + y_f32) / out_scale + out_zero_point # applying (2) to substitute out_fp32 with x_fp32 + y_fp32
|
||||
= ((x_i8 - x_zero_point) * x_scale + (y_i8 - y_zero_point) * y_scale) / out_scale + out_zero_point # apply (3) and (4)
|
||||
"""
|
    x_i32 = x_i8.to(torch.int32)
    y_i32 = y_i8.to(torch.int32)
    # TODO: use out_dtype op
    x_i32 = torch.round((x_scale / out_scale) * (x_i32 - x_zero_point)).to(torch.int32)
    y_i32 = torch.round((y_scale / out_scale) * (y_i32 - y_zero_point)).to(torch.int32)
    out_i32 = x_i32 + y_i32 + out_zero_point
    quant_min = -128
    quant_max = 127
    out_i8 = torch.ops.aten.clamp(out_i32, quant_min, quant_max).to(torch.int8)
    return out_i8

_EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS = [
    (_QUANTIZED_ADD_EXAMPLE_INPUTS, _qdq_quantized_add, _reference_quantized_add)
]

def reference_representation_rewrite(model: GraphModule) -> GraphModule:
    _remove_tensor_overload_for_qdq_ops(model)
    for example_inputs, pattern, replacement in _EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS:
        pattern_graph = _get_aten_graph_module(pattern, example_inputs)
        _remove_tensor_overload_for_qdq_ops(pattern_graph)
        replacement_graph = _get_aten_graph_module(replacement, example_inputs)
        matches = replace_pattern(model, pattern_graph, replacement_graph)
    return model
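The rewrite above is driven by torch.fx subgraph matching; for intuition, here is a minimal toy example (module and ops chosen only for illustration, not taken from this PR) of how torch.fx.subgraph_rewriter.replace_pattern swaps one subgraph for another:

```
import torch
from torch.fx import symbolic_trace
from torch.fx.subgraph_rewriter import replace_pattern

class Toy(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y) * 2

def pattern(x, y):
    return torch.add(x, y)

def replacement(x, y):
    # numerically identical to add, just to make the rewrite visible in the graph
    return torch.sub(x, torch.neg(y))

gm = symbolic_trace(Toy())
matches = replace_pattern(gm, pattern, replacement)  # mutates gm in place
print(len(matches))  # 1 match found and replaced
print(gm.code)       # the add call is now sub(x, neg(y))
```

In reference_representation_rewrite the pattern and replacement are first exported to decomposed aten graphs via _get_aten_graph_module, so matching happens on the same decomposed ops that convert_pt2e emits.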
@ -9,7 +9,8 @@ from torch.ao.quantization.fx.prepare import (
    _is_activation_post_process_node,
)
import operator
from typing import Dict, Optional, Tuple
from typing import Dict, Optional, Tuple, Callable, Any
import copy


def _get_tensor_constant_from_node(node, m):
@ -172,3 +173,42 @@ def _get_node_name_to_scope(model: GraphModule) -> Dict[str, Tuple[str, type]]:
            current_scope = (bt[0].split(".")[-1], bt[1])
        node_name_to_scope[n.name] = current_scope
    return node_name_to_scope

def _get_aten_graph_module(
    pattern: Callable,
    example_inputs: Tuple[Any, ...],
    **kwargs,
) -> GraphModule:
    """
    Convert the pattern to an FX graph with decomposed aten ops.
    """
    # Avoid circular imports
    import torch._dynamo
    aten_pattern, _ = torch._dynamo.export(
        pattern,
        *copy.deepcopy(example_inputs),
        aten_graph=True,
        tracing_mode="real",
        **kwargs,
    )
    aten_pattern.graph.eliminate_dead_code()
    aten_pattern.recompile()
    return aten_pattern

def _remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
    """ Remove .tensor overload for quantize/dequantize ops so that we can
    use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e
    """
    _MAP = {
        torch.ops.quantized_decomposed.quantize_per_tensor.default: torch.ops.quantized_decomposed.quantize_per_tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor.default: torch.ops.quantized_decomposed.dequantize_per_tensor,
        torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor,
        torch.ops.quantized_decomposed.quantize_per_channel.default: torch.ops.quantized_decomposed.quantize_per_channel,
        torch.ops.quantized_decomposed.dequantize_per_channel.default: torch.ops.quantized_decomposed.dequantize_per_channel
    }
    for n in match_pattern.graph.nodes:
        if n.op != "call_function":
            continue
        if n.target in _MAP:
            n.target = _MAP[n.target]
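One detail worth spelling out for _remove_tensor_overload_for_qdq_ops: torch.ops.quantized_decomposed.quantize_per_tensor is an OpOverloadPacket, while its .default and .tensor variants are distinct OpOverloads, so a pattern captured with tensor-valued scale/zero_point would otherwise never match a graph that uses the scalar overload. A small illustration (assumes the quantized_decomposed ops are registered by importing the _decomposed module, as the rewrite module above does):

```
import torch
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401 (registers the ops)

packet = torch.ops.quantized_decomposed.quantize_per_tensor
print(packet.default)                    # overload taking scalar scale/zero_point
print(packet.tensor)                     # overload taking tensor scale/zero_point
print(packet.default is packet.tensor)   # False: different node targets, hence no structural match
# _remove_tensor_overload_for_qdq_ops maps both overloads onto the packet itself so that the
# exported pattern and the convert_pt2e output become comparable node-for-node.
```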
@ -11,6 +11,7 @@ from ._pt2e.utils import (
    _fuse_conv_bn_,
    _rearrange_weight_observer_for_decomposed_linear,
)
from ._pt2e.representation import reference_representation_rewrite
from .fx.prepare import prepare as fx_prepare
from .quantize_fx import _convert_to_reference_decomposed_fx
from torch.ao.quantization import QConfigMapping
@ -82,8 +83,11 @@ def prepare_qat_pt2e_quantizer(
    return model

def convert_pt2e(
    model: GraphModule
    model: GraphModule,
    use_reference_representation: bool = False,
) -> GraphModule:
    model = _convert_to_reference_decomposed_fx(model)
    model = _fold_conv_bn_qat(model)
    if use_reference_representation:
        model = reference_representation_rewrite(model)
    return model
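use_reference_representation defaults to False, so existing convert_pt2e callers keep the plain q/dq output; opting into the new representation is a one-argument change. A minimal sketch, reusing the prepare/calibrate steps from the test added above:

```
m = prepare_pt2e_quantizer(m, quantizer)
m(*example_inputs)  # calibrate
m = convert_pt2e(m)                                       # q/dq representation, unchanged behavior
# or:
m = convert_pt2e(m, use_reference_representation=True)    # also rewrites dq -> aten.add -> q into the int8 reference form
```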
@ -133,7 +133,7 @@ def dequantize_per_tensor(
    Returns:
       dequantized float32 Tensor
    """
    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
    if dtype in [torch.uint8, torch.int8, torch.int32]:
        # TODO: investigate why
        # (input - zero_point).to(torch.float32) * scale